From 4b55993141422535e605644b07412b1b46116557 Mon Sep 17 00:00:00 2001 From: Rafi Shamim Date: Thu, 7 Apr 2022 13:50:31 -0400 Subject: [PATCH] Avoid asserting on error message for cancel tests --- conn_test.go | 59 ++++++++++++++++++++++++++------------------------ go18_test.go | 13 ++++++----- issues_test.go | 10 +++++---- 3 files changed, 45 insertions(+), 37 deletions(-) diff --git a/conn_test.go b/conn_test.go index b32f983a..eb259570 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" "fmt" "io" "net" @@ -1859,34 +1860,34 @@ func TestStmtQueryContext(t *testing.T) { defer db.Close() tests := []struct { - name string - ctx func() (context.Context, context.CancelFunc) - sql string - err error + name string + ctx func() (context.Context, context.CancelFunc) + sql string + cancelExpected bool }{ { name: "context.Background", ctx: func() (context.Context, context.CancelFunc) { return context.Background(), nil }, - sql: "SELECT pg_sleep(1);", - err: nil, + sql: "SELECT pg_sleep(1);", + cancelExpected: false, }, { name: "context.WithTimeout exceeded", ctx: func() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), 1*time.Second) }, - sql: "SELECT pg_sleep(10);", - err: &Error{Message: "canceling statement due to user request"}, + sql: "SELECT pg_sleep(10);", + cancelExpected: true, }, { name: "context.WithTimeout", ctx: func() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), time.Minute) }, - sql: "SELECT pg_sleep(1);", - err: nil, + sql: "SELECT pg_sleep(1);", + cancelExpected: false, }, } for _, tt := range tests { @@ -1900,11 +1901,12 @@ func TestStmtQueryContext(t *testing.T) { t.Fatal(err) } _, err = stmt.QueryContext(ctx) + pgErr := (*Error)(nil) switch { - case (err != nil) != (tt.err != nil): - t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, expected = %v", err, tt.err) - case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()): - t.Errorf("stmt.QueryContext() got = %v, expected = %v", err.Error(), tt.err.Error()) + case (err != nil) != tt.cancelExpected: + t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected) + case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode): + t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected) } }) } @@ -1915,34 +1917,34 @@ func TestStmtExecContext(t *testing.T) { defer db.Close() tests := []struct { - name string - ctx func() (context.Context, context.CancelFunc) - sql string - err error + name string + ctx func() (context.Context, context.CancelFunc) + sql string + cancelExpected bool }{ { name: "context.Background", ctx: func() (context.Context, context.CancelFunc) { return context.Background(), nil }, - sql: "SELECT pg_sleep(1);", - err: nil, + sql: "SELECT pg_sleep(1);", + cancelExpected: false, }, { name: "context.WithTimeout exceeded", ctx: func() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), 1*time.Second) }, - sql: "SELECT pg_sleep(10);", - err: &Error{Message: "canceling statement due to user request"}, + sql: "SELECT pg_sleep(10);", + cancelExpected: true, }, { name: "context.WithTimeout", ctx: func() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), time.Minute) }, - sql: "SELECT pg_sleep(1);", - err: nil, + sql: "SELECT pg_sleep(1);", + cancelExpected: false, }, } for _, tt := range tests { @@ -1956,11 +1958,12 @@ func TestStmtExecContext(t *testing.T) { t.Fatal(err) } _, err = stmt.ExecContext(ctx) + pgErr := (*Error)(nil) switch { - case (err != nil) != (tt.err != nil): - t.Fatalf("stmt.ExecContext() unexpected nil err got = %v, expected = %v", err, tt.err) - case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()): - t.Errorf("stmt.ExecContext() got = %v, expected = %v", err.Error(), tt.err.Error()) + case (err != nil) != tt.cancelExpected: + t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected) + case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode): + t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected) } }) } diff --git a/go18_test.go b/go18_test.go index 27501e74..bcc02006 100644 --- a/go18_test.go +++ b/go18_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" "runtime" "strings" "testing" @@ -75,6 +76,8 @@ func TestMultipleSimpleQuery(t *testing.T) { const contextRaceIterations = 100 +const cancelErrorCode ErrorCode = "57014" + func TestContextCancelExec(t *testing.T) { db := openTestConn(t) defer db.Close() @@ -87,7 +90,7 @@ func TestContextCancelExec(t *testing.T) { // Not canceled until after the exec has started. if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil { t.Fatal("expected error") - } else if err.Error() != "pq: canceling statement due to user request" { + } else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { t.Fatalf("unexpected error: %s", err) } @@ -125,7 +128,7 @@ func TestContextCancelQuery(t *testing.T) { // Not canceled until after the exec has started. if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil { t.Fatal("expected error") - } else if err.Error() != "pq: canceling statement due to user request" { + } else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { t.Fatalf("unexpected error: %s", err) } @@ -215,7 +218,7 @@ func TestContextCancelBegin(t *testing.T) { // Not canceled until after the exec has started. if _, err := tx.Exec("select pg_sleep(1)"); err == nil { t.Fatal("expected error") - } else if err.Error() != "pq: canceling statement due to user request" { + } else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { t.Fatalf("unexpected error: %s", err) } @@ -240,8 +243,8 @@ func TestContextCancelBegin(t *testing.T) { cancel() if err != nil { t.Fatal(err) - } else if err := tx.Rollback(); err != nil && - err.Error() != "pq: canceling statement due to user request" && + } else if err, pgErr := tx.Rollback(), (*Error)(nil); err != nil && + !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) && err != sql.ErrTxDone && err != driver.ErrBadConn && err != context.Canceled { t.Fatal(err) } diff --git a/issues_test.go b/issues_test.go index 4d24c9dd..26a70282 100644 --- a/issues_test.go +++ b/issues_test.go @@ -2,6 +2,7 @@ package pq import ( "context" + "errors" "testing" "time" ) @@ -51,10 +52,9 @@ func TestIssue1046(t *testing.T) { t.Logf("FAIL %s: query returned after context deadline: %v\n", t.Name(), since) t.Fail() } - expectedErr := &Error{Message: "canceling statement due to user request"} - if err == nil || err.Error() != expectedErr.Error() { + if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { t.Logf("ctx.Err(): [%T]%+v\n", ctx.Err(), ctx.Err()) - t.Logf("got err: [%T] %+v expected err: [%T] %+v", err, err, expectedErr, expectedErr) + t.Logf("got err: [%T] %+v expected errCode: %v", err, err, cancelErrorCode) t.Fail() } } @@ -72,7 +72,9 @@ func TestIssue1062(t *testing.T) { var v int err := row.Scan(&v) - if err != nil && err != context.Canceled && err.Error() != "pq: canceling statement due to user request" { + if pgErr := (*Error)(nil); err != nil && + err != context.Canceled && + !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { t.Fatalf("Scan resulted in unexpected error %v for canceled QueryRowContext at attempt %d", err, i+1) } }