Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: server start handler #419

Merged
merged 6 commits into from Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 32 additions & 0 deletions core/client_test.go
Expand Up @@ -50,6 +50,12 @@ func TestFrameRoundTrip(t *testing.T) {
server.ConfigMetadataBuilder(metadata.DefaultBuilder())
server.ConfigRouter(router.Default([]config.App{{Name: "sfn-1"}}))

// test server hooks
ht := &hookTester{t}
server.SetStartHandlers(ht.startHandler)
server.SetBeforeHandlers(ht.beforeHandler)
server.SetAfterHandlers(ht.afterHandler)

w := newMockFrameWriter()
server.AddDownstreamServer("mockAddr", w)

Expand Down Expand Up @@ -114,6 +120,32 @@ func TestFrameRoundTrip(t *testing.T) {
assert.NoError(t, sfn.Close(), "sfn client.Close() should not return error")
}

type hookTester struct {
t *testing.T
}

func (a *hookTester) startHandler(ctx *Context) error {
ctx.Set("start", "yes")
return nil
}

func (a *hookTester) beforeHandler(ctx *Context) error {
ctx.Set("before", "ok")
return nil
}

func (a *hookTester) afterHandler(ctx *Context) error {
v, ok := ctx.Get("start")
assert.True(a.t, ok)
assert.Equal(a.t, v, "yes")

v, ok = ctx.Get("before")
assert.True(a.t, ok)
assert.Equal(a.t, v, "ok")

return nil
}

// mockFrameWriter mock a FrameWriter
type mockFrameWriter struct {
mu sync.Mutex
Expand Down
32 changes: 18 additions & 14 deletions core/context.go
Expand Up @@ -27,7 +27,7 @@ type Context struct {

mu sync.RWMutex

logger *slog.Logger
Logger *slog.Logger
}

func newContext(conn quic.Connection, stream quic.Stream, logger *slog.Logger) (ctx *Context) {
Expand All @@ -40,18 +40,22 @@ func newContext(conn quic.Connection, stream quic.Stream, logger *slog.Logger) (
ctx.Conn = conn
ctx.Stream = stream
ctx.connID = conn.RemoteAddr().String()
ctx.logger = logger.With("conn_id", conn.RemoteAddr().String(), "stream_id", stream.StreamID())
ctx.Logger = logger.With("conn_id", conn.RemoteAddr().String(), "stream_id", stream.StreamID())
return
}

const clientInfoKey = "client_info"

// ClientInfo holds client info, you can use `*Context.ClientInfo()` to get it after handshake.
type ClientInfo struct {
clientID string
clientType byte
clientName string
authName string
// ID is client id from handshake.
ID string
// Type is client type from handshake.
Type byte
// Type is client type from handshake.
Name string
// AuthName is client authName from handshake.
AuthName string
}

// ClientInfo get client info from context.
Expand All @@ -67,17 +71,17 @@ func (c *Context) ClientInfo() *ClientInfo {
func (c *Context) WithFrame(f frame.Frame) *Context {
if f.Type() == frame.TagOfHandshakeFrame {
handshakeFrame := f.(*frame.HandshakeFrame)
c.logger = c.logger.With(
c.Logger = c.Logger.With(
"client_id", handshakeFrame.ClientID,
"client_type", ClientType(handshakeFrame.ClientType).String(),
"client_name", handshakeFrame.Name,
"auth_name", handshakeFrame.AuthName(),
)
c.Set(clientInfoKey, &ClientInfo{
clientID: handshakeFrame.ClientID,
clientType: handshakeFrame.ClientType,
clientName: handshakeFrame.Name,
authName: handshakeFrame.AuthName(),
ID: handshakeFrame.ClientID,
Type: handshakeFrame.ClientType,
Name: handshakeFrame.Name,
AuthName: handshakeFrame.AuthName(),
})
}
c.Frame = f
Expand All @@ -86,7 +90,7 @@ func (c *Context) WithFrame(f frame.Frame) *Context {

// Clean the context.
func (c *Context) Clean() {
c.logger.Debug("conn context clean", "conn_id", c.connID)
c.Logger.Debug("conn context clean", "conn_id", c.connID)
c.reset()
ctxPool.Put(c)
}
Expand All @@ -96,15 +100,15 @@ func (c *Context) reset() {
c.connID = ""
c.Stream = nil
c.Frame = nil
c.logger = nil
c.Logger = nil
for k := range c.Keys {
delete(c.Keys, k)
}
}

// CloseWithError closes the stream and cleans the context.
func (c *Context) CloseWithError(code yerr.ErrorCode, msg string) {
c.logger.Debug("conn context close, ", "err_code", code, "err_msg", msg)
c.Logger.Debug("conn context close, ", "err_code", code, "err_msg", msg)
if c.Stream != nil {
c.Stream.Close()
}
Expand Down