diff --git a/redis/conn.go b/redis/conn.go index 99dc6fa1..526a5649 100644 --- a/redis/conn.go +++ b/redis/conn.go @@ -710,6 +710,36 @@ func (c *conn) Receive() (interface{}, error) { return c.ReceiveWithTimeout(c.readTimeout) } +func (c *conn) ReceiveContext(ctx context.Context) (interface{}, error) { + var realTimeout time.Duration + if dl, ok := ctx.Deadline(); ok { + timeout := time.Until(dl) + if timeout >= c.readTimeout && c.readTimeout != 0 { + realTimeout = c.readTimeout + } else if timeout <= 0 { + return nil, c.fatal(context.DeadlineExceeded) + } else { + realTimeout = timeout + } + } else { + realTimeout = c.readTimeout + } + endch := make(chan struct{}) + var r interface{} + var e error + go func() { + defer close(endch) + + r, e = c.ReceiveWithTimeout(realTimeout) + }() + select { + case <-ctx.Done(): + return nil, c.fatal(ctx.Err()) + case <-endch: + return r, e + } +} + func (c *conn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) { var deadline time.Time if timeout != 0 { @@ -744,6 +774,36 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { return c.DoWithTimeout(c.readTimeout, cmd, args...) } +func (c *conn) DoContext(ctx context.Context, cmd string, args ...interface{}) (interface{}, error) { + var realTimeout time.Duration + if dl, ok := ctx.Deadline(); ok { + timeout := time.Until(dl) + if timeout >= c.readTimeout && c.readTimeout != 0 { + realTimeout = c.readTimeout + } else if timeout <= 0 { + return nil, c.fatal(context.DeadlineExceeded) + } else { + realTimeout = timeout + } + } else { + realTimeout = c.readTimeout + } + endch := make(chan struct{}) + var r interface{} + var e error + go func() { + defer close(endch) + + r, e = c.DoWithTimeout(realTimeout, cmd, args) + }() + select { + case <-ctx.Done(): + return nil, c.fatal(ctx.Err()) + case <-endch: + return r, e + } +} + func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) { c.mu.Lock() pending := c.pending diff --git a/redis/log.go b/redis/log.go index ef8cd7a0..72e054f0 100644 --- a/redis/log.go +++ b/redis/log.go @@ -16,6 +16,7 @@ package redis import ( "bytes" + "context" "fmt" "log" "time" @@ -121,6 +122,12 @@ func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{}, return reply, err } +func (c *loggingConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (interface{}, error) { + reply, err := DoContext(c.Conn, ctx, commandName, args...) + c.print("DoContext", commandName, args, reply, err) + return reply, err +} + func (c *loggingConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (interface{}, error) { reply, err := DoWithTimeout(c.Conn, timeout, commandName, args...) c.print("DoWithTimeout", commandName, args, reply, err) @@ -139,6 +146,12 @@ func (c *loggingConn) Receive() (interface{}, error) { return reply, err } +func (c *loggingConn) ReceiveContext(ctx context.Context) (interface{}, error) { + reply, err := ReceiveContext(c.Conn, ctx) + c.print("ReceiveContext", "", nil, reply, err) + return reply, err +} + func (c *loggingConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) { reply, err := ReceiveWithTimeout(c.Conn, timeout) c.print("ReceiveWithTimeout", "", nil, reply, err) diff --git a/redis/pool.go b/redis/pool.go index c7a2f194..d7bb71e0 100644 --- a/redis/pool.go +++ b/redis/pool.go @@ -512,6 +512,20 @@ func (ac *activeConn) Err() error { return pc.c.Err() } +func (ac *activeConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (reply interface{}, err error) { + pc := ac.pc + if pc == nil { + return nil, errConnClosed + } + cwt, ok := pc.c.(ConnWithContext) + if !ok { + return nil, errContextNotSupported + } + ci := lookupCommandInfo(commandName) + ac.state = (ac.state | ci.Set) &^ ci.Clear + return cwt.DoContext(ctx, commandName, args...) +} + func (ac *activeConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) { pc := ac.pc if pc == nil { @@ -562,6 +576,18 @@ func (ac *activeConn) Receive() (reply interface{}, err error) { return pc.c.Receive() } +func (ac *activeConn) ReceiveContext(ctx context.Context) (reply interface{}, err error) { + pc := ac.pc + if pc == nil { + return nil, errConnClosed + } + cwt, ok := pc.c.(ConnWithContext) + if !ok { + return nil, errContextNotSupported + } + return cwt.ReceiveContext(ctx) +} + func (ac *activeConn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) { pc := ac.pc if pc == nil { @@ -577,6 +603,9 @@ func (ac *activeConn) ReceiveWithTimeout(timeout time.Duration) (reply interface type errorConn struct{ err error } func (ec errorConn) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err } +func (ec errorConn) DoContext(context.Context, string, ...interface{}) (interface{}, error) { + return nil, ec.err +} func (ec errorConn) DoWithTimeout(time.Duration, string, ...interface{}) (interface{}, error) { return nil, ec.err } @@ -585,6 +614,7 @@ func (ec errorConn) Err() error { ret func (ec errorConn) Close() error { return nil } func (ec errorConn) Flush() error { return ec.err } func (ec errorConn) Receive() (interface{}, error) { return nil, ec.err } +func (ec errorConn) ReceiveContext(context.Context) (interface{}, error) { return nil, ec.err } func (ec errorConn) ReceiveWithTimeout(time.Duration) (interface{}, error) { return nil, ec.err } type idleList struct { diff --git a/redis/redis.go b/redis/redis.go index e4464874..5529dbd2 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -15,6 +15,7 @@ package redis import ( + "context" "errors" "time" ) @@ -33,6 +34,7 @@ type Conn interface { Err() error // Do sends a command to the server and returns the received reply. + // This function will use the timeout which was set when the connection is created Do(commandName string, args ...interface{}) (reply interface{}, err error) // Send writes the command to the client's output buffer. @@ -82,17 +84,52 @@ type Scanner interface { type ConnWithTimeout interface { Conn - // Do sends a command to the server and returns the received reply. - // The timeout overrides the read timeout set when dialing the - // connection. + // DoWithTimeout sends a command to the server and returns the received reply. + // The timeout overrides the readtimeout set when dialing the connection. DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error) - // Receive receives a single reply from the Redis server. The timeout - // overrides the read timeout set when dialing the connection. + // ReceiveWithTimeout receives a single reply from the Redis server. + // The timeout overrides the readtimeout set when dialing the connection. ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) } +// ConnWithContext is an optional interface that allows the caller to control the command's life with context. +type ConnWithContext interface { + Conn + + // DoContext sends a command to server and returns the received reply. + // min(ctx,DialReadTimeout()) will be used as the deadline. + // The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running. + // DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout"). + // ctx timeout return err context.DeadlineExceeded. + // ctx canceled return err context.Canceled. + DoContext(ctx context.Context, commandName string, args ...interface{}) (reply interface{}, err error) + + // ReceiveContext receives a single reply from the Redis server. + // min(ctx,DialReadTimeout()) will be used as the deadline. + // The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running. + // DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout"). + // ctx timeout return err context.DeadlineExceeded. + // ctx canceled return err context.Canceled. + ReceiveContext(ctx context.Context) (reply interface{}, err error) +} + var errTimeoutNotSupported = errors.New("redis: connection does not support ConnWithTimeout") +var errContextNotSupported = errors.New("redis: connection does not support ConnWithContext") + +// DoContext sends a command to server and returns the received reply. +// min(ctx,DialReadTimeout()) will be used as the deadline. +// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running. +// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout"). +// ctx timeout return err context.DeadlineExceeded. +// ctx canceled return err context.Canceled. +func DoContext(c Conn, ctx context.Context, cmd string, args ...interface{}) (interface{}, error) { + cwt, ok := c.(ConnWithContext) + if !ok { + return nil, errContextNotSupported + } + return cwt.DoContext(ctx, cmd, args...) +} // DoWithTimeout executes a Redis command with the specified read timeout. If // the connection does not satisfy the ConnWithTimeout interface, then an error @@ -105,6 +142,20 @@ func DoWithTimeout(c Conn, timeout time.Duration, cmd string, args ...interface{ return cwt.DoWithTimeout(timeout, cmd, args...) } +// ReceiveContext receives a single reply from the Redis server. +// min(ctx,DialReadTimeout()) will be used as the deadline. +// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running. +// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout"). +// ctx timeout return err context.DeadlineExceeded. +// ctx canceled return err context.Canceled. +func ReceiveContext(c Conn, ctx context.Context) (interface{}, error) { + cwt, ok := c.(ConnWithContext) + if !ok { + return nil, errContextNotSupported + } + return cwt.ReceiveContext(ctx) +} + // ReceiveWithTimeout receives a reply with the specified read timeout. If the // connection does not satisfy the ConnWithTimeout interface, then an error is // returned. diff --git a/redis/redis_test.go b/redis/redis_test.go index 5a98f535..e0f4538d 100644 --- a/redis/redis_test.go +++ b/redis/redis_test.go @@ -15,6 +15,7 @@ package redis_test import ( + "context" "testing" "time" @@ -26,6 +27,7 @@ type timeoutTestConn int func (tc timeoutTestConn) Do(string, ...interface{}) (interface{}, error) { return time.Duration(-1), nil } + func (tc timeoutTestConn) DoWithTimeout(timeout time.Duration, cmd string, args ...interface{}) (interface{}, error) { return timeout, nil } @@ -33,6 +35,7 @@ func (tc timeoutTestConn) DoWithTimeout(timeout time.Duration, cmd string, args func (tc timeoutTestConn) Receive() (interface{}, error) { return time.Duration(-1), nil } + func (tc timeoutTestConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) { return timeout, nil } @@ -69,3 +72,52 @@ func TestPoolConnTimeout(t *testing.T) { p := &redis.Pool{Dial: func() (redis.Conn, error) { return timeoutTestConn(0), nil }} testTimeout(t, p.Get()) } + +type contextDeadTestConn int + +func (cc contextDeadTestConn) Do(string, ...interface{}) (interface{}, error) { + return -1, nil +} +func (cc contextDeadTestConn) DoContext(ctx context.Context, cmd string, args ...interface{}) (interface{}, error) { + return 1, nil +} +func (cc contextDeadTestConn) Receive() (interface{}, error) { + return -1, nil +} +func (cc contextDeadTestConn) ReceiveContext(ctx context.Context) (interface{}, error) { + return 1, nil +} +func (cc contextDeadTestConn) Send(string, ...interface{}) error { return nil } +func (cc contextDeadTestConn) Err() error { return nil } +func (cc contextDeadTestConn) Close() error { return nil } +func (cc contextDeadTestConn) Flush() error { return nil } + +func testcontext(t *testing.T, c redis.Conn) { + r, e := c.Do("PING") + if r != -1 || e != nil { + t.Errorf("Do() = %v, %v, want %v, %v", r, e, -1, nil) + } + ctx, f := context.WithTimeout(context.Background(), time.Minute) + defer f() + r, e = redis.DoContext(c, ctx, "PING") + if r != 1 || e != nil { + t.Errorf("DoContext() = %v, %v, want %v, %v", r, e, 1, nil) + } + r, e = c.Receive() + if r != -1 || e != nil { + t.Errorf("Receive() = %v, %v, want %v, %v", r, e, -1, nil) + } + r, e = redis.ReceiveContext(c, ctx) + if r != 1 || e != nil { + t.Errorf("ReceiveContext() = %v, %v, want %v, %v", r, e, 1, nil) + } +} + +func TestConnContext(t *testing.T) { + testcontext(t, contextDeadTestConn(0)) +} + +func TestPoolConnContext(t *testing.T) { + p := redis.Pool{Dial: func() (redis.Conn, error) { return contextDeadTestConn(0), nil }} + testcontext(t, p.Get()) +} diff --git a/redis/script.go b/redis/script.go index d0cec1ed..bb5d7b00 100644 --- a/redis/script.go +++ b/redis/script.go @@ -15,6 +15,7 @@ package redis import ( + "context" "crypto/sha1" "encoding/hex" "io" @@ -60,6 +61,18 @@ func (s *Script) Hash() string { return s.hash } +func (s *Script) DoContext(ctx context.Context, c Conn, keysAndArgs ...interface{}) (interface{}, error) { + cwt, ok := c.(ConnWithContext) + if !ok { + return nil, errContextNotSupported + } + v, err := cwt.DoContext(ctx, "EVALSHA", s.args(s.hash, keysAndArgs)...) + if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") { + v, err = cwt.DoContext(ctx, "EVAL", s.args(s.src, keysAndArgs)...) + } + return v, err +} + // Do evaluates the script. Under the covers, Do optimistically evaluates the // script using the EVALSHA command. If the command fails because the script is // not loaded, then Do evaluates the script using the EVAL command (thus