diff --git a/AUTHORS b/AUTHORS index 221f4a395..98cb1e66f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -107,3 +107,4 @@ Multiplay Ltd. Percona LLC Pivotal Inc. Stripe Inc. +Zendesk Inc. diff --git a/connection.go b/connection.go index b07cd7651..6769e3ce1 100644 --- a/connection.go +++ b/connection.go @@ -489,6 +489,10 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { // BeginTx implements driver.ConnBeginTx interface func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if mc.closed.IsSet() { + return nil, driver.ErrBadConn + } + if err := mc.watchCancel(ctx); err != nil { return nil, err } diff --git a/driver_test.go b/driver_test.go index 8edd17c47..34b476ed3 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2608,7 +2608,12 @@ func TestContextCancelBegin(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) - tx, err := dbt.db.BeginTx(ctx, nil) + conn, err := dbt.db.Conn(ctx) + if err != nil { + dbt.Fatal(err) + } + defer conn.Close() + tx, err := conn.BeginTx(ctx, nil) if err != nil { dbt.Fatal(err) } @@ -2638,7 +2643,17 @@ func TestContextCancelBegin(t *testing.T) { dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err) } - // Context is canceled, so cannot begin a transaction. + // The connection is now in an inoperable state - so performing other + // operations should fail with ErrBadConn + // Important to exercise isolation level too - it runs SET TRANSACTION ISOLATION + // LEVEL XXX first, which needs to return ErrBadConn if the connection's context + // is cancelled + _, err = conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}) + if err != driver.ErrBadConn { + dbt.Errorf("expected driver.ErrBadConn, got %v", err) + } + + // cannot begin a transaction (on a different conn) with a canceled context if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) }