Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rebind safety #775

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
64 changes: 55 additions & 9 deletions bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"

"github.com/jmoiron/sqlx/reflectx"
"github.com/muir/sqltoken"
)

// Bindvar types supported by Rebind, BindMap and BindStruct.
Expand All @@ -22,21 +23,39 @@ const (
)

var defaultBinds = map[int][]string{
DOLLAR: []string{"postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql", "nrpostgres", "cockroach"},
QUESTION: []string{"mysql", "sqlite3", "nrmysql", "nrsqlite3"},
NAMED: []string{"oci8", "ora", "goracle", "godror"},
AT: []string{"sqlserver"},
DOLLAR: {"postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql", "nrpostgres", "cockroach"},
QUESTION: {"mysql", "sqlite3", "nrmysql", "nrsqlite3"},
NAMED: {"oci8", "ora", "goracle", "godror"},
AT: {"sqlserver"},
}

var binds sync.Map

// rebindConfigs holds the sqltoken tokenizer configuration that Rebind uses
// for each bind type; the slice is indexed by the bind type constant and has
// AT+1 entries. Only DOLLAR, NAMED and AT are populated: Rebind returns early
// for QUESTION and UNKNOWN, so their entries stay at the zero value.
//
// Every populated dialect config has NoticeQuestionMark switched on, because
// Rebind looks for sqltoken.QuestionMark tokens to replace. The dialect's own
// placeholder flag (NoticeDollarNumber, NoticeColonWord, NoticeAtWord) is
// switched off.
// NOTE(review): presumably so that pre-existing native placeholders pass
// through as ordinary text -- confirm against the sqltoken documentation.
var rebindConfigs = func() []sqltoken.Config {
	postgres := sqltoken.PostgreSQLConfig()
	postgres.NoticeDollarNumber = false
	postgres.NoticeQuestionMark = true

	oracle := sqltoken.OracleConfig()
	oracle.NoticeQuestionMark = true
	oracle.NoticeColonWord = false

	sqlServer := sqltoken.SQLServerConfig()
	sqlServer.NoticeQuestionMark = true
	sqlServer.NoticeAtWord = false

	byBindType := make([]sqltoken.Config, AT+1)
	byBindType[DOLLAR] = postgres
	byBindType[NAMED] = oracle
	byBindType[AT] = sqlServer
	return byBindType
}()

// init walks defaultBinds and registers each listed driver name with its
// bind type by calling BindDriver.
func init() {
	for bindType, names := range defaultBinds {
		for i := range names {
			BindDriver(names[i], bindType)
		}
	}
}

// BindType returns the bindtype for a given database given a drivername.
Expand All @@ -53,15 +72,43 @@ func BindDriver(driverName string, bindType int) {
binds.Store(driverName, bindType)
}

// FIXME: this should be able to be tolerant of escaped ?'s in queries without
// losing much speed, and should be to avoid confusion.

// Rebind a query from the default bindtype (QUESTION) to the target bindtype.
//
// The query is tokenized with the sqltoken configuration stored for bindType
// in rebindConfigs. Every token of type sqltoken.QuestionMark is replaced by
// a numbered placeholder, counting from 1 in order of appearance:
//
//	DOLLAR: $1, $2, ...
//	NAMED:  :arg1, :arg2, ...
//	AT:     @p1, @p2, ...
//
// All other tokens are copied to the result verbatim.
//
// QUESTION and UNKNOWN need no rewriting, so the query is returned as is. A
// bindType with no entry in rebindConfigs (negative, or greater than the
// largest known bind type) is likewise returned unchanged instead of
// panicking on an out-of-range slice index.
func Rebind(bindType int, query string) string {
	switch bindType {
	case QUESTION, UNKNOWN:
		return query
	}
	// Guard the table lookup: callers may pass an arbitrary int.
	if bindType < 0 || bindType >= len(rebindConfigs) {
		return query
	}

	// The placeholder prefix depends only on bindType, so pick it once
	// rather than on every token.
	var prefix string
	switch bindType {
	case DOLLAR:
		prefix = "$"
	case NAMED:
		prefix = ":arg"
	case AT:
		prefix = "@p"
	}

	tokens := sqltoken.Tokenize(query, rebindConfigs[bindType])

	// Add space enough for 10 params before we have to allocate.
	rqb := make([]byte, 0, len(query)+10)

	var n int64
	for _, token := range tokens {
		if token.Type != sqltoken.QuestionMark {
			// append accepts a string operand directly; no []byte copy needed.
			rqb = append(rqb, token.Text...)
			continue
		}
		n++
		rqb = append(rqb, prefix...)
		rqb = strconv.AppendInt(rqb, n, 10)
	}
	return string(rqb)
}

