Skip to content

Commit

Permalink
feat: add DoContext and ReceiveContext (#537)
Browse files Browse the repository at this point in the history
Add support for context during the Do cycle of a request.

This is supported by DoContext and ReceiveContext to control
the command life by both context and read timeout.

Co-authored-by: Mikhail Mazurskiy <126021+ash2k@users.noreply.github.com>
Co-authored-by: Lilith Games <lilithgames@LilithdeMacBook-Pro.local>
  • Loading branch information
3 people committed Sep 30, 2021
1 parent bf63cd5 commit 56d6448
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 5 deletions.
60 changes: 60 additions & 0 deletions redis/conn.go
Expand Up @@ -710,6 +710,36 @@ func (c *conn) Receive() (interface{}, error) {
return c.ReceiveWithTimeout(c.readTimeout)
}

func (c *conn) ReceiveContext(ctx context.Context) (interface{}, error) {
var realTimeout time.Duration
if dl, ok := ctx.Deadline(); ok {
timeout := time.Until(dl)
if timeout >= c.readTimeout && c.readTimeout != 0 {
realTimeout = c.readTimeout
} else if timeout <= 0 {
return nil, c.fatal(context.DeadlineExceeded)
} else {
realTimeout = timeout
}
} else {
realTimeout = c.readTimeout
}
endch := make(chan struct{})
var r interface{}
var e error
go func() {
defer close(endch)

r, e = c.ReceiveWithTimeout(realTimeout)
}()
select {
case <-ctx.Done():
return nil, c.fatal(ctx.Err())
case <-endch:
return r, e
}
}

func (c *conn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
var deadline time.Time
if timeout != 0 {
Expand Down Expand Up @@ -744,6 +774,36 @@ func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
return c.DoWithTimeout(c.readTimeout, cmd, args...)
}

func (c *conn) DoContext(ctx context.Context, cmd string, args ...interface{}) (interface{}, error) {
var realTimeout time.Duration
if dl, ok := ctx.Deadline(); ok {
timeout := time.Until(dl)
if timeout >= c.readTimeout && c.readTimeout != 0 {
realTimeout = c.readTimeout
} else if timeout <= 0 {
return nil, c.fatal(context.DeadlineExceeded)
} else {
realTimeout = timeout
}
} else {
realTimeout = c.readTimeout
}
endch := make(chan struct{})
var r interface{}
var e error
go func() {
defer close(endch)

r, e = c.DoWithTimeout(realTimeout, cmd, args)
}()
select {
case <-ctx.Done():
return nil, c.fatal(ctx.Err())
case <-endch:
return r, e
}
}

func (c *conn) DoWithTimeout(readTimeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
c.mu.Lock()
pending := c.pending
Expand Down
13 changes: 13 additions & 0 deletions redis/log.go
Expand Up @@ -16,6 +16,7 @@ package redis

import (
"bytes"
"context"
"fmt"
"log"
"time"
Expand Down Expand Up @@ -121,6 +122,12 @@ func (c *loggingConn) Do(commandName string, args ...interface{}) (interface{},
return reply, err
}

func (c *loggingConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (interface{}, error) {
reply, err := DoContext(c.Conn, ctx, commandName, args...)
c.print("DoContext", commandName, args, reply, err)
return reply, err
}

func (c *loggingConn) DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (interface{}, error) {
reply, err := DoWithTimeout(c.Conn, timeout, commandName, args...)
c.print("DoWithTimeout", commandName, args, reply, err)
Expand All @@ -139,6 +146,12 @@ func (c *loggingConn) Receive() (interface{}, error) {
return reply, err
}

func (c *loggingConn) ReceiveContext(ctx context.Context) (interface{}, error) {
reply, err := ReceiveContext(c.Conn, ctx)
c.print("ReceiveContext", "", nil, reply, err)
return reply, err
}

func (c *loggingConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
reply, err := ReceiveWithTimeout(c.Conn, timeout)
c.print("ReceiveWithTimeout", "", nil, reply, err)
Expand Down
30 changes: 30 additions & 0 deletions redis/pool.go
Expand Up @@ -512,6 +512,20 @@ func (ac *activeConn) Err() error {
return pc.c.Err()
}

func (ac *activeConn) DoContext(ctx context.Context, commandName string, args ...interface{}) (reply interface{}, err error) {
pc := ac.pc
if pc == nil {
return nil, errConnClosed
}
cwt, ok := pc.c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
ci := lookupCommandInfo(commandName)
ac.state = (ac.state | ci.Set) &^ ci.Clear
return cwt.DoContext(ctx, commandName, args...)
}

func (ac *activeConn) Do(commandName string, args ...interface{}) (reply interface{}, err error) {
pc := ac.pc
if pc == nil {
Expand Down Expand Up @@ -562,6 +576,18 @@ func (ac *activeConn) Receive() (reply interface{}, err error) {
return pc.c.Receive()
}

func (ac *activeConn) ReceiveContext(ctx context.Context) (reply interface{}, err error) {
pc := ac.pc
if pc == nil {
return nil, errConnClosed
}
cwt, ok := pc.c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
return cwt.ReceiveContext(ctx)
}

func (ac *activeConn) ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error) {
pc := ac.pc
if pc == nil {
Expand All @@ -577,6 +603,9 @@ func (ac *activeConn) ReceiveWithTimeout(timeout time.Duration) (reply interface
type errorConn struct{ err error }

func (ec errorConn) Do(string, ...interface{}) (interface{}, error) { return nil, ec.err }
func (ec errorConn) DoContext(context.Context, string, ...interface{}) (interface{}, error) {
return nil, ec.err
}
func (ec errorConn) DoWithTimeout(time.Duration, string, ...interface{}) (interface{}, error) {
return nil, ec.err
}
Expand All @@ -585,6 +614,7 @@ func (ec errorConn) Err() error { ret
func (ec errorConn) Close() error { return nil }
func (ec errorConn) Flush() error { return ec.err }
func (ec errorConn) Receive() (interface{}, error) { return nil, ec.err }
func (ec errorConn) ReceiveContext(context.Context) (interface{}, error) { return nil, ec.err }
func (ec errorConn) ReceiveWithTimeout(time.Duration) (interface{}, error) { return nil, ec.err }

type idleList struct {
Expand Down
61 changes: 56 additions & 5 deletions redis/redis.go
Expand Up @@ -15,6 +15,7 @@
package redis

import (
"context"
"errors"
"time"
)
Expand All @@ -33,6 +34,7 @@ type Conn interface {
Err() error

// Do sends a command to the server and returns the received reply.
// This function will use the timeout which was set when the connection is created
Do(commandName string, args ...interface{}) (reply interface{}, err error)

// Send writes the command to the client's output buffer.
Expand Down Expand Up @@ -82,17 +84,52 @@ type Scanner interface {
type ConnWithTimeout interface {
Conn

// Do sends a command to the server and returns the received reply.
// The timeout overrides the read timeout set when dialing the
// connection.
// DoWithTimeout sends a command to the server and returns the received reply.
// The timeout overrides the readtimeout set when dialing the connection.
DoWithTimeout(timeout time.Duration, commandName string, args ...interface{}) (reply interface{}, err error)

// Receive receives a single reply from the Redis server. The timeout
// overrides the read timeout set when dialing the connection.
// ReceiveWithTimeout receives a single reply from the Redis server.
// The timeout overrides the readtimeout set when dialing the connection.
ReceiveWithTimeout(timeout time.Duration) (reply interface{}, err error)
}

// ConnWithContext is an optional interface that allows the caller to control the command's life with context.
type ConnWithContext interface {
Conn

// DoContext sends a command to server and returns the received reply.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
DoContext(ctx context.Context, commandName string, args ...interface{}) (reply interface{}, err error)

// ReceiveContext receives a single reply from the Redis server.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
ReceiveContext(ctx context.Context) (reply interface{}, err error)
}

var errTimeoutNotSupported = errors.New("redis: connection does not support ConnWithTimeout")
var errContextNotSupported = errors.New("redis: connection does not support ConnWithContext")

// DoContext sends a command to server and returns the received reply.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
func DoContext(c Conn, ctx context.Context, cmd string, args ...interface{}) (interface{}, error) {
cwt, ok := c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
return cwt.DoContext(ctx, cmd, args...)
}

// DoWithTimeout executes a Redis command with the specified read timeout. If
// the connection does not satisfy the ConnWithTimeout interface, then an error
Expand All @@ -105,6 +142,20 @@ func DoWithTimeout(c Conn, timeout time.Duration, cmd string, args ...interface{
return cwt.DoWithTimeout(timeout, cmd, args...)
}

// ReceiveContext receives a single reply from the Redis server.
// min(ctx,DialReadTimeout()) will be used as the deadline.
// The connection will be closed if DialReadTimeout() timeout or ctx timeout or ctx canceled when this function is running.
// DialReadTimeout() timeout return err can be checked by strings.Contains(e.Error(), "io/timeout").
// ctx timeout return err context.DeadlineExceeded.
// ctx canceled return err context.Canceled.
func ReceiveContext(c Conn, ctx context.Context) (interface{}, error) {
cwt, ok := c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
return cwt.ReceiveContext(ctx)
}

// ReceiveWithTimeout receives a reply with the specified read timeout. If the
// connection does not satisfy the ConnWithTimeout interface, then an error is
// returned.
Expand Down
52 changes: 52 additions & 0 deletions redis/redis_test.go
Expand Up @@ -15,6 +15,7 @@
package redis_test

import (
"context"
"testing"
"time"

Expand All @@ -26,13 +27,15 @@ type timeoutTestConn int
func (tc timeoutTestConn) Do(string, ...interface{}) (interface{}, error) {
return time.Duration(-1), nil
}

func (tc timeoutTestConn) DoWithTimeout(timeout time.Duration, cmd string, args ...interface{}) (interface{}, error) {
return timeout, nil
}

func (tc timeoutTestConn) Receive() (interface{}, error) {
return time.Duration(-1), nil
}

func (tc timeoutTestConn) ReceiveWithTimeout(timeout time.Duration) (interface{}, error) {
return timeout, nil
}
Expand Down Expand Up @@ -69,3 +72,52 @@ func TestPoolConnTimeout(t *testing.T) {
p := &redis.Pool{Dial: func() (redis.Conn, error) { return timeoutTestConn(0), nil }}
testTimeout(t, p.Get())
}

type contextDeadTestConn int

func (cc contextDeadTestConn) Do(string, ...interface{}) (interface{}, error) {
return -1, nil
}
func (cc contextDeadTestConn) DoContext(ctx context.Context, cmd string, args ...interface{}) (interface{}, error) {
return 1, nil
}
func (cc contextDeadTestConn) Receive() (interface{}, error) {
return -1, nil
}
func (cc contextDeadTestConn) ReceiveContext(ctx context.Context) (interface{}, error) {
return 1, nil
}
func (cc contextDeadTestConn) Send(string, ...interface{}) error { return nil }
func (cc contextDeadTestConn) Err() error { return nil }
func (cc contextDeadTestConn) Close() error { return nil }
func (cc contextDeadTestConn) Flush() error { return nil }

func testcontext(t *testing.T, c redis.Conn) {
r, e := c.Do("PING")
if r != -1 || e != nil {
t.Errorf("Do() = %v, %v, want %v, %v", r, e, -1, nil)
}
ctx, f := context.WithTimeout(context.Background(), time.Minute)
defer f()
r, e = redis.DoContext(c, ctx, "PING")
if r != 1 || e != nil {
t.Errorf("DoContext() = %v, %v, want %v, %v", r, e, 1, nil)
}
r, e = c.Receive()
if r != -1 || e != nil {
t.Errorf("Receive() = %v, %v, want %v, %v", r, e, -1, nil)
}
r, e = redis.ReceiveContext(c, ctx)
if r != 1 || e != nil {
t.Errorf("ReceiveContext() = %v, %v, want %v, %v", r, e, 1, nil)
}
}

func TestConnContext(t *testing.T) {
testcontext(t, contextDeadTestConn(0))
}

func TestPoolConnContext(t *testing.T) {
p := redis.Pool{Dial: func() (redis.Conn, error) { return contextDeadTestConn(0), nil }}
testcontext(t, p.Get())
}
13 changes: 13 additions & 0 deletions redis/script.go
Expand Up @@ -15,6 +15,7 @@
package redis

import (
"context"
"crypto/sha1"
"encoding/hex"
"io"
Expand Down Expand Up @@ -60,6 +61,18 @@ func (s *Script) Hash() string {
return s.hash
}

func (s *Script) DoContext(ctx context.Context, c Conn, keysAndArgs ...interface{}) (interface{}, error) {
cwt, ok := c.(ConnWithContext)
if !ok {
return nil, errContextNotSupported
}
v, err := cwt.DoContext(ctx, "EVALSHA", s.args(s.hash, keysAndArgs)...)
if e, ok := err.(Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") {
v, err = cwt.DoContext(ctx, "EVAL", s.args(s.src, keysAndArgs)...)
}
return v, err
}

// Do evaluates the script. Under the covers, Do optimistically evaluates the
// script using the EVALSHA command. If the command fails because the script is
// not loaded, then Do evaluates the script using the EVAL command (thus
Expand Down

0 comments on commit 56d6448

Please sign in to comment.