Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UnsafeLogged added to report missing fields #789

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
113 changes: 96 additions & 17 deletions sqlx.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"

"io"
"io/ioutil"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -144,6 +145,43 @@ func isUnsafe(i interface{}) bool {
}
}

func logFor(i interface{}) io.Writer {
switch v := i.(type) {
case Row:
return v.log
case *Row:
return v.log
case Rows:
return v.log
case *Rows:
return v.log
case NamedStmt:
return v.Stmt.log
case *NamedStmt:
return v.Stmt.log
case Stmt:
return v.log
case *Stmt:
return v.log
case qStmt:
return v.log
case *qStmt:
return v.log
case DB:
return v.log
case *DB:
return v.log
case Tx:
return v.log
case *Tx:
return v.log
case sql.Rows, *sql.Rows:
return nil
default:
return nil
}
}

func mapperFor(i interface{}) *reflectx.Mapper {
switch i := i.(type) {
case DB:
Expand All @@ -167,6 +205,7 @@ var _valuerInterface = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
type Row struct {
err error
unsafe bool
log io.Writer
rows *sql.Rows
Mapper *reflectx.Mapper
}
Expand Down Expand Up @@ -243,6 +282,7 @@ type DB struct {
*sql.DB
driverName string
unsafe bool
log io.Writer
Mapper *reflectx.Mapper
}

Expand Down Expand Up @@ -291,7 +331,16 @@ func (db *DB) Rebind(query string) string {
// sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its
// safety behavior.
func (db *DB) Unsafe() *DB {
return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, Mapper: db.Mapper}
return db.UnsafeLogged(nil)
}

// Like Unsafe, UnsafeLogged returns a version of DB which will succeed to scan
// when columns in the SQL result have no fields in the destination struct.
// But unlike Unsafe(), this will write a short log, if it does.
// sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its
// safety behavior.
func (db *DB) UnsafeLogged(log io.Writer) *DB {
return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, log: log, Mapper: db.Mapper}
}

// BindNamed binds a query using the DB driver's bindvar type.
Expand Down Expand Up @@ -340,7 +389,7 @@ func (db *DB) Beginx() (*Tx, error) {
if err != nil {
return nil, err
}
return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err
return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, log: db.log, Mapper: db.Mapper}, err
}

// Queryx queries the database and returns an *sqlx.Rows.
Expand All @@ -350,14 +399,14 @@ func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err
return &Rows{Rows: r, unsafe: db.unsafe, log: db.log, Mapper: db.Mapper}, err
}

// QueryRowx queries the database and returns an *sqlx.Row.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) QueryRowx(query string, args ...interface{}) *Row {
rows, err := db.DB.Query(query, args...)
return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper}
return &Row{rows: rows, err: err, unsafe: db.unsafe, log: db.log, Mapper: db.Mapper}
}

// MustExec (panic) runs MustExec using this database.
Expand All @@ -381,6 +430,7 @@ type Conn struct {
*sql.Conn
driverName string
unsafe bool
log io.Writer
Mapper *reflectx.Mapper
}

Expand All @@ -389,6 +439,7 @@ type Tx struct {
*sql.Tx
driverName string
unsafe bool
log io.Writer
Mapper *reflectx.Mapper
}

Expand All @@ -405,7 +456,14 @@ func (tx *Tx) Rebind(query string) string {
// Unsafe returns a version of Tx which will silently succeed to scan when
// columns in the SQL result have no fields in the destination struct.
func (tx *Tx) Unsafe() *Tx {
return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, Mapper: tx.Mapper}
return tx.UnsafeLogged(nil)
}

// Like Unsafe, UnsafeLogged returns a version of Tx which will succeed to
// scan when columns in the SQL result have no fields in the destination struct.
// But unlike Unsafe(), this will write a short log, if it does.
func (tx *Tx) UnsafeLogged(log io.Writer) *Tx {
return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, log: log, Mapper: tx.Mapper}
}