// Previous rebind implementation, kept here for benchmarking purposes
// at least for now.
func oldRebind(bindType int, query string) string {
switch bindType {
case QUESTION, UNKNOWN:
return query
}

// Add space enough for 10 params before we have to allocate
rqb := make([]byte, 0, len(query)+10)
Expand Down Expand Up @@ -130,7 +177,6 @@ func asSliceForIn(i interface{}) (v reflect.Value, ok bool) {
// []byte is a driver.Value type so it should not be expanded
if t == reflect.TypeOf([]byte{}) {
return reflect.Value{}, false

}

return v, true
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ require (
github.com/go-sql-driver/mysql v1.6.0
github.com/lib/pq v1.2.0
github.com/mattn/go-sqlite3 v1.14.6
github.com/muir/sqltoken v0.0.4
)
21 changes: 21 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
github.com/alvaroloes/enumer v1.1.2/go.mod h1:FxrjvuXoDAx9isTJrv4c+T410zFi0DtXIT0m65DJ+Wo=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/muir/sqltoken v0.0.4 h1:SioNnG90ZYXmlfnPaUxUdNC1dFkhKL64pDeS+wXZ8k8=
github.com/muir/sqltoken v0.0.4/go.mod h1:6hPsZxszMpYyNf12og4f4VShFo/Qipz6Of0cn5KGAAU=
github.com/pascaldekloe/name v0.0.0-20180628100202-0fd16699aae1/go.mod h1:eD5JxqMiuNYyFNmyY9rkJ/slN8y59oEu4Ei7F8OoKWQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20190524210228-3d17549cdc6b/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
64 changes: 62 additions & 2 deletions named.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"unicode"

"github.com/jmoiron/sqlx/reflectx"
"github.com/muir/sqltoken"
)

// NamedStmt is a prepared statement that executes named queries. Prepare it
Expand Down Expand Up @@ -157,7 +158,6 @@ func convertMapStringInterface(v interface{}) (map[string]interface{}, bool) {
return nil, false
}
return reflect.ValueOf(v).Convert(mtype).Interface().(map[string]interface{}), true

}

func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
Expand Down Expand Up @@ -282,7 +282,7 @@ func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper)
if arrayLen == 0 {
return "", []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg)
}
var arglist = make([]interface{}, 0, len(names)*arrayLen)
arglist := make([]interface{}, 0, len(names)*arrayLen)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was made by gofmt and so were a bunch of other similar changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're using gofumpt rather than gofmt? gofmt doesn't make these kinds of changes anyway.

