Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow set session role for PostgreSQL and CockroachDB #1028

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions database/cockroachdb/README.md
Expand Up @@ -8,6 +8,7 @@
| `x-lock-table` | `LockTable` | Name of the table which maintains the migration lock |
| `x-force-lock` | `ForceLock` | Force lock acquisition to fix faulty migrations which may not have released the schema lock (Boolean, default is `false`) |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `x-role` | `Role` | The role to be set in case it differs from the user |
| `user` | | The user to sign in as |
| `password` | | The user's password |
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
Expand Down
13 changes: 13 additions & 0 deletions database/cockroachdb/cockroachdb.go
Expand Up @@ -37,6 +37,7 @@ type Config struct {
LockTable string
ForceLock bool
DatabaseName string
Role string
}

type CockroachDb struct {
Expand Down Expand Up @@ -77,6 +78,12 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if len(config.LockTable) == 0 {
config.LockTable = DefaultLockTable
}
if len(config.Role) > 0 {
query := fmt.Sprintf("SET ROLE %s", database.QuoteString(config.Role))
if _, err := instance.Exec(query); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
}

px := &CockroachDb{
db: instance,
Expand Down Expand Up @@ -127,11 +134,17 @@ func (c *CockroachDb) Open(url string) (database.Driver, error) {
forceLock = false
}

var role string
if s := purl.Query().Get("x-role"); len(s) > 0 {
role = s
}

px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
LockTable: lockTable,
ForceLock: forceLock,
Role: role,
})
if err != nil {
return nil, err
Expand Down
86 changes: 81 additions & 5 deletions database/cockroachdb/cockroachdb_test.go
Expand Up @@ -7,7 +7,10 @@ import (
"database/sql"
"fmt"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/pkg/errors"
"log"
"regexp"
"strings"
"testing"
)
Expand All @@ -26,13 +29,12 @@ import (
const defaultPort = 26257

var (
opts = dktest.Options{Cmd: []string{"start", "--insecure"}, PortRequired: true, ReadyFunc: isReady}
opts = dktest.Options{Cmd: []string{"start-single-node", "--insecure"}, PortRequired: true, ReadyFunc: isReady}
// Released versions: https://www.cockroachlabs.com/docs/releases/
specs = []dktesting.ContainerSpec{
{ImageName: "cockroachdb/cockroach:v1.0.7", Options: opts},
{ImageName: "cockroachdb/cockroach:v1.1.9", Options: opts},
{ImageName: "cockroachdb/cockroach:v2.0.7", Options: opts},
{ImageName: "cockroachdb/cockroach:v2.1.3", Options: opts},
{ImageName: "cockroachdb/cockroach:latest-v22.1", Options: opts},
{ImageName: "cockroachdb/cockroach:latest-v22.2", Options: opts},
{ImageName: "cockroachdb/cockroach:latest-v23.1", Options: opts},
}
)

Expand Down Expand Up @@ -82,6 +84,14 @@ func createDB(t *testing.T, c dktest.ContainerInfo) {
}
}

func mustRun(t *testing.T, d database.Driver, statements []string) {
for _, statement := range statements {
if err := d.Run(strings.NewReader(statement)); err != nil {
t.Fatal(err)
}
}
}

func Test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) {
createDB(t, ci)
Expand Down Expand Up @@ -172,3 +182,69 @@ func TestFilterCustomQuery(t *testing.T) {
}
})
}

func TestRole(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) {
createDB(t, ci)

ip, port, err := ci.Port(26257)
if err != nil {
t.Fatal(err)
}

d, err := sql.Open("postgres", fmt.Sprintf("postgres://root@%v:%v?sslmode=disable", ip, port))
if err != nil {
t.Fatal(err)
}
prepare := []string{
"CREATE ROLE IF NOT EXISTS _fa NOLOGIN;",
"CREATE ROLE IF NOT EXISTS _fa_ungranted NOLOGIN",
"CREATE ROLE deploy LOGIN",
"GRANT _fa TO deploy",
"GRANT CREATE ON DATABASE migrate TO _fa, _fa_ungranted;",
}
for _, query := range prepare {
if _, err := d.Exec(query); err != nil {
t.Fatal(err)
}
}

c := &CockroachDb{}

// positive: connecting with deploy user and setting role to _fa
d2, err := c.Open(fmt.Sprintf("cockroach://deploy@%v:%v/migrate?sslmode=disable&x-role=_fa", ip, port))
if err != nil {
t.Fatal(err)
}
mustRun(t, d2, []string{
"CREATE TABLE foo (role INT UNIQUE);",
})
var exists bool
if err := d2.(*CockroachDb).db.QueryRow("SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = 'foo' AND schemaname = (SELECT current_schema()) AND tableowner = '_fa');").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table foo owned by _fa role to exist")
}

