Skip to content

Commit

Permalink
make RequestCtx's userdata accept keys that are of type: interface{} (#…
Browse files Browse the repository at this point in the history
…1387)

Co-authored-by: rocketlaunchr-cto <rocketlaunchr.cloud@gmail.com>
  • Loading branch information
pjebs and rocketlaunchr-cto committed Oct 6, 2022
1 parent bcf7e8e commit d404f2d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 21 deletions.
26 changes: 18 additions & 8 deletions server.go
Expand Up @@ -670,7 +670,7 @@ func (ctx *RequestCtx) Hijacked() bool {
// All the values are removed from ctx after returning from the top
// RequestHandler. Additionally, Close method is called on each value
// implementing io.Closer before removing the value from ctx.
func (ctx *RequestCtx) SetUserValue(key string, value interface{}) {
func (ctx *RequestCtx) SetUserValue(key interface{}, value interface{}) {
ctx.userValues.Set(key, value)
}

Expand All @@ -688,7 +688,7 @@ func (ctx *RequestCtx) SetUserValueBytes(key []byte, value interface{}) {
}

// UserValue returns the value stored via SetUserValue* under the given key.
func (ctx *RequestCtx) UserValue(key string) interface{} {
func (ctx *RequestCtx) UserValue(key interface{}) interface{} {
return ctx.userValues.Get(key)
}

Expand All @@ -698,11 +698,24 @@ func (ctx *RequestCtx) UserValueBytes(key []byte) interface{} {
return ctx.userValues.GetBytes(key)
}

// VisitUserValues calls visitor for each existing userValue.
// VisitUserValues calls visitor for each existing userValue with a key that is a string or []byte.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValues(visitor func([]byte, interface{})) {
for i, n := 0, len(ctx.userValues); i < n; i++ {
kv := &ctx.userValues[i]
if _, ok := kv.key.(string); ok {
visitor(s2b(kv.key.(string)), kv.value)
}
}
}

// VisitUserValuesAll calls visitor for each existing userValue.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValuesAll(visitor func(interface{}, interface{})) {
for i, n := 0, len(ctx.userValues); i < n; i++ {
kv := &ctx.userValues[i]
visitor(kv.key, kv.value)
Expand All @@ -715,7 +728,7 @@ func (ctx *RequestCtx) ResetUserValues() {
}

// RemoveUserValue removes the given key and the value under it in ctx.
func (ctx *RequestCtx) RemoveUserValue(key string) {
func (ctx *RequestCtx) RemoveUserValue(key interface{}) {
ctx.userValues.Remove(key)
}

Expand Down Expand Up @@ -2696,10 +2709,7 @@ func (ctx *RequestCtx) Err() error {
// This method is present to make RequestCtx implement the context interface.
// This method is the same as calling ctx.UserValue(key)
func (ctx *RequestCtx) Value(key interface{}) interface{} {
if keyString, ok := key.(string); ok {
return ctx.UserValue(keyString)
}
return nil
return ctx.UserValue(key)
}

var fakeServer = &Server{
Expand Down
2 changes: 1 addition & 1 deletion server_test.go
Expand Up @@ -1737,7 +1737,7 @@ func TestRequestCtxUserValue(t *testing.T) {
vlen := 0
ctx.VisitUserValues(func(key []byte, value interface{}) {
vlen++
v := ctx.UserValueBytes(key)
v := ctx.UserValue(key)
if v != value {
t.Fatalf("unexpected value obtained from VisitUserValues for key: %q, expecting: %#v but got: %#v", key, v, value)
}
Expand Down
33 changes: 21 additions & 12 deletions userdata.go
Expand Up @@ -5,18 +5,21 @@ import (
)

type userDataKV struct {
key []byte
key interface{}
value interface{}
}

type userData []userDataKV

func (d *userData) Set(key string, value interface{}) {
func (d *userData) Set(key interface{}, value interface{}) {
if b, ok := key.([]byte); ok {
key = string(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
kv.value = value
return
}
Expand All @@ -30,36 +33,39 @@ func (d *userData) Set(key string, value interface{}) {
if c > n {
args = args[:n+1]
kv := &args[n]
kv.key = append(kv.key[:0], key...)
kv.key = key
kv.value = value
*d = args
return
}

kv := userDataKV{}
kv.key = append(kv.key[:0], key...)
kv.key = key
kv.value = value
*d = append(args, kv)
}

func (d *userData) SetBytes(key []byte, value interface{}) {
d.Set(b2s(key), value)
d.Set(key, value)
}

func (d *userData) Get(key string) interface{} {
func (d *userData) Get(key interface{}) interface{} {
if b, ok := key.([]byte); ok {
key = b2s(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
return kv.value
}
}
return nil
}

func (d *userData) GetBytes(key []byte) interface{} {
return d.Get(b2s(key))
return d.Get(key)
}

func (d *userData) Reset() {
Expand All @@ -74,12 +80,15 @@ func (d *userData) Reset() {
*d = (*d)[:0]
}

func (d *userData) Remove(key string) {
func (d *userData) Remove(key interface{}) {
if b, ok := key.([]byte); ok {
key = b2s(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
n--
args[i], args[n] = args[n], args[i]
args[n].value = nil
Expand All @@ -91,5 +100,5 @@ func (d *userData) Remove(key string) {
}

func (d *userData) RemoveBytes(key []byte) {
d.Remove(b2s(key))
d.Remove(key)
}

0 comments on commit d404f2d

Please sign in to comment.