Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add DoContext and ReceiveContext,use context to control the life #537

Merged
merged 12 commits into from Sep 30, 2021
60 changes: 60 additions & 0 deletions redis/conn.go
Expand Up @@ -678,6 +678,36 @@ func (c *conn) Receive() (interface{}, error) {
return c.ReceiveWithTimeout(c.readTimeout)
}

func (c *conn) ReceiveContext(ctx context.Context) (interface{}, error) {
chenjie199234 marked this conversation as resolved.
Show resolved Hide resolved
var realTimeout time.Duration
if dl, ok := ctx.Deadline(); ok {
timeout := time.Until(dl)
if timeout >= c.readTimeout && c.readTimeout != 0 {
realTimeout = c.readTimeout
} else if timeout <= 0 {
return nil, c.fatal(context.DeadlineExceeded)
} else {
realTimeout = timeout
}
} else {
realTimeout = c.readTimeout
}
endch := make(chan struct{})
var r interface{}
var e error
go func() {
defer close(endch)

r, e = c.ReceiveWithTimeout(realTimeout)
}()
select {
case <-ctx.Done():
return nil, c.fatal(ctx.Err())
case <-endch:
return r, e
}
}

func (c *conn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
var deadline time.Time
if timeout != 0 {
Expand Down Expand Up @@ -710,6 +740,36 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return c.DoWithTimeout(c.readTimeout, cmd, args...)
}

func (c *conn) DoContext(ctx context.Context, cmd string, args ...interface{}) (interface{}, error) {
chenjie199234 marked this conversation as resolved.
Show resolved Hide resolved
var realTimeout time.Duration
if dl, ok := ctx.Deadline(); ok {
timeout := time.Until(dl)
if timeout >= c.readTimeout && c.readTimeout != 0 {
realTimeout = c.readTimeout
} else if timeout <= 0 {
return nil, c.fatal(context.DeadlineExceeded)
} else {
realTimeout = timeout
}
} else {
realTimeout = c.readTimeout
}
endch := make(chan struct{})
var r interface{}
var e error
go func() {
defer close(endch)

r, e = c.DoWithTimeout(realTimeout, cmd, args)
}()
select {
case <-ctx.Done():
return nil, c.fatal(ctx.Err())
case <-endch:
return r, e
}
}
chenjie199234 marked this conversation as resolved.
Show resolved Hide resolved

func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
c.mu.Lock()
pending := c.pending
Expand Down
13 changes: 13 additions & 0 deletions redis/log.go
Expand Up @@ -16,6 +16,7 @@ package redis

import (
"bytes"
"context"
"fmt"
"log"
"time"
Expand Down Expand Up @@ -121,6 +122,12 @@ func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{},
return reply, err
}

func (c *loggingConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (interface{}, error) {
reply, err := DoContext(c.Conn, ctx, commandName, args...)
c.print("DoContext", commandName, args, reply, err)
return reply, err
}

func (c *loggingConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (interface{}, error) {
reply, err := DoWithTimeout(c.Conn, timeout, commandName, args...)
c.print("DoWithTimeout", commandName, args, reply, err)
Expand All @@ -139,6 +146,12 @@ func (c *loggingConn) Receive() (interface{}, error) {
return reply, err
}

func (c *loggingConn) ReceiveContext(ctx context.Context) (interface{}, error) {
reply, err := ReceiveContext(c.Conn, ctx)
c.print("ReceiveContext", "", nil, reply, err)
return reply, err
}

func (c *loggingConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
reply, err := ReceiveWithTimeout(c.Conn, timeout)
c.print("ReceiveWithTimeout", "", nil, reply, err)
Expand Down
30 changes: 30 additions & 0 deletions redis/pool.go
Expand Up @@ -497,6 +497,20 @@ func (ac *activeConn) Err() error {
return pc.c.Err()
}

func (ac *activeConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (reply interface{}, err error) {
pc := ac.pc
if pc == nil {
return nil, errConnClosed
}
cwt, ok := pc.c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
ci := lookupCommandInfo(commandName)
ac.state = (ac.state | ci.Set) &^ ci.Clear
return cwt.DoContext(ctx, commandName, args...)
}

func (ac *activeConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
pc := ac.pc
if pc == nil {
Expand Down Expand Up @@ -547,6 +561,18 @@ func (ac *activeConn) Receive() (reply interface{}, err error) {
return pc.c.Receive()
}

func (ac *activeConn) ReceiveContext(ctx context.Context) (reply interface{}, err error) {
chenjie199234 marked this conversation as resolved.
Show resolved Hide resolved
pc := ac.pc
if pc == nil {
return nil, errConnClosed
}
cwt, ok := pc.c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
return cwt.ReceiveContext(ctx)
}
chenjie199234 marked this conversation as resolved.
Show resolved Hide resolved

func (ac *activeConn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
pc := ac.pc
if pc == nil {
Expand All @@ -562,6 +588,9 @@ func (ac *activeConn) ReceiveWithTimeout(timeout time.Duration) (reply interface
type errorConn struct{ err error }

func (ec errorConn) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err }
func (ec errorConn) DoContext(context.Context, string, ...interface{}) (interface{}, error) {
return nil, ec.err
}
func (ec errorConn) DoWithTimeout(time.Duration, string, ...interface{}) (interface{}, error) {
return nil, ec.err
}
Expand All @@ -570,6 +599,7 @@ func (ec errorConn) Err() error { ret
func (ec errorConn) Close() error { return nil }
func (ec errorConn) Flush() error { return ec.err }
func (ec errorConn) Receive() (interface{}, error) { return nil, ec.err }
func (ec errorConn) ReceiveContext(context.Context) (interface{}, error) { return nil, ec.err }
func (ec errorConn) ReceiveWithTimeout(time.Duration) (interface{}, error) { return nil, ec.err }

type idleList struct {
Expand Down
61 changes: 56 additions & 5 deletions redis/redis.go
Expand Up @@ -15,6 +15,7 @@
package redis

import (
"context"
"errors"
"time"
)
Expand All @@ -33,6 +34,7 @@ type Conn interface {
Err() error

// Do sends a command to the server and returns the received reply.
// This function will use the timeout which was set when the connection is created
Do(commandName string, args ...interface{}) (reply interface{}, err error)

// Send writes the command to the client's output buffer.
Expand Down Expand Up @@ -82,17 +84,52 @@ type Scanner interface {
type ConnWithTimeout interface {
Conn

// Do sends a command to the server and returns the received reply.
// The timeout overrides the read timeout set when dialing the
// connection.
// DoWithTimeout sends a command to the server and returns the received reply.
// The timeout overrides the readtimeout set when dialing the connection.
DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error)

// Receive receives a single reply from the Redis server. The timeout
// overrides the read timeout set when dialing the connection.
// ReceiveWithTimeout receives a single reply from the Redis server.
// The timeout overrides the readtimeout set when dialing the connection.
ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error)
}

// ConnWithContext is an optional interface that allows the caller to control the command's life with context.
type ConnWithContext interface {
Conn

// DoContext sends a command to server and returns the received reply.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
DoContext(ctx context.Context, commandName string, args ...interface{}) (reply interface{}, err error)
chenjie199234 marked this conversation as resolved.
Show resolved Hide resolved

// ReceiveContext receives a single reply from the Redis server.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
ReceiveContext(ctx context.Context) (reply interface{}, err error)
}

var errTimeoutNotSupported = errors.New("redis: connection does not support ConnWithTimeout")
var errContextNotSupported = errors.New("redis: connection does not support ConnWithContext")

// DoContext sends a command to server and returns the received reply.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
func DoContext(c Conn, ctx context.Context, cmd string, args ...interface{}) (interface{}, error) {
cwt, ok := c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
return cwt.DoContext(ctx, cmd, args...)
}

// DoWithTimeout executes a Redis command with the specified read timeout. If
// the connection does not satisfy the ConnWithTimeout interface, then an error
Expand All @@ -105,6 +142,20 @@ func DoWithTimeout(c Conn, timeout time.Duration, cmd string, args ...interface{
return cwt.DoWithTimeout(timeout, cmd, args...)
}

// ReceiveContext receives a single reply from the Redis server.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
func ReceiveContext(c Conn, ctx context.Context) (interface{}, error) {
cwt, ok := c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
return cwt.ReceiveContext(ctx)
}

// ReceiveWithTimeout receives a reply with the specified read timeout. If the
// connection does not satisfy the ConnWithTimeout interface, then an error is
// returned.
Expand Down
52 changes: 52 additions & 0 deletions redis/redis_test.go
Expand Up @@ -15,6 +15,7 @@
package redis_test

import (
"context"
"testing"
"time"

Expand All @@ -26,13 +27,15 @@ type timeoutTestConn int
func (tc timeoutTestConn) Do(string, ...interface{}) (interface{}, error) {
return time.Duration(-1), nil
}

func (tc timeoutTestConn) DoWithTimeout(timeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
return timeout, nil
}

func (tc timeoutTestConn) Receive() (interface{}, error) {
return time.Duration(-1), nil
}

func (tc timeoutTestConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
return timeout, nil
}
Expand Down Expand Up @@ -69,3 +72,52 @@ func TestPoolConnTimeout(t *testing.T) {
p := &redis.Pool{Dial: func() (redis.Conn, error) { return timeoutTestConn(0), nil }}
testTimeout(t, p.Get())
}

type contextDeadTestConn int

func (cc contextDeadTestConn) Do(string, ...interface{}) (interface{}, error) {
return -1, nil
}
func (cc contextDeadTestConn) DoContext(ctx context.Context, cmd string, args ...interface{}) (interface{}, error) {
return 1, nil
}
func (cc contextDeadTestConn) Receive() (interface{}, error) {
return -1, nil
}
func (cc contextDeadTestConn) ReceiveContext(ctx context.Context) (interface{}, error) {
return 1, nil
}
func (cc contextDeadTestConn) Send(string, ...interface{}) error { return nil }
func (cc contextDeadTestConn) Err() error { return nil }
func (cc contextDeadTestConn) Close() error { return nil }
func (cc contextDeadTestConn) Flush() error { return nil }

func testcontext(t *testing.T, c redis.Conn) {
r, e := c.Do("PING")
if r != -1 || e != nil {
t.Errorf("Do() = %v, %v, want %v, %v", r, e, -1, nil)
}
ctx, f := context.WithTimeout(context.Background(), time.Minute)
defer f()
r, e = redis.DoContext(c, ctx, "PING")
if r != 1 || e != nil {
t.Errorf("DoContext() = %v, %v, want %v, %v", r, e, 1, nil)
}
r, e = c.Receive()
if r != -1 || e != nil {
t.Errorf("Receive() = %v, %v, want %v, %v", r, e, -1, nil)
}
r, e = redis.ReceiveContext(c, ctx)
if r != 1 || e != nil {
t.Errorf("ReceiveContext() = %v, %v, want %v, %v", r, e, 1, nil)
}
}

func TestConnContext(t *testing.T) {
testcontext(t, contextDeadTestConn(0))
}

func TestPoolConnContext(t *testing.T) {
p := redis.Pool{Dial: func() (redis.Conn, error) { return contextDeadTestConn(0), nil }}
testcontext(t, p.Get())
}
13 changes: 13 additions & 0 deletions redis/script.go
Expand Up @@ -15,6 +15,7 @@
package redis

import (
"context"
"crypto/sha1"
"encoding/hex"
"io"
Expand Down Expand Up @@ -60,6 +61,18 @@ func (s *Script) Hash() string {
return s.hash
}

func (s *Script) DoContext(ctx context.Context, c Conn, keysAndArgs ...interface{}) (interface{}, error) {
cwt, ok := c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
v, err := cwt.DoContext(ctx, "EVALSHA", s.args(s.hash, keysAndArgs)...)
if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") {
v, err = cwt.DoContext(ctx, "EVAL", s.args(s.src, keysAndArgs)...)
}
return v, err
}

// Do evaluates the script. Under the covers, Do optimistically evaluates the
// script using the EVALSHA command. If the command fails because the script is
// not loaded, then Do evaluates the script using the EVAL command (thus
Expand Down