for i := 0; i < arrayLen; i++ {
elemArglist, err := bindAnyArgs(names, arrayValue.Index(i).Interface(), m)
if err != nil {
Expand Down Expand Up @@ -311,6 +311,33 @@ func bindMap(bindType int, query string, args map[string]interface{}) (string, [
return bound, arglist, err
}

// namedParseConfigs holds the sqltoken tokenizer configuration that
// compileNamedQuery uses for each bind type; the slice is indexed by the bind
// type constant and has AT+1 entries. QUESTION and UNKNOWN share the same
// MySQL-based configuration.
//
// compileNamedQuery extracts names from sqltoken.ColonWord tokens, so
// NoticeColonWord is switched on for every dialect whose stock config is
// modified here (the Oracle config is left as sqltoken provides it in that
// respect). ColonWordIncludesUnicode is set for all of them.
// NOTE(review): presumably this allows non-ASCII letters in :name
// parameters -- confirm against the sqltoken documentation.
var namedParseConfigs = func() []sqltoken.Config {
	postgres := sqltoken.PostgreSQLConfig()
	postgres.NoticeDollarNumber = false
	postgres.NoticeColonWord = true
	postgres.ColonWordIncludesUnicode = true

	oracle := sqltoken.OracleConfig()
	oracle.ColonWordIncludesUnicode = true

	sqlServer := sqltoken.SQLServerConfig()
	sqlServer.NoticeAtWord = false
	sqlServer.NoticeColonWord = true
	sqlServer.ColonWordIncludesUnicode = true

	question := sqltoken.MySQLConfig()
	question.NoticeQuestionMark = false
	question.NoticeColonWord = true
	question.ColonWordIncludesUnicode = true

	byBindType := make([]sqltoken.Config, AT+1)
	byBindType[UNKNOWN] = question
	byBindType[QUESTION] = question
	byBindType[DOLLAR] = postgres
	byBindType[NAMED] = oracle
	byBindType[AT] = sqlServer
	return byBindType
}()

// -- Compilation of Named Queries

// Allow digits and letters in bind params; additionally runes are
Expand All @@ -332,6 +359,39 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e
names = make([]string, 0, 10)
rebound := make([]byte, 0, len(qs))

currentVar := 1
tokens := sqltoken.Tokenize(string(qs), namedParseConfigs[bindType])

for _, token := range tokens {
if token.Type != sqltoken.ColonWord {
rebound = append(rebound, ([]byte)(token.Text)...)
continue
}
names = append(names, token.Text[1:])
switch bindType {
// oracle only supports named type bind vars even for positional
case NAMED:
rebound = append(rebound, ([]byte)(token.Text)...)
case QUESTION, UNKNOWN:
rebound = append(rebound, '?')
case DOLLAR:
rebound = append(rebound, '$')
rebound = strconv.AppendInt(rebound, int64(currentVar), 10)
currentVar++
case AT:
rebound = append(rebound, '@', 'p')
rebound = strconv.AppendInt(rebound, int64(currentVar), 10)
currentVar++
}
}
return string(rebound), names, nil
}

// kept for benchmarking purposes
func oldCmpileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) {
names = make([]string, 0, 10)
rebound := make([]byte, 0, len(qs))

inName := false
last := len(qs) - 1
currentVar := 1
Expand Down
89 changes: 53 additions & 36 deletions named_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,21 @@ import (

func TestCompileQuery(t *testing.T) {
table := []struct {
d string
Q, R, D, T, N string
V []string
}{
// basic test for named parameters, invalid char ',' terminating
{
d: "basic test for named parameters, invalid char ',' terminating",
Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`,
N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
V: []string{"name", "age", "first", "last"},
},
// This query tests a named parameter ending the string as well as numbers
{
d: "This query tests a named parameter ending the string as well as numbers",
Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`,
R: `SELECT * FROM a WHERE first_name=? AND last_name=?`,
D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`,
Expand All @@ -30,15 +31,15 @@ func TestCompileQuery(t *testing.T) {
V: []string{"name1", "name2"},
},
{
Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
Q: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume the idea here is/was that ::foo would "escape" the :foo; sqltoken is smarter about this as it recognizes the quotes so you don't need to use ::foo – which is clearly better – but this is an incompatible change that will break people's SQL if they're using it like this.

As an aside, this library is "sporadically maintained", at best, so I'm not expecting a merge or feedback from the maintainer(s) any time soon, so I wouldn't rush to fix this. I'm (probably) going to use it in my own library/fork that I'm working on, though. Just wanted to point this out.

Also might be nice to replace those four calls to testify in your library? Don't want to start a discussion about testify vs. stdlib tests, but smaller dependency trees are always nice, and with just four calls it's not that much benefit 😅

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I called this out in my PR description. Since this is not backwards compatible it requires a major version change. The current behavior isn't documented, but people may still be relying on it.

If getting rid of the calls to testify would mean this gets merged, I would certainly do it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry, I missed that 😅

If getting rid of the calls to testify would mean this gets merged, I would certainly do it.

I have no say in the matter at all; just something I noticed.

R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`,
D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`,
T: `SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
V: []string{"name1", "name2"},
},
{
Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
Q: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`,
D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`,
T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`,
Expand All @@ -53,49 +54,66 @@ func TestCompileQuery(t *testing.T) {
T: `SELECT @name := "name", @p1, @p2, @p3`,
V: []string{"age", "first", "last"},
},
/* This unicode awareness test sadly fails, because of our byte-wise worldview.
* We could certainly iterate by Rune instead, though it's a great deal slower,
* it's probably the RightWay(tm)
{
Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
N: []string{"name", "age", "first", "last"},
N: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`,
V: []string{"あ", "b", "キコ", "名前"},
},
{
Q: `SELECT id, added_at::date FROM person WHERE first_name=:first_name AND last_name=:last_name`,
R: `SELECT id, added_at::date FROM person WHERE first_name=? AND last_name=?`,
D: `SELECT id, added_at::date FROM person WHERE first_name=$1 AND last_name=$2`,
T: `SELECT id, added_at::date FROM person WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT id, added_at::date FROM person WHERE first_name=:first_name AND last_name=:last_name`,
V: []string{"first_name", "last_name"},
},
*/
}

for _, test := range table {
qr, names, err := compileNamedQuery([]byte(test.Q), QUESTION)
if err != nil {
t.Error(err)
}
if qr != test.R {
t.Errorf("expected %s, got %s", test.R, qr)
test := test
n := test.d
if n == "" {
n = test.Q
}
if len(names) != len(test.V) {
t.Errorf("expected %#v, got %#v", test.V, names)
} else {
for i, name := range names {
if name != test.V[i] {
t.Errorf("expected %dth name to be %s, got %s", i+1, test.V[i], name)
t.Run(n, func(t *testing.T) {
if test.d != "" {
t.Log(test.d)
}
t.Log(test.Q)
qr, names, err := compileNamedQuery([]byte(test.Q), QUESTION)
if err != nil {
t.Error(err)
}
if qr != test.R {
t.Errorf("R: expected %s, got(R) %s", test.R, qr)
}
if len(names) != len(test.V) {
t.Errorf("V: expected %#v, got(V) %#v", test.V, names)
} else {
for i, name := range names {
if name != test.V[i] {
t.Errorf("expected %dth name to be %s, got(V) %s", i+1, test.V[i], name)
}
}
}
}
qd, _, _ := compileNamedQuery([]byte(test.Q), DOLLAR)
if qd != test.D {
t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd)
}
qd, _, _ := compileNamedQuery([]byte(test.Q), DOLLAR)
if qd != test.D {
t.Errorf("\nexpected: `%s`\ngot(D): `%s`", test.D, qd)
}

qt, _, _ := compileNamedQuery([]byte(test.Q), AT)
if qt != test.T {
t.Errorf("\nexpected: `%s`\ngot: `%s`", test.T, qt)
}
qt, _, _ := compileNamedQuery([]byte(test.Q), AT)
if qt != test.T {
t.Errorf("\nexpected: `%s`\ngot(T): `%s`", test.T, qt)
}

qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED)
if qq != test.N {
t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq))
}
qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED)
if qq != test.N {
t.Errorf("\nexpected: `%s`\ngot(N): `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq))
}
})
}
}

Expand Down Expand Up @@ -123,7 +141,7 @@ func (t Test) Errorf(err error, format string, args ...interface{}) {

func TestEscapedColons(t *testing.T) {
t.Skip("not sure it is possible to support this in general case without an SQL parser")
var qs = `SELECT * FROM testtable WHERE timeposted BETWEEN (now() AT TIME ZONE 'utc') AND
qs := `SELECT * FROM testtable WHERE timeposted BETWEEN (now() AT TIME ZONE 'utc') AND
(now() AT TIME ZONE 'utc') - interval '01:30:00') AND name = '\'this is a test\'' and id = :id`
_, _, err := compileNamedQuery([]byte(qs), DOLLAR)
if err != nil {
Expand Down Expand Up @@ -298,7 +316,6 @@ func TestNamedQueries(t *testing.T) {
if p2.Email != sl.Email {
t.Errorf("expected %s, got %s", sl.Email, p2.Email)
}

})
}

Expand Down