From ad7e9d236a5f934fac7cfa50e8691a440705d5d9 Mon Sep 17 00:00:00 2001 From: "Peter A. Bigot" Date: Thu, 17 Mar 2022 09:15:09 -0700 Subject: [PATCH] feat: add RequestContext to PubSubConn (#603) Add a wrapper that goes through the standard receiveInternal processing to match the API of the existing PubSubConn Receive methods. Fixes: #592 --- redis/pubsub.go | 8 ++++++++ redis/pubsub_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/redis/pubsub.go b/redis/pubsub.go index cc585757..67b885b9 100644 --- a/redis/pubsub.go +++ b/redis/pubsub.go @@ -15,6 +15,7 @@ package redis import ( + "context" "errors" "time" ) @@ -116,6 +117,13 @@ func (c PubSubConn) ReceiveWithTimeout(timeout time.Duration) interface{} { return c.receiveInternal(ReceiveWithTimeout(c.Conn, timeout)) } +// ReceiveContext is like Receive, but it allows termination of the receive +// via a Context. If the call returns due to closure of the context's Done +// channel the underlying Conn will have been closed. +func (c PubSubConn) ReceiveContext(ctx context.Context) interface{} { + return c.receiveInternal(ReceiveContext(c.Conn, ctx)) +} + func (c PubSubConn) receiveInternal(replyArg interface{}, errArg error) interface{} { reply, err := Values(replyArg, errArg) if err != nil { diff --git a/redis/pubsub_test.go b/redis/pubsub_test.go index 83b91580..63e08692 100644 --- a/redis/pubsub_test.go +++ b/redis/pubsub_test.go @@ -15,6 +15,8 @@ package redis_test import ( + "context" + "errors" "reflect" "testing" "time" @@ -74,3 +76,25 @@ func TestPushed(t *testing.T) { t.Errorf("recv /w timeout got %v, want %v", got, want) } } + +func TestPubSubReceiveContext(t *testing.T) { + sc, err := redis.DialDefaultServer() + if err != nil { + t.Fatalf("error connection to database, %v", err) + } + defer sc.Close() + + c := redis.PubSubConn{Conn: sc} + + require.NoError(t, c.Subscribe("c1")) + expectPushed(t, c, "Subscribe(c1)", redis.Subscription{Kind: "subscribe", Channel: "c1", Count: 1}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + got := c.ReceiveContext(ctx) + if err, ok := got.(error); !ok { + t.Errorf("recv w/canceled expected Canceled got non-error type %T", got) + } else if !errors.Is(err, context.Canceled) { + t.Errorf("recv w/canceled expected Canceled got %v", err) + } +}