diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 28c77af70ab..53643fa9747 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -326,6 +326,8 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts keepaliveEnabled: keepaliveEnabled, bufferPool: newBufferPool(), } + // Add peer information to the http2client context. + t.ctx = peer.NewContext(t.ctx, t.getPeer()) if md, ok := addr.Metadata.(*metadata.MD); ok { t.md = *md @@ -469,7 +471,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { func (t *http2Client) getPeer() *peer.Peer { return &peer.Peer{ Addr: t.remoteAddr, - AuthInfo: t.authInfo, + AuthInfo: t.authInfo, // Can be nil } } diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 28bcba0a33c..3dd15647bc8 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -265,6 +265,9 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, czData: new(channelzData), bufferPool: newBufferPool(), } + // Add peer information to the http2server context. + t.ctx = peer.NewContext(t.ctx, t.getPeer()) + t.controlBuf = newControlBuffer(t.done) if dynamicWindow { t.bdpEst = &bdpEstimator{ @@ -485,14 +488,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } else { s.ctx, s.cancel = context.WithCancel(t.ctx) } - pr := &peer.Peer{ - Addr: t.remoteAddr, - } - // Attach Auth info if there is any. - if t.authInfo != nil { - pr.AuthInfo = t.authInfo - } - s.ctx = peer.NewContext(s.ctx, pr) + // Attach the received metadata to the context. if len(mdata) > 0 { s.ctx = metadata.NewIncomingContext(s.ctx, mdata) @@ -1416,6 +1412,13 @@ func (t *http2Server) getOutFlowWindow() int64 { } } +func (t *http2Server) getPeer() *peer.Peer { + return &peer.Peer{ + Addr: t.remoteAddr, + AuthInfo: t.authInfo, // Can be nil + } +} + func getJitter(v time.Duration) time.Duration { if v == infinity { return 0 diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index c1f9664ada6..760e1b64f35 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -35,6 +35,8 @@ import ( "testing" "time" + "google.golang.org/grpc/peer" + "github.com/google/go-cmp/cmp" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" @@ -2450,3 +2452,52 @@ func TestConnectionError_Unwrap(t *testing.T) { t.Error("ConnectionError does not unwrap") } } + +func (s) TestPeerSetInServerContext(t *testing.T) { + // create client and server transports. + server, client, cancel := setUp(t, 0, math.MaxUint32, normal) + defer cancel() + defer server.stop() + defer client.Close(fmt.Errorf("closed manually by test")) + + // create a stream with client transport. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + stream, err := client.NewStream(ctx, &CallHdr{}) + if err != nil { + t.Fatalf("failed to create a stream: %v", err) + } + + waitWhileTrue(t, func() (bool, error) { + server.mu.Lock() + defer server.mu.Unlock() + + if len(server.conns) == 0 { + return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") + } + return false, nil + }) + + // verify peer is set in client transport context. + if _, ok := peer.FromContext(client.ctx); !ok { + t.Fatalf("Peer expected in client transport's context, but actually not found.") + } + + // verify peer is set in stream context. + if _, ok := peer.FromContext(stream.ctx); !ok { + t.Fatalf("Peer expected in stream context, but actually not found.") + } + + // verify peer is set in server transport context. + server.mu.Lock() + for k := range server.conns { + sc, ok := k.(*http2Server) + if !ok { + t.Fatalf("ServerTransport is of type %T, want %T", k, &http2Server{}) + } + if _, ok = peer.FromContext(sc.ctx); !ok { + t.Fatalf("Peer expected in server transport's context, but actually not found.") + } + } + server.mu.Unlock() +}