Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make RequestCtx's userdata accept keys that are of type: interface{} #1387

Merged
merged 1 commit into from Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 18 additions & 8 deletions server.go
Expand Up @@ -670,7 +670,7 @@ func (ctx *RequestCtx) Hijacked() bool {
// All the values are removed from ctx after returning from the top
// RequestHandler. Additionally, Close method is called on each value
// implementing io.Closer before removing the value from ctx.
func (ctx *RequestCtx) SetUserValue(key string, value interface{}) {
func (ctx *RequestCtx) SetUserValue(key interface{}, value interface{}) {
ctx.userValues.Set(key, value)
}

Expand All @@ -688,7 +688,7 @@ func (ctx *RequestCtx) SetUserValueBytes(key []byte, value interface{}) {
}

// UserValue returns the value stored via SetUserValue* under the given key.
func (ctx *RequestCtx) UserValue(key string) interface{} {
func (ctx *RequestCtx) UserValue(key interface{}) interface{} {
return ctx.userValues.Get(key)
}

Expand All @@ -698,11 +698,24 @@ func (ctx *RequestCtx) UserValueBytes(key []byte) interface{} {
return ctx.userValues.GetBytes(key)
}

// VisitUserValues calls visitor for each existing userValue.
// VisitUserValues calls visitor for each existing userValue with a key that is a string or []byte.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValues(visitor func([]byte, interface{})) {
for i, n := 0, len(ctx.userValues); i < n; i++ {
kv := &ctx.userValues[i]
if _, ok := kv.key.(string); ok {
visitor(s2b(kv.key.(string)), kv.value)
}
}
}

// VisitUserValuesAll calls visitor for each existing userValue.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValuesAll(visitor func(interface{}, interface{})) {
for i, n := 0, len(ctx.userValues); i < n; i++ {
kv := &ctx.userValues[i]
visitor(kv.key, kv.value)
Expand All @@ -715,7 +728,7 @@ func (ctx *RequestCtx) ResetUserValues() {
}

// RemoveUserValue removes the given key and the value under it in ctx.
func (ctx *RequestCtx) RemoveUserValue(key string) {
func (ctx *RequestCtx) RemoveUserValue(key interface{}) {
ctx.userValues.Remove(key)
}

Expand Down Expand Up @@ -2696,10 +2709,7 @@ func (ctx *RequestCtx) Err() error {
// This method is present to make RequestCtx implement the context interface.
// This method is the same as calling ctx.UserValue(key)
func (ctx *RequestCtx) Value(key interface{}) interface{} {
if keyString, ok := key.(string); ok {
return ctx.UserValue(keyString)
}
return nil
return ctx.UserValue(key)
}

var fakeServer = &Server{
Expand Down
2 changes: 1 addition & 1 deletion server_test.go
Expand Up @@ -1737,7 +1737,7 @@ func TestRequestCtxUserValue(t *testing.T) {
vlen := 0
ctx.VisitUserValues(func(key []byte, value interface{}) {
vlen++
v := ctx.UserValueBytes(key)
v := ctx.UserValue(key)
if v != value {
t.Fatalf("unexpected value obtained from VisitUserValues for key: %q, expecting: %#v but got: %#v", key, v, value)
}
Expand Down
33 changes: 21 additions & 12 deletions userdata.go
Expand Up @@ -5,18 +5,21 @@ import (
)

type userDataKV struct {
key []byte
key interface{}
value interface{}
}

type userData []userDataKV

func (d *userData) Set(key string, value interface{}) {
func (d *userData) Set(key interface{}, value interface{}) {
if b, ok := key.([]byte); ok {
key = string(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
kv.value = value
return
}
Expand All @@ -30,36 +33,39 @@ func (d *userData) Set(key string, value interface{}) {
if c > n {
args = args[:n+1]
kv := &args[n]
kv.key = append(kv.key[:0], key...)
kv.key = key
kv.value = value
*d = args
return
}

kv := userDataKV{}
kv.key = append(kv.key[:0], key...)
kv.key = key
kv.value = value
*d = append(args, kv)
}

func (d *userData) SetBytes(key []byte, value interface{}) {
d.Set(b2s(key), value)
d.Set(key, value)
}

func (d *userData) Get(key string) interface{} {
func (d *userData) Get(key interface{}) interface{} {
if b, ok := key.([]byte); ok {
key = b2s(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
return kv.value
}
}
return nil
}

func (d *userData) GetBytes(key []byte) interface{} {
return d.Get(b2s(key))
return d.Get(key)
}

func (d *userData) Reset() {
Expand All @@ -74,12 +80,15 @@ func (d *userData) Reset() {
*d = (*d)[:0]
}

func (d *userData) Remove(key string) {
func (d *userData) Remove(key interface{}) {
if b, ok := key.([]byte); ok {
key = b2s(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
n--
args[i], args[n] = args[n], args[i]
args[n].value = nil
Expand All @@ -91,5 +100,5 @@ func (d *userData) Remove(key string) {
}

func (d *userData) RemoveBytes(key []byte) {
d.Remove(b2s(key))
d.Remove(key)
}