diff --git a/aws/endpoints/v3model.go b/aws/endpoints/v3model.go index 610dc1e1d39..b6d1b13e0ff 100644 --- a/aws/endpoints/v3model.go +++ b/aws/endpoints/v3model.go @@ -74,16 +74,29 @@ func (p partition) canResolveEndpoint(service, region string, strictMatch bool) return p.RegionRegex.MatchString(region) } +func allowLegacyEmptyRegion(service string) bool { + legacy := map[string]struct{}{ + "ec2metadata": {}, + } + + _, allowed := legacy[service] + return allowed +} + func (p partition) EndpointFor(service, region string, opts ResolveOptions) (resolved aws.Endpoint, err error) { s, hasService := p.Services[service] - if !hasService && opts.StrictMatching { + if len(service) == 0 || (!hasService && opts.StrictMatching) { // Only return error if the resolver will not fallback to creating // endpoint based on service endpoint ID passed in. return resolved, NewUnknownServiceError(p.ID, service, serviceList(p.Services)) } + if len(region) == 0 && allowLegacyEmptyRegion(service) && len(s.PartitionEndpoint) != 0 { + region = s.PartitionEndpoint + } + e, hasEndpoint := s.endpointForRegion(region) - if !hasEndpoint && opts.StrictMatching { + if len(region) == 0 || (!hasEndpoint && opts.StrictMatching) { return resolved, NewUnknownEndpointError(p.ID, service, region, endpointList(s.Endpoints)) } diff --git a/aws/endpoints/v3model_test.go b/aws/endpoints/v3model_test.go index 9f6641498f5..5779c5596d5 100644 --- a/aws/endpoints/v3model_test.go +++ b/aws/endpoints/v3model_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "reflect" "regexp" + "strings" "testing" ) @@ -522,3 +523,46 @@ func TestResolveEndpoint_AwsGlobal(t *testing.T) { t.Errorf("expect the signing name to be derived") } } + +func TestEndpointFor_EmptyRegion(t *testing.T) { + cases := map[string]struct { + Service string + Region string + RealRegion string + ExpectErr string + }{ + // Legacy services that previous accepted empty region + "ec2metadata": {Service: "ec2metadata", RealRegion: "aws-global"}, + + // Other services + "s3": {Service: "s3", Region: "us-east-1", RealRegion: "us-east-1"}, + "s3 no region": {Service: "s3", ExpectErr: "could not resolve endpoint"}, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + actual, err := NewDefaultResolver().ResolveEndpoint(c.Service, c.Region) + if len(c.ExpectErr) != 0 { + if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) { + t.Errorf("expect %q error in %q", e, a) + } + return + } + if err != nil { + t.Fatalf("expect no error got, %v", err) + } + + expect, err := NewDefaultResolver().ResolveEndpoint(c.Service, c.RealRegion) + if err != nil { + t.Fatalf("failed to get endpoint for default resolver") + } + if e, a := expect.URL, actual.URL; e != a { + t.Errorf("expect %v URL, got %v", e, a) + } + if e, a := expect.SigningRegion, actual.SigningRegion; e != a { + t.Errorf("expect %v signing region, got %v", e, a) + } + + }) + } +}