Skip to content

Commit

Permalink
Merge pull request #768 from kylejbrock/master
Browse files Browse the repository at this point in the history
implement additional context specific sql interfaces
  • Loading branch information
maddyblue committed Dec 2, 2019
2 parents f91d341 + 4d5921a commit a2bfbdf
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
*.test
*~
*.swp
.idea
4 changes: 4 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,10 @@ func (st *stmt) Close() (err error) {
}

func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
return st.query(v)
}

func (st *stmt) query(v []driver.Value) (r *rows, err error) {
if st.cn.bad {
return nil, driver.ErrBadConn
}
Expand Down
65 changes: 65 additions & 0 deletions conn_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam
return cn.Exec(query, list)
}

// Implement the "ConnPrepareContext" interface
func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if finish := cn.watchCancel(ctx); finish != nil {
defer finish()
}
return cn.Prepare(query)
}

// Implement the "ConnBeginTx" interface
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
var mode string
Expand Down Expand Up @@ -147,3 +155,60 @@ func (cn *conn) cancel(ctx context.Context) error {
return err
}
}

// Implement the "StmtQueryContext" interface
func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
list := make([]driver.Value, len(args))
for i, nv := range args {
list[i] = nv.Value
}
finish := st.watchCancel(ctx)
r, err := st.query(list)
if err != nil {
if finish != nil {
finish()
}
return nil, err
}
r.finish = finish
return r, nil
}

// Implement the "StmtExecContext" interface
func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
list := make([]driver.Value, len(args))
for i, nv := range args {
list[i] = nv.Value
}

if finish := st.watchCancel(ctx); finish != nil {
defer finish()
}

return st.Exec(list)
}

func (st *stmt) watchCancel(ctx context.Context) func() {
if done := ctx.Done(); done != nil {
finished := make(chan struct{})
go func() {
select {
case <-done:
_ = st.cancel()
finished <- struct{}{}
case <-finished:
}
}()
return func() {
select {
case <-finished:
case finished <- struct{}{}:
}
}
}
return nil
}

func (st *stmt) cancel() error {
return st.cn.cancel()
}

0 comments on commit a2bfbdf

Please sign in to comment.