diff --git a/balancer/balancer.go b/balancer/balancer.go index 178de0898aa..bcc6f5451c9 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -174,25 +174,32 @@ type ClientConn interface { // BuildOptions contains additional information for Build. type BuildOptions struct { - // DialCreds is the transport credential the Balancer implementation can - // use to dial to a remote load balancer server. The Balancer implementations - // can ignore this if it does not need to talk to another party securely. + // DialCreds is the transport credentials to use when communicating with a + // remote load balancer server. Balancer implementations which do not + // communicate with a remote load balancer server can ignore this field. DialCreds credentials.TransportCredentials - // CredsBundle is the credentials bundle that the Balancer can use. + // CredsBundle is the credentials bundle to use when communicating with a + // remote load balancer server. Balancer implementations which do not + // communicate with a remote load balancer server can ignore this field. CredsBundle credentials.Bundle - // Dialer is the custom dialer the Balancer implementation can use to dial - // to a remote load balancer server. The Balancer implementations - // can ignore this if it doesn't need to talk to remote balancer. + // Dialer is the custom dialer to use when communicating with a remote load + // balancer server. Balancer implementations which do not communicate with a + // remote load balancer server can ignore this field. Dialer func(context.Context, string) (net.Conn, error) - // ChannelzParentID is the entity parent's channelz unique identification number. + // Authority is the server name to use as part of the authentication + // handshake when communicating with a remote load balancer server. Balancer + // implementations which do not communicate with a remote load balancer + // server can ignore this field. + Authority string + // ChannelzParentID is the parent ClientConn's channelz ID. ChannelzParentID int64 // CustomUserAgent is the custom user agent set on the parent ClientConn. // The balancer should set the same custom user agent if it creates a // ClientConn. CustomUserAgent string - // Target contains the parsed address info of the dial target. It is the same resolver.Target as - // passed to the resolver. - // See the documentation for the resolver.Target type for details about what it contains. + // Target contains the parsed address info of the dial target. It is the + // same resolver.Target as passed to the resolver. See the documentation for + // the resolver.Target type for details about what it contains. Target resolver.Target } diff --git a/clientconn.go b/clientconn.go index 972ff1a65ba..97b793e05e2 100644 --- a/clientconn.go +++ b/clientconn.go @@ -285,6 +285,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * DialCreds: credsClone, CredsBundle: cc.dopts.copts.CredsBundle, Dialer: cc.dopts.copts.Dialer, + Authority: cc.authority, CustomUserAgent: cc.dopts.copts.UserAgent, ChannelzParentID: cc.channelzID, Target: cc.parsedTarget, diff --git a/test/balancer_test.go b/test/balancer_test.go index 47332db7975..5d5c85896d3 100644 --- a/test/balancer_test.go +++ b/test/balancer_test.go @@ -36,6 +36,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/balancerload" "google.golang.org/grpc/internal/grpcutil" @@ -821,3 +822,127 @@ func (s) TestWaitForReady(t *testing.T) { t.Fatal(err.Error()) } } + +// authorityOverrideTransportCreds returns the configured authority value in its +// Info() method. +type authorityOverrideTransportCreds struct { + credentials.TransportCredentials + authorityOverride string +} + +func (ao *authorityOverrideTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (ao *authorityOverrideTransportCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{ServerName: ao.authorityOverride} +} +func (ao *authorityOverrideTransportCreds) Clone() credentials.TransportCredentials { + return &authorityOverrideTransportCreds{authorityOverride: ao.authorityOverride} +} + +// TestAuthorityInBuildOptions tests that the Authority field in +// balancer.BuildOptions is setup correctly from gRPC. +func (s) TestAuthorityInBuildOptions(t *testing.T) { + const dialTarget = "test.server" + + tests := []struct { + name string + dopts []grpc.DialOption + wantAuthority string + }{ + { + name: "authority from dial target", + dopts: []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, + wantAuthority: dialTarget, + }, + { + name: "authority from dial option", + dopts: []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithAuthority("authority-override"), + }, + wantAuthority: "authority-override", + }, + { + name: "authority from transport creds", + dopts: []grpc.DialOption{grpc.WithTransportCredentials(&authorityOverrideTransportCreds{authorityOverride: "authority-override-from-transport-creds"})}, + wantAuthority: "authority-override-from-transport-creds", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + authorityCh := make(chan string, 1) + bf := stub.BalancerFuncs{ + UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { + select { + case authorityCh <- bd.BuildOptions.Authority: + default: + } + + addrs := ccs.ResolverState.Addresses + if len(addrs) == 0 { + return nil + } + + // Only use the first address. + sc, err := bd.ClientConn.NewSubConn([]resolver.Address{addrs[0]}, balancer.NewSubConnOptions{}) + if err != nil { + return err + } + sc.Connect() + return nil + }, + UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) { + bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}}) + }, + } + balancerName := "stub-balancer-" + test.name + stub.Register(balancerName, bf) + t.Logf("Registered balancer %s...", balancerName) + + lis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatal(err) + } + + s := grpc.NewServer() + testpb.RegisterTestServiceServer(s, &testServer{}) + go s.Serve(lis) + defer s.Stop() + t.Logf("Started gRPC server at %s...", lis.Addr().String()) + + r := manual.NewBuilderWithScheme("whatever") + t.Logf("Registered manual resolver with scheme %s...", r.Scheme()) + r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}) + + dopts := append([]grpc.DialOption{ + grpc.WithResolvers(r), + grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, balancerName)), + }, test.dopts...) + cc, err := grpc.Dial(r.Scheme()+":///"+dialTarget, dopts...) + if err != nil { + t.Fatal(err) + } + defer cc.Close() + tc := testpb.NewTestServiceClient(cc) + t.Log("Created a ClientConn...") + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall() = _, %v, want _, ", err) + } + t.Log("Made an RPC which succeeded...") + + select { + case <-ctx.Done(): + t.Fatal("timeout when waiting for Authority in balancer.BuildOptions") + case gotAuthority := <-authorityCh: + if gotAuthority != test.wantAuthority { + t.Fatalf("Authority in balancer.BuildOptions is %s, want %s", gotAuthority, test.wantAuthority) + } + } + }) + } +}