From d404f2db91fcf59244f75272ceaa1d28d9ff352c Mon Sep 17 00:00:00 2001 From: pj Date: Fri, 7 Oct 2022 01:25:32 +1100 Subject: [PATCH] make RequestCtx's userdata accept keys that are of type: interface{} (#1387) Co-authored-by: rocketlaunchr-cto --- server.go | 26 ++++++++++++++++++-------- server_test.go | 2 +- userdata.go | 33 +++++++++++++++++++++------------ 3 files changed, 40 insertions(+), 21 deletions(-) diff --git a/server.go b/server.go index a7e93b174e..3f68eae705 100644 --- a/server.go +++ b/server.go @@ -670,7 +670,7 @@ func (ctx *RequestCtx) Hijacked() bool { // All the values are removed from ctx after returning from the top // RequestHandler. Additionally, Close method is called on each value // implementing io.Closer before removing the value from ctx. -func (ctx *RequestCtx) SetUserValue(key string, value interface{}) { +func (ctx *RequestCtx) SetUserValue(key interface{}, value interface{}) { ctx.userValues.Set(key, value) } @@ -688,7 +688,7 @@ func (ctx *RequestCtx) SetUserValueBytes(key []byte, value interface{}) { } // UserValue returns the value stored via SetUserValue* under the given key. -func (ctx *RequestCtx) UserValue(key string) interface{} { +func (ctx *RequestCtx) UserValue(key interface{}) interface{} { return ctx.userValues.Get(key) } @@ -698,11 +698,24 @@ func (ctx *RequestCtx) UserValueBytes(key []byte) interface{} { return ctx.userValues.GetBytes(key) } -// VisitUserValues calls visitor for each existing userValue. +// VisitUserValues calls visitor for each existing userValue with a key that is a string or []byte. // // visitor must not retain references to key and value after returning. // Make key and/or value copies if you need storing them after returning. func (ctx *RequestCtx) VisitUserValues(visitor func([]byte, interface{})) { + for i, n := 0, len(ctx.userValues); i < n; i++ { + kv := &ctx.userValues[i] + if _, ok := kv.key.(string); ok { + visitor(s2b(kv.key.(string)), kv.value) + } + } +} + +// VisitUserValuesAll calls visitor for each existing userValue. +// +// visitor must not retain references to key and value after returning. +// Make key and/or value copies if you need storing them after returning. +func (ctx *RequestCtx) VisitUserValuesAll(visitor func(interface{}, interface{})) { for i, n := 0, len(ctx.userValues); i < n; i++ { kv := &ctx.userValues[i] visitor(kv.key, kv.value) @@ -715,7 +728,7 @@ func (ctx *RequestCtx) ResetUserValues() { } // RemoveUserValue removes the given key and the value under it in ctx. -func (ctx *RequestCtx) RemoveUserValue(key string) { +func (ctx *RequestCtx) RemoveUserValue(key interface{}) { ctx.userValues.Remove(key) } @@ -2696,10 +2709,7 @@ func (ctx *RequestCtx) Err() error { // This method is present to make RequestCtx implement the context interface. // This method is the same as calling ctx.UserValue(key) func (ctx *RequestCtx) Value(key interface{}) interface{} { - if keyString, ok := key.(string); ok { - return ctx.UserValue(keyString) - } - return nil + return ctx.UserValue(key) } var fakeServer = &Server{ diff --git a/server_test.go b/server_test.go index af3cd4a81d..333a866364 100644 --- a/server_test.go +++ b/server_test.go @@ -1737,7 +1737,7 @@ func TestRequestCtxUserValue(t *testing.T) { vlen := 0 ctx.VisitUserValues(func(key []byte, value interface{}) { vlen++ - v := ctx.UserValueBytes(key) + v := ctx.UserValue(key) if v != value { t.Fatalf("unexpected value obtained from VisitUserValues for key: %q, expecting: %#v but got: %#v", key, v, value) } diff --git a/userdata.go b/userdata.go index 37f7d9b1f9..40690f69bf 100644 --- a/userdata.go +++ b/userdata.go @@ -5,18 +5,21 @@ import ( ) type userDataKV struct { - key []byte + key interface{} value interface{} } type userData []userDataKV -func (d *userData) Set(key string, value interface{}) { +func (d *userData) Set(key interface{}, value interface{}) { + if b, ok := key.([]byte); ok { + key = string(b) + } args := *d n := len(args) for i := 0; i < n; i++ { kv := &args[i] - if string(kv.key) == key { + if kv.key == key { kv.value = value return } @@ -30,28 +33,31 @@ func (d *userData) Set(key string, value interface{}) { if c > n { args = args[:n+1] kv := &args[n] - kv.key = append(kv.key[:0], key...) + kv.key = key kv.value = value *d = args return } kv := userDataKV{} - kv.key = append(kv.key[:0], key...) + kv.key = key kv.value = value *d = append(args, kv) } func (d *userData) SetBytes(key []byte, value interface{}) { - d.Set(b2s(key), value) + d.Set(key, value) } -func (d *userData) Get(key string) interface{} { +func (d *userData) Get(key interface{}) interface{} { + if b, ok := key.([]byte); ok { + key = b2s(b) + } args := *d n := len(args) for i := 0; i < n; i++ { kv := &args[i] - if string(kv.key) == key { + if kv.key == key { return kv.value } } @@ -59,7 +65,7 @@ func (d *userData) Get(key string) interface{} { } func (d *userData) GetBytes(key []byte) interface{} { - return d.Get(b2s(key)) + return d.Get(key) } func (d *userData) Reset() { @@ -74,12 +80,15 @@ func (d *userData) Reset() { *d = (*d)[:0] } -func (d *userData) Remove(key string) { +func (d *userData) Remove(key interface{}) { + if b, ok := key.([]byte); ok { + key = b2s(b) + } args := *d n := len(args) for i := 0; i < n; i++ { kv := &args[i] - if string(kv.key) == key { + if kv.key == key { n-- args[i], args[n] = args[n], args[i] args[n].value = nil @@ -91,5 +100,5 @@ func (d *userData) Remove(key string) { } func (d *userData) RemoveBytes(key []byte) { - d.Remove(b2s(key)) + d.Remove(key) }