Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Queryable interface #809

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 30 additions & 0 deletions sqlx.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sqlx

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
Expand Down Expand Up @@ -237,6 +238,35 @@ func (r *Row) Err() error {
return r.err
}

// Queryable includes all methods shared by sqlx.DB and sqlx.Tx, allowing
// either type to be used interchangeably.
type Queryable interface {
Ext
ExecerContext
PreparerContext
QueryerContext
Preparer

GetContext(context.Context, interface{}, string, ...interface{}) error
SelectContext(context.Context, interface{}, string, ...interface{}) error
Get(interface{}, string, ...interface{}) error
MustExecContext(context.Context, string, ...interface{}) sql.Result
PreparexContext(context.Context, string) (*Stmt, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
Select(interface{}, string, ...interface{}) error
QueryRow(string, ...interface{}) *sql.Row
PrepareNamedContext(context.Context, string) (*NamedStmt, error)
PrepareNamed(string) (*NamedStmt, error)
Preparex(string) (*Stmt, error)
NamedExec(string, interface{}) (sql.Result, error)
NamedExecContext(context.Context, string, interface{}) (sql.Result, error)
MustExec(string, ...interface{}) sql.Result
NamedQuery(string, interface{}) (*Rows, error)
}

var _ Queryable = (*DB)(nil)
var _ Queryable = (*Tx)(nil)

// DB is a wrapper around sql.DB which keeps track of the driverName upon Open,
// used mostly to automatically bind named queries using the right bindvars.
type DB struct {
Expand Down
71 changes: 71 additions & 0 deletions sqlx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1924,3 +1924,74 @@ func TestSelectReset(t *testing.T) {
}
})
}

func TestQueryable(t *testing.T) {
sqlDBType := reflect.TypeOf(&sql.DB{})
dbType := reflect.TypeOf(&DB{})
sqlTxType := reflect.TypeOf(&sql.Tx{})
txType := reflect.TypeOf(&Tx{})

dbMethods := exportableMethods(sqlDBType)
for k, v := range exportableMethods(dbType) {
dbMethods[k] = v
}

txMethods := exportableMethods(sqlTxType)
for k, v := range exportableMethods(txType) {
txMethods[k] = v
}

sharedMethods := make([]string, 0)

for name, dbMethod := range dbMethods {
if txMethod, ok := txMethods[name]; ok {
if methodsEqual(dbMethod.Type, txMethod.Type) {
sharedMethods = append(sharedMethods, name)
}
}
}

queryableType := reflect.TypeOf((*Queryable)(nil)).Elem()
queryableMethods := exportableMethods(queryableType)

for _, sharedMethodName := range sharedMethods {
if _, ok := queryableMethods[sharedMethodName]; !ok {
t.Errorf("Queryable does not include shared DB/Tx method: %s", sharedMethodName)
}
}
}

func exportableMethods(t reflect.Type) map[string]reflect.Method {
methods := make(map[string]reflect.Method)

for i := 0; i < t.NumMethod(); i++ {
method := t.Method(i)

if method.IsExported() {
methods[method.Name] = method
}
}

return methods
}

func methodsEqual(t reflect.Type, ot reflect.Type) bool {
if t.NumIn() != ot.NumIn() || t.NumOut() != ot.NumOut() || t.IsVariadic() != ot.IsVariadic() {
return false
}

// Start at 1 to avoid comparing receiver argument
for i := 1; i < t.NumIn(); i++ {
if t.In(i) != ot.In(i) {
return false
}
}

for i := 0; i < t.NumOut(); i++ {
if t.Out(i) != ot.Out(i) {
return false
}
}

return true
}