From c7c94111cd0a844f0f31c422748b89896caad943 Mon Sep 17 00:00:00 2001 From: Jakub Nyckowski Date: Fri, 7 Jan 2022 17:23:22 -0500 Subject: [PATCH] Add "limiter" support to database service (#9087) Add rate and connection limiter to database service. --- lib/auth/auth.go | 1 + lib/config/configuration.go | 1 + lib/limiter/connlimiter.go | 1 - lib/limiter/limiter.go | 25 +++- lib/service/cfg.go | 3 + lib/service/db.go | 7 + lib/service/service.go | 5 + lib/srv/db/access_test.go | 12 +- lib/srv/db/common/interfaces.go | 20 ++- lib/srv/db/mongodb/engine.go | 28 ++-- lib/srv/db/mysql/engine.go | 35 +++-- lib/srv/db/mysql/proxy.go | 24 +++- lib/srv/db/postgres/engine.go | 51 ++++--- lib/srv/db/postgres/proxy.go | 21 ++- lib/srv/db/proxyserver.go | 49 +++++-- lib/srv/db/proxyserver_test.go | 230 ++++++++++++++++++++++++++++++++ lib/srv/db/server.go | 70 ++++++++-- lib/srv/db/server_test.go | 118 ++++++++++++++++ lib/utils/net.go | 35 +++++ 19 files changed, 667 insertions(+), 69 deletions(-) create mode 100644 lib/srv/db/proxyserver_test.go create mode 100644 lib/utils/net.go diff --git a/lib/auth/auth.go b/lib/auth/auth.go index e95c97d4d0241..71965579142de 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -325,6 +325,7 @@ type Server struct { // if not set, cache uses itself cache Cache + // limiter limits the number of active connections per client IP. limiter *limiter.ConnectionsLimiter // Emitter is events emitter, used to submit discrete events diff --git a/lib/config/configuration.go b/lib/config/configuration.go index a9c47ac674a71..981f75551fbc6 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -349,6 +349,7 @@ func ApplyFileConfig(fc *FileConfig, cfg *service.Config) error { &cfg.SSH.Limiter, &cfg.Auth.Limiter, &cfg.Proxy.Limiter, + &cfg.Databases.Limiter, &cfg.Kube.Limiter, &cfg.WindowsDesktop.ConnLimiter, } diff --git a/lib/limiter/connlimiter.go b/lib/limiter/connlimiter.go index ee3d562c6b01e..4d5c7949c4309 100644 --- a/lib/limiter/connlimiter.go +++ b/lib/limiter/connlimiter.go @@ -87,7 +87,6 @@ func (l *ConnectionsLimiter) AcquireConnection(token string) error { // ReleaseConnection decrements the counter func (l *ConnectionsLimiter) ReleaseConnection(token string) { - l.Lock() defer l.Unlock() diff --git a/lib/limiter/limiter.go b/lib/limiter/limiter.go index 5cb431d68b0c4..7d365fec0eb7c 100644 --- a/lib/limiter/limiter.go +++ b/lib/limiter/limiter.go @@ -80,8 +80,31 @@ func (l *Limiter) RegisterRequestWithCustomRate(token string, customRate *rateli return l.rateLimiter.RegisterRequest(token, customRate) } -// Add limiter to the handle +// WrapHandle adds limiter to the handle func (l *Limiter) WrapHandle(h http.Handler) { l.rateLimiter.Wrap(h) l.ConnLimiter.Wrap(l.rateLimiter) } + +// RegisterRequestAndConnection register a rate and connection limiter for a given token. Close function is returned, +// and it must be called to release the token. When a limit is hit an error is returned. +// Example usage: +// +// release, err := limiter.RegisterRequestAndConnection(clientIP) +// if err != nil { +// return trace.Wrap(err) +// } +// defer release() +func (l *Limiter) RegisterRequestAndConnection(token string) (func(), error) { + // Apply rate limiting. + if err := l.RegisterRequest(token); err != nil { + return func() {}, trace.LimitExceeded("rate limit exceeded for %q", token) + } + + // Apply connection limiting. + if err := l.AcquireConnection(token); err != nil { + return func() {}, trace.LimitExceeded("exceeded connection limit for %q", token) + } + + return func() { l.ReleaseConnection(token) }, nil +} diff --git a/lib/service/cfg.go b/lib/service/cfg.go index 80821897448b5..bda4efe5b5205 100644 --- a/lib/service/cfg.go +++ b/lib/service/cfg.go @@ -612,6 +612,8 @@ type DatabasesConfig struct { ResourceMatchers []services.ResourceMatcher // AWSMatchers match AWS hosted databases. AWSMatchers []services.AWSMatcher + // Limiter limits the connection and request rates. + Limiter limiter.Config } // Database represents a single database that's being proxied. @@ -1103,6 +1105,7 @@ func ApplyDefaults(cfg *Config) { // Databases proxy service is disabled by default. cfg.Databases.Enabled = false + defaults.ConfigureLimiter(&cfg.Databases.Limiter) // Metrics service defaults. cfg.Metrics.Enabled = false diff --git a/lib/service/db.go b/lib/service/db.go index 87d80e69ae178..7054fb7622dca 100644 --- a/lib/service/db.go +++ b/lib/service/db.go @@ -21,6 +21,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/db" @@ -167,6 +168,11 @@ func (process *TeleportProcess) initDatabaseService() (retErr error) { return trace.Wrap(err) } + connLimiter, err := limiter.NewLimiter(process.Config.Databases.Limiter) + if err != nil { + return trace.Wrap(err) + } + // Create and start the database service. dbService, err := db.New(process.ExitContext(), db.Config{ Clock: process.Clock, @@ -179,6 +185,7 @@ func (process *TeleportProcess) initDatabaseService() (retErr error) { }, Authorizer: authorizer, TLSConfig: tlsConfig, + Limiter: connLimiter, GetRotation: process.getRotation, Hostname: process.Config.Hostname, HostID: process.Config.HostUUID, diff --git a/lib/service/service.go b/lib/service/service.go index 174d0ee412f47..caeca6be9400e 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -3120,6 +3120,10 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { if err != nil { return trace.Wrap(err) } + connLimiter, err := limiter.NewLimiter(process.Config.Databases.Limiter) + if err != nil { + return trace.Wrap(err) + } dbProxyServer, err := db.NewProxyServer(process.ExitContext(), db.ProxyServerConfig{ AuthClient: conn.Client, @@ -3127,6 +3131,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Authorizer: authorizer, Tunnel: tsrv, TLSConfig: tlsConfig, + Limiter: connLimiter, Emitter: asyncEmitter, Clock: process.Clock, ServerID: cfg.HostUUID, diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 6d185866b09b7..a03011483aca1 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/reversetunnel" @@ -520,7 +521,7 @@ type testModules struct { func (m *testModules) Features() modules.Features { return modules.Features{ - DB: false, // Explicily turn off database access. + DB: false, // Explicitly turn off database access. } } @@ -938,6 +939,9 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa testCtx.fakeRemoteSite, }, } + // Empty config means no limit. + connLimiter, err := limiter.NewLimiter(limiter.Config{}) + require.NoError(t, err) // Create test audit events emitter. testCtx.emitter = newTestEmitter() @@ -949,6 +953,7 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa Authorizer: proxyAuthorizer, Tunnel: tunnel, TLSConfig: tlsConfig, + Limiter: connLimiter, Emitter: testCtx.emitter, Clock: testCtx.clock, ServerID: "proxy-server", @@ -1021,6 +1026,10 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t *testing.T, p a }) require.NoError(t, err) + // Create default limiter. + connLimiter, err := limiter.NewLimiter(limiter.Config{}) + require.NoError(t, err) + // Create database server agent itself. server, err := New(ctx, Config{ Clock: clockwork.NewFakeClockAt(time.Now()), @@ -1032,6 +1041,7 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t *testing.T, p a Hostname: constants.APIDomain, HostID: p.HostID, TLSConfig: tlsConfig, + Limiter: connLimiter, Auth: testAuth, Databases: p.Databases, ResourceMatchers: p.ResourceMatchers, diff --git a/lib/srv/db/common/interfaces.go b/lib/srv/db/common/interfaces.go index f023b6c11d174..c480873614c17 100644 --- a/lib/srv/db/common/interfaces.go +++ b/lib/srv/db/common/interfaces.go @@ -30,10 +30,20 @@ type Proxy interface { HandleConnection(context.Context, net.Conn) error } +// ConnectParams keeps parameters used when connecting to Service. +type ConnectParams struct { + // User is a database username. + User string + // Database is a database name/schema. + Database string + // ClientIP is a client real IP. Currently, used for rate limiting. + ClientIP string +} + // Service defines an interface for connecting to a remote database service. type Service interface { // Connect is used to connect to remote database server over reverse tunnel. - Connect(ctx context.Context, user, database string) (net.Conn, *auth.Context, error) + Connect(ctx context.Context, params ConnectParams) (net.Conn, *auth.Context, error) // Proxy starts proxying between client and service connections. Proxy(ctx context.Context, authContext *auth.Context, clientConn, serviceConn net.Conn) error } @@ -41,7 +51,13 @@ type Service interface { // Engine defines an interface for specific database protocol engine such // as Postgres or MySQL. type Engine interface { + // InitializeConnection initializes the client connection. No DB connection is made at this point, but a message + // can be sent to a client in a database format. + InitializeConnection(clientConn net.Conn, sessionCtx *Session) error + // SendError sends an error to a client in database encoded format. + // NOTE: Client connection must be initialized before this function is called. + SendError(error) // HandleConnection proxies the connection received from the proxy to // the particular database instance. - HandleConnection(context.Context, *Session, net.Conn) error + HandleConnection(context.Context, *Session) error } diff --git a/lib/srv/db/mongodb/engine.go b/lib/srv/db/mongodb/engine.go index dde524a1b8449..1cafd2837bd74 100644 --- a/lib/srv/db/mongodb/engine.go +++ b/lib/srv/db/mongodb/engine.go @@ -50,6 +50,21 @@ type Engine struct { Clock clockwork.Clock // Log is used for logging. Log logrus.FieldLogger + // clientConn is an incoming client connection. + clientConn net.Conn +} + +// InitializeConnection initializes the client connection. +func (e *Engine) InitializeConnection(clientConn net.Conn, _ *common.Session) error { + e.clientConn = clientConn + return nil +} + +// SendError sends an error to the connected client in MongoDB understandable format. +func (e *Engine) SendError(err error) { + if err != nil && !utils.IsOKNetworkError(err) { + e.replyError(e.clientConn, nil, err) + } } // HandleConnection processes the connection from MongoDB proxy coming @@ -58,14 +73,9 @@ type Engine struct { // It handles all necessary startup actions, authorization and acts as a // middleman between the proxy and the database intercepting and interpreting // all messages i.e. doing protocol parsing. -func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session, clientConn net.Conn) (err error) { - defer func() { - if err != nil && !utils.IsOKNetworkError(err) { - e.replyError(clientConn, nil, err) - } - }() +func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session) error { // Check that the user has access to the database. - err = e.authorizeConnection(ctx, sessionCtx) + err := e.authorizeConnection(ctx, sessionCtx) if err != nil { return trace.Wrap(err, "error authorizing database access") } @@ -84,11 +94,11 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio defer e.Audit.OnSessionEnd(e.Context, sessionCtx) // Start reading client messages and sending them to server. for { - clientMessage, err := protocol.ReadMessage(clientConn) + clientMessage, err := protocol.ReadMessage(e.clientConn) if err != nil { return trace.Wrap(err) } - err = e.handleClientMessage(ctx, sessionCtx, clientMessage, clientConn, serverConn) + err = e.handleClientMessage(ctx, sessionCtx, clientMessage, e.clientConn, serverConn) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/mysql/engine.go b/lib/srv/db/mysql/engine.go index 48a5ba4e6f867..4ee74d6365559 100644 --- a/lib/srv/db/mysql/engine.go +++ b/lib/srv/db/mysql/engine.go @@ -58,6 +58,22 @@ type Engine struct { Clock clockwork.Clock // Log is used for logging. Log logrus.FieldLogger + // proxyConn is a client connection. + proxyConn server.Conn +} + +// InitializeConnection initializes the engine with client connection. +func (e *Engine) InitializeConnection(clientConn net.Conn, _ *common.Session) error { + // Make server conn to get access to protocol's WriteOK/WriteError methods. + e.proxyConn = server.Conn{Conn: packet.NewConn(clientConn)} + return nil +} + +// SendError sends an error to connected client in the MySQL understandable format. +func (e *Engine) SendError(err error) { + if writeErr := e.proxyConn.WriteError(err); writeErr != nil { + e.Log.WithError(writeErr).Debugf("Failed to send error %q to MySQL client.", err) + } } // HandleConnection processes the connection from MySQL proxy coming @@ -66,18 +82,9 @@ type Engine struct { // It handles all necessary startup actions, authorization and acts as a // middleman between the proxy and the database intercepting and interpreting // all messages i.e. doing protocol parsing. -func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session, clientConn net.Conn) (err error) { - // Make server conn to get access to protocol's WriteOK/WriteError methods. - proxyConn := server.Conn{Conn: packet.NewConn(clientConn)} - defer func() { - if err != nil { - if writeErr := proxyConn.WriteError(err); writeErr != nil { - e.Log.WithError(writeErr).Debugf("Failed to send error %q to MySQL client.", err) - } - } - }() +func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session) error { // Perform authorization checks. - err = e.checkAccess(ctx, sessionCtx) + err := e.checkAccess(ctx, sessionCtx) if err != nil { return trace.Wrap(err) } @@ -97,7 +104,7 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio }() // Send back OK packet to indicate auth/connect success. At this point // the original client should consider the connection phase completed. - err = proxyConn.WriteOK(nil) + err = e.proxyConn.WriteOK(nil) if err != nil { return trace.Wrap(err) } @@ -106,8 +113,8 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio // Copy between the connections. clientErrCh := make(chan error, 1) serverErrCh := make(chan error, 1) - go e.receiveFromClient(clientConn, serverConn, clientErrCh, sessionCtx) - go e.receiveFromServer(serverConn, clientConn, serverErrCh) + go e.receiveFromClient(e.proxyConn.Conn, serverConn, clientErrCh, sessionCtx) + go e.receiveFromServer(serverConn, e.proxyConn.Conn, serverErrCh) select { case err := <-clientErrCh: e.Log.WithError(err).Debug("Client done.") diff --git a/lib/srv/db/mysql/proxy.go b/lib/srv/db/mysql/proxy.go index 281e5e819ff3b..814530f9c8535 100644 --- a/lib/srv/db/mysql/proxy.go +++ b/lib/srv/db/mysql/proxy.go @@ -24,9 +24,11 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/mysql/protocol" + "github.com/gravitational/teleport/lib/utils" "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/server" @@ -48,6 +50,8 @@ type Proxy struct { Service common.Service // Log is used for logging. Log logrus.FieldLogger + // Limiter limits the number of active connections per client IP. + Limiter *limiter.Limiter } // HandleConnection accepts connection from a MySQL client, authenticates @@ -58,7 +62,7 @@ func (p *Proxy) HandleConnection(ctx context.Context, clientConn net.Conn) (err // proxy protocol which otherwise would interfere with MySQL protocol. conn := multiplexer.NewConn(clientConn) server := p.makeServer(conn) - // If any error happens, make sure to send it back to the client so it + // If any error happens, make sure to send it back to the client, so it // has a chance to close the connection from its side. defer func() { if r := recover(); r != nil { @@ -81,7 +85,23 @@ func (p *Proxy) HandleConnection(ctx context.Context, clientConn net.Conn) (err if err != nil { return trace.Wrap(err) } - serviceConn, authContext, err := p.Service.Connect(ctx, server.GetUser(), server.GetDatabase()) + + clientIP, err := utils.ClientIPFromConn(clientConn) + if err != nil { + return trace.Wrap(err) + } + // Apply connection and rate limiting. + releaseConn, err := p.Limiter.RegisterRequestAndConnection(clientIP) + if err != nil { + return trace.Wrap(err) + } + defer releaseConn() + + serviceConn, authContext, err := p.Service.Connect(ctx, common.ConnectParams{ + User: server.GetUser(), + Database: server.GetDatabase(), + ClientIP: clientIP, + }) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/postgres/engine.go b/lib/srv/db/postgres/engine.go index 2ab84eee8bbf9..f6695f17ca695 100644 --- a/lib/srv/db/postgres/engine.go +++ b/lib/srv/db/postgres/engine.go @@ -52,6 +52,30 @@ type Engine struct { Clock clockwork.Clock // Log is used for logging. Log logrus.FieldLogger + // client is a client connection. + client *pgproto3.Backend +} + +// InitializeConnection initializes the client connection. +func (e *Engine) InitializeConnection(clientConn net.Conn, sessionCtx *common.Session) error { + e.client = pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn) + + // The proxy is supposed to pass a startup message it received from + // the psql client over to us, so wait for it and extract database + // and username from it. + err := e.handleStartup(e.client, sessionCtx) + if err != nil { + return trace.Wrap(err) + } + + return nil +} + +// SendError sends an error to connected client in a Postgres understandable format. +func (e *Engine) SendError(err error) { + if err := e.client.Send(toErrorResponse(err)); err != nil && !utils.IsOKNetworkError(err) { + e.Log.WithError(err).Error("Failed to send error to client.") + } } // toErrorResponse converts the provided error to a Postgres wire protocol @@ -78,25 +102,10 @@ func toErrorResponse(err error) *pgproto3.ErrorResponse { // It handles all necessary startup actions, authorization and acts as a // middleman between the proxy and the database intercepting and interpreting // all messages i.e. doing protocol parsing. -func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session, clientConn net.Conn) (err error) { - client := pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn) - defer func() { - if err != nil { - if err := client.Send(toErrorResponse(err)); err != nil && !utils.IsOKNetworkError(err) { - e.Log.WithError(err).Error("Failed to send error to client.") - } - } - }() - // The proxy is supposed to pass a startup message it received from - // the psql client over to us, so wait for it and extract database - // and username from it. - err = e.handleStartup(client, sessionCtx) - if err != nil { - return trace.Wrap(err) - } +func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session) error { // Now we know which database/username the user is connecting to, so // perform an authorization check. - err = e.checkAccess(ctx, sessionCtx) + err := e.checkAccess(ctx, sessionCtx) if err != nil { return trace.Wrap(err) } @@ -106,8 +115,8 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio return trace.Wrap(err) } // Upon successful connect, indicate to the Postgres client that startup - // has been completed and it can start sending queries. - err = e.makeClientReady(client, hijackedConn) + // has been completed, and it can start sending queries. + err = e.makeClientReady(e.client, hijackedConn) if err != nil { return trace.Wrap(err) } @@ -131,8 +140,8 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio // the client (psql or other Postgres client) and the server (database). clientErrCh := make(chan error, 1) serverErrCh := make(chan error, 1) - go e.receiveFromClient(client, server, clientErrCh, sessionCtx) - go e.receiveFromServer(server, client, serverConn, serverErrCh, sessionCtx) + go e.receiveFromClient(e.client, server, clientErrCh, sessionCtx) + go e.receiveFromServer(server, e.client, serverConn, serverErrCh, sessionCtx) select { case err := <-clientErrCh: e.Log.WithError(err).Debug("Client done.") diff --git a/lib/srv/db/postgres/proxy.go b/lib/srv/db/postgres/proxy.go index 924094613eada..9bcf34d852ba6 100644 --- a/lib/srv/db/postgres/proxy.go +++ b/lib/srv/db/postgres/proxy.go @@ -22,7 +22,9 @@ import ( "net" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/utils" "github.com/jackc/pgproto3/v2" @@ -43,6 +45,8 @@ type Proxy struct { Service common.Service // Log is used for logging. Log logrus.FieldLogger + // Limiter limits the number of active connections per client IP. + Limiter *limiter.Limiter } // HandleConnection accepts connection from a Postgres client, authenticates @@ -63,7 +67,22 @@ func (p *Proxy) HandleConnection(ctx context.Context, clientConn net.Conn) (err if err != nil { return trace.Wrap(err) } - serviceConn, authContext, err := p.Service.Connect(ctx, "", "") + + clientIP, err := utils.ClientIPFromConn(clientConn) + if err != nil { + return trace.Wrap(err) + } + + // Apply connection and rate limiting. + releaseConn, err := p.Limiter.RegisterRequestAndConnection(clientIP) + if err != nil { + return trace.Wrap(err) + } + defer releaseConn() + + serviceConn, authContext, err := p.Service.Connect(ctx, common.ConnectParams{ + ClientIP: clientIP, + }) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/proxyserver.go b/lib/srv/db/proxyserver.go index a9cc444750d71..bc2d5e5dc6454 100644 --- a/lib/srv/db/proxyserver.go +++ b/lib/srv/db/proxyserver.go @@ -34,9 +34,11 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/native" + "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" + "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/mysql" "github.com/gravitational/teleport/lib/srv/db/postgres" "github.com/gravitational/teleport/lib/tlsca" @@ -73,6 +75,8 @@ type ProxyServerConfig struct { Tunnel reversetunnel.Server // TLSConfig is the proxy server TLS configuration. TLSConfig *tls.Config + // Limiter is the connection/rate limiter. + Limiter *limiter.Limiter // Emitter is used to emit audit events. Emitter events.Emitter // Clock to override clock in tests. @@ -120,6 +124,15 @@ func (c *ProxyServerConfig) CheckAndSetDefaults() error { if c.LockWatcher == nil { return trace.BadParameter("missing LockWatcher") } + if c.Limiter == nil { + // Empty config means no connection limit. + connLimiter, err := limiter.NewLimiter(limiter.Config{}) + if err != nil { + return trace.Wrap(err) + } + + c.Limiter = connLimiter + } return nil } @@ -252,7 +265,22 @@ func (s *ProxyServer) handleConnection(conn net.Conn) error { if err != nil { return trace.Wrap(err) } - serviceConn, authContext, err := s.Connect(ctx, "", "") + + clientIP, err := utils.ClientIPFromConn(conn) + if err != nil { + return trace.Wrap(err) + } + + // Apply connection and rate limiting. + release, err := s.cfg.Limiter.RegisterRequestAndConnection(clientIP) + if err != nil { + return trace.Wrap(err) + } + defer release() + + serviceConn, authContext, err := s.Connect(ctx, common.ConnectParams{ + ClientIP: clientIP, + }) if err != nil { return trace.Wrap(err) } @@ -270,6 +298,7 @@ func (s *ProxyServer) PostgresProxy() *postgres.Proxy { TLSConfig: s.cfg.TLSConfig, Middleware: s.middleware, Service: s, + Limiter: s.cfg.Limiter, Log: s.log, } } @@ -280,6 +309,7 @@ func (s *ProxyServer) MySQLProxy() *mysql.Proxy { TLSConfig: s.cfg.TLSConfig, Middleware: s.middleware, Service: s, + Limiter: s.cfg.Limiter, Log: s.log, } } @@ -292,8 +322,8 @@ func (s *ProxyServer) MySQLProxy() *mysql.Proxy { // decoded from the client certificate by auth.Middleware. // // Implements common.Service. -func (s *ProxyServer) Connect(ctx context.Context, user, database string) (net.Conn, *auth.Context, error) { - proxyContext, err := s.authorize(ctx, user, database) +func (s *ProxyServer) Connect(ctx context.Context, params common.ConnectParams) (net.Conn, *auth.Context, error) { + proxyContext, err := s.authorize(ctx, params) if err != nil { return nil, nil, trace.Wrap(err) } @@ -465,17 +495,20 @@ type proxyContext struct { authContext *auth.Context } -func (s *ProxyServer) authorize(ctx context.Context, user, database string) (*proxyContext, error) { +func (s *ProxyServer) authorize(ctx context.Context, params common.ConnectParams) (*proxyContext, error) { authContext, err := s.cfg.Authorizer.Authorize(ctx) if err != nil { return nil, trace.Wrap(err) } identity := authContext.Identity.GetIdentity() - if user != "" { - identity.RouteToDatabase.Username = user + if params.User != "" { + identity.RouteToDatabase.Username = params.User + } + if params.Database != "" { + identity.RouteToDatabase.Database = params.Database } - if database != "" { - identity.RouteToDatabase.Database = database + if params.ClientIP != "" { + identity.ClientIP = params.ClientIP } cluster, servers, err := s.getDatabaseServers(ctx, identity) if err != nil { diff --git a/lib/srv/db/proxyserver_test.go b/lib/srv/db/proxyserver_test.go new file mode 100644 index 0000000000000..1b61b4fbd8723 --- /dev/null +++ b/lib/srv/db/proxyserver_test.go @@ -0,0 +1,230 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package db + +import ( + "context" + "testing" + "time" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/limiter" + "github.com/stretchr/testify/require" +) + +func TestProxyConnectionLimiting(t *testing.T) { + const ( + user = "bob" + role = "admin" + postgresDbName = "postgres" + dbUser = user + connLimitNumber = 3 // Arbitrary number + ) + + ctx := context.Background() + testCtx := setupTestContext(ctx, t, + withSelfHostedPostgres("postgres"), + withSelfHostedMySQL("mysql")) + // TODO(jakule): Mongo seems to create some internal connections. I didn't find a way to predict + // how many connection will be created and decided to skip it for now. Otherwise, the whole test may be flaky. + + connLimit, err := limiter.NewLimiter(limiter.Config{MaxConnections: connLimitNumber}) + require.NoError(t, err) + + // Set proxy connection limiter. + testCtx.proxyServer.cfg.Limiter = connLimit + + go testCtx.startHandlingConnections() + + // Create user/role with the requested permissions. + testCtx.createUserAndRole(ctx, t, user, role, []string{types.Wildcard}, []string{types.Wildcard}) + + tests := []struct { + name string + connect func() (func(context.Context) error, error) + }{ + { + "postgres", + func() (func(context.Context) error, error) { + pgConn, err := testCtx.postgresClient(ctx, user, "postgres", dbUser, postgresDbName) + return pgConn.Close, err + }, + }, + { + "mysql", + func() (func(context.Context) error, error) { + mysqlClient, err := testCtx.mysqlClient(user, "mysql", dbUser) + return func(_ context.Context) error { + return mysqlClient.Close() + }, err + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + // Keep close functions to all connections. Call and release all active connection at the end of test. + connsClosers := make([]func(context.Context) error, 0) + t.Cleanup(func() { + for _, connClose := range connsClosers { + err := connClose(ctx) + require.NoError(t, err) + } + }) + + t.Run("limit can be hit", func(t *testing.T) { + for i := 0; i < connLimitNumber; i++ { + // Try to connect to the database. + pgConn, err := tt.connect() + require.NoError(t, err) + + connsClosers = append(connsClosers, pgConn) + } + + // This connection should go over the limit. + _, err = tt.connect() + require.Error(t, err) + require.Contains(t, err.Error(), "exceeded connection limit") + }) + + // When a connection is released a new can be established + t.Run("reconnect one", func(t *testing.T) { + // Get one open connection. + oneConn := connsClosers[len(connsClosers)-1] + connsClosers = connsClosers[:len(connsClosers)-1] + + // Close it, this should decrease the connection limit. + err = oneConn(ctx) + require.NoError(t, err) + + // Create a new connection. We do not expect an error here as we have just closed one. + pgConn, err := tt.connect() + require.NoError(t, err) + connsClosers = append(connsClosers, pgConn) + + // Here the limit should be reached again. + _, err = tt.connect() + require.Error(t, err) + require.Contains(t, err.Error(), "exceeded connection limit") + }) + }) + } +} + +func TestProxyRateLimiting(t *testing.T) { + const ( + user = "bob" + role = "admin" + postgresDbName = "postgres" + dbUser = user + connLimitNumber = 20 // Should be enough to hit the connection limit. + ) + + ctx := context.Background() + testCtx := setupTestContext(ctx, t, + withSelfHostedPostgres("postgres"), + withSelfHostedMySQL("mysql"), + withSelfHostedMongo("mongodb"), + ) + + connLimit, err := limiter.NewLimiter(limiter.Config{ + // Set rates low, so we can easily hit them. + Rates: []limiter.Rate{ + { + Period: 10 * time.Second, + Average: 3, + Burst: 3, + }, + }}) + require.NoError(t, err) + + // Set proxy connection limiter. + testCtx.proxyServer.cfg.Limiter = connLimit + + go testCtx.startHandlingConnections() + + // Create user/role with the requested permissions. + testCtx.createUserAndRole(ctx, t, user, role, []string{types.Wildcard}, []string{types.Wildcard}) + + tests := []struct { + name string + connect func() (func(context.Context) error, error) + }{ + { + "postgres", + func() (func(context.Context) error, error) { + pgConn, err := testCtx.postgresClient(ctx, user, "postgres", dbUser, postgresDbName) + return pgConn.Close, err + }, + }, + { + "mysql", + func() (func(context.Context) error, error) { + mysqlClient, err := testCtx.mysqlClient(user, "mysql", dbUser) + return func(_ context.Context) error { + return mysqlClient.Close() + }, err + }, + }, + { + "mongodb", + func() (func(context.Context) error, error) { + mongoClient, err := testCtx.mongoClient(ctx, user, "mongodb", dbUser) + return mongoClient.Disconnect, err + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + // Keep close functions to all connections. Call and release all active connection at the end of test. + connsClosers := make([]func(context.Context) error, 0) + t.Cleanup(func() { + for _, connClose := range connsClosers { + err := connClose(ctx) + require.NoError(t, err) + } + }) + + for i := 0; i < connLimitNumber; i++ { + // Try to connect to the database. + pgConn, err := tt.connect() + if err == nil { + connsClosers = append(connsClosers, pgConn) + + continue + } + + require.Error(t, err) + + //TODO(jakule) currently mongodb proxy don't know how to propagate an error, + // so this check for mongo is disabled + if tt.name != "mongodb" { + require.Contains(t, err.Error(), "rate limit exceeded") + } + + return + } + + require.FailNow(t, "we should hit the limit by now") + }) + } +} diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 02d37e47e9640..0f5fc1daf8e8c 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/labels" + "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/db/cloud" @@ -61,6 +62,8 @@ type Config struct { NewAudit NewAuditFn // TLSConfig is the *tls.Config for this server. TLSConfig *tls.Config + // Limiter limits the number of connections per client IP. + Limiter *limiter.Limiter // Authorizer is used to authorize requests coming from proxy. Authorizer auth.Authorizer // GetRotation returns the certificate rotation state. @@ -169,6 +172,13 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { return trace.Wrap(err) } } + if c.Limiter == nil { + // Use default limiter if nothing is provided. Connection limiting will be disabled. + c.Limiter, err = limiter.NewLimiter(limiter.Config{}) + if err != nil { + return trace.Wrap(err) + } + } return nil } @@ -617,7 +627,7 @@ func (s *Server) HandleConnection(conn net.Conn) { // Make sure to close the upgraded connection, not "conn", otherwise // the other side may not detect that connection has closed. defer tlsConn.Close() - // Perform the hanshake explicitly, normally it should be performed + // Perform the handshake explicitly, normally it should be performed // on the first read/write but when the connection is passed over // reverse tunnel it doesn't happen for some reason. err := tlsConn.Handshake() @@ -641,11 +651,12 @@ func (s *Server) HandleConnection(conn net.Conn) { } } -func (s *Server) handleConnection(ctx context.Context, conn net.Conn) error { +func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) error { sessionCtx, err := s.authorize(ctx) if err != nil { return trace.Wrap(err) } + streamWriter, err := s.newStreamWriter(sessionCtx) if err != nil { return trace.Wrap(err) @@ -653,7 +664,7 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) error { defer func() { // Closing the stream writer is needed to flush all recorded data // and trigger upload. Do it in a goroutine since depending on - // session size it can take a while and we don't want to block + // session size it can take a while, and we don't want to block // the client. go func() { // Use the server closing context to make sure that upload @@ -664,15 +675,25 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) error { } }() }() - engine, err := s.dispatch(sessionCtx, streamWriter) + engine, err := s.dispatch(sessionCtx, streamWriter, clientConn) if err != nil { return trace.Wrap(err) } + defer func() { + if r := recover(); r != nil { + s.log.Warnf("Recovered while handling DB connection from %v: %v.", clientConn.RemoteAddr(), r) + err = trace.BadParameter("failed to handle client connection") + } + if err != nil { + engine.SendError(err) + } + }() + // Wrap a client connection into monitor that auto-terminates // idle connection and connection with expired cert. - conn, err = monitorConn(ctx, monitorConnConfig{ - conn: conn, + clientConn, err = monitorConn(ctx, monitorConnConfig{ + conn: clientConn, lockWatcher: s.cfg.LockWatcher, lockTargets: sessionCtx.LockTargets, identity: sessionCtx.Identity, @@ -689,21 +710,51 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) error { return trace.Wrap(err) } - err = engine.HandleConnection(ctx, sessionCtx, conn) + // TODO(jakule): ClientIP should be required starting from 10.0. + clientIP := sessionCtx.Identity.ClientIP + if clientIP != "" { + s.log.Debugf("Real client IP %s", clientIP) + + var release func() + release, err = s.cfg.Limiter.RegisterRequestAndConnection(clientIP) + if err != nil { + return trace.Wrap(err) + } + defer release() + } else { + s.log.Debug("ClientIP is not set (Proxy Service has to be updated). Rate limiting is disabled.") + } + + err = engine.HandleConnection(ctx, sessionCtx) if err != nil { return trace.Wrap(err) } return nil } -// dispatch returns an appropriate database engine for the session. -func (s *Server) dispatch(sessionCtx *common.Session, streamWriter events.StreamWriter) (common.Engine, error) { +// dispatch creates and initializes an appropriate database engine for the session. +func (s *Server) dispatch(sessionCtx *common.Session, streamWriter events.StreamWriter, clientConn net.Conn) (common.Engine, error) { audit, err := s.cfg.NewAudit(common.AuditConfig{ Emitter: streamWriter, }) if err != nil { return nil, trace.Wrap(err) } + engine, err := s.createEngine(sessionCtx, audit) + if err != nil { + return nil, trace.Wrap(err) + } + + if err := engine.InitializeConnection(clientConn, sessionCtx); err != nil { + return nil, trace.Wrap(err) + } + + return engine, nil +} + +// createEngine creates a new database engine base on the database protocol. An error is returned when +// a protocol is not supported. +func (s *Server) createEngine(sessionCtx *common.Session, audit common.Audit) (common.Engine, error) { switch sessionCtx.Database.GetProtocol() { case defaults.ProtocolPostgres, defaults.ProtocolCockroachDB: return &postgres.Engine{ @@ -731,6 +782,7 @@ func (s *Server) dispatch(sessionCtx *common.Session, streamWriter events.Stream Log: sessionCtx.Log, }, nil } + return nil, trace.BadParameter("unsupported database protocol %q", sessionCtx.Database.GetProtocol()) } diff --git a/lib/srv/db/server_test.go b/lib/srv/db/server_test.go index 0e3d2980834d8..8105809a46297 100644 --- a/lib/srv/db/server_test.go +++ b/lib/srv/db/server_test.go @@ -22,8 +22,13 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/limiter" + "github.com/jackc/pgconn" + "github.com/siddontang/go-mysql/client" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/mongo" ) // TestDatabaseServerStart validates that started database server updates its @@ -58,3 +63,116 @@ func TestDatabaseServerStart(t *testing.T) { server.GetDatabase().GetAllLabels()) } } + +func TestDatabaseServerLimiting(t *testing.T) { + const ( + user = "bob" + role = "admin" + dbName = "postgres" + dbUser = user + connLimit = int64(5) // Arbitrary number + ) + + ctx := context.Background() + allowDbUsers := []string{types.Wildcard} + allowDbNames := []string{types.Wildcard} + + testCtx := setupTestContext(ctx, t, + withSelfHostedPostgres("postgres"), + withSelfHostedMySQL("mysql"), + withSelfHostedMongo("mongo"), + ) + + connLimiter, err := limiter.NewLimiter(limiter.Config{MaxConnections: connLimit}) + require.NoError(t, err) + + // Set connection limit + testCtx.server.cfg.Limiter = connLimiter + + go testCtx.startHandlingConnections() + t.Cleanup(func() { + err := testCtx.Close() + require.NoError(t, err) + }) + + // Create user/role with the requested permissions. + testCtx.createUserAndRole(ctx, t, user, role, allowDbUsers, allowDbNames) + + t.Run("postgres", func(t *testing.T) { + dbConns := make([]*pgconn.PgConn, 0) + t.Cleanup(func() { + // Disconnect all clients. + for _, pgConn := range dbConns { + err = pgConn.Close(ctx) + require.NoError(t, err) + } + }) + + // Connect the maximum allowed number of clients. + for i := int64(0); i < connLimit; i++ { + pgConn, err := testCtx.postgresClient(ctx, user, "postgres", dbUser, dbName) + require.NoError(t, err) + + // Save all connection, so we can close them later. + dbConns = append(dbConns, pgConn) + } + + // We keep the previous connections open, so this one should be rejected, because we exhausted the limit. + _, err = testCtx.postgresClient(ctx, user, "postgres", dbUser, dbName) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeded connection limit") + }) + + t.Run("mysql", func(t *testing.T) { + dbConns := make([]*client.Conn, 0) + t.Cleanup(func() { + // Disconnect all clients. + for _, dbConn := range dbConns { + err = dbConn.Close() + require.NoError(t, err) + } + }) + // Connect the maximum allowed number of clients. + for i := int64(0); i < connLimit; i++ { + mysqlConn, err := testCtx.mysqlClient(user, "mysql", dbUser) + require.NoError(t, err) + + // Save all connection, so we can close them later. + dbConns = append(dbConns, mysqlConn) + } + + // We keep the previous connections open, so this one should be rejected, because we exhausted the limit. + _, err = testCtx.mysqlClient(user, "mysql", dbUser) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeded connection limit") + }) + + t.Run("mongodb", func(t *testing.T) { + dbConns := make([]*mongo.Client, 0) + t.Cleanup(func() { + // Disconnect all clients. + for _, dbConn := range dbConns { + err = dbConn.Disconnect(ctx) + require.NoError(t, err) + } + }) + // Mongo driver behave different from MySQL and Postgres. In this case we just want to hit the limit + // by creating some DB connections. + for i := int64(0); i < 2*connLimit; i++ { + mongoConn, err := testCtx.mongoClient(ctx, user, "mongo", dbUser) + + if err == nil { + // Save all connection, so we can close them later. + dbConns = append(dbConns, mongoConn) + + continue + } + + assert.Contains(t, err.Error(), "exceeded connection limit") + // When we hit the expected error we can exit. + return + } + + require.FailNow(t, "we should exceed the connection limit by now") + }) +} diff --git a/lib/utils/net.go b/lib/utils/net.go new file mode 100644 index 0000000000000..a67507d68a8be --- /dev/null +++ b/lib/utils/net.go @@ -0,0 +1,35 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "net" + + "github.com/gravitational/trace" +) + +// ClientIPFromConn extracts host from provided remote address. +func ClientIPFromConn(conn net.Conn) (string, error) { + clientRemoteAddr := conn.RemoteAddr() + + clientIP, _, err := net.SplitHostPort(clientRemoteAddr.String()) + if err != nil { + return "", trace.Wrap(err) + } + + return clientIP, nil +}