Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve multistatement support for postgres #1018

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion Makefile
Expand Up @@ -105,6 +105,9 @@ echo-database:
@echo "$(DATABASE)"


lint:
golangci-lint run -c .golangci.yml

define external_deps
@echo '-- $(1)'; go list -f '{{join .Deps "\n"}}' $(1) | grep -v github.com/$(REPO_OWNER)/migrate | xargs go list -f '{{if not .Standard}}{{.ImportPath}}{{end}}'

Expand All @@ -113,7 +116,8 @@ endef

.PHONY: build build-docker build-cli clean test-short test test-with-flags html-coverage \
restore-import-paths rewrite-import-paths list-external-deps release \
docs kill-docs open-docs kill-orphaned-docker-containers echo-source echo-database
docs kill-docs open-docs kill-orphaned-docker-containers echo-source echo-database \
lint

SHELL = /bin/sh
RAND = $(shell echo $$RANDOM)
Expand Down
15 changes: 11 additions & 4 deletions database/multistmt/parse.go
Expand Up @@ -15,7 +15,7 @@ var StartBufSize = 4096
// from the multi-statement migration should be parsed and handled.
type Handler func(migration []byte) bool

func splitWithDelimiter(delimiter []byte) func(d []byte, atEOF bool) (int, []byte, error) {
func splitWithDelimiter(delimiter []byte) bufio.SplitFunc {
return func(d []byte, atEOF bool) (int, []byte, error) {
// SplitFunc inspired by bufio.ScanLines() implementation
if atEOF {
Expand All @@ -31,11 +31,13 @@ func splitWithDelimiter(delimiter []byte) func(d []byte, atEOF bool) (int, []byt
}
}

// Parse parses the given multi-statement migration
func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error {
type scanFuncFactory func(delimiter []byte) bufio.SplitFunc

// parse parses the given multi-statement migration
func parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler, factory scanFuncFactory) error {
scanner := bufio.NewScanner(reader)
scanner.Buffer(make([]byte, 0, StartBufSize), maxMigrationSize)
scanner.Split(splitWithDelimiter(delimiter))
scanner.Split(factory(delimiter))
for scanner.Scan() {
cont := h(scanner.Bytes())
if !cont {
Expand All @@ -44,3 +46,8 @@ func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler)
}
return scanner.Err()
}

// Parse reads a multi-statement migration from reader, cuts it into
// statements at every occurrence of delimiter and passes each statement to h.
// maxMigrationSize is the upper bound handed to the underlying scanner buffer.
// The returned error is whatever the shared parse routine reports.
func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error {
	// The plain splitter knows nothing about quoting; it only looks for delimiter.
	var factory scanFuncFactory = splitWithDelimiter

	return parse(reader, delimiter, maxMigrationSize, h, factory)
}
93 changes: 93 additions & 0 deletions database/multistmt/parse_postgres.go
@@ -0,0 +1,93 @@
// Package multistmt provides methods for parsing multi-statement database migrations
package multistmt

import (
"bufio"
"bytes"
"io"
"unicode"

"golang.org/x/exp/slices"
)

// dollar is the character that opens and closes a dollar-quote tag.
const dollar = '$'

// isValidTagSymbol reports whether r may appear inside a dollar-quote tag:
// a letter, a digit or an underscore.
func isValidTagSymbol(r rune) bool {
	return unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_'
}

// pgSplitWithDelimiter returns a bufio.SplitFunc that cuts the input after
// every occurrence of delimiter, except for occurrences that lie inside a
// PostgreSQL dollar-quoted string ($$ ... $$ or $tag$ ... $tag$).
//
// https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-DOLLAR-QUOTING
// Inside a dollar-quoted string nothing is ever escaped: the content is
// written literally, and dollar signs are only special when they are part of
// a sequence matching the opening tag. "Nested" dollar quotes are therefore
// just literal content of the outer string, so once an opening tag has been
// seen the splitter skips straight to the matching closing tag.
//
// Each returned token includes its trailing delimiter. Trailing data without
// a delimiter is returned as the last token once atEOF is true. Single-quoted
// strings and SQL comments are not recognized by this splitter.
func pgSplitWithDelimiter(delimiter []byte) bufio.SplitFunc {
	return func(d []byte, atEOF bool) (int, []byte, error) {
		var (
			// closingTag is non-nil while inside a dollar-quoted string and
			// holds the tag that terminates it.
			closingTag []byte
			// reader is only used to decode one rune at a time.
			reader bytes.Reader
		)
		// tagStart is the offset of a '$' that may open a tag, -1 if none.
		tagStart := -1
		position := 0

		for position < len(d) {
			if closingTag != nil {
				// Literal content: nothing matters except the closing tag.
				end := bytes.Index(d[position:], closingTag)
				if end < 0 {
					// The string is not terminated within d (yet).
					break
				}
				position += end + len(closingTag)
				closingTag = nil
				continue
			}

			rest := d[position:]
			if len(rest) >= len(delimiter) && slices.Equal(rest[:len(delimiter)], delimiter) {
				position += len(delimiter)
				return position, d[:position], nil
			}

			// Decode without touching delimiter or d; neither is ever written to.
			reader.Reset(rest)
			r, size, err := reader.ReadRune()
			if err != nil {
				// Cannot happen while rest is non-empty; kept as a safeguard.
				return 0, nil, err
			}

			switch {
			case r == dollar && tagStart < 0:
				tagStart = position
			case r == dollar:
				// Second '$' completes the opening tag.
				closingTag = d[tagStart : position+size]
				tagStart = -1
			case tagStart >= 0 && (!isValidTagSymbol(r) || (position == tagStart+1 && r >= '0' && r <= '9')):
				// Not a tag after all. A tag follows identifier rules, so it
				// cannot start with a digit ($1 is a positional parameter).
				tagStart = -1
			}

			position += size
		}

		// No complete statement in d. At EOF whatever is left is the final
		// token; otherwise ask the scanner for more data.
		if atEOF && len(d) > 0 {
			return len(d), d, nil
		}

		return 0, nil, nil
	}
}

// PGParse reads a multi-statement PostgreSQL migration from reader and passes
// each statement to h. Statements are separated by delimiter, but a delimiter
// that appears inside a dollar-quoted string does not end the statement.
// maxMigrationSize is the upper bound handed to the underlying scanner buffer.
func PGParse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error {
	// Same scanning loop as the generic parser, only the split function differs.
	var factory scanFuncFactory = pgSplitWithDelimiter

	return parse(reader, delimiter, maxMigrationSize, h, factory)
}
102 changes: 102 additions & 0 deletions database/multistmt/parse_postgres_test.go
@@ -0,0 +1,102 @@
package multistmt_test

import (
"strings"
"testing"

"github.com/stretchr/testify/assert"

"github.com/golang-migrate/migrate/v4/database/multistmt"
)

// TestPGParse feeds PGParse a table of inputs and checks that the statements
// handed to the handler match the expected split, in particular that
// delimiters inside dollar-quoted strings do not split a statement.
func TestPGParse(t *testing.T) {
// Fixture: function body quoted with the empty tag ($$); the body contains delimiters.
createFunctionEmptyTagStmt := `CREATE FUNCTION set_new_id() RETURNS TRIGGER AS
$$
BEGIN
NEW.new_id := NEW.id;
RETURN NEW;
END
$$ LANGUAGE PLPGSQL;`

// Fixture: same function, quoted with a named tag ($BODY$).
createFunctionStmt := `CREATE FUNCTION set_new_id() RETURNS TRIGGER AS
$BODY$
BEGIN
NEW.new_id := NEW.id;
RETURN NEW;
END
$BODY$ LANGUAGE PLPGSQL;`

// Fixture: ordinary statement without any dollar quoting.
createTriggerStmt := `CREATE TRIGGER set_new_id_trigger BEFORE INSERT OR UPDATE ON mytable
FOR EACH ROW EXECUTE PROCEDURE set_new_id();`

// Fixture: a $q$-quoted string inside a $function$-quoted body, plus a $1 parameter.
nestedDollarQuotes := `$function$
BEGIN
RETURN ($1 ~ $q$[\t\r\n\v\\]$q$);
END;
$function$;`

// Fixture: positional parameters ($1, $2) inside the body and more clauses
// (including a comment line) after the closing tag.
advancedCreateFunction := `CREATE FUNCTION check_password(uname TEXT, pass TEXT)
RETURNS BOOLEAN AS $$
DECLARE passed BOOLEAN;
BEGIN
SELECT (pwd = $2) INTO passed
FROM pwds
WHERE username = $1;

RETURN passed;
END;
$$ LANGUAGE plpgsql
SECURITY DEFINER
-- Set a secure search_path: trusted schema(s), then 'pg_temp'.
SET search_path = admin, pg_temp;`

// Each case lists the raw input, the delimiter, and the statements (with
// their trailing delimiter) that the handler is expected to receive.
testCases := []struct {
name string
multiStmt string
delimiter string
expected []string
expectedErr error
}{
{name: "single statement, no delimiter", multiStmt: "single statement, no delimiter", delimiter: ";",
expected: []string{"single statement, no delimiter"}, expectedErr: nil},
{name: "single statement, one delimiter", multiStmt: "single statement, one delimiter;", delimiter: ";",
expected: []string{"single statement, one delimiter;"}, expectedErr: nil},
{name: "two statements, no trailing delimiter", multiStmt: "statement one; statement two", delimiter: ";",
expected: []string{"statement one;", " statement two"}, expectedErr: nil},
{name: "two statements, with trailing delimiter", multiStmt: "statement one; statement two;", delimiter: ";",
expected: []string{"statement one;", " statement two;"}, expectedErr: nil},
{name: "singe statement with nested dollar-quoted string", multiStmt: nestedDollarQuotes, delimiter: ";",
expected: []string{nestedDollarQuotes}},
// The fixtures are concatenated without a separator, so every expected
// statement is exactly one fixture.
{name: "multiple statements with dollar-quoted strings", multiStmt: strings.Join([]string{createFunctionStmt,
createFunctionEmptyTagStmt, advancedCreateFunction, createTriggerStmt, nestedDollarQuotes}, ""),
delimiter: ";",
expected: []string{createFunctionStmt, createFunctionEmptyTagStmt, advancedCreateFunction,
createTriggerStmt, nestedDollarQuotes}},
}

// Run every case as a sub-test, collecting what the handler receives.
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
stmts := make([]string, 0, len(tc.expected))
err := multistmt.PGParse(strings.NewReader(tc.multiStmt), []byte(tc.delimiter), maxMigrationSize, func(b []byte) bool {
// Copy the bytes into a string and ask the parser to keep going.
stmts = append(stmts, string(b))
return true
})
assert.Equal(t, tc.expectedErr, err)
assert.Equal(t, tc.expected, stmts)
})
}
}