var e *database.Error
// negative: connecting with deploy user and trying to set not existing role
_, err = c.Open(fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable&x-role=_not_existing_role", ip, port))
if !errors.As(err, &e) || err == nil {
t.Fatal(fmt.Errorf("unexpected success, wanted pq: role/user does not exist, got: %w", err))
}
re := regexp.MustCompile("^pq: role(/user)? (\")?_not_existing_role(\")? does not exist$")
if !re.MatchString(e.OrigErr.Error()) {
t.Fatal(fmt.Errorf("unexpected error, wanted _not_existing_role does not exist, got: %s", e.OrigErr.Error()))
}

// negative: connecting with deploy user and trying to set not granted role
_, err = c.Open(fmt.Sprintf("cockroach://deploy@%v:%v/migrate?sslmode=disable&x-role=_fa_ungranted", ip, port))
if !errors.As(err, &e) || err == nil {
t.Fatal(fmt.Errorf("unexpected success, wanted permission denied error, got: %w", err))
}
if !strings.Contains(e.OrigErr.Error(), "permission denied to set role") {
t.Fatal(fmt.Errorf("unexpected error, wanted permission denied error, got: %s", e.OrigErr.Error()))
}
})
}
1 change: 1 addition & 0 deletions database/postgres/README.md
Expand Up @@ -9,6 +9,7 @@
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
| `x-role` | `Role` | The role to be set in case it differs from the user |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
| `user` | | The user to sign in as |
Expand Down
27 changes: 27 additions & 0 deletions database/postgres/postgres.go
Expand Up @@ -6,6 +6,7 @@ package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
nurl "net/url"
Expand Down Expand Up @@ -41,6 +42,7 @@ var (
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNoSchema = fmt.Errorf("no schema")
ErrDatabaseDirty = fmt.Errorf("database is dirty")
ErrNoSuchRole = fmt.Errorf("no such role")
)

type Config struct {
Expand All @@ -53,6 +55,7 @@ type Config struct {
migrationsTableName string
StatementTimeout time.Duration
MultiStatementMaxSize int
Role string
}

type Postgres struct {
Expand Down Expand Up @@ -106,6 +109,18 @@ func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postg
config.MigrationsTable = DefaultMigrationsTable
}

if len(config.Role) > 0 {
var role string
query := `SELECT rolname FROM pg_roles WHERE rolname = $1`
if err := conn.QueryRowContext(ctx, query, config.Role).Scan(&role); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNoSuchRole
}
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
config.Role = role
}

config.migrationsSchemaName = config.SchemaName
config.migrationsTableName = config.MigrationsTable
if config.MigrationsTableQuoted {
Expand Down Expand Up @@ -202,13 +217,19 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
}
}

var role string
if s := purl.Query().Get("x-role"); len(s) > 0 {
role = s
}

px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
MigrationsTableQuoted: migrationsTableQuoted,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
MultiStatementEnabled: multiStatementEnabled,
MultiStatementMaxSize: multiStatementMaxSize,
Role: role,
})

if err != nil {
Expand Down Expand Up @@ -291,6 +312,12 @@ func (p *Postgres) runStatement(statement []byte) error {
ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
defer cancel()
}
if len(p.config.Role) > 0 {
query := "SET ROLE " + p.config.Role
if _, err := p.conn.ExecContext(ctx, query); err != nil {
return database.Error{OrigErr: err, Err: "failed to set role", Query: statement}
}
}
query := string(statement)
if strings.TrimSpace(query) == "" {
return nil
Expand Down
69 changes: 69 additions & 0 deletions database/postgres/postgres_test.go
Expand Up @@ -801,3 +801,72 @@ func Test_computeLineFromPos(t *testing.T) {
})
}
}

func TestRole(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

addr := pgConnectionString(ip, port)

// Check that opening the postgres connection returns NilVersion
p := &Postgres{}

d, err := p.Open(addr)

if err != nil {
t.Fatal(err)
}

defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()

mustRun(t, d, []string{
"CREATE USER _fa",
"GRANT CONNECT, CREATE, TEMPORARY ON DATABASE postgres TO _fa",
"CREATE USER deploy WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
"GRANT _fa TO deploy",
})

// positive: connecting with deploy user and setting role to _fa
d2, err := p.Open(fmt.Sprintf("postgres://deploy:%s@%v:%v/postgres?sslmode=disable&x-role=_fa", pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
if err := d2.Run(strings.NewReader("CREATE TABLE foo (foo text);")); err != nil {
t.Fatal(err)
}
var exists bool
if err := d2.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = 'foo' AND tableowner = '_fa');").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table foo owned by _fa role to exist")
}

// negative: connecting with deploy user and trying to set not existing role
_, err = p.Open(fmt.Sprintf("postgres://deploy:%s@%v:%v/postgres?sslmode=disable&x-role=not_existing_role", pgPassword, ip, port))
if err != ErrNoSuchRole {
t.Fatal(fmt.Errorf("expected %w, but got %w", ErrNoSuchRole, err))
}

// negative: connecting with deploy user and trying to set not granted role
d3, err := p.Open(fmt.Sprintf("postgres://deploy:%s@%v:%v/postgres?sslmode=disable&x-role=postgres", pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
err = d3.Run(strings.NewReader("CREATE TABLE imdasuperuser (foo text);"))
var e database.Error
if !errors.As(err, &e) || err == nil {
t.Fatal(fmt.Errorf("unexpected success, wanted permission denied error, got: %w", err))
}
if !strings.Contains(e.OrigErr.Error(), "permission denied to set role") {
t.Fatal(fmt.Errorf("unexpected error, wanted permission denied error, got: %w", err))
}
})
}
4 changes: 4 additions & 0 deletions database/util.go
Expand Up @@ -31,3 +31,7 @@ func CasRestoreOnErr(lock *atomic.Bool, o, n bool, casErr error, f func() error)
}
return nil
}

func QuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
}