-
Notifications
You must be signed in to change notification settings - Fork 38
/
dbconn.go
413 lines (361 loc) · 13.3 KB
/
dbconn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
package cli
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/XSAM/otelsql"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/rds"
pop "github.com/gobuffalo/pop/v6"
"github.com/luna-duclos/instrumentedsql"
"github.com/pkg/errors"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"go.uber.org/zap"
iampg "github.com/transcom/mymove/pkg/iampostgres"
)
const (
	// DbDebugFlag is the DB Debug flag
	DbDebugFlag string = "db-debug"
	// DbEnvFlag is the DB environment flag
	DbEnvFlag string = "db-env"
	// DbNameFlag is the DB name flag
	DbNameFlag string = "db-name"
	// DbHostFlag is the DB host flag
	DbHostFlag string = "db-host"
	// DbPortFlag is the DB port flag
	DbPortFlag string = "db-port"
	// DbUserFlag is the DB user flag
	DbUserFlag string = "db-user"
	// DbPasswordFlag is the DB password flag
	DbPasswordFlag string = "db-password"
	// DbPoolFlag is the DB pool flag
	DbPoolFlag string = "db-pool"
	// DbIdlePoolFlag is the DB idle pool flag
	DbIdlePoolFlag string = "db-idle-pool"
	// DbSSLModeFlag is the DB SSL Mode flag
	DbSSLModeFlag string = "db-ssl-mode"
	// DbSSLRootCertFlag is the DB SSL Root Cert flag
	DbSSLRootCertFlag string = "db-ssl-root-cert"
	// DbIamFlag is the DB IAM flag
	DbIamFlag string = "db-iam"
	// DbIamRoleFlag is the DB IAM Role flag
	DbIamRoleFlag string = "db-iam-role"
	// DbRegionFlag is the DB Region flag
	DbRegionFlag string = "db-region"
	// DbInstrumentedFlag is the flag that indicates whether additional db
	// instrumentation (an otelsql-wrapped driver) should be used.
	DbInstrumentedFlag = "db-instrumented"

	// DbEnvContainer is the Container DB Env name
	DbEnvContainer string = "container"
	// DbEnvTest is the Test DB Env name
	DbEnvTest string = "test"
	// DbEnvDevelopment is the Development DB Env name
	DbEnvDevelopment string = "development"
	// DbNameTest is the name of the test database
	DbNameTest string = "test_db"

	// SSLModeDisable is the disable SSL Mode
	SSLModeDisable string = "disable"
	// SSLModeAllow is the allow SSL Mode
	SSLModeAllow string = "allow"
	// SSLModePrefer is the prefer SSL Mode
	SSLModePrefer string = "prefer"
	// SSLModeRequire is the require SSL Mode
	SSLModeRequire string = "require"
	// SSLModeVerifyCA is the verify-ca SSL Mode
	SSLModeVerifyCA string = "verify-ca"
	// SSLModeVerifyFull is the verify-full SSL Mode
	SSLModeVerifyFull string = "verify-full"

	// awsRdsT3SmallMaxConnections is the max connections to an RDS T3
	// Small instance
	//
	// The T3 small instance has 2 GB
	// https://aws.amazon.com/rds/instance-types/
	//
	// These docs say we can calculate the max connections
	// https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/CHAP_Limits.html
	//
	// If correct it is
	//
	// LEAST({DBInstanceClassMemory/9531392}, 5000)
	//
	// DBInstanceClassMemory = 2147483648
	// so 2147483648 / 9531392 = 225.3 which is less than 5000
	//
	// we deploy two containers for the AWS service, so divide that in
	// half and round down conservatively: 225 / 2 =~ 110
	awsRdsT3SmallMaxConnections = 110
	// DbPoolDefault is the default db pool connections
	DbPoolDefault = awsRdsT3SmallMaxConnections
	// DbIdlePoolDefault is the default db idle pool connections
	DbIdlePoolDefault = 2
	// DbPoolMax is the upper limit the db pool can use for connections which constrains the user input
	DbPoolMax int = awsRdsT3SmallMaxConnections
)
var (
	// allSSLModes lists every value accepted for the db-ssl-mode flag.
	//
	// The dependency https://github.com/lib/pq only supports a limited subset of
	// SSL Modes and rejects the others with the error:
	// pq: unsupported sslmode \"prefer\"; only \"require\" (default), \"verify-full\", \"verify-ca\", and \"disable\" supported
	// - https://www.postgresql.org/docs/10/libpq-ssl.html
	// which is why SSLModeAllow and SSLModePrefer are deliberately left out.
	allSSLModes = []string{
		SSLModeDisable,
		SSLModeRequire,
		SSLModeVerifyCA,
		SSLModeVerifyFull,
	}

	// containerSSLModes is the subset of allSSLModes permitted when the db env
	// is DbEnvContainer; every entry enables SSL.
	containerSSLModes = []string{
		SSLModeRequire,
		SSLModeVerifyCA,
		SSLModeVerifyFull,
	}

	// allDbEnvs lists every value accepted for the db-env flag.
	allDbEnvs = []string{
		DbEnvContainer,
		DbEnvTest,
		DbEnvDevelopment,
	}
)
// errInvalidDbPool reports a db-pool value that is outside the range 1..DbPoolMax.
type errInvalidDbPool struct {
	DbPool int // the rejected pool size
}

// Error describes the rejected pool size and the allowed upper bound.
func (e *errInvalidDbPool) Error() string {
	const format = "invalid db pool of %d. Pool must be greater than 0 and less than or equal to %d"
	return fmt.Sprintf(format, e.DbPool, DbPoolMax)
}
// errInvalidDbIdlePool reports a db-idle-pool value that is not acceptable
// relative to the configured overall pool size.
type errInvalidDbIdlePool struct {
	DbPool     int // configured overall pool size, used as the upper bound
	DbIdlePool int // the rejected idle pool size
}

// Error describes the rejected idle pool size and the pool size bounding it.
func (e *errInvalidDbIdlePool) Error() string {
	idle, limit := e.DbIdlePool, e.DbPool
	return fmt.Sprintf("invalid db idle pool of %d. Pool must be greater than 0 and less than or equal to %d", idle, limit)
}
// errInvalidDbEnv reports a db-env value that is not one of the known environments.
type errInvalidDbEnv struct {
	Value  string   // the rejected environment name
	DbEnvs []string // the environment names that would have been accepted
}

// Error names the rejected environment and lists the accepted ones.
func (e *errInvalidDbEnv) Error() string {
	accepted := strings.Join(e.DbEnvs, ", ")
	return fmt.Sprintf("invalid db env %s, must be one of: %s", e.Value, accepted)
}
// errInvalidSSLMode reports a db-ssl-mode value that is not allowed in the
// current context.
type errInvalidSSLMode struct {
	Mode  string   // the rejected SSL mode
	Modes []string // the SSL modes that would have been accepted
}

// Error names the rejected SSL mode and lists the accepted ones.
//
// The accepted modes are passed as a %s argument rather than spliced into the
// format string, so a '%' inside any mode cannot be misread as a verb.
func (e *errInvalidSSLMode) Error() string {
	return fmt.Sprintf("invalid ssl mode %s, must be one of: %s", e.Mode, strings.Join(e.Modes, ", "))
}
// InitDatabaseFlags registers every database command line flag on flag,
// together with its default value and help text.
func InitDatabaseFlags(flag *pflag.FlagSet) {
	// Which database to use and where it is reachable.
	envHelp := "database environment: " + strings.Join(allDbEnvs, ", ")
	flag.String(DbEnvFlag, DbEnvDevelopment, envHelp)
	flag.String(DbNameFlag, "dev_db", "Database Name")
	flag.String(DbHostFlag, "localhost", "Database Hostname")
	flag.Int(DbPortFlag, 5432, "Database Port")

	// Username / password credentials.
	flag.String(DbUserFlag, "crud", "Database Username")
	flag.String(DbPasswordFlag, "", "Database Password")

	// Connection pool sizing.
	flag.Int(DbPoolFlag, DbPoolDefault, "Database Pool or max DB connections")
	flag.Int(DbIdlePoolFlag, DbIdlePoolDefault, "Database Idle Pool or max DB idle connections")

	// Transport security.
	sslHelp := "Database SSL Mode: " + strings.Join(allSSLModes, ", ")
	flag.String(DbSSLModeFlag, SSLModeDisable, sslHelp)
	flag.String(DbSSLRootCertFlag, "", "Path to the database root certificate file used for database connections")

	// Debugging.
	flag.Bool(DbDebugFlag, false, "Set Pop to debug mode")

	// AWS IAM authentication.
	flag.Bool(DbIamFlag, false, "Use AWS IAM authentication")
	flag.String(DbIamRoleFlag, "", "The arn of the AWS IAM role to assume when connecting to the database.")
	// Required by https://docs.aws.amazon.com/sdk-for-go/api/service/rds/rdsutils/#BuildAuthToken
	flag.String(DbRegionFlag, "", "AWS Region of the database")

	// Tracing.
	flag.Bool(DbInstrumentedFlag, false, "Use instrumented db driver")
}
// CheckDatabase validates DB command line flags.
//
// It checks, in order:
//   - the host and port
//   - the pool sizes: 1 <= db-pool <= DbPoolMax and 0 <= db-idle-pool <= db-pool
//   - that db-env is one of allDbEnvs
//   - that db-ssl-mode is one of allSSLModes, and for the container env one of
//     containerSSLModes
//   - that the optional root certificate file can be read
//   - when IAM authentication is enabled, that the region is valid for RDS and
//     a role is provided
//
// logger is only used to report how many certificates were parsed from the
// root certificate file.
func CheckDatabase(v *viper.Viper, logger *zap.Logger) error {
	if err := ValidateHost(v, DbHostFlag); err != nil {
		return err
	}
	if err := ValidatePort(v, DbPortFlag); err != nil {
		return err
	}

	dbPool := v.GetInt(DbPoolFlag)
	dbIdlePool := v.GetInt(DbIdlePoolFlag)
	if dbPool < 1 || dbPool > DbPoolMax {
		return &errInvalidDbPool{DbPool: dbPool}
	}
	// The idle pool is bounded on both sides: it can never be negative and
	// never exceed the overall pool.
	if dbIdlePool < 0 || dbIdlePool > dbPool {
		return &errInvalidDbIdlePool{DbPool: dbPool, DbIdlePool: dbIdlePool}
	}

	dbEnv := v.GetString(DbEnvFlag)
	if !stringSliceContains(allDbEnvs, dbEnv) {
		return &errInvalidDbEnv{Value: dbEnv, DbEnvs: allDbEnvs}
	}

	sslMode := v.GetString(DbSSLModeFlag)
	if len(sslMode) == 0 || !stringSliceContains(allSSLModes, sslMode) {
		return &errInvalidSSLMode{Mode: sslMode, Modes: allSSLModes}
	}
	// sslMode is now known to be in allSSLModes; only the stricter container
	// requirement remains to be checked.
	if dbEnv == DbEnvContainer && !stringSliceContains(containerSSLModes, sslMode) {
		return errors.Wrap(&errInvalidSSLMode{Mode: sslMode, Modes: containerSSLModes}, "container db env requires SSL connection to the database")
	}

	if filename := v.GetString(DbSSLRootCertFlag); len(filename) > 0 {
		b, err := os.ReadFile(filepath.Clean(filename))
		if err != nil {
			return errors.Wrapf(err, "error reading %s at %q", DbSSLRootCertFlag, filename)
		}
		tlsCerts := ParseCertificates(string(b))
		logger.Debug(fmt.Sprintf("certificate chain from %s parsed", DbSSLRootCertFlag), zap.Int("count", len(tlsCerts)))
	}

	// Check IAM Authentication
	if v.GetBool(DbIamFlag) {
		// DbRegionFlag must be set if IAM authentication is enabled.
		dbRegion := v.GetString(DbRegionFlag)
		if err := CheckAWSRegionForService(dbRegion, rds.ServiceName); err != nil {
			// Report the offending value, not just the flag name.
			return errors.Wrapf(err, "%s %q is invalid for service %s", DbRegionFlag, dbRegion, rds.ServiceName)
		}
		if len(v.GetString(DbIamRoleFlag)) == 0 {
			return errors.New("database IAM role not provided")
		}
	}

	return nil
}
// InitDatabase initializes a Pop connection from command line flags.
// v is the viper Configuration.
// creds must relate to an assumed role and can't point to a user or task role directly.
// logger is the application logger.
//
// The returned connection has already been opened and pinged. If a step fails
// after the connection has been opened, the connection is closed before the
// error is returned, so callers never receive or leak a half-initialized
// connection.
func InitDatabase(v *viper.Viper, creds *credentials.Credentials, logger *zap.Logger) (*pop.Connection, error) {
	dbEnv := v.GetString(DbEnvFlag)
	dbName := v.GetString(DbNameFlag)
	dbHost := v.GetString(DbHostFlag)
	dbPort := strconv.Itoa(v.GetInt(DbPortFlag))
	dbUser := v.GetString(DbUserFlag)
	dbPassword := v.GetString(DbPasswordFlag)
	dbPool := v.GetInt(DbPoolFlag)
	dbIdlePool := v.GetInt(DbIdlePoolFlag)
	dbUseInstrumentedDriver := v.GetBool(DbInstrumentedFlag)

	// Modify DB options by environment
	dbOptions := map[string]string{
		"sslmode": v.GetString(DbSSLModeFlag),
	}

	if dbEnv == DbEnvTest {
		// Leave the test database name hardcoded, since we run tests in the same
		// environment as development, and it's extra confusing to have to swap environment
		// variables before running tests.
		dbName = DbNameTest
	}

	if str := v.GetString(DbSSLRootCertFlag); len(str) > 0 {
		dbOptions["sslrootcert"] = str
	}

	// Construct a safe URL (password masked) and log it
	safeURL := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s",
		dbUser, "*****", dbHost, dbPort, dbName, dbOptions["sslmode"])
	logger.Info("Connecting to the database", zap.String("url", safeURL), zap.String(DbSSLRootCertFlag, v.GetString(DbSSLRootCertFlag)))

	// Configure DB connection details
	dbConnectionDetails := pop.ConnectionDetails{
		Dialect:  "postgres",
		Driver:   iampg.CustomPostgres,
		Database: dbName,
		Host:     dbHost,
		Port:     dbPort,
		User:     dbUser,
		Password: dbPassword,
		Options:  dbOptions,
		Pool:     dbPool,
		IdlePool: dbIdlePool,
	}

	if v.GetBool(DbIamFlag) {
		// Set a bogus password holder. It will be replaced with an RDS auth token as the password.
		passHolder := "*****"
		iampg.EnableIAM(dbConnectionDetails.Host,
			dbConnectionDetails.Port,
			v.GetString(DbRegionFlag),
			dbConnectionDetails.User,
			passHolder,
			creds,
			iampg.RDSU{},
			time.NewTicker(10*time.Minute), // Refresh every 10 minutes
			logger,
			make(chan bool))
		dbConnectionDetails.Password = passHolder

		// pop now uses url.QueryEscape on the password, but that
		// doesn't work with iampg. If we manually construct the URL,
		// we can override that behavior
		dbConnectionDetails.URL = fmt.Sprintf("postgres://%s:%s@%s:%s/%s?%s",
			dbConnectionDetails.User, dbConnectionDetails.Password,
			dbConnectionDetails.Host, dbConnectionDetails.Port,
			dbConnectionDetails.Database, dbConnectionDetails.OptionsString(""))
	}

	if dbUseInstrumentedDriver {
		// to fake pop out, we need to register the otelsql instrumented
		// driver under the driverName that pop would use. To do that,
		// we need to get the otelsql driver.Driver, which is easiest
		// to get from sql.DB.Driver()
		db, err := sql.Open(dbConnectionDetails.Driver, "")
		if err != nil {
			logger.Error("Failed opening uninstrumented connection", zap.Error(err))
			return nil, err
		}
		currentDriver := db.Driver()
		if err := db.Close(); err != nil {
			logger.Error("Failed closing uninstrumented connection", zap.Error(err))
			return nil, err
		}

		// This is the name from pop's instrumented connection code
		// https://github.com/gobuffalo/pop/blob/master/connection_instrumented.go#L44
		popInstrumentedDriverName := "instrumented-sql-driver-postgres"

		// and we're going to fake out pop with the Driver so that the
		// driver name matches what pop is looking for, but it will
		// wind up using the desired driver under a wrapped otelsql connection
		dbConnectionDetails.Driver = "postgres"

		// sql.Register panics if a name is registered twice, so a second
		// call to InitDatabase in the same process must reuse the driver
		// registered by the first (pop itself applies the same guard).
		alreadyRegistered := false
		for _, name := range sql.Drivers() {
			if name == popInstrumentedDriverName {
				alreadyRegistered = true
				break
			}
		}
		if alreadyRegistered {
			logger.Warn("Instrumented sql driver already registered, reusing it", zap.String("driver", popInstrumentedDriverName))
		} else {
			spanOptions := otelsql.SpanOptions{
				Ping:     true,
				RowsNext: v.GetBool(DbDebugFlag),
			}
			sql.Register(popInstrumentedDriverName,
				otelsql.WrapDriver(currentDriver,
					otelsql.WithSpanOptions(spanOptions)))
		}

		// now we can update the connection details to indicate we
		// want an instrumented connection
		dbConnectionDetails.UseInstrumentedDriver = true
		// pop expects at least one option when using instrumented
		// sql, but the options will be ignored since we are faking
		// things out
		dbConnectionDetails.InstrumentedDriverOptions = []instrumentedsql.Opt{
			instrumentedsql.WithOmitArgs(),
		}
		logger.Info("Using otelsql instrumented sql driver")
	}

	if err := dbConnectionDetails.Finalize(); err != nil {
		logger.Error("Failed to finalize DB connection details", zap.Error(err))
		return nil, err
	}

	// Set up the connection
	connection, err := pop.NewConnection(&dbConnectionDetails)
	if err != nil {
		logger.Error("Failed create DB connection", zap.Error(err))
		return nil, err
	}

	// Open the connection - required
	if err := connection.Open(); err != nil {
		logger.Error("Failed to open DB connection", zap.Error(err))
		return nil, err
	}

	// From here on the connection is open; close it on any failure so its
	// resources are released instead of leaked.
	closeConnection := func() {
		if closeErr := connection.Close(); closeErr != nil {
			logger.Error("Failed to close DB connection", zap.Error(closeErr))
		}
	}

	dbWithPinger, ok := connection.Store.(pinger)
	if !ok {
		logger.Error("Failed to convert to pinger interface")
		closeConnection()
		return nil, errors.New("Failed to convert to pinger interface")
	}

	// Make the db ping
	logger.Info("Starting database ping....")
	if err := dbWithPinger.Ping(); err != nil {
		logger.Warn("Failed to ping DB connection", zap.Error(err))
		closeConnection()
		return nil, err
	}
	logger.Info("...DB ping successful!")

	// Return the open connection
	return connection, nil
}
// pinger is the single-method view InitDatabase needs of a pop connection's
// Store: something that can be pinged, reporting any failure as an error.
type pinger interface{ Ping() error }