From 7fdd5261e8cc744f04c0a90d7319688179cf78e9 Mon Sep 17 00:00:00 2001 From: tyltr Date: Fri, 8 Oct 2021 23:45:45 +0800 Subject: [PATCH] feat: a new userData API `Remove` (#1117) * feat:userData new api "delete" * ctx api `remove` * rename * modify --- server.go | 10 ++++++++++ userdata.go | 21 +++++++++++++++++++++ userdata_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+) diff --git a/server.go b/server.go index c8a9966309..4fef94cac9 100644 --- a/server.go +++ b/server.go @@ -698,6 +698,16 @@ func (ctx *RequestCtx) ResetUserValues() { ctx.userValues.Reset() } +// RemoveUserValue removes the given key and the value under it in ctx. +func (ctx *RequestCtx) RemoveUserValue(key string) { + ctx.userValues.Remove(key) +} + +// RemoveUserValueBytes removes the given key and the value under it in ctx. +func (ctx *RequestCtx) RemoveUserValueBytes(key []byte) { + ctx.userValues.RemoveBytes(key) +} + type connTLSer interface { Handshake() error ConnectionState() tls.ConnectionState diff --git a/userdata.go b/userdata.go index 8d2eff9f74..9a7c98835c 100644 --- a/userdata.go +++ b/userdata.go @@ -21,6 +21,7 @@ func (d *userData) Set(key string, value interface{}) { return } } + if value == nil { return } @@ -72,3 +73,23 @@ func (d *userData) Reset() { } *d = (*d)[:0] } + +func (d *userData) Remove(key string) { + args := *d + n := len(args) + for i := 0; i < n; i++ { + kv := &args[i] + if string(kv.key) == key { + n-- + args[i] = args[n] + args[n].value = nil + args = args[:n] + *d = args + return + } + } +} + +func (d *userData) RemoveBytes(key []byte) { + d.Remove(b2s(key)) +} diff --git a/userdata_test.go b/userdata_test.go index d70d387bb6..94f04dd6d9 100644 --- a/userdata_test.go +++ b/userdata_test.go @@ -76,3 +76,31 @@ func (cv *closerValue) Close() error { (*cv.closeCalls)++ return nil } + +func TestUserDataDelete(t *testing.T) { + t.Parallel() + + var u userData + + for i := 0; i < 10; i++ { + key := fmt.Sprintf("key_%d", i) + u.Set(key, i) + testUserDataGet(t, &u, []byte(key), i) + } + + for i := 0; i < 10; i += 2 { + k := fmt.Sprintf("key_%d", i) + u.Remove(k) + if val := u.Get(k); val != nil { + t.Fatalf("unexpected key= %s, value =%v ,Expecting key= %s, value = nil", k, val, k) + } + kk := fmt.Sprintf("key_%d", i+1) + testUserDataGet(t, &u, []byte(kk), i+1) + } + for i := 0; i < 10; i++ { + key := fmt.Sprintf("key_new_%d", i) + u.Set(key, i) + testUserDataGet(t, &u, []byte(key), i) + } + +}