// BindNamed binds a query within a transaction's bindvar type.
Expand Down Expand Up @@ -438,14 +496,14 @@ func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err
return &Rows{Rows: r, unsafe: tx.unsafe, log: tx.log, Mapper: tx.Mapper}, err
}

// QueryRowx within a transaction.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row {
rows, err := tx.Tx.Query(query, args...)
return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper}
return &Row{rows: rows, err: err, unsafe: tx.unsafe, log: tx.log, Mapper: tx.Mapper}
}

// Get within a transaction.
Expand Down Expand Up @@ -501,13 +559,21 @@ func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) {
type Stmt struct {
*sql.Stmt
unsafe bool
log io.Writer
Mapper *reflectx.Mapper
}

// Unsafe returns a version of Stmt which will silently succeed to scan when
// columns in the SQL result have no fields in the destination struct.
func (s *Stmt) Unsafe() *Stmt {
return &Stmt{Stmt: s.Stmt, unsafe: true, Mapper: s.Mapper}
return s.UnsafeLogged(nil)
}

// Like Unsafe, UnsafeLogged returns a version of Stmt which will succeed to
// scan when columns in the SQL result have no fields in the destination struct.
// But unlike Unsafe(), this will write a short log, if it does.
func (s *Stmt) UnsafeLogged(log io.Writer) *Stmt {
return &Stmt{Stmt: s.Stmt, unsafe: true, log: log, Mapper: s.Mapper}
}

// Select using the prepared statement.
Expand Down Expand Up @@ -557,12 +623,12 @@ func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err
return &Rows{Rows: r, unsafe: q.Stmt.unsafe, log: q.Stmt.log, Mapper: q.Stmt.Mapper}, err
}

func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row {
rows, err := q.Stmt.Query(args...)
return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}
return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, log: q.Stmt.log, Mapper: q.Stmt.Mapper}
}

func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) {
Expand All @@ -574,6 +640,7 @@ func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) {
type Rows struct {
*sql.Rows
unsafe bool
log io.Writer
Mapper *reflectx.Mapper
// these fields cache memory use for a rows during iteration w/ structScan
started bool
Expand Down Expand Up @@ -614,8 +681,12 @@ func (r *Rows) StructScan(dest interface{}) error {

r.fields = m.TraversalsByName(v.Type(), columns)
// if we are not unsafe and are missing fields, return an error
if f, err := missingFields(r.fields); err != nil && !r.unsafe {
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
if f, err := missingFields(r.fields); err != nil {
if !r.unsafe {
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
} else if r.log != nil {
fmt.Fprintf(r.log, "missing destination name %s in %T\n", columns[f], dest)
}
}
r.values = make([]interface{}, len(columns))
r.started = true
Expand Down Expand Up @@ -662,7 +733,7 @@ func Preparex(p Preparer, query string) (*Stmt, error) {
if err != nil {
return nil, err
}
return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err
return &Stmt{Stmt: s, unsafe: isUnsafe(p), log: logFor(p), Mapper: mapperFor(p)}, err
}

// Select executes a query using the provided Queryer, and StructScans each row
Expand Down Expand Up @@ -776,8 +847,12 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error {

fields := m.TraversalsByName(v.Type(), columns)
// if we are not unsafe and are missing fields, return an error
if f, err := missingFields(fields); err != nil && !r.unsafe {
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
if f, err := missingFields(fields); err != nil {
if !r.unsafe {
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
} else if r.log != nil {
fmt.Fprintf(r.log, "missing destination name %s in %T\n", columns[f], dest)
}
}
values := make([]interface{}, len(columns))

Expand Down Expand Up @@ -942,8 +1017,12 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {

fields := m.TraversalsByName(base, columns)
// if we are not unsafe and are missing fields, return an error
if f, err := missingFields(fields); err != nil && !isUnsafe(rows) {
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
if f, err := missingFields(fields); err != nil {
if !isUnsafe(rows) {
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
} else if log := logFor(rows); log != nil {
fmt.Fprintf(log, "missing destination name %s in %T\n", columns[f], dest)
}
}
values = make([]interface{}, len(columns))

Expand Down