Skip to content

Commit

Permalink
aws/endpoints: Fix SDK resolving endpoint without region (#420)
Browse files Browse the repository at this point in the history
Fixes the SDK's endpoint resolve incorrectly resolving endpoints for a
service when the region is empty. Also fixes the SDK attempting to
resolve a service when the service value is empty.

Related to: aws/aws-sdk-go#2911
  • Loading branch information
jasdel committed Dec 12, 2019
1 parent dadd7ec commit 7ede46e
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 8 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG_PENDING.md
Expand Up @@ -9,4 +9,6 @@ SDK Enhancements

SDK Bugs
---

* `aws/endpoints`: aws/endpoints: Fix SDK resolving endpoint without region ([#420](https://github.com/aws/aws-sdk-go-v2/pull/420))
* Fixes the SDK's endpoint resolve incorrectly resolving endpoints for a service when the region is empty. Also fixes the SDK attempting to resolve a service when the service value is empty.
* Related to [aws/aws-sdk-go#2909](https://github.com/aws/aws-sdk-go/issues/2909)
17 changes: 15 additions & 2 deletions aws/endpoints/v3model.go
Expand Up @@ -74,16 +74,29 @@ func (p partition) canResolveEndpoint(service, region string, strictMatch bool)
return p.RegionRegex.MatchString(region)
}

var allowEmptyRegion = map[string]struct{}{
"ec2metadata": {},
}

func serviceRequiresRegion(service string) bool {
_, allowed := allowEmptyRegion[service]
return !allowed
}

func (p partition) EndpointFor(service, region string, opts ResolveOptions) (resolved aws.Endpoint, err error) {
s, hasService := p.Services[service]
if !hasService && opts.StrictMatching {
if len(service) == 0 || (!hasService && opts.StrictMatching) {
// Only return error if the resolver will not fallback to creating
// endpoint based on service endpoint ID passed in.
return resolved, NewUnknownServiceError(p.ID, service, serviceList(p.Services))
}

if len(region) == 0 && !serviceRequiresRegion(service) && len(s.PartitionEndpoint) != 0 {
region = s.PartitionEndpoint
}

e, hasEndpoint := s.endpointForRegion(region)
if !hasEndpoint && opts.StrictMatching {
if len(region) == 0 || (!hasEndpoint && opts.StrictMatching) {
return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(s.Endpoints))
}

Expand Down
44 changes: 44 additions & 0 deletions aws/endpoints/v3model_test.go
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"reflect"
"regexp"
"strings"
"testing"
)

Expand Down Expand Up @@ -522,3 +523,46 @@ func TestResolveEndpoint_AwsGlobal(t *testing.T) {
t.Errorf("expect the signing name to be derived")
}
}

func TestEndpointFor_EmptyRegion(t *testing.T) {
cases := map[string]struct {
Service string
Region string
RealRegion string
ExpectErr string
}{
// Legacy services that previous accepted empty region
"ec2metadata": {Service: "ec2metadata", RealRegion: "aws-global"},

// Other services
"s3": {Service: "s3", Region: "us-east-1", RealRegion: "us-east-1"},
"s3 no region": {Service: "s3", ExpectErr: "could not resolve endpoint"},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
actual, err := NewDefaultResolver().ResolveEndpoint(c.Service, c.Region)
if len(c.ExpectErr) != 0 {
if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %q error in %q", e, a)
}
return
}
if err != nil {
t.Fatalf("expect no error got, %v", err)
}

expect, err := NewDefaultResolver().ResolveEndpoint(c.Service, c.RealRegion)
if err != nil {
t.Fatalf("failed to get endpoint for default resolver")
}
if e, a := expect.URL, actual.URL; e != a {
t.Errorf("expect %v URL, got %v", e, a)
}
if e, a := expect.SigningRegion, actual.SigningRegion; e != a {
t.Errorf("expect %v signing region, got %v", e, a)
}

})
}
}
14 changes: 11 additions & 3 deletions example/service/s3/mockPaginator/mockPaginator.go
Expand Up @@ -5,6 +5,7 @@ package main
import (
"context"
"fmt"
"log"
"os"

"github.com/aws/aws-sdk-go-v2/aws/external"
Expand All @@ -24,12 +25,15 @@ func main() {

bucket := os.Args[1]
svc := s3.New(cfg)
keys := getKeys(svc, bucket)
keys, err := getKeys(svc, bucket)
if err != nil {
log.Fatalf("failed to get keys, %v", err)
}

fmt.Printf("keys for bucket %q,\n%v\n", bucket, keys)
}

func getKeys(svc s3iface.ClientAPI, bucket string) []string {
func getKeys(svc s3iface.ClientAPI, bucket string) ([]string, error) {
req := svc.ListObjectsRequest(&s3.ListObjectsInput{
Bucket: &bucket,
})
Expand All @@ -41,5 +45,9 @@ func getKeys(svc s3iface.ClientAPI, bucket string) []string {
keys = append(keys, *obj.Key)
}
}
return keys
if err := p.Err(); err != nil {
return nil, err
}

return keys, nil
}
10 changes: 8 additions & 2 deletions example/service/s3/mockPaginator/mockPaginator_test.go
Expand Up @@ -72,10 +72,16 @@ func TestListObjectsPagination(t *testing.T) {
},
}

svc.Client = s3.New(defaults.Config())
cfg := defaults.Config()
cfg.Region = "us-west-2"

svc.Client = s3.New(cfg)
svc.objects = objects

keys := getKeys(svc, "foo")
keys, err := getKeys(svc, "foo")
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
expected := []string{"1", "2", "3"}

if e, a := 3, len(keys); e != a {
Expand Down

0 comments on commit 7ede46e

Please sign in to comment.