diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index d518b07e16f..847ac9a521d 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -59,11 +59,15 @@ var clientConnectionCounter uint64 // http2Client implements the ClientTransport interface with HTTP2. type http2Client struct { - lastRead int64 // Keep this field 64-bit aligned. Accessed atomically. - ctx context.Context - cancel context.CancelFunc - ctxDone <-chan struct{} // Cache the ctx.Done() chan. - userAgent string + lastRead int64 // Keep this field 64-bit aligned. Accessed atomically. + ctx context.Context + cancel context.CancelFunc + ctxDone <-chan struct{} // Cache the ctx.Done() chan. + userAgent string + // address contains the resolver returned address for this transport. + // If the `ServerName` field is set, it takes precedence over `CallHdr.Host` + // passed to `NewStream`, when determining the :authority header. + address resolver.Address md metadata.MD conn net.Conn // underlying communication channel loopy *loopyWriter @@ -314,6 +318,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts cancel: cancel, userAgent: opts.UserAgent, registeredCompressors: grpcutil.RegisteredCompressors(), + address: addr, conn: conn, remoteAddr: conn.RemoteAddr(), localAddr: conn.LocalAddr(), @@ -702,7 +707,18 @@ func (e NewStreamError) Error() string { // streams. All non-nil errors returned will be *NewStreamError. func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) { ctx = peer.NewContext(ctx, t.getPeer()) - headerFields, err := t.createHeaderFields(ctx, callHdr) + + dupCallHdr := *callHdr + // ServerName field of the resolver returned address takes precedence over + // Host field of CallHdr to determine the :authority header. This is because, + // the ServerName field takes precedence for server authentication during + // TLS handshake, and the :authority header should match the value used + // for server authentication. + if t.address.ServerName != "" { + dupCallHdr.Host = t.address.ServerName + } + + headerFields, err := t.createHeaderFields(ctx, &dupCallHdr) if err != nil { return nil, &NewStreamError{Err: err, AllowTransparentRetry: false} } diff --git a/test/authority_test.go b/test/authority_test.go index c841c64736f..ad21c169ae6 100644 --- a/test/authority_test.go +++ b/test/authority_test.go @@ -36,6 +36,8 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" ) @@ -205,3 +207,35 @@ func (s) TestColonPortAuthority(t *testing.T) { t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err) } } + +// TestAuthorityReplacedWithResolverAddress This test makes sure that the http2 client +// replaces the authority to the resolver address server name when it is set. +func (s) TestAuthorityReplacedWithResolverAddress(t *testing.T) { + const expectedAuthority = "test.server.name" + + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + return authorityChecker(ctx, expectedAuthority) + }, + Network: "tcp", + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + r := manual.NewBuilderWithScheme("whatever") + r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address, ServerName: expectedAuthority}}}) + cc, err := grpc.Dial(r.Scheme()+":///whatever", grpc.WithInsecure(), grpc.WithResolvers(r)) + if err != nil { + t.Fatalf("grpc.Dial(%q) = %v", ss.Address, err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + _, err = testpb.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}) + if err != nil { + t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err) + } +} diff --git a/xds/internal/balancer/clusterresolver/configbuilder.go b/xds/internal/balancer/clusterresolver/configbuilder.go index b76a40355cc..0f9efba4fb3 100644 --- a/xds/internal/balancer/clusterresolver/configbuilder.go +++ b/xds/internal/balancer/clusterresolver/configbuilder.go @@ -198,7 +198,7 @@ func buildClusterImplConfigForDNS(g *nameGenerator, addrStrs []string, mechanism retAddrs := make([]resolver.Address, 0, len(addrStrs)) pName := fmt.Sprintf("priority-%v", g.prefix) for _, addrStr := range addrStrs { - retAddrs = append(retAddrs, hierarchy.Set(resolver.Address{Addr: addrStr}, []string{pName})) + retAddrs = append(retAddrs, hierarchy.Set(resolver.Address{Addr: addrStr, ServerName: mechanism.DNSHostname}, []string{pName})) } return pName, &clusterimpl.LBConfig{ Cluster: mechanism.Cluster,