diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 227608c7f21..0956b500c18 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -36,6 +36,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/internal/grpcutil" + "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -231,6 +232,11 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, if kp.Timeout == 0 { kp.Timeout = defaultServerKeepaliveTimeout } + if kp.Time != infinity { + if err = syscall.SetTCPUserTimeout(conn, kp.Timeout); err != nil { + return nil, connectionErrorf(false, err, "transport: failed to set TCP_USER_TIMEOUT: %v", err) + } + } kep := config.KeepalivePolicy if kep.MinTime == 0 { kep.MinTime = defaultKeepalivePolicyMinTime diff --git a/internal/transport/keepalive_test.go b/internal/transport/keepalive_test.go index 4f5a2bed622..41395216fe4 100644 --- a/internal/transport/keepalive_test.go +++ b/internal/transport/keepalive_test.go @@ -645,19 +645,28 @@ func (s) TestKeepaliveServerEnforcementWithDormantKeepaliveOnClient(t *testing.T // the keepalive timeout, as detailed in proposal A18. func (s) TestTCPUserTimeout(t *testing.T) { tests := []struct { - time time.Duration - timeout time.Duration - wantTimeout time.Duration + time time.Duration + timeout time.Duration + clientWantTimeout time.Duration + serverWantTimeout time.Duration }{ { 10 * time.Second, 10 * time.Second, 10 * 1000 * time.Millisecond, + 10 * 1000 * time.Millisecond, }, { 0, 0, 0, + 20 * 1000 * time.Millisecond, + }, + { + infinity, + infinity, + 0, + 0, }, } for _, tt := range tests { @@ -666,7 +675,7 @@ func (s) TestTCPUserTimeout(t *testing.T) { 0, &ServerConfig{ KeepaliveParams: keepalive.ServerParameters{ - Time: tt.timeout, + Time: tt.time, Timeout: tt.timeout, }, }, @@ -684,6 +693,26 @@ func (s) TestTCPUserTimeout(t *testing.T) { cancel() }() + var sc *http2Server + // Wait until the server transport is setup. + for { + server.mu.Lock() + if len(server.conns) == 0 { + server.mu.Unlock() + time.Sleep(time.Millisecond) + continue + } + for k := range server.conns { + var ok bool + sc, ok = k.(*http2Server) + if !ok { + t.Fatalf("Failed to convert %v to *http2Server", k) + } + } + server.mu.Unlock() + break + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() stream, err := client.NewStream(ctx, &CallHdr{}) @@ -692,15 +721,23 @@ func (s) TestTCPUserTimeout(t *testing.T) { } client.CloseStream(stream, io.EOF) - opt, err := syscall.GetTCPUserTimeout(client.conn) + cltOpt, err := syscall.GetTCPUserTimeout(client.conn) if err != nil { t.Fatalf("syscall.GetTCPUserTimeout() failed: %v", err) } - if opt < 0 { + if cltOpt < 0 { t.Skipf("skipping test on unsupported environment") } - if gotTimeout := time.Duration(opt) * time.Millisecond; gotTimeout != tt.wantTimeout { - t.Fatalf("syscall.GetTCPUserTimeout() = %d, want %d", gotTimeout, tt.wantTimeout) + if gotTimeout := time.Duration(cltOpt) * time.Millisecond; gotTimeout != tt.clientWantTimeout { + t.Fatalf("syscall.GetTCPUserTimeout() = %d, want %d", gotTimeout, tt.clientWantTimeout) + } + + srvOpt, err := syscall.GetTCPUserTimeout(sc.conn) + if err != nil { + t.Fatalf("syscall.GetTCPUserTimeout() failed: %v", err) + } + if gotTimeout := time.Duration(srvOpt) * time.Millisecond; gotTimeout != tt.serverWantTimeout { + t.Fatalf("syscall.GetTCPUserTimeout() = %d, want %d", gotTimeout, tt.serverWantTimeout) } } }