From c98612b33a7c517dbd37c0fc973ee0a842ace859 Mon Sep 17 00:00:00 2001 From: Jesse White Date: Sat, 6 Apr 2024 11:26:36 -0400 Subject: [PATCH 1/5] add support for AWS IAM authentication in the postgres datastore fixes #659 --- e2e/go.mod | 14 ++++ e2e/go.sum | 28 +++++++ go.mod | 14 ++++ go.sum | 28 +++++++ .../datastore/postgres/migrations/driver.go | 16 +++- internal/datastore/postgres/options.go | 8 ++ internal/datastore/postgres/postgres.go | 21 +++++- internal/testserver/datastore/postgres.go | 2 +- pkg/cmd/datastore/datastore.go | 3 + pkg/cmd/datastore/zz_generated.options.go | 9 +++ pkg/cmd/migrate.go | 10 ++- pkg/datastore/credentials.go | 73 +++++++++++++++++++ pkg/datastore/credentials_test.go | 19 +++++ 13 files changed, 237 insertions(+), 8 deletions(-) create mode 100644 pkg/datastore/credentials.go create mode 100644 pkg/datastore/credentials_test.go diff --git a/e2e/go.mod b/e2e/go.mod index 348639f7d6..04313f7554 100644 --- a/e2e/go.mod +++ b/e2e/go.mod @@ -18,6 +18,20 @@ require ( require ( github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230512164433-5d1fd1a340c9 // indirect github.com/authzed/cel-go v0.17.5 // indirect + github.com/aws/aws-sdk-go-v2 v1.26.1 // indirect + github.com/aws/aws-sdk-go-v2/config v1.27.11 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.11 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect + github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.4.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 // indirect + github.com/aws/smithy-go v1.20.2 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d // indirect github.com/creasty/defaults v1.7.0 // indirect diff --git a/e2e/go.sum b/e2e/go.sum index 82fac4f994..b532d68ea0 100644 --- a/e2e/go.sum +++ b/e2e/go.sum @@ -28,6 +28,34 @@ github.com/authzed/cel-go v0.17.5 h1:lfpkNrR99B5QRHg5qdG9oLu/kguVlZC68VJuMk8tH9Y github.com/authzed/cel-go v0.17.5/go.mod h1:XL/zEq5hKGVF8aOdMbG7w+BQPihLjY2W8N+UIygDA2I= github.com/authzed/grpcutil v0.0.0-20240123092924-129dc0a6a6e1 h1:zBfQzia6Hz45pJBeURTrv1b6HezmejB6UmiGuBilHZM= github.com/authzed/grpcutil v0.0.0-20240123092924-129dc0a6a6e1/go.mod h1:s3qC7V7XIbiNWERv7Lfljy/Lx25/V1Qlexb0WJuA8uQ= +github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= +github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/config v1.27.11 h1:f47rANd2LQEYHda2ddSCKYId18/8BhSRM4BULGmfgNA= +github.com/aws/aws-sdk-go-v2/config v1.27.11/go.mod h1:SMsV78RIOYdve1vf36z8LmnszlRWkwMQtomCAI0/mIE= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.4.5 h1:Jm5og3wZoeKE1fkRkp/zT53vsOAZl3cR5FJ9JRNuIgQ= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.4.5/go.mod h1:RI6PT6IXi7wmGtuRDfc8gmqMsYzTyz+py0cvLw0itck= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 h1:vN8hEbpRnL7+Hopy9dzmRle1xmDc7o8tmY0klsr175w= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.5/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 h1:Jux+gDDyi1Lruk+KHF91tK2KCuY61kzoCpvtvJJBtOE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 h1:cwIxeBttqPN3qkaAjcEcsh8NYr8n2HZPkcKgPAi1phU= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.6/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= diff --git a/go.mod b/go.mod index 44bf12c086..3b9f5b26ee 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,9 @@ require ( github.com/authzed/consistent v0.1.0 github.com/authzed/grpcutil v0.0.0-20240123092924-129dc0a6a6e1 github.com/aws/aws-sdk-go v1.51.11 + github.com/aws/aws-sdk-go-v2 v1.26.1 + github.com/aws/aws-sdk-go-v2/config v1.27.11 + github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.4.5 github.com/benbjohnson/clock v1.3.5 github.com/bits-and-blooms/bloom/v3 v3.7.0 github.com/cenkalti/backoff/v4 v4.3.0 @@ -99,6 +102,17 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2/credentials v1.17.11 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 // indirect + github.com/aws/smithy-go v1.20.2 // indirect github.com/bombsimon/wsl/v4 v4.2.1 // indirect github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 // indirect github.com/jjti/go-spancheck v0.5.3 // indirect diff --git a/go.sum b/go.sum index 3d33272b6f..25ee3116a2 100644 --- a/go.sum +++ b/go.sum @@ -124,6 +124,34 @@ github.com/authzed/grpcutil v0.0.0-20240123092924-129dc0a6a6e1/go.mod h1:s3qC7V7 github.com/aws/aws-sdk-go v1.44.256/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-sdk-go v1.51.11 h1:El5VypsMIz7sFwAAj/j06JX9UGs4KAbAIEaZ57bNY4s= github.com/aws/aws-sdk-go v1.51.11/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= +github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/config v1.27.11 h1:f47rANd2LQEYHda2ddSCKYId18/8BhSRM4BULGmfgNA= +github.com/aws/aws-sdk-go-v2/config v1.27.11/go.mod h1:SMsV78RIOYdve1vf36z8LmnszlRWkwMQtomCAI0/mIE= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.4.5 h1:Jm5og3wZoeKE1fkRkp/zT53vsOAZl3cR5FJ9JRNuIgQ= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.4.5/go.mod h1:RI6PT6IXi7wmGtuRDfc8gmqMsYzTyz+py0cvLw0itck= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 h1:vN8hEbpRnL7+Hopy9dzmRle1xmDc7o8tmY0klsr175w= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.5/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 h1:Jux+gDDyi1Lruk+KHF91tK2KCuY61kzoCpvtvJJBtOE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 h1:cwIxeBttqPN3qkaAjcEcsh8NYr8n2HZPkcKgPAi1phU= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.6/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= diff --git a/internal/datastore/postgres/migrations/driver.go b/internal/datastore/postgres/migrations/driver.go index 1c6ea2ce1c..23b03a1326 100644 --- a/internal/datastore/postgres/migrations/driver.go +++ b/internal/datastore/postgres/migrations/driver.go @@ -5,12 +5,13 @@ import ( "errors" "fmt" - pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" - "github.com/authzed/spicedb/pkg/migrate" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "go.opentelemetry.io/otel" + + pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/migrate" ) const postgresMissingTableErrorCode = "42P01" @@ -26,7 +27,7 @@ type AlembicPostgresDriver struct { } // NewAlembicPostgresDriver creates a new driver with active connections to the database specified. -func NewAlembicPostgresDriver(ctx context.Context, url string) (*AlembicPostgresDriver, error) { +func NewAlembicPostgresDriver(ctx context.Context, url string, credentialsProvider datastore.CredentialsProvider) (*AlembicPostgresDriver, error) { ctx, span := tracer.Start(ctx, "NewAlembicPostgresDriver") defer span.End() @@ -37,6 +38,13 @@ func NewAlembicPostgresDriver(ctx context.Context, url string) (*AlembicPostgres pgxcommon.ConfigurePGXLogger(connConfig) pgxcommon.ConfigureOTELTracer(connConfig) + if credentialsProvider != nil { + connConfig.User, connConfig.Password, err = credentialsProvider.Get(ctx, connConfig.Host, connConfig.Port, connConfig.User) + if err != nil { + return nil, err + } + } + db, err := pgx.ConnectConfig(ctx, connConfig) if err != nil { return nil, err diff --git a/internal/datastore/postgres/options.go b/internal/datastore/postgres/options.go index b10e6dd231..bc3ea935ec 100644 --- a/internal/datastore/postgres/options.go +++ b/internal/datastore/postgres/options.go @@ -12,6 +12,8 @@ type postgresOptions struct { maxRevisionStalenessPercent float64 + credentialsProviderName string + watchBufferLength uint16 watchBufferWriteTimeout time.Duration revisionQuantization time.Duration @@ -58,6 +60,7 @@ const ( defaultEnablePrometheusStats = false defaultMaxRetries = 10 defaultGCEnabled = true + defaultCredentialsProviderName = "" ) // Option provides the facility to configure how clients within the @@ -76,6 +79,7 @@ func generateConfig(options []Option) (postgresOptions, error) { enablePrometheusStats: defaultEnablePrometheusStats, maxRetries: defaultMaxRetries, gcEnabled: defaultGCEnabled, + credentialsProviderName: defaultCredentialsProviderName, queryInterceptor: nil, } @@ -332,3 +336,7 @@ func WithQueryInterceptor(interceptor pgxcommon.QueryInterceptor) Option { func MigrationPhase(phase string) Option { return func(po *postgresOptions) { po.migrationPhase = phase } } + +func CredentialsProviderName(credentialsProviderName string) Option { + return func(po *postgresOptions) { po.credentialsProviderName = credentialsProviderName } +} diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 6cd9015f56..76b1e7704e 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -146,6 +146,12 @@ func newPostgresDatastore( return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, pgURL) } + // Setup the credential provider + credentialsProvider, err := datastore.NewCredentialsProvider(ctx, config.credentialsProviderName) + if err != nil { + return nil, err + } + // Setup the config for each of the read and write pools. readPoolConfig := pgConfig.Copy() config.readPoolOpts.ConfigurePgx(readPoolConfig) @@ -163,6 +169,16 @@ func newPostgresDatastore( return nil } + if credentialsProvider != nil { + // add before connect callbacks to trigger the token + getToken := func(ctx context.Context, config *pgx.ConnConfig) error { + config.User, config.Password, err = credentialsProvider.Get(ctx, config.Host, config.Port, config.User) + return err + } + readPoolConfig.BeforeConnect = getToken + writePoolConfig.BeforeConnect = getToken + } + if config.migrationPhase != "" { log.Info(). Str("phase", config.migrationPhase). @@ -260,6 +276,7 @@ func newPostgresDatastore( cancelGc: cancelGc, readTxOptions: pgx.TxOptions{IsoLevel: pgx.RepeatableRead, AccessMode: pgx.ReadOnly}, maxRetries: config.maxRetries, + credentialsProvider: credentialsProvider, } datastore.SetOptimizedRevisionFunc(datastore.optimizedRevisionFunc) @@ -300,6 +317,8 @@ type pgDatastore struct { maxRetries uint8 watchEnabled bool + credentialsProvider datastore.CredentialsProvider + gcGroup *errgroup.Group gcCtx context.Context cancelGc context.CancelFunc @@ -534,7 +553,7 @@ func (pgd *pgDatastore) ReadyState(ctx context.Context) (datastore.ReadyState, e return datastore.ReadyState{}, fmt.Errorf("invalid head migration found for postgres: %w", err) } - pgDriver, err := migrations.NewAlembicPostgresDriver(ctx, pgd.dburl) + pgDriver, err := migrations.NewAlembicPostgresDriver(ctx, pgd.dburl, pgd.credentialsProvider) if err != nil { return datastore.ReadyState{}, err } diff --git a/internal/testserver/datastore/postgres.go b/internal/testserver/datastore/postgres.go index 2915b467cb..b4242547de 100644 --- a/internal/testserver/datastore/postgres.go +++ b/internal/testserver/datastore/postgres.go @@ -137,7 +137,7 @@ func (b *postgresTester) NewDatastore(t testing.TB, initFunc InitFunc) datastore for i := 0; i < retryCount; i++ { connectStr := b.NewDatabase(t) - migrationDriver, err := pgmigrations.NewAlembicPostgresDriver(context.Background(), connectStr) + migrationDriver, err := pgmigrations.NewAlembicPostgresDriver(context.Background(), connectStr, nil) if err == nil { ctx := context.WithValue(context.Background(), migrate.BackfillBatchSize, uint64(1000)) require.NoError(t, pgmigrations.DatabaseMigrations.Run(ctx, migrationDriver, b.targetMigration, migrate.LiveRun)) diff --git a/pkg/cmd/datastore/datastore.go b/pkg/cmd/datastore/datastore.go index 2fd3eaa7c7..ad637a727b 100644 --- a/pkg/cmd/datastore/datastore.go +++ b/pkg/cmd/datastore/datastore.go @@ -99,6 +99,7 @@ type Config struct { LegacyFuzzing time.Duration `debugmap:"visible"` RevisionQuantization time.Duration `debugmap:"visible"` MaxRevisionStalenessPercent float64 `debugmap:"visible"` + CredentialsProviderName string `debugmap:"visible"` // Options ReadConnPool ConnPoolConfig `debugmap:"visible"` @@ -166,6 +167,7 @@ func RegisterDatastoreFlagsWithPrefix(flagSet *pflag.FlagSet, prefix string, opt flagSet.StringVar(&opts.Engine, flagName("datastore-engine"), defaults.Engine, fmt.Sprintf(`type of datastore to initialize (%s)`, datastore.EngineOptions())) flagSet.StringVar(&opts.URI, flagName("datastore-conn-uri"), defaults.URI, `connection string used by remote datastores (e.g. "postgres://postgres:password@localhost:5432/spicedb")`) + flagSet.StringVar(&opts.CredentialsProviderName, flagName("datastore-credentials-provider-name"), defaults.CredentialsProviderName, fmt.Sprintf(`retrieve datastore credentials dynamically using (%s)`, datastore.CredentialsProviderOptions())) var legacyConnPool ConnPoolConfig RegisterConnPoolFlagsWithPrefix(flagSet, "datastore-conn", DefaultReadConnPool(), &legacyConnPool) @@ -390,6 +392,7 @@ func newCRDBDatastore(ctx context.Context, opts Config) (datastore.Datastore, er func newPostgresDatastore(ctx context.Context, opts Config) (datastore.Datastore, error) { pgOpts := []postgres.Option{ + postgres.CredentialsProviderName(opts.CredentialsProviderName), postgres.GCWindow(opts.GCWindow), postgres.GCEnabled(!opts.ReadOnly), postgres.RevisionQuantization(opts.RevisionQuantization), diff --git a/pkg/cmd/datastore/zz_generated.options.go b/pkg/cmd/datastore/zz_generated.options.go index fbb25752f3..e3457544c4 100644 --- a/pkg/cmd/datastore/zz_generated.options.go +++ b/pkg/cmd/datastore/zz_generated.options.go @@ -37,6 +37,7 @@ func (c *Config) ToOption() ConfigOption { to.LegacyFuzzing = c.LegacyFuzzing to.RevisionQuantization = c.RevisionQuantization to.MaxRevisionStalenessPercent = c.MaxRevisionStalenessPercent + to.CredentialsProviderName = c.CredentialsProviderName to.ReadConnPool = c.ReadConnPool to.WriteConnPool = c.WriteConnPool to.ReadOnly = c.ReadOnly @@ -78,6 +79,7 @@ func (c Config) DebugMap() map[string]any { debugMap["LegacyFuzzing"] = helpers.DebugValue(c.LegacyFuzzing, false) debugMap["RevisionQuantization"] = helpers.DebugValue(c.RevisionQuantization, false) debugMap["MaxRevisionStalenessPercent"] = helpers.DebugValue(c.MaxRevisionStalenessPercent, false) + debugMap["CredentialsProviderName"] = helpers.DebugValue(c.CredentialsProviderName, false) debugMap["ReadConnPool"] = helpers.DebugValue(c.ReadConnPool, false) debugMap["WriteConnPool"] = helpers.DebugValue(c.WriteConnPool, false) debugMap["ReadOnly"] = helpers.DebugValue(c.ReadOnly, false) @@ -168,6 +170,13 @@ func WithMaxRevisionStalenessPercent(maxRevisionStalenessPercent float64) Config } } +// WithCredentialsProviderName returns an option that can set CredentialsProviderName on a Config +func WithCredentialsProviderName(credentialsProviderName string) ConfigOption { + return func(c *Config) { + c.CredentialsProviderName = credentialsProviderName + } +} + // WithReadConnPool returns an option that can set ReadConnPool on a Config func WithReadConnPool(readConnPool ConnPoolConfig) ConfigOption { return func(c *Config) { diff --git a/pkg/cmd/migrate.go b/pkg/cmd/migrate.go index 17c30d0c7e..00c50121fe 100644 --- a/pkg/cmd/migrate.go +++ b/pkg/cmd/migrate.go @@ -23,6 +23,7 @@ import ( func RegisterMigrateFlags(cmd *cobra.Command) { cmd.Flags().String("datastore-engine", "memory", fmt.Sprintf(`type of datastore to initialize (%s)`, datastore.EngineOptions())) cmd.Flags().String("datastore-conn-uri", "", `connection string used by remote datastores (e.g. "postgres://postgres:password@localhost:5432/spicedb")`) + cmd.Flags().String("datastore-credentials-provider-name", "", fmt.Sprintf(`retrieve datastore credentials dynamically using (%s)`, datastore.CredentialsProviderOptions())) cmd.Flags().String("datastore-spanner-credentials", "", "path to service account key credentials file with access to the cloud spanner instance (omit to use application default credentials)") cmd.Flags().String("datastore-spanner-emulator-host", "", "URI of spanner emulator instance used for development and testing (e.g. localhost:9010)") cmd.Flags().String("datastore-mysql-table-prefix", "", "prefix to add to the name of all mysql database tables") @@ -59,8 +60,13 @@ func migrateRun(cmd *cobra.Command, args []string) error { } else if datastoreEngine == "postgres" { log.Ctx(cmd.Context()).Info().Msg("migrating postgres datastore") - var err error - migrationDriver, err := migrations.NewAlembicPostgresDriver(cmd.Context(), dbURL) + credentialsProviderName := cobrautil.MustGetString(cmd, "datastore-credentials-provider-name") + credentialsProvider, err := datastore.NewCredentialsProvider(cmd.Context(), credentialsProviderName) + if err != nil { + return err + } + + migrationDriver, err := migrations.NewAlembicPostgresDriver(cmd.Context(), dbURL, credentialsProvider) if err != nil { return fmt.Errorf("unable to create migration driver for %s: %w", datastoreEngine, err) } diff --git a/pkg/datastore/credentials.go b/pkg/datastore/credentials.go new file mode 100644 index 0000000000..79f1ae9b52 --- /dev/null +++ b/pkg/datastore/credentials.go @@ -0,0 +1,73 @@ +package datastore + +import ( + "context" + "fmt" + "sort" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + rdsauth "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "golang.org/x/exp/maps" + + log "github.com/authzed/spicedb/internal/logging" +) + +type CredentialsProvider interface { + // Get return the username and password to use when connecting to the database + Get(ctx context.Context, dbHostname string, dbPort uint16, dbUser string) (string, string, error) +} + +type credentialsProviderBuilderFunc func(ctx context.Context) (CredentialsProvider, error) + +const ( + AWSIAMCredentialProvider = "aws-iam" +) + +var BuilderForCredentialProvider = map[string]credentialsProviderBuilderFunc{ + AWSIAMCredentialProvider: newAWSIAMCredentialsProvider, +} + +// CredentialsProviderOptions returns the full set of credential provider names, sorted and quoted into a string. +func CredentialsProviderOptions() string { + ids := maps.Keys(BuilderForCredentialProvider) + sort.Strings(ids) + quoted := make([]string, 0, len(ids)) + for _, id := range ids { + quoted = append(quoted, `"`+id+`"`) + } + return strings.Join(quoted, ", ") +} + +// NewCredentialsProvider create a new CredentialsProvider for the given name +// return nil if no match is found +// return an error if there is a problem initializing the given CredentialsProvider +func NewCredentialsProvider(ctx context.Context, name string) (CredentialsProvider, error) { + builder, ok := BuilderForCredentialProvider[name] + if !ok { + return nil, nil + } + return builder(ctx) +} + +// AWS IAM provider + +func newAWSIAMCredentialsProvider(ctx context.Context) (CredentialsProvider, error) { + awsSdkConfig, err := awsconfig.LoadDefaultConfig(ctx) + if err != nil { + return nil, err + } + return &awsIamCredentialsProvider{awsSdkConfig: awsSdkConfig}, nil +} + +type awsIamCredentialsProvider struct { + awsSdkConfig aws.Config +} + +func (d awsIamCredentialsProvider) Get(ctx context.Context, dbHostname string, dbPort uint16, dbUser string) (string, string, error) { + dbEndpoint := fmt.Sprintf("%s:%d", dbHostname, dbPort) + authToken, err := rdsauth.BuildAuthToken(ctx, dbEndpoint, d.awsSdkConfig.Region, dbUser, d.awsSdkConfig.Credentials) + log.Ctx(ctx).Trace().Str("region", d.awsSdkConfig.Region).Str("endpoint", dbEndpoint).Str("user", dbUser).Msg("successfully retrieved IAM auth token for DB") + return dbUser, authToken, err +} diff --git a/pkg/datastore/credentials_test.go b/pkg/datastore/credentials_test.go new file mode 100644 index 0000000000..158269b6ae --- /dev/null +++ b/pkg/datastore/credentials_test.go @@ -0,0 +1,19 @@ +package datastore + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewCredentialsProvider(t *testing.T) { + unknownCredentialsProviders := []string{"", "some-unknown-credentials-provider"} + for _, unknownCredentialsProvider := range unknownCredentialsProviders { + t.Run(unknownCredentialsProvider, func(t *testing.T) { + credentialsProvider, err := NewCredentialsProvider(context.Background(), unknownCredentialsProvider) + require.Nil(t, credentialsProvider) + require.NoError(t, err) + }) + } +} From e6905b4fcab4ffa393ef716ff51dd9d4901a43b4 Mon Sep 17 00:00:00 2001 From: Jesse White Date: Sat, 13 Apr 2024 15:12:15 -0400 Subject: [PATCH 2/5] review feedback & more unit testing --- .../datastore/postgres/migrations/driver.go | 3 +++ internal/testserver/datastore/postgres.go | 2 +- pkg/datastore/credentials.go | 16 ++++++++++++-- pkg/datastore/credentials_test.go | 21 ++++++++++++++++++- 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/internal/datastore/postgres/migrations/driver.go b/internal/datastore/postgres/migrations/driver.go index 23b03a1326..9fe11d2164 100644 --- a/internal/datastore/postgres/migrations/driver.go +++ b/internal/datastore/postgres/migrations/driver.go @@ -9,6 +9,8 @@ import ( "github.com/jackc/pgx/v5/pgconn" "go.opentelemetry.io/otel" + log "github.com/authzed/spicedb/internal/logging" + pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/migrate" @@ -39,6 +41,7 @@ func NewAlembicPostgresDriver(ctx context.Context, url string, credentialsProvid pgxcommon.ConfigureOTELTracer(connConfig) if credentialsProvider != nil { + log.Ctx(ctx).Debug().Str("name", credentialsProvider.Name()).Msg("using credentials provider") connConfig.User, connConfig.Password, err = credentialsProvider.Get(ctx, connConfig.Host, connConfig.Port, connConfig.User) if err != nil { return nil, err diff --git a/internal/testserver/datastore/postgres.go b/internal/testserver/datastore/postgres.go index b4242547de..ab48f0f3ab 100644 --- a/internal/testserver/datastore/postgres.go +++ b/internal/testserver/datastore/postgres.go @@ -137,7 +137,7 @@ func (b *postgresTester) NewDatastore(t testing.TB, initFunc InitFunc) datastore for i := 0; i < retryCount; i++ { connectStr := b.NewDatabase(t) - migrationDriver, err := pgmigrations.NewAlembicPostgresDriver(context.Background(), connectStr, nil) + migrationDriver, err := pgmigrations.NewAlembicPostgresDriver(context.Background(), connectStr, datastore.NoCredentialsProvider) if err == nil { ctx := context.WithValue(context.Background(), migrate.BackfillBatchSize, uint64(1000)) require.NoError(t, pgmigrations.DatabaseMigrations.Run(ctx, migrationDriver, b.targetMigration, migrate.LiveRun)) diff --git a/pkg/datastore/credentials.go b/pkg/datastore/credentials.go index 79f1ae9b52..7b3bfa4a9d 100644 --- a/pkg/datastore/credentials.go +++ b/pkg/datastore/credentials.go @@ -14,14 +14,20 @@ import ( log "github.com/authzed/spicedb/internal/logging" ) +// CredentialsProvider allows datastore credentials to be retrieved dynamically type CredentialsProvider interface { - // Get return the username and password to use when connecting to the database + // Name return the name of the provider + Name() string + // Get return the username and password to use when connecting to the underlying datastore Get(ctx context.Context, dbHostname string, dbPort uint16, dbUser string) (string, string, error) } +var NoCredentialsProvider CredentialsProvider = nil + type credentialsProviderBuilderFunc func(ctx context.Context) (CredentialsProvider, error) const ( + // AWSIAMCredentialProvider generates AWS IAM tokens for authenticating with the datastore (i.e. RDS) AWSIAMCredentialProvider = "aws-iam" ) @@ -65,9 +71,15 @@ type awsIamCredentialsProvider struct { awsSdkConfig aws.Config } +func (d awsIamCredentialsProvider) Name() string { + return AWSIAMCredentialProvider +} + func (d awsIamCredentialsProvider) Get(ctx context.Context, dbHostname string, dbPort uint16, dbUser string) (string, string, error) { dbEndpoint := fmt.Sprintf("%s:%d", dbHostname, dbPort) authToken, err := rdsauth.BuildAuthToken(ctx, dbEndpoint, d.awsSdkConfig.Region, dbUser, d.awsSdkConfig.Credentials) - log.Ctx(ctx).Trace().Str("region", d.awsSdkConfig.Region).Str("endpoint", dbEndpoint).Str("user", dbUser).Msg("successfully retrieved IAM auth token for DB") + if err != nil { + log.Ctx(ctx).Trace().Str("region", d.awsSdkConfig.Region).Str("endpoint", dbEndpoint).Str("user", dbUser).Msg("successfully retrieved IAM auth token for DB") + } return dbUser, authToken, err } diff --git a/pkg/datastore/credentials_test.go b/pkg/datastore/credentials_test.go index 158269b6ae..240040e66a 100644 --- a/pkg/datastore/credentials_test.go +++ b/pkg/datastore/credentials_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestNewCredentialsProvider(t *testing.T) { +func TestUnknownCredentialsProvider(t *testing.T) { unknownCredentialsProviders := []string{"", "some-unknown-credentials-provider"} for _, unknownCredentialsProvider := range unknownCredentialsProviders { t.Run(unknownCredentialsProvider, func(t *testing.T) { @@ -17,3 +17,22 @@ func TestNewCredentialsProvider(t *testing.T) { }) } } + +func TestAWSIAMCredentialsProvider(t *testing.T) { + // set up the environment, so we don't make any external calls to AWS + t.Setenv("AWS_CONFIG_FILE", "file_not_exists") + t.Setenv("AWS_SHARED_CREDENTIALS_FILE", "file_not_exists") + t.Setenv("AWS_ENDPOINT_URL", "http://169.254.169.254/aws") + t.Setenv("AWS_ACCESS_KEY", "access_key") + t.Setenv("AWS_SECRET_KEY", "secret_key") + t.Setenv("AWS_REGION", "us-east-1") + + credentialsProvider, err := NewCredentialsProvider(context.Background(), AWSIAMCredentialProvider) + require.NotNil(t, credentialsProvider) + require.NoError(t, err) + + username, password, err := credentialsProvider.Get(context.Background(), "some-hostname", 5432, "some-user") + require.NoError(t, err) + require.Equal(t, "some-user", username) + require.Containsf(t, password, "X-Amz-Algorithm", "signed token should contain algorithm attribute") +} From 96d6342b87aacbdc9ddc7deb6927fe3dbcbdef90 Mon Sep 17 00:00:00 2001 From: Jesse White Date: Mon, 15 Apr 2024 14:14:02 -0400 Subject: [PATCH 3/5] return -> returns --- pkg/datastore/credentials.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/datastore/credentials.go b/pkg/datastore/credentials.go index 7b3bfa4a9d..46c5b1e183 100644 --- a/pkg/datastore/credentials.go +++ b/pkg/datastore/credentials.go @@ -16,9 +16,9 @@ import ( // CredentialsProvider allows datastore credentials to be retrieved dynamically type CredentialsProvider interface { - // Name return the name of the provider + // Name returns the name of the provider Name() string - // Get return the username and password to use when connecting to the underlying datastore + // Get returns the username and password to use when connecting to the underlying datastore Get(ctx context.Context, dbHostname string, dbPort uint16, dbUser string) (string, string, error) } @@ -47,8 +47,8 @@ func CredentialsProviderOptions() string { } // NewCredentialsProvider create a new CredentialsProvider for the given name -// return nil if no match is found -// return an error if there is a problem initializing the given CredentialsProvider +// returns nil if no match is found +// returns an error if there is a problem initializing the given CredentialsProvider func NewCredentialsProvider(ctx context.Context, name string) (CredentialsProvider, error) { builder, ok := BuilderForCredentialProvider[name] if !ok { From 587fed64ab312e7080d561c307e4b4567a3b9a03 Mon Sep 17 00:00:00 2001 From: Jesse White Date: Mon, 15 Apr 2024 14:11:26 -0400 Subject: [PATCH 4/5] return an error for unknown credential providers return the NoCredentialsProvider (aka nil) when given a empty string --- pkg/datastore/credentials.go | 8 +++++--- pkg/datastore/credentials_test.go | 10 ++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pkg/datastore/credentials.go b/pkg/datastore/credentials.go index 46c5b1e183..783430be80 100644 --- a/pkg/datastore/credentials.go +++ b/pkg/datastore/credentials.go @@ -47,12 +47,14 @@ func CredentialsProviderOptions() string { } // NewCredentialsProvider create a new CredentialsProvider for the given name -// returns nil if no match is found -// returns an error if there is a problem initializing the given CredentialsProvider +// returns an error if no match is found, of if there is a problem creating the given CredentialsProvider func NewCredentialsProvider(ctx context.Context, name string) (CredentialsProvider, error) { + if name == "" { + return NoCredentialsProvider, nil + } builder, ok := BuilderForCredentialProvider[name] if !ok { - return nil, nil + return nil, fmt.Errorf("unknown credentials provider: %s", name) } return builder(ctx) } diff --git a/pkg/datastore/credentials_test.go b/pkg/datastore/credentials_test.go index 240040e66a..780f583e9c 100644 --- a/pkg/datastore/credentials_test.go +++ b/pkg/datastore/credentials_test.go @@ -7,13 +7,19 @@ import ( "github.com/stretchr/testify/require" ) +func TestNoCredentialsProvider(t *testing.T) { + credentialsProvider, err := NewCredentialsProvider(context.Background(), "") + require.Equal(t, NoCredentialsProvider, credentialsProvider) + require.NoError(t, err) +} + func TestUnknownCredentialsProvider(t *testing.T) { - unknownCredentialsProviders := []string{"", "some-unknown-credentials-provider"} + unknownCredentialsProviders := []string{"some-unknown-credentials-provider", " "} for _, unknownCredentialsProvider := range unknownCredentialsProviders { t.Run(unknownCredentialsProvider, func(t *testing.T) { credentialsProvider, err := NewCredentialsProvider(context.Background(), unknownCredentialsProvider) require.Nil(t, credentialsProvider) - require.NoError(t, err) + require.Error(t, err) }) } } From 8d853ceb6657c4866b8479f9c9da28e5066ce3f1 Mon Sep 17 00:00:00 2001 From: Jesse White Date: Mon, 15 Apr 2024 15:56:21 -0400 Subject: [PATCH 5/5] don't treat "" as a special case, avoid calling instead --- internal/datastore/postgres/postgres.go | 11 +++++++---- pkg/cmd/migrate.go | 10 +++++++--- pkg/datastore/credentials.go | 3 --- pkg/datastore/credentials_test.go | 8 +------- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 76b1e7704e..b77b939b37 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -146,10 +146,13 @@ func newPostgresDatastore( return nil, common.RedactAndLogSensitiveConnString(ctx, errUnableToInstantiate, err, pgURL) } - // Setup the credential provider - credentialsProvider, err := datastore.NewCredentialsProvider(ctx, config.credentialsProviderName) - if err != nil { - return nil, err + // Setup the credentials provider + var credentialsProvider datastore.CredentialsProvider + if config.credentialsProviderName != "" { + credentialsProvider, err = datastore.NewCredentialsProvider(ctx, config.credentialsProviderName) + if err != nil { + return nil, err + } } // Setup the config for each of the read and write pools. diff --git a/pkg/cmd/migrate.go b/pkg/cmd/migrate.go index 00c50121fe..16a30592b8 100644 --- a/pkg/cmd/migrate.go +++ b/pkg/cmd/migrate.go @@ -60,10 +60,14 @@ func migrateRun(cmd *cobra.Command, args []string) error { } else if datastoreEngine == "postgres" { log.Ctx(cmd.Context()).Info().Msg("migrating postgres datastore") + var credentialsProvider datastore.CredentialsProvider credentialsProviderName := cobrautil.MustGetString(cmd, "datastore-credentials-provider-name") - credentialsProvider, err := datastore.NewCredentialsProvider(cmd.Context(), credentialsProviderName) - if err != nil { - return err + if credentialsProviderName != "" { + var err error + credentialsProvider, err = datastore.NewCredentialsProvider(cmd.Context(), credentialsProviderName) + if err != nil { + return err + } } migrationDriver, err := migrations.NewAlembicPostgresDriver(cmd.Context(), dbURL, credentialsProvider) diff --git a/pkg/datastore/credentials.go b/pkg/datastore/credentials.go index 783430be80..ee71fc309c 100644 --- a/pkg/datastore/credentials.go +++ b/pkg/datastore/credentials.go @@ -49,9 +49,6 @@ func CredentialsProviderOptions() string { // NewCredentialsProvider create a new CredentialsProvider for the given name // returns an error if no match is found, of if there is a problem creating the given CredentialsProvider func NewCredentialsProvider(ctx context.Context, name string) (CredentialsProvider, error) { - if name == "" { - return NoCredentialsProvider, nil - } builder, ok := BuilderForCredentialProvider[name] if !ok { return nil, fmt.Errorf("unknown credentials provider: %s", name) diff --git a/pkg/datastore/credentials_test.go b/pkg/datastore/credentials_test.go index 780f583e9c..076df2dd0b 100644 --- a/pkg/datastore/credentials_test.go +++ b/pkg/datastore/credentials_test.go @@ -7,14 +7,8 @@ import ( "github.com/stretchr/testify/require" ) -func TestNoCredentialsProvider(t *testing.T) { - credentialsProvider, err := NewCredentialsProvider(context.Background(), "") - require.Equal(t, NoCredentialsProvider, credentialsProvider) - require.NoError(t, err) -} - func TestUnknownCredentialsProvider(t *testing.T) { - unknownCredentialsProviders := []string{"some-unknown-credentials-provider", " "} + unknownCredentialsProviders := []string{"", " ", "some-unknown-credentials-provider"} for _, unknownCredentialsProvider := range unknownCredentialsProviders { t.Run(unknownCredentialsProvider, func(t *testing.T) { credentialsProvider, err := NewCredentialsProvider(context.Background(), unknownCredentialsProvider)