// TestPGParseDiscontinue verifies that PGParse stops handing out statements
// once the handler returns false: only the first of two statements is seen
// and no error is reported.
func TestPGParseDiscontinue(t *testing.T) {
	const (
		input = "statement one; statement two"
		sep   = ";"
	)
	want := []string{"statement one;"}

	got := make([]string, 0, len(want))
	// Record the statement, then tell the parser to stop.
	stopAfterFirst := func(stmt []byte) bool {
		got = append(got, string(stmt))
		return false
	}

	err := multistmt.PGParse(strings.NewReader(input), []byte(sep), maxMigrationSize, stopAfterFirst)

	assert.Nil(t, err)
	assert.Equal(t, want, got)
}
2 changes: 1 addition & 1 deletion database/pgx/pgx.go
Expand Up @@ -363,7 +363,7 @@ func (p *Postgres) releaseTableLock() error {
func (p *Postgres) Run(migration io.Reader) error {
if p.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
if e := multistmt.PGParse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
if err = p.runStatement(m); err != nil {
return false
}
Expand Down
17 changes: 16 additions & 1 deletion database/pgx/pgx_test.go
Expand Up @@ -211,7 +211,13 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
if err := d.Run(strings.NewReader(`CREATE TABLE foo (foo text);
CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);
CREATE FUNCTION baz() RETURNS integer AS $$
BEGIN
RETURN 1;
END;
$$ LANGUAGE plpgsql;`)); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}

Expand All @@ -223,6 +229,15 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
if !exists {
t.Fatalf("expected table bar to exist")
}

// make sure procedure exists
var proc string
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 'baz'::regproc;").Scan(&proc); err != nil {
t.Fatal(err)
}
if proc != "baz" {
t.Fatalf("expected procedure baz to exists")
}
})
}

Expand Down
2 changes: 1 addition & 1 deletion database/pgx/v5/pgx.go
Expand Up @@ -254,7 +254,7 @@ func (p *Postgres) Unlock() error {
func (p *Postgres) Run(migration io.Reader) error {
if p.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
if e := multistmt.PGParse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
if err = p.runStatement(m); err != nil {
return false
}
Expand Down
17 changes: 16 additions & 1 deletion database/pgx/v5/pgx_test.go
Expand Up @@ -186,7 +186,13 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
if err := d.Run(strings.NewReader(`CREATE TABLE foo (foo text);
CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);
CREATE FUNCTION baz() RETURNS integer AS $$
BEGIN
RETURN 1;
END;
$$ LANGUAGE plpgsql;`)); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}

Expand All @@ -198,6 +204,15 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
if !exists {
t.Fatalf("expected table bar to exist")
}

// make sure procedure exists
var proc string
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 'baz'::regproc;").Scan(&proc); err != nil {
t.Fatal(err)
}
if proc != "baz" {
t.Fatalf("expected procedure baz to exists")
}
})
}

Expand Down
2 changes: 1 addition & 1 deletion database/postgres/postgres.go
Expand Up @@ -267,7 +267,7 @@ func (p *Postgres) Unlock() error {
func (p *Postgres) Run(migration io.Reader) error {
if p.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
if e := multistmt.PGParse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
if err = p.runStatement(m); err != nil {
return false
}
Expand Down
18 changes: 17 additions & 1 deletion database/postgres/postgres_test.go
Expand Up @@ -183,7 +183,14 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
if err := d.Run(strings.NewReader(`CREATE TABLE foo (foo text);
CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);
CREATE FUNCTION baz() RETURNS integer AS $$
BEGIN
RETURN 1;
END;
$$ LANGUAGE plpgsql;
`)); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}

Expand All @@ -195,6 +202,15 @@ func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
if !exists {
t.Fatalf("expected table bar to exist")
}

// make sure procedure exists
var proc string
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT 'baz'::regproc;").Scan(&proc); err != nil {
t.Fatal(err)
}
if proc != "baz" {
t.Fatalf("expected procedure baz to exists")
}
})
}

Expand Down