diff --git a/crdb/common.go b/crdb/common.go index 8039923..7beefa3 100644 --- a/crdb/common.go +++ b/crdb/common.go @@ -62,8 +62,7 @@ func ExecuteInTx(ctx context.Context, tx Tx, fn func() error) (err error) { return err } - // TODO(rafi): make the maxRetryCount configurable. Maybe pass it in the context?) - const maxRetries = 50 + maxRetries := numRetriesFromContext(ctx) retryCount := 0 for { releaseFailed := false @@ -90,7 +89,7 @@ func ExecuteInTx(ctx context.Context, tx Tx, fn func() error) (err error) { } retryCount++ - if retryCount > maxRetries { + if maxRetries > 0 && retryCount > maxRetries { return newMaxRetriesExceededError(err, maxRetries) } } diff --git a/crdb/tx.go b/crdb/tx.go index 07e78ad..7120e08 100644 --- a/crdb/tx.go +++ b/crdb/tx.go @@ -102,6 +102,26 @@ func Execute(fn func() error) (err error) { } } +type txConfigKey struct {} + +// WithMaxRetries configures context so that ExecuteTx retries tx specified +// number of times when encountering retryable errors. +// Setting retries to 0 will retry indefinitely. +func WithMaxRetries(ctx context.Context, retries int) context.Context { + return context.WithValue(ctx, txConfigKey{}, retries) +} + +const defaultRetries = 50 + +func numRetriesFromContext(ctx context.Context) int { + if v := ctx.Value(txConfigKey{}); v != nil { + if retries, ok := v.(int); ok && retries >= 0 { + return retries + } + } + return defaultRetries +} + // ExecuteTx runs fn inside a transaction and retries it as needed. On // non-retryable failures, the transaction is aborted and rolled back; on // success, the transaction is committed. diff --git a/crdb/tx_test.go b/crdb/tx_test.go index 2c1ae9d..ef0dda9 100644 --- a/crdb/tx_test.go +++ b/crdb/tx_test.go @@ -35,6 +35,19 @@ func TestExecuteTx(t *testing.T) { } } +// TestConfigureRetries verifies that the number of retries can be specified +// via context. +func TestConfigureRetries(t *testing.T) { + ctx := context.Background() + if numRetriesFromContext(ctx) != defaultRetries { + t.Fatal("expect default number of retries") + } + ctx = WithMaxRetries(context.Background(), 123 + defaultRetries) + if numRetriesFromContext(ctx) != defaultRetries + 123 { + t.Fatal("expected default+123 retires") + } +} + type stdlibWriteSkewTest struct { db *sql.DB }