Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport: new stream with actual server name #5748

Merged
merged 10 commits into from Nov 18, 2022
16 changes: 15 additions & 1 deletion internal/transport/http2_client.go
Expand Up @@ -64,6 +64,7 @@ type http2Client struct {
cancel context.CancelFunc
ctxDone <-chan struct{} // Cache the ctx.Done() chan.
userAgent string
address resolver.Address // Record the used resolver address of client, and replace :authority to resolver address if serverName is not empty.
holdno marked this conversation as resolved.
Show resolved Hide resolved
md metadata.MD
conn net.Conn // underlying communication channel
loopy *loopyWriter
Expand Down Expand Up @@ -314,6 +315,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
cancel: cancel,
userAgent: opts.UserAgent,
registeredCompressors: grpcutil.RegisteredCompressors(),
address: addr, // resolver address
holdno marked this conversation as resolved.
Show resolved Hide resolved
conn: conn,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
Expand Down Expand Up @@ -454,6 +456,11 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
return t, nil
}

// Address return the resolver address meta info
holdno marked this conversation as resolved.
Show resolved Hide resolved
func (t *http2Client) Address() resolver.Address {
return t.address
}

func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
Expand Down Expand Up @@ -702,7 +709,14 @@ func (e NewStreamError) Error() string {
// streams. All non-nil errors returned will be *NewStreamError.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) {
ctx = peer.NewContext(ctx, t.getPeer())
headerFields, err := t.createHeaderFields(ctx, callHdr)

dupCallHdr := *callHdr
// replace host with the actual server name, if it is exist and unmatch
holdno marked this conversation as resolved.
Show resolved Hide resolved
if t.address.ServerName != "" && t.address.ServerName != dupCallHdr.Host {
dupCallHdr.Host = t.address.ServerName
}

headerFields, err := t.createHeaderFields(ctx, &dupCallHdr)
if err != nil {
return nil, &NewStreamError{Err: err, AllowTransparentRetry: false}
}
Expand Down
3 changes: 3 additions & 0 deletions internal/transport/transport.go
Expand Up @@ -657,6 +657,9 @@ type ClientTransport interface {
// with a human readable string with debug info.
GetGoAwayReason() (GoAwayReason, string)

// Address return the transport used resolver address meta info
Address() resolver.Address

// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr

Expand Down
67 changes: 66 additions & 1 deletion internal/transport/transport_test.go
Expand Up @@ -35,6 +35,7 @@ import (
"testing"
"time"

"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"

holdno marked this conversation as resolved.
Show resolved Hide resolved
"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -91,6 +92,7 @@ const (
invalidHeaderField
delayRead
pingpong
returnHeaderAuthority
)

func (h *testStreamHandler) handleStreamAndNotify(s *Stream) {
Expand Down Expand Up @@ -203,6 +205,31 @@ func (h *testStreamHandler) handleStreamInvalidHeaderField(s *Stream) {
})
}

func (h *testStreamHandler) handleStreamReturnValueOfAuthority(t *testing.T, s *Stream) {
var (
md, exist = metadata.FromIncomingContext(s.ctx)
resp string
)

if !exist || len(md.Get(":authority")) == 0 {
h.handleStreamInvalidHeaderField(s)
return
}

resp = md.Get(":authority")[0]

req := expectedRequest
p := make([]byte, len(req))
_, err := s.Read(p)
if err != nil {
return
}
// send a response back to the client.
h.t.Write(s, nil, []byte(resp), &Options{})
// send the trailer to end the stream.
h.t.WriteStatus(s, status.New(codes.OK, ""))
}

// handleStreamDelayRead delays reads so that the other side has to halt on
// stream-level flow control.
// This handler assumes dynamic flow control is turned off and assumes window
Expand Down Expand Up @@ -379,6 +406,12 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case returnHeaderAuthority:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamReturnValueOfAuthority(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case delayRead:
h.notify = make(chan struct{})
h.getNotified = make(chan struct{})
Expand Down Expand Up @@ -448,7 +481,7 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2

func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
server := setUpServerOnly(t, port, sc, ht)
addr := resolver.Address{Addr: "localhost:" + server.port}
addr := resolver.Address{Addr: "localhost:" + server.port, ServerName: server.addr()}
copts.ChannelzParentID = channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)

connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
Expand Down Expand Up @@ -1431,6 +1464,38 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
server.stop()
}

func (s) TestHeaderHostReplacedWithResolverAddress(t *testing.T) {
holdno marked this conversation as resolved.
Show resolved Hide resolved
server, ct, cancel := setUp(t, 0, math.MaxUint32, returnHeaderAuthority)
defer cancel()
callHdr := &CallHdr{
Host: "scheme://testSrv.com/testPath",
Method: "foo",
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
s, err := ct.NewStream(ctx, callHdr)
if err != nil {
return
}

opts := Options{Last: true}
if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != errStreamDone {
t.Fatalf("Failed to write the request: %v", err)
}
respOfAuthority := make([]byte, http2MaxFrameLen)
len, recvErr := s.Read(respOfAuthority)
if err, ok := status.FromError(recvErr); ok {
t.Fatalf("Read got error %v, headers are unexpected", err)
}

if string(respOfAuthority[:len]) != server.addr() {
t.Fatalf("Read got a unexpected :authority value %v, want %v", string(respOfAuthority), server.addr())
}

ct.Close(fmt.Errorf("closed manually by test"))
server.stop()
}

func (s) TestInvalidHeaderField(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
defer cancel()
Expand Down