Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manually construct pop URL when using iampg auth #9538

Merged
merged 7 commits into from Nov 8, 2022
74 changes: 21 additions & 53 deletions pkg/cli/dbconn.go
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/rds"
pop "github.com/gobuffalo/pop/v6"
"github.com/jmoiron/sqlx"
"github.com/luna-duclos/instrumentedsql"
"github.com/pkg/errors"
"github.com/spf13/pflag"
Expand Down Expand Up @@ -315,6 +314,14 @@ func InitDatabase(v *viper.Viper, creds *credentials.Credentials, logger *zap.Lo
make(chan bool))

dbConnectionDetails.Password = passHolder
// pop now use url.QueryEscape on the password, but that
// doesn't work with iampg. If we manually construct the URL,
// we can override that behavior
s := "postgres://%s:%s@%s:%s/%s?%s"
dbConnectionDetails.URL = fmt.Sprintf(s,
dbConnectionDetails.User, dbConnectionDetails.Password,
dbConnectionDetails.Host, dbConnectionDetails.Port,
dbConnectionDetails.Database, dbConnectionDetails.OptionsString(""))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we changed the passHolder to something that does not need query escaping, we wouldn't have to do this assignment to URL, but that seems like leaving a footgun around, even if we comment on it.

}

if dbUseInstrumentedDriver {
Expand Down Expand Up @@ -381,65 +388,26 @@ func InitDatabase(v *viper.Viper, creds *credentials.Credentials, logger *zap.Lo
return nil, err
}

err = testConnection(&dbConnectionDetails, v.GetBool(DbIamFlag), logger)
if err != nil {
logger.Error("Failed to ping database")
return connection, err
}

// Return the open connection
return connection, nil
}

// testConnection tests the connection to determine successful ping
func testConnection(dbConnDetails *pop.ConnectionDetails, useIam bool, logger *zap.Logger) error {
// Copy connection info as we don't want to alter connection info
dbConnectionDetails := pop.ConnectionDetails{
Dialect: "postgres",
Driver: dbConnDetails.Driver,
Database: dbConnDetails.Database,
Host: dbConnDetails.Host,
Port: dbConnDetails.Port,
User: dbConnDetails.User,
Password: dbConnDetails.Password,
Options: dbConnDetails.Options,
Pool: dbConnDetails.Pool,
IdlePool: dbConnDetails.IdlePool,
}

if useIam {
dbConnectionDetails.Password = iampg.GetCurrentPass()
}

// Set up the connection
connection, err := pop.NewConnection(&dbConnectionDetails)
if err != nil {
logger.Error("Failed create DB connection", zap.Error(err))
return err
}

// Open the connection
err = connection.Open()
if err != nil {
logger.Error("Failed to open DB connection", zap.Error(err))
return err
}

// Check the connection
db, err := sqlx.Open(connection.Dialect.Details().Dialect, connection.Dialect.URL())
if err != nil {
logger.Warn("Failed to open DB by driver name", zap.Error(err))
return err
dbWithPinger, ok := connection.Store.(pinger)
if !ok {
logger.Error("Failed to convert to pinger interface")
return nil, errors.New("Failed to convert to pinger interface")
}

// Make the db ping
logger.Info("Starting database ping....")
err = db.Ping()
err = dbWithPinger.Ping()
if err != nil {
logger.Warn("Failed to ping DB connection", zap.Error(err))
return err
return nil, err
}

logger.Info("...DB ping successful!")
return nil

// Return the open connection
return connection, nil
}

type pinger interface {
Ping() error
}