diff --git a/CHANGELOG.md b/CHANGELOG.md index 484d40f5e..6b540a4f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,11 @@ > :heart: > [**Uptrace.dev** - All-in-one tool to optimize performance and monitor errors & logs](https://uptrace.dev) +## v8.9 + +- Changed `PubSub.Channel` to only rely on `Ping` result. You can now use `WithChannelSize`, + `WithChannelHealthCheckInterval`, and `WithChannelSendTimeout` to override default settings. + ## v8.8 - To make updating easier, extra modules now have the same version as go-redis does. That means that diff --git a/commands.go b/commands.go index 7990b72eb..540c3b59a 100644 --- a/commands.go +++ b/commands.go @@ -255,6 +255,8 @@ type Cmdable interface { ZCount(ctx context.Context, key, min, max string) *IntCmd ZLexCount(ctx context.Context, key, min, max string) *IntCmd ZIncrBy(ctx context.Context, key string, increment float64, member string) *FloatCmd + ZInter(ctx context.Context, store *ZStore) *StringSliceCmd + ZInterWithScores(ctx context.Context, store *ZStore) *ZSliceCmd ZInterStore(ctx context.Context, destination string, store *ZStore) *IntCmd ZMScore(ctx context.Context, key string, members ...string) *FloatSliceCmd ZPopMax(ctx context.Context, key string, count ...int64) *ZSliceCmd @@ -279,6 +281,9 @@ type Cmdable interface { ZUnionStore(ctx context.Context, dest string, store *ZStore) *IntCmd ZRandMember(ctx context.Context, key string, count int) *StringSliceCmd ZRandMemberWithScores(ctx context.Context, key string, count int) *ZSliceCmd + ZDiff(ctx context.Context, keys ...string) *StringSliceCmd + ZDiffWithScores(ctx context.Context, keys ...string) *ZSliceCmd + ZDiffStore(ctx context.Context, destination string, keys ...string) *IntCmd PFAdd(ctx context.Context, key string, els ...interface{}) *IntCmd PFCount(ctx context.Context, keys ...string) *IntCmd @@ -384,7 +389,7 @@ func (c statefulCmdable) Auth(ctx context.Context, password string) *StatusCmd { return cmd } -// Perform an AUTH command, using the given user and pass. +// AuthACL Perform an AUTH command, using the given user and pass. // Should be used to authenticate the current connection with one of the connections defined in the ACL list // when connecting to a Redis 6.0 instance, or greater, that is using the Redis ACL system. func (c statefulCmdable) AuthACL(ctx context.Context, username, password string) *StatusCmd { @@ -418,7 +423,7 @@ func (c statefulCmdable) ClientSetName(ctx context.Context, name string) *BoolCm return cmd } -// Set the resp protocol used. +// Hello Set the resp protocol used. func (c statefulCmdable) Hello(ctx context.Context, ver int, username, password, clientName string) *MapStringInterfaceCmd { args := make([]interface{}, 0, 7) @@ -728,7 +733,7 @@ func (c cmdable) DecrBy(ctx context.Context, key string, decrement int64) *IntCm return cmd } -// Redis `GET key` command. It returns redis.Nil error when key does not exist. +// Get redis `GET key` command. It returns redis.Nil error when key does not exist. func (c cmdable) Get(ctx context.Context, key string) *StringCmd { cmd := NewStringCmd(ctx, "get", key) _ = c(ctx, cmd) @@ -747,7 +752,7 @@ func (c cmdable) GetSet(ctx context.Context, key string, value interface{}) *Str return cmd } -// An expiration of zero removes the TTL associated with the key (i.e. GETEX key persist). +// GetEx An expiration of zero removes the TTL associated with the key (i.e. GETEX key persist). // Requires Redis >= 6.2.0. func (c cmdable) GetEx(ctx context.Context, key string, expiration time.Duration) *StringCmd { args := make([]interface{}, 0, 4) @@ -767,7 +772,7 @@ func (c cmdable) GetEx(ctx context.Context, key string, expiration time.Duration return cmd } -// redis-server version >= 6.2.0. +// GetDel redis-server version >= 6.2.0. func (c cmdable) GetDel(ctx context.Context, key string) *StringCmd { cmd := NewStringCmd(ctx, "getdel", key) _ = c(ctx, cmd) @@ -829,7 +834,7 @@ func (c cmdable) MSetNX(ctx context.Context, values ...interface{}) *BoolCmd { return cmd } -// Redis `SET key value [expiration]` command. +// Set Redis `SET key value [expiration]` command. // Use expiration for `SETEX`-like behavior. // // Zero expiration means the key has no expiration time. @@ -904,14 +909,14 @@ func (c cmdable) SetArgs(ctx context.Context, key string, value interface{}, a S return cmd } -// Redis `SETEX key expiration value` command. +// SetEX Redis `SETEX key expiration value` command. func (c cmdable) SetEX(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd { cmd := NewStatusCmd(ctx, "setex", key, formatSec(ctx, expiration), value) _ = c(ctx, cmd) return cmd } -// Redis `SET key value [expiration] NX` command. +// SetNX Redis `SET key value [expiration] NX` command. // // Zero expiration means the key has no expiration time. // KeepTTL(-1) expiration is a Redis KEEPTTL option to keep existing TTL. @@ -935,7 +940,7 @@ func (c cmdable) SetNX(ctx context.Context, key string, value interface{}, expir return cmd } -// Redis `SET key value [expiration] XX` command. +// SetXX Redis `SET key value [expiration] XX` command. // // Zero expiration means the key has no expiration time. // KeepTTL(-1) expiration is a Redis KEEPTTL option to keep existing TTL. @@ -1246,14 +1251,14 @@ func (c cmdable) HVals(ctx context.Context, key string) *StringSliceCmd { return cmd } -// redis-server version >= 6.2.0. +// HRandField redis-server version >= 6.2.0. func (c cmdable) HRandField(ctx context.Context, key string, count int) *StringSliceCmd { cmd := NewStringSliceCmd(ctx, "hrandfield", key, count) _ = c(ctx, cmd) return cmd } -// redis-server version >= 6.2.0. +// HRandFieldWithValues redis-server version >= 6.2.0. func (c cmdable) HRandFieldWithValues(ctx context.Context, key string, count int) *KeyValueSliceCmd { cmd := NewKeyValueSliceCmd(ctx, "hrandfield", key, count, "withvalues") _ = c(ctx, cmd) @@ -1538,7 +1543,7 @@ func (c cmdable) SIsMember(ctx context.Context, key string, member interface{}) return cmd } -// Redis `SMISMEMBER key member [member ...]` command. +// SMIsMember Redis `SMISMEMBER key member [member ...]` command. func (c cmdable) SMIsMember(ctx context.Context, key string, members ...interface{}) *BoolSliceCmd { args := make([]interface{}, 2, 2+len(members)) args[0] = "smismember" @@ -1549,14 +1554,14 @@ func (c cmdable) SMIsMember(ctx context.Context, key string, members ...interfac return cmd } -// Redis `SMEMBERS key` command output as a slice. +// SMembers Redis `SMEMBERS key` command output as a slice. func (c cmdable) SMembers(ctx context.Context, key string) *StringSliceCmd { cmd := NewStringSliceCmd(ctx, "smembers", key) _ = c(ctx, cmd) return cmd } -// Redis `SMEMBERS key` command output as a map. +// SMembersMap Redis `SMEMBERS key` command output as a map. func (c cmdable) SMembersMap(ctx context.Context, key string) *StringStructMapCmd { cmd := NewStringStructMapCmd(ctx, "smembers", key) _ = c(ctx, cmd) @@ -1569,28 +1574,28 @@ func (c cmdable) SMove(ctx context.Context, source, destination string, member i return cmd } -// Redis `SPOP key` command. +// SPop Redis `SPOP key` command. func (c cmdable) SPop(ctx context.Context, key string) *StringCmd { cmd := NewStringCmd(ctx, "spop", key) _ = c(ctx, cmd) return cmd } -// Redis `SPOP key count` command. +// SPopN Redis `SPOP key count` command. func (c cmdable) SPopN(ctx context.Context, key string, count int64) *StringSliceCmd { cmd := NewStringSliceCmd(ctx, "spop", key, count) _ = c(ctx, cmd) return cmd } -// Redis `SRANDMEMBER key` command. +// SRandMember Redis `SRANDMEMBER key` command. func (c cmdable) SRandMember(ctx context.Context, key string) *StringCmd { cmd := NewStringCmd(ctx, "srandmember", key) _ = c(ctx, cmd) return cmd } -// Redis `SRANDMEMBER key count` command. +// SRandMemberN Redis `SRANDMEMBER key count` command. func (c cmdable) SRandMemberN(ctx context.Context, key string, count int64) *StringSliceCmd { cmd := NewStringSliceCmd(ctx, "srandmember", key, count) _ = c(ctx, cmd) @@ -1793,7 +1798,7 @@ func (c cmdable) XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSlic args := make([]interface{}, 0, 8+len(a.Streams)) args = append(args, "xreadgroup", "group", a.Group, a.Consumer) - keyPos := int8(1) + keyPos := int8(4) if a.Count > 0 { args = append(args, "count", a.Count) keyPos += 2 @@ -1963,7 +1968,18 @@ type ZStore struct { Aggregate string } -// Redis `BZPOPMAX key [key ...] timeout` command. +func (z *ZStore) len() (n int) { + n = len(z.Keys) + if len(z.Weights) > 0 { + n += 1 + len(z.Weights) + } + if z.Aggregate != "" { + n += 2 + } + return n +} + +// BZPopMax Redis `BZPOPMAX key [key ...] timeout` command. func (c cmdable) BZPopMax(ctx context.Context, timeout time.Duration, keys ...string) *ZWithKeyCmd { args := make([]interface{}, 1+len(keys)+1) args[0] = "bzpopmax" @@ -1977,7 +1993,7 @@ func (c cmdable) BZPopMax(ctx context.Context, timeout time.Duration, keys ...st return cmd } -// Redis `BZPOPMIN key [key ...] timeout` command. +// BZPopMin Redis `BZPOPMIN key [key ...] timeout` command. func (c cmdable) BZPopMin(ctx context.Context, timeout time.Duration, keys ...string) *ZWithKeyCmd { args := make([]interface{}, 1+len(keys)+1) args[0] = "bzpopmin" @@ -2001,7 +2017,7 @@ func (c cmdable) zAdd(ctx context.Context, a []interface{}, n int, members ...*Z return cmd } -// Redis `ZADD key score member [score member ...]` command. +// ZAdd Redis `ZADD key score member [score member ...]` command. func (c cmdable) ZAdd(ctx context.Context, key string, members ...*Z) *IntCmd { const n = 2 a := make([]interface{}, n+2*len(members)) @@ -2009,7 +2025,7 @@ func (c cmdable) ZAdd(ctx context.Context, key string, members ...*Z) *IntCmd { return c.zAdd(ctx, a, n, members...) } -// Redis `ZADD key NX score member [score member ...]` command. +// ZAddNX Redis `ZADD key NX score member [score member ...]` command. func (c cmdable) ZAddNX(ctx context.Context, key string, members ...*Z) *IntCmd { const n = 3 a := make([]interface{}, n+2*len(members)) @@ -2017,7 +2033,7 @@ func (c cmdable) ZAddNX(ctx context.Context, key string, members ...*Z) *IntCmd return c.zAdd(ctx, a, n, members...) } -// Redis `ZADD key XX score member [score member ...]` command. +// ZAddXX Redis `ZADD key XX score member [score member ...]` command. func (c cmdable) ZAddXX(ctx context.Context, key string, members ...*Z) *IntCmd { const n = 3 a := make([]interface{}, n+2*len(members)) @@ -2025,7 +2041,7 @@ func (c cmdable) ZAddXX(ctx context.Context, key string, members ...*Z) *IntCmd return c.zAdd(ctx, a, n, members...) } -// Redis `ZADD key CH score member [score member ...]` command. +// ZAddCh Redis `ZADD key CH score member [score member ...]` command. func (c cmdable) ZAddCh(ctx context.Context, key string, members ...*Z) *IntCmd { const n = 3 a := make([]interface{}, n+2*len(members)) @@ -2033,7 +2049,7 @@ func (c cmdable) ZAddCh(ctx context.Context, key string, members ...*Z) *IntCmd return c.zAdd(ctx, a, n, members...) } -// Redis `ZADD key NX CH score member [score member ...]` command. +// ZAddNXCh Redis `ZADD key NX CH score member [score member ...]` command. func (c cmdable) ZAddNXCh(ctx context.Context, key string, members ...*Z) *IntCmd { const n = 4 a := make([]interface{}, n+2*len(members)) @@ -2041,7 +2057,7 @@ func (c cmdable) ZAddNXCh(ctx context.Context, key string, members ...*Z) *IntCm return c.zAdd(ctx, a, n, members...) } -// Redis `ZADD key XX CH score member [score member ...]` command. +// ZAddXXCh Redis `ZADD key XX CH score member [score member ...]` command. func (c cmdable) ZAddXXCh(ctx context.Context, key string, members ...*Z) *IntCmd { const n = 4 a := make([]interface{}, n+2*len(members)) @@ -2059,7 +2075,7 @@ func (c cmdable) zIncr(ctx context.Context, a []interface{}, n int, members ...* return cmd } -// Redis `ZADD key INCR score member` command. +// ZIncr Redis `ZADD key INCR score member` command. func (c cmdable) ZIncr(ctx context.Context, key string, member *Z) *FloatCmd { const n = 3 a := make([]interface{}, n+2) @@ -2067,7 +2083,7 @@ func (c cmdable) ZIncr(ctx context.Context, key string, member *Z) *FloatCmd { return c.zIncr(ctx, a, n, member) } -// Redis `ZADD key NX INCR score member` command. +// ZIncrNX Redis `ZADD key NX INCR score member` command. func (c cmdable) ZIncrNX(ctx context.Context, key string, member *Z) *FloatCmd { const n = 4 a := make([]interface{}, n+2) @@ -2075,7 +2091,7 @@ func (c cmdable) ZIncrNX(ctx context.Context, key string, member *Z) *FloatCmd { return c.zIncr(ctx, a, n, member) } -// Redis `ZADD key XX INCR score member` command. +// ZIncrXX Redis `ZADD key XX INCR score member` command. func (c cmdable) ZIncrXX(ctx context.Context, key string, member *Z) *FloatCmd { const n = 4 a := make([]interface{}, n+2) @@ -2108,7 +2124,7 @@ func (c cmdable) ZIncrBy(ctx context.Context, key string, increment float64, mem } func (c cmdable) ZInterStore(ctx context.Context, destination string, store *ZStore) *IntCmd { - args := make([]interface{}, 0, 3+len(store.Keys)) + args := make([]interface{}, 0, 3+store.len()) args = append(args, "zinterstore", destination, len(store.Keys)) for _, key := range store.Keys { args = append(args, key) @@ -2128,6 +2144,50 @@ func (c cmdable) ZInterStore(ctx context.Context, destination string, store *ZSt return cmd } +func (c cmdable) ZInter(ctx context.Context, store *ZStore) *StringSliceCmd { + args := make([]interface{}, 0, 2+store.len()) + args = append(args, "zinter", len(store.Keys)) + for _, key := range store.Keys { + args = append(args, key) + } + if len(store.Weights) > 0 { + args = append(args, "weights") + for _, weights := range store.Weights { + args = append(args, weights) + } + } + + if store.Aggregate != "" { + args = append(args, "aggregate", store.Aggregate) + } + cmd := NewStringSliceCmd(ctx, args...) + cmd.setFirstKeyPos(2) + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) ZInterWithScores(ctx context.Context, store *ZStore) *ZSliceCmd { + args := make([]interface{}, 0, 3+store.len()) + args = append(args, "zinter", len(store.Keys)) + for _, key := range store.Keys { + args = append(args, key) + } + if len(store.Weights) > 0 { + args = append(args, "weights") + for _, weights := range store.Weights { + args = append(args, weights) + } + } + if store.Aggregate != "" { + args = append(args, "aggregate", store.Aggregate) + } + args = append(args, "withscores") + cmd := NewZSliceCmd(ctx, args...) + cmd.setFirstKeyPos(2) + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) ZMScore(ctx context.Context, key string, members ...string) *FloatSliceCmd { args := make([]interface{}, 2+len(members)) args[0] = "zmscore" @@ -2354,7 +2414,7 @@ func (c cmdable) ZScore(ctx context.Context, key, member string) *FloatCmd { } func (c cmdable) ZUnionStore(ctx context.Context, dest string, store *ZStore) *IntCmd { - args := make([]interface{}, 0, 3+len(store.Keys)) + args := make([]interface{}, 0, 3+store.len()) args = append(args, "zunionstore", dest, len(store.Keys)) for _, key := range store.Keys { args = append(args, key) @@ -2389,6 +2449,49 @@ func (c cmdable) ZRandMemberWithScores(ctx context.Context, key string, count in return cmd } +// ZDiff redis-server version >= 6.2.0. +func (c cmdable) ZDiff(ctx context.Context, keys ...string) *StringSliceCmd { + args := make([]interface{}, 2+len(keys)) + args[0] = "zdiff" + args[1] = len(keys) + for i, key := range keys { + args[i+2] = key + } + + cmd := NewStringSliceCmd(ctx, args...) + cmd.setFirstKeyPos(2) + _ = c(ctx, cmd) + return cmd +} + +// ZDiffWithScores redis-server version >= 6.2.0. +func (c cmdable) ZDiffWithScores(ctx context.Context, keys ...string) *ZSliceCmd { + args := make([]interface{}, 3+len(keys)) + args[0] = "zdiff" + args[1] = len(keys) + for i, key := range keys { + args[i+2] = key + } + args[len(keys)+2] = "withscores" + + cmd := NewZSliceCmd(ctx, args...) + cmd.setFirstKeyPos(2) + _ = c(ctx, cmd) + return cmd +} + +// ZDiffStore redis-server version >=6.2.0. +func (c cmdable) ZDiffStore(ctx context.Context, destination string, keys ...string) *IntCmd { + args := make([]interface{}, 0, 3+len(keys)) + args = append(args, "zdiffstore", destination, len(keys)) + for _, key := range keys { + args = append(args, key) + } + cmd := NewIntCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + //------------------------------------------------------------------------------ func (c cmdable) PFAdd(ctx context.Context, key string, els ...interface{}) *IntCmd { @@ -2648,6 +2751,7 @@ func (c cmdable) MemoryUsage(ctx context.Context, key string, samples ...int) *I args = append(args, "SAMPLES", samples[0]) } cmd := NewIntCmd(ctx, args...) + cmd.setFirstKeyPos(2) _ = c(ctx, cmd) return cmd } @@ -2664,6 +2768,7 @@ func (c cmdable) Eval(ctx context.Context, script string, keys []string, args .. } cmdArgs = appendArgs(cmdArgs, args) cmd := NewCmd(ctx, cmdArgs...) + cmd.setFirstKeyPos(3) _ = c(ctx, cmd) return cmd } @@ -2678,6 +2783,7 @@ func (c cmdable) EvalSha(ctx context.Context, sha1 string, keys []string, args . } cmdArgs = appendArgs(cmdArgs, args) cmd := NewCmd(ctx, cmdArgs...) + cmd.setFirstKeyPos(3) _ = c(ctx, cmd) return cmd } @@ -2926,7 +3032,7 @@ func (c cmdable) GeoRadiusStore( return cmd } -// GeoRadius is a read-only GEORADIUSBYMEMBER_RO command. +// GeoRadiusByMember is a read-only GEORADIUSBYMEMBER_RO command. func (c cmdable) GeoRadiusByMember( ctx context.Context, key, member string, query *GeoRadiusQuery, ) *GeoLocationCmd { diff --git a/commands_test.go b/commands_test.go index 3db4fe029..a47a5fb31 100644 --- a/commands_test.go +++ b/commands_test.go @@ -3976,6 +3976,119 @@ var _ = Describe("Commands", func() { Equal([]redis.Z{{Member: "two", Score: 2}}), )) }) + + It("should ZDiff", func() { + err := client.ZAdd(ctx, "zset1", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset1", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset1", &redis.Z{Score: 3, Member: "three"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + + v, err := client.ZDiff(ctx, "zset1", "zset2").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(v).To(Equal([]string{"two", "three"})) + }) + + It("should ZDiffWithScores", func() { + err := client.ZAdd(ctx, "zset1", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset1", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset1", &redis.Z{Score: 3, Member: "three"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + + v, err := client.ZDiffWithScores(ctx, "zset1", "zset2").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(v).To(Equal([]redis.Z{ + { + Member: "two", + Score: 2, + }, + { + Member: "three", + Score: 3, + }, + })) + }) + + It("should ZInter", func() { + err := client.ZAdd(ctx, "zset1", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset1", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 3, Member: "three"}).Err() + Expect(err).NotTo(HaveOccurred()) + + v, err := client.ZInter(ctx, &redis.ZStore{ + Keys: []string{"zset1", "zset2"}, + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(v).To(Equal([]string{"one", "two"})) + }) + + It("should ZInterWithScores", func() { + err := client.ZAdd(ctx, "zset1", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset1", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 3, Member: "three"}).Err() + Expect(err).NotTo(HaveOccurred()) + + v, err := client.ZInterWithScores(ctx, &redis.ZStore{ + Keys: []string{"zset1", "zset2"}, + Weights: []float64{2, 3}, + Aggregate: "Max", + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(v).To(Equal([]redis.Z{ + { + Member: "one", + Score: 3, + }, + { + Member: "two", + Score: 6, + }, + })) + }) + + It("should ZDiffStore", func() { + err := client.ZAdd(ctx, "zset1", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset1", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 1, Member: "one"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 2, Member: "two"}).Err() + Expect(err).NotTo(HaveOccurred()) + err = client.ZAdd(ctx, "zset2", &redis.Z{Score: 3, Member: "three"}).Err() + Expect(err).NotTo(HaveOccurred()) + v, err := client.ZDiffStore(ctx, "out1", "zset1", "zset2").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(v).To(Equal(int64(0))) + v, err = client.ZDiffStore(ctx, "out1", "zset2", "zset1").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(v).To(Equal(int64(1))) + vals, err := client.ZRangeWithScores(ctx, "out1", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(vals).To(Equal([]redis.Z{{ + Score: 3, + Member: "three", + }})) + }) }) Describe("streams", func() { diff --git a/pubsub.go b/pubsub.go index c56270b44..c6ffb2562 100644 --- a/pubsub.go +++ b/pubsub.go @@ -2,7 +2,6 @@ package redis import ( "context" - "errors" "fmt" "strings" "sync" @@ -13,13 +12,6 @@ import ( "github.com/go-redis/redis/v8/internal/proto" ) -const ( - pingTimeout = time.Second - chanSendTimeout = time.Minute -) - -var errPingTimeout = errors.New("redis: ping timeout") - // PubSub implements Pub/Sub commands as described in // http://redis.io/topics/pubsub. Message receiving is NOT safe // for concurrent use by multiple goroutines. @@ -43,9 +35,12 @@ type PubSub struct { cmd *Cmd chOnce sync.Once - msgCh chan *Message - allCh chan interface{} - ping chan struct{} + msgCh *channel + allCh *channel +} + +func (c *PubSub) init() { + c.exit = make(chan struct{}) } func (c *PubSub) String() string { @@ -54,10 +49,6 @@ func (c *PubSub) String() string { return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", ")) } -func (c *PubSub) init() { - c.exit = make(chan struct{}) -} - func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) { c.mu.Lock() cn, err := c.conn(ctx, nil) @@ -418,6 +409,15 @@ func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) { } } +func (c *PubSub) getContext() context.Context { + if c.cmd != nil { + return c.cmd.ctx + } + return context.Background() +} + +//------------------------------------------------------------------------------ + // Channel returns a Go channel for concurrently receiving messages. // The channel is closed together with the PubSub. If the Go channel // is blocked full for 30 seconds the message is dropped. @@ -425,26 +425,24 @@ func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) { // // go-redis periodically sends ping messages to test connection health // and re-subscribes if ping can not not received for 30 seconds. -func (c *PubSub) Channel() <-chan *Message { - return c.ChannelSize(100) -} - -// ChannelSize is like Channel, but creates a Go channel -// with specified buffer size. -func (c *PubSub) ChannelSize(size int) <-chan *Message { +func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message { c.chOnce.Do(func() { - c.initPing() - c.initMsgChan(size) + c.msgCh = newChannel(c, opts...) + c.msgCh.initMsgChan() }) if c.msgCh == nil { err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions") panic(err) } - if cap(c.msgCh) != size { - err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created") - panic(err) - } - return c.msgCh + return c.msgCh.msgCh +} + +// ChannelSize is like Channel, but creates a Go channel +// with specified buffer size. +// +// Deprecated: use Channel(WithChannelSize(size)), remove in v9. +func (c *PubSub) ChannelSize(size int) <-chan *Message { + return c.Channel(WithChannelSize(size)) } // ChannelWithSubscriptions is like Channel, but message type can be either @@ -452,59 +450,101 @@ func (c *PubSub) ChannelSize(size int) <-chan *Message { // reconnections. // // ChannelWithSubscriptions can not be used together with Channel or ChannelSize. -func (c *PubSub) ChannelWithSubscriptions(ctx context.Context, size int) <-chan interface{} { +func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} { c.chOnce.Do(func() { - c.initPing() - c.initAllChan(size) + c.allCh = newChannel(c, WithChannelSize(size)) + c.allCh.initAllChan() }) if c.allCh == nil { err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel") panic(err) } - if cap(c.allCh) != size { - err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created") - panic(err) + return c.allCh.allCh +} + +type ChannelOption func(c *channel) + +// WithChannelSize specifies the Go chan size that is used to buffer incoming messages. +// +// The default is 100 messages. +func WithChannelSize(size int) ChannelOption { + return func(c *channel) { + c.chanSize = size } - return c.allCh } -func (c *PubSub) getContext() context.Context { - if c.cmd != nil { - return c.cmd.ctx +// WithChannelHealthCheckInterval specifies the health check interval. +// PubSub will ping Redis Server if it does not receive any messages within the interval. +// To disable health check, use zero interval. +// +// The default is 3 seconds. +func WithChannelHealthCheckInterval(d time.Duration) ChannelOption { + return func(c *channel) { + c.checkInterval = d } - return context.Background() } -func (c *PubSub) initPing() { +// WithChannelSendTimeout specifies the channel send timeout after which +// the message is dropped. +// +// The default is 60 seconds. +func WithChannelSendTimeout(d time.Duration) ChannelOption { + return func(c *channel) { + c.chanSendTimeout = d + } +} + +type channel struct { + pubSub *PubSub + + msgCh chan *Message + allCh chan interface{} + ping chan struct{} + + chanSize int + chanSendTimeout time.Duration + checkInterval time.Duration +} + +func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel { + c := &channel{ + pubSub: pubSub, + + chanSize: 100, + chanSendTimeout: time.Minute, + checkInterval: 3 * time.Second, + } + for _, opt := range opts { + opt(c) + } + if c.checkInterval > 0 { + c.initHealthCheck() + } + return c +} + +func (c *channel) initHealthCheck() { ctx := context.TODO() c.ping = make(chan struct{}, 1) + go func() { timer := time.NewTimer(time.Minute) timer.Stop() - healthy := true for { - timer.Reset(pingTimeout) + timer.Reset(c.checkInterval) select { case <-c.ping: - healthy = true if !timer.Stop() { <-timer.C } case <-timer.C: - pingErr := c.Ping(ctx) - if healthy { - healthy = false - } else { - if pingErr == nil { - pingErr = errPingTimeout - } - c.mu.Lock() - c.reconnect(ctx, pingErr) - healthy = true - c.mu.Unlock() + if pingErr := c.pubSub.Ping(ctx); pingErr != nil { + c.pubSub.mu.Lock() + c.pubSub.reconnect(ctx, pingErr) + c.pubSub.mu.Unlock() } - case <-c.exit: + case <-c.pubSub.exit: return } } @@ -512,16 +552,17 @@ func (c *PubSub) initPing() { } // initMsgChan must be in sync with initAllChan. -func (c *PubSub) initMsgChan(size int) { +func (c *channel) initMsgChan() { ctx := context.TODO() - c.msgCh = make(chan *Message, size) + c.msgCh = make(chan *Message, c.chanSize) + go func() { timer := time.NewTimer(time.Minute) timer.Stop() var errCount int for { - msg, err := c.Receive(ctx) + msg, err := c.pubSub.Receive(ctx) if err != nil { if err == pool.ErrClosed { close(c.msgCh) @@ -548,7 +589,7 @@ func (c *PubSub) initMsgChan(size int) { case *Pong: // Ignore. case *Message: - timer.Reset(chanSendTimeout) + timer.Reset(c.chanSendTimeout) select { case c.msgCh <- msg: if !timer.Stop() { @@ -556,30 +597,28 @@ func (c *PubSub) initMsgChan(size int) { } case <-timer.C: internal.Logger.Printf( - c.getContext(), - "redis: %s channel is full for %s (message is dropped)", - c, - chanSendTimeout, - ) + ctx, "redis: %s channel is full for %s (message is dropped)", + c, c.chanSendTimeout) } default: - internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) + internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) } } }() } // initAllChan must be in sync with initMsgChan. -func (c *PubSub) initAllChan(size int) { +func (c *channel) initAllChan() { ctx := context.TODO() - c.allCh = make(chan interface{}, size) + c.allCh = make(chan interface{}, c.chanSize) + go func() { - timer := time.NewTimer(pingTimeout) + timer := time.NewTimer(time.Minute) timer.Stop() var errCount int for { - msg, err := c.Receive(ctx) + msg, err := c.pubSub.Receive(ctx) if err != nil { if err == pool.ErrClosed { close(c.allCh) @@ -601,29 +640,23 @@ func (c *PubSub) initAllChan(size int) { } switch msg := msg.(type) { - case *Subscription: - c.sendMessage(msg, timer) case *Pong: // Ignore. - case *Message: - c.sendMessage(msg, timer) + case *Subscription, *Message: + timer.Reset(c.chanSendTimeout) + select { + case c.allCh <- msg: + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + internal.Logger.Printf( + ctx, "redis: %s channel is full for %s (message is dropped)", + c, c.chanSendTimeout) + } default: - internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) + internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) } } }() } - -func (c *PubSub) sendMessage(msg interface{}, timer *time.Timer) { - timer.Reset(pingTimeout) - select { - case c.allCh <- msg: - if !timer.Stop() { - <-timer.C - } - case <-timer.C: - internal.Logger.Printf( - c.getContext(), - "redis: %s channel is full for %s (message is dropped)", c, pingTimeout) - } -} diff --git a/pubsub_test.go b/pubsub_test.go index b9633b2fc..14e37c2be 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -443,4 +443,23 @@ var _ = Describe("PubSub", func() { Fail("timeout") } }) + + It("should ChannelMessage", func() { + pubsub := client.Subscribe(ctx, "mychannel") + defer pubsub.Close() + + ch := pubsub.Channel( + redis.WithChannelSize(10), + redis.WithChannelHealthCheckInterval(time.Second), + ) + + text := "test channel message" + err := client.Publish(ctx, "mychannel", text).Err() + Expect(err).NotTo(HaveOccurred()) + + var msg *redis.Message + Eventually(ch).Should(Receive(&msg)) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal(text)) + }) })