-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
engine.go
235 lines (218 loc) · 8.13 KB
/
engine.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mongodb
import (
"context"
"net"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/common/role"
"github.com/gravitational/teleport/lib/srv/db/mongodb/protocol"
"github.com/gravitational/teleport/lib/utils"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
)
// Engine implements the MongoDB database service that accepts client
// connections coming over reverse tunnel from the proxy and proxies
// them between the proxy and the MongoDB database instance.
//
// Implements common.Engine.
//
// Lifecycle as used in this file: InitializeConnection stores the client
// connection, HandleConnection then authorizes and proxies traffic, and
// SendError may be used to report a failure back to the client.
type Engine struct {
// Auth handles database access authentication. In this file it is also
// used to fetch the cluster auth preference (session MFA requirement).
Auth common.Auth
// Audit emits database access audit events: session start/end and one
// query event per authorized or rejected client message.
Audit common.Audit
// Context is the database server close context. Audit events are emitted
// with this context rather than with the per-connection ctx.
Context context.Context
// Clock is the clock interface.
Clock clockwork.Clock
// Log is used for logging.
Log logrus.FieldLogger
// clientConn is an incoming client connection. It is set by
// InitializeConnection and read by HandleConnection and SendError.
clientConn net.Conn
}
// InitializeConnection remembers the client connection on the engine so that
// HandleConnection and SendError can later read from and reply to it.
//
// The session argument is not needed by the MongoDB engine and is ignored.
// This implementation never returns an error.
func (e *Engine) InitializeConnection(clientConn net.Conn, _ *common.Session) error {
	e.clientConn = clientConn
	return nil
}
// SendError sends an error to the connected client in MongoDB understandable format.
//
// Nothing is sent when err is nil or when utils.IsOKNetworkError classifies
// it as an ordinary network error. The error is sent without a request to
// reply to, and delivery failures are only logged (see replyError).
func (e *Engine) SendError(err error) {
	if err == nil || utils.IsOKNetworkError(err) {
		return
	}
	e.replyError(e.clientConn, nil, err)
}
// HandleConnection processes the connection from MongoDB proxy coming
// over reverse tunnel.
//
// Steps, in order:
//  1. Authorize the connection for the requested database user.
//  2. Dial the MongoDB server.
//  3. Emit the session-start audit event.
//  4. Proxy client messages one by one, parsing each so it can be authorized.
//
// It only returns when authorization, dialing, reading a client message, or
// handling a message fails. The returned error is wrapped with trace.
func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session) error {
	if authErr := e.authorizeConnection(ctx, sessionCtx); authErr != nil {
		return trace.Wrap(authErr, "error authorizing database access")
	}

	serverConn, err := e.connect(ctx, sessionCtx)
	if err != nil {
		return trace.Wrap(err, "error connecting to the database")
	}
	// Registered before the audit defer below, so on return the session-end
	// event is emitted first and the server connection is closed last.
	defer func() {
		if closeErr := serverConn.Close(); closeErr != nil {
			e.Log.WithError(closeErr).Error("Failed to close server connection.")
		}
	}()

	e.Audit.OnSessionStart(e.Context, sessionCtx, nil)
	defer e.Audit.OnSessionEnd(e.Context, sessionCtx)

	// Relay client -> server traffic until something fails.
	for {
		msg, readErr := protocol.ReadMessage(e.clientConn)
		if readErr != nil {
			return trace.Wrap(readErr)
		}
		if handleErr := e.handleClientMessage(ctx, sessionCtx, msg, e.clientConn, serverConn); handleErr != nil {
			return trace.Wrap(handleErr)
		}
	}
}
// handleClientMessage implements the client message's roundtrip which can go
// down a few different ways:
//  1. If the client's command is not allowed by user's role, we do not pass it
//     to the server and return an error to the client.
//  2. In the most common case, we send client message to the server, read its
//     reply and send it back to the client.
//  3. Some client commands do not receive a reply in which case we just return
//     after sending message to the server and wait for next client message.
//  4. Server can also send multiple messages in a row in which case we exhaust
//     them before returning to listen for next client message.
//
// A non-nil return terminates the proxy loop in HandleConnection.
func (e *Engine) handleClientMessage(ctx context.Context, sessionCtx *common.Session, clientMessage protocol.Message, clientConn net.Conn, serverConn driver.Connection) error {
	e.Log.Debugf("===> %v", clientMessage)
	// First check the client command against user's role and log in the audit.
	if err := e.authorizeClientMessage(sessionCtx, clientMessage); err != nil {
		// The denial is reported to the client instead of being returned as-is.
		// The result of protocol.ReplyError decides whether the connection stays
		// open. NOTE(review): presumably nil when the reply was written
		// successfully, which keeps the session alive for the next message.
		return protocol.ReplyError(clientConn, clientMessage, err)
	}
	// If RBAC is ok, pass the message to the server.
	if err := serverConn.WriteWireMessage(ctx, clientMessage.GetBytes()); err != nil {
		return trace.Wrap(err)
	}
	// Some client messages will not receive a reply.
	if clientMessage.MoreToCome(nil) {
		return nil
	}
	// Relay server replies to the client. There is always at least one reply.
	// Keep reading while the server indicates it has more to send for this
	// request.
	for {
		serverMessage, err := protocol.ReadServerMessage(ctx, serverConn)
		if err != nil {
			return trace.Wrap(err)
		}
		e.Log.Debugf("<=== %v", serverMessage)
		if _, err := clientConn.Write(serverMessage.GetBytes()); err != nil {
			return trace.Wrap(err)
		}
		if !serverMessage.MoreToCome(clientMessage) {
			return nil
		}
	}
}
// authorizeConnection checks, before a MongoDB connection is established,
// that the session's role allows the requested database user.
//
// The check uses the cluster auth preference to decide whether session MFA is
// always required, and whether the caller's identity carries an MFA
// verification. If access is denied, a failed session-start audit event is
// emitted before the error is returned. A failure to fetch the auth
// preference is returned without emitting an audit event.
func (e *Engine) authorizeConnection(ctx context.Context, sessionCtx *common.Session) error {
	authPref, err := e.Auth.GetAuthPreference(ctx)
	if err != nil {
		return trace.Wrap(err)
	}
	// Only the database user is checked here. MongoDB names the target database
	// in every protocol message, so database access is enforced per message
	// (see checkClientMessage) instead of at connect time.
	if accessErr := sessionCtx.Checker.CheckAccess(
		sessionCtx.Database,
		services.AccessMFAParams{
			Verified:       sessionCtx.Identity.MFAVerified != "",
			AlwaysRequired: authPref.GetRequireSessionMFA(),
		},
		&services.DatabaseUserMatcher{User: sessionCtx.DatabaseUser},
	); accessErr != nil {
		e.Audit.OnSessionStart(e.Context, sessionCtx, accessErr)
		return trace.Wrap(accessErr)
	}
	return nil
}
// authorizeClientMessage decides whether the user may run the MongoDB command
// in message, and records the decision as an audit query event.
//
// The target database is taken from the message itself and passed to
// checkClientMessage. If the database cannot be extracted, that error is
// returned and no audit event is emitted. Otherwise exactly one OnQuery event
// is emitted, carrying the authorization error (nil when allowed).
func (e *Engine) authorizeClientMessage(sessionCtx *common.Session, message protocol.Message) error {
	database, err := message.GetDatabase()
	if err != nil {
		return trace.Wrap(err)
	}
	checkErr := e.checkClientMessage(sessionCtx, message, database)
	e.Audit.OnQuery(e.Context, sessionCtx, common.Query{
		Database: database,
		Query:    message.String(),
		Error:    checkErr,
	})
	return trace.Wrap(checkErr)
}
// checkClientMessage runs the RBAC check for a single client message that
// targets the given database.
//
// Every check passes MFA as already verified; MFA is evaluated once per
// connection in authorizeConnection. Rules:
//   - Legacy OP_KILL_CURSORS messages carry no database name, so only the
//     database user is matched.
//   - Authentication-related commands (authenticate, saslStart, saslContinue,
//     logout) are always denied.
//   - Every other command is checked with the matchers built by
//     role.DatabaseRoleMatchers for the MongoDB protocol, the session's
//     database user, and the target database.
func (e *Engine) checkClientMessage(sessionCtx *common.Session, message protocol.Message, database string) error {
	mfa := services.AccessMFAParams{Verified: true}

	if _, isKillCursors := message.(*protocol.MessageOpKillCursors); isKillCursors {
		userMatcher := &services.DatabaseUserMatcher{User: sessionCtx.DatabaseUser}
		return sessionCtx.Checker.CheckAccess(sessionCtx.Database, mfa, userMatcher)
	}

	command, err := message.GetCommand()
	if err != nil {
		return trace.Wrap(err)
	}
	switch command {
	case "authenticate", "saslStart", "saslContinue", "logout":
		// Authentication on the wire is handled by the engine itself, so
		// clients may not issue these commands.
		return trace.AccessDenied("access denied")
	}

	matchers := role.DatabaseRoleMatchers(defaults.ProtocolMongoDB, sessionCtx.DatabaseUser, database)
	return sessionCtx.Checker.CheckAccess(sessionCtx.Database, mfa, matchers...)
}
// replyError writes err to clientConn as a MongoDB error reply to replyTo.
// replyTo may be nil; SendError passes nil. If the reply cannot be sent, the
// send failure is logged together with the original error and nothing is
// returned.
func (e *Engine) replyError(clientConn net.Conn, replyTo protocol.Message, err error) {
	if sendErr := protocol.ReplyError(clientConn, replyTo, err); sendErr != nil {
		e.Log.WithError(sendErr).Errorf("Failed to send error message to MongoDB client: %v.", err)
	}
}