diff --git a/client/pkg/transport/keepalive_listener.go b/client/pkg/transport/keepalive_listener.go index 4ff8e7f0010..2241d8823ed 100644 --- a/client/pkg/transport/keepalive_listener.go +++ b/client/pkg/transport/keepalive_listener.go @@ -21,26 +21,29 @@ import ( "time" ) -type keepAliveConn interface { - SetKeepAlive(bool) error - SetKeepAlivePeriod(d time.Duration) error -} - // NewKeepAliveListener returns a listener that listens on the given address. // Be careful when wrap around KeepAliveListener with another Listener if TLSInfo is not nil. // Some pkgs (like go/http) might expect Listener to return TLSConn type to start TLS handshake. // http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html +// +// Note(ahrtr): +// only `net.TCPConn` supports `SetKeepAlive` and `SetKeepAlivePeriod` +// by default, so if you want to wrap multiple layers of net.Listener, +// the `keepaliveListener` should be the one which is closest to the +// original `net.Listener` implementation, namely `TCPListener`. func NewKeepAliveListener(l net.Listener, scheme string, tlscfg *tls.Config) (net.Listener, error) { + kal := &keepaliveListener{ + Listener: l, + } + if scheme == "https" { if tlscfg == nil { return nil, fmt.Errorf("cannot listen on TLS for given listener: KeyFile and CertFile are not presented") } - return newTLSKeepaliveListener(l, tlscfg), nil + return newTLSKeepaliveListener(kal, tlscfg), nil } - return &keepaliveListener{ - Listener: l, - }, nil + return kal, nil } type keepaliveListener struct{ net.Listener } @@ -50,13 +53,43 @@ func (kln *keepaliveListener) Accept() (net.Conn, error) { if err != nil { return nil, err } - kac := c.(keepAliveConn) + + kac, err := createKeepaliveConn(c) + if err != nil { + return nil, fmt.Errorf("create keepalive connection failed, %w", err) + } // detection time: tcp_keepalive_time + tcp_keepalive_probes + tcp_keepalive_intvl // default on linux: 30 + 8 * 30 // default on osx: 30 + 8 * 75 - kac.SetKeepAlive(true) - kac.SetKeepAlivePeriod(30 * time.Second) - return c, nil + if err := kac.SetKeepAlive(true); err != nil { + return nil, fmt.Errorf("SetKeepAlive failed, %w", err) + } + if err := kac.SetKeepAlivePeriod(30 * time.Second); err != nil { + return nil, fmt.Errorf("SetKeepAlivePeriod failed, %w", err) + } + return kac, nil +} + +func createKeepaliveConn(c net.Conn) (*keepAliveConn, error) { + tcpc, ok := c.(*net.TCPConn) + if !ok { + return nil, ErrNotTCP + } + return &keepAliveConn{tcpc}, nil +} + +type keepAliveConn struct { + *net.TCPConn +} + +// SetKeepAlive sets keepalive +func (l *keepAliveConn) SetKeepAlive(doKeepAlive bool) error { + return l.TCPConn.SetKeepAlive(doKeepAlive) +} + +// SetKeepAlivePeriod sets keepalive period +func (l *keepAliveConn) SetKeepAlivePeriod(d time.Duration) error { + return l.TCPConn.SetKeepAlivePeriod(d) } // A tlsKeepaliveListener implements a network listener (net.Listener) for TLS connections. @@ -72,12 +105,6 @@ func (l *tlsKeepaliveListener) Accept() (c net.Conn, err error) { if err != nil { return } - kac := c.(keepAliveConn) - // detection time: tcp_keepalive_time + tcp_keepalive_probes + tcp_keepalive_intvl - // default on linux: 30 + 8 * 30 - // default on osx: 30 + 8 * 75 - kac.SetKeepAlive(true) - kac.SetKeepAlivePeriod(30 * time.Second) c = tls.Server(c, l.config) return c, nil } diff --git a/client/pkg/transport/keepalive_listener_test.go b/client/pkg/transport/keepalive_listener_test.go index 425f53368b5..041e48953b1 100644 --- a/client/pkg/transport/keepalive_listener_test.go +++ b/client/pkg/transport/keepalive_listener_test.go @@ -40,6 +40,9 @@ func TestNewKeepAliveListener(t *testing.T) { if err != nil { t.Fatalf("unexpected Accept error: %v", err) } + if _, ok := conn.(*keepAliveConn); !ok { + t.Fatalf("Unexpected conn type: %T, wanted *keepAliveConn", conn) + } conn.Close() ln.Close() diff --git a/client/pkg/transport/limit_listen.go b/client/pkg/transport/limit_listen.go index 930c542066f..404722ba76e 100644 --- a/client/pkg/transport/limit_listen.go +++ b/client/pkg/transport/limit_listen.go @@ -63,6 +63,9 @@ func (l *limitListenerConn) Close() error { return err } +// SetKeepAlive sets keepalive +// +// Deprecated: use (*keepAliveConn) SetKeepAlive instead. func (l *limitListenerConn) SetKeepAlive(doKeepAlive bool) error { tcpc, ok := l.Conn.(*net.TCPConn) if !ok { @@ -71,6 +74,9 @@ func (l *limitListenerConn) SetKeepAlive(doKeepAlive bool) error { return tcpc.SetKeepAlive(doKeepAlive) } +// SetKeepAlivePeriod sets keepalive period +// +// Deprecated: use (*keepAliveConn) SetKeepAlivePeriod instead. func (l *limitListenerConn) SetKeepAlivePeriod(d time.Duration) error { tcpc, ok := l.Conn.(*net.TCPConn) if !ok { diff --git a/client/pkg/transport/listener.go b/client/pkg/transport/listener.go index 992c773eaac..e8f475eb824 100644 --- a/client/pkg/transport/listener.go +++ b/client/pkg/transport/listener.go @@ -68,7 +68,7 @@ func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, err fallthrough case lnOpts.IsTimeout(), lnOpts.IsSocketOpts(): // timeout listener with socket options. - ln, err := lnOpts.ListenConfig.Listen(context.TODO(), "tcp", addr) + ln, err := newKeepAliveListener(&lnOpts.ListenConfig, addr) if err != nil { return nil, err } @@ -78,7 +78,7 @@ func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, err writeTimeout: lnOpts.writeTimeout, } case lnOpts.IsTimeout(): - ln, err := net.Listen("tcp", addr) + ln, err := newKeepAliveListener(nil, addr) if err != nil { return nil, err } @@ -88,7 +88,7 @@ func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, err writeTimeout: lnOpts.writeTimeout, } default: - ln, err := net.Listen("tcp", addr) + ln, err := newKeepAliveListener(nil, addr) if err != nil { return nil, err } @@ -102,6 +102,19 @@ func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, err return wrapTLS(scheme, lnOpts.tlsInfo, lnOpts.Listener) } +func newKeepAliveListener(cfg *net.ListenConfig, addr string) (ln net.Listener, err error) { + if cfg != nil { + ln, err = cfg.Listen(context.TODO(), "tcp", addr) + } else { + ln, err = net.Listen("tcp", addr) + } + if err != nil { + return + } + + return NewKeepAliveListener(ln, "tcp", nil) +} + func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) { if scheme != "https" && scheme != "unixs" { return l, nil diff --git a/client/pkg/transport/listener_test.go b/client/pkg/transport/listener_test.go index 00657648ece..96d0563da0e 100644 --- a/client/pkg/transport/listener_test.go +++ b/client/pkg/transport/listener_test.go @@ -213,6 +213,15 @@ func TestNewListenerWithSocketOpts(t *testing.T) { if !test.expectedErr && err != nil { t.Fatalf("unexpected error: %v", err) } + + if test.scheme == "http" { + lnOpts := newListenOpts(test.opts...) + if !lnOpts.IsSocketOpts() && !lnOpts.IsTimeout() { + if _, ok := ln.(*keepaliveListener); !ok { + t.Fatalf("ln: unexpected listener type: %T, wanted *keepaliveListener", ln) + } + } + } }) } } diff --git a/server/embed/etcd.go b/server/embed/etcd.go index 5321fa476e8..28b0ff92cf6 100644 --- a/server/embed/etcd.go +++ b/server/embed/etcd.go @@ -650,12 +650,6 @@ func configureClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err erro sctx.l = transport.LimitListener(sctx.l, int(fdLimit-reservedInternalFDNum)) } - if network == "tcp" { - if sctx.l, err = transport.NewKeepAliveListener(sctx.l, network, nil); err != nil { - return nil, err - } - } - defer func(u url.URL) { if err == nil { return