Skip to content

Commit

Permalink
transport: new stream with actual server name (#5748)
Browse files Browse the repository at this point in the history
  • Loading branch information
holdno committed Nov 18, 2022
1 parent 817c1e8 commit 0238b6e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
27 changes: 22 additions & 5 deletions internal/transport/http2_client.go
Expand Up @@ -59,11 +59,15 @@ var clientConnectionCounter uint64

// http2Client implements the ClientTransport interface with HTTP2.
type http2Client struct {
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
ctx context.Context
cancel context.CancelFunc
ctxDone <-chan struct{} // Cache the ctx.Done() chan.
userAgent string
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
ctx context.Context
cancel context.CancelFunc
ctxDone <-chan struct{} // Cache the ctx.Done() chan.
userAgent string
// address contains the resolver returned address for this transport.
// If the `ServerName` field is set, it takes precedence over `CallHdr.Host`
// passed to `NewStream`, when determining the :authority header.
address resolver.Address
md metadata.MD
conn net.Conn // underlying communication channel
loopy *loopyWriter
Expand Down Expand Up @@ -314,6 +318,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
cancel: cancel,
userAgent: opts.UserAgent,
registeredCompressors: grpcutil.RegisteredCompressors(),
address: addr,
conn: conn,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
Expand Down Expand Up @@ -702,6 +707,18 @@ func (e NewStreamError) Error() string {
// streams. All non-nil errors returned will be *NewStreamError.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) {
ctx = peer.NewContext(ctx, t.getPeer())

// ServerName field of the resolver returned address takes precedence over
// Host field of CallHdr to determine the :authority header. This is because,
// the ServerName field takes precedence for server authentication during
// TLS handshake, and the :authority header should match the value used
// for server authentication.
if t.address.ServerName != "" {
newCallHdr := *callHdr
newCallHdr.Host = t.address.ServerName
callHdr = &newCallHdr
}

headerFields, err := t.createHeaderFields(ctx, callHdr)
if err != nil {
return nil, &NewStreamError{Err: err, AllowTransparentRetry: false}
Expand Down
34 changes: 34 additions & 0 deletions test/authority_test.go
Expand Up @@ -36,6 +36,8 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
)
Expand Down Expand Up @@ -205,3 +207,35 @@ func (s) TestColonPortAuthority(t *testing.T) {
t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err)
}
}

// TestAuthorityReplacedWithResolverAddress tests the scenario where the resolver
// returned address contains a ServerName override. The test verifies that the the
// :authority header value sent to the server as part of the http/2 HEADERS frame
// is set to the value specified in the resolver returned address.
func (s) TestAuthorityReplacedWithResolverAddress(t *testing.T) {
const expectedAuthority = "test.server.name"

ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
return authorityChecker(ctx, expectedAuthority)
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

r := manual.NewBuilderWithScheme("whatever")
r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: ss.Address, ServerName: expectedAuthority}}})
cc, err := grpc.Dial(r.Scheme()+":///whatever", grpc.WithInsecure(), grpc.WithResolvers(r))
if err != nil {
t.Fatalf("grpc.Dial(%q) = %v", ss.Address, err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = testpb.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
}

0 comments on commit 0238b6e

Please sign in to comment.