From dfa57a516243fb6df82bad27122806b739247501 Mon Sep 17 00:00:00 2001 From: woorui Date: Thu, 22 Dec 2022 19:25:43 +0800 Subject: [PATCH 1/5] feat(server): add start handler --- core/client_test.go | 32 ++++++++++++++++++++++++++++++++ core/server.go | 15 +++++++++++++++ core/yerr/errors.go | 3 +++ 3 files changed, 50 insertions(+) diff --git a/core/client_test.go b/core/client_test.go index 3e788aab1..6cec945dc 100644 --- a/core/client_test.go +++ b/core/client_test.go @@ -50,6 +50,12 @@ func TestFrameRoundTrip(t *testing.T) { server.ConfigMetadataBuilder(metadata.DefaultBuilder()) server.ConfigRouter(router.Default([]config.App{{Name: "sfn-1"}})) + // test server hooks + ht := &hookTester{t} + server.SetStartHandlers(ht.startHandler) + server.SetBeforeHandlers(ht.beforeHandler) + server.SetAfterHandlers(ht.afterHandler) + w := newMockFrameWriter() server.AddDownstreamServer("mockAddr", w) @@ -114,6 +120,32 @@ func TestFrameRoundTrip(t *testing.T) { assert.NoError(t, sfn.Close(), "sfn client.Close() should not return error") } +type hookTester struct { + t *testing.T +} + +func (a *hookTester) startHandler(ctx *Context) error { + ctx.Set("start", "yes") + return nil +} + +func (a *hookTester) beforeHandler(ctx *Context) error { + ctx.Set("before", "ok") + return nil +} + +func (a *hookTester) afterHandler(ctx *Context) error { + v, ok := ctx.Get("start") + assert.True(a.t, ok) + assert.Equal(a.t, v, "yes") + + v, ok = ctx.Get("before") + assert.True(a.t, ok) + assert.Equal(a.t, v, "ok") + + return nil +} + // mockFrameWriter mock a FrameWriter type mockFrameWriter struct { mu sync.Mutex diff --git a/core/server.go b/core/server.go index 70a24ad75..ccd7da158 100644 --- a/core/server.go +++ b/core/server.go @@ -39,6 +39,7 @@ type Server struct { downstreams map[string]frame.Writer mu sync.Mutex opts *serverOptions + startHandlers []FrameHandler beforeHandlers []FrameHandler afterHandlers []FrameHandler connectionCloseHandlers []ConnectionHandler @@ -160,6 +161,14 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { yctx, ok := s.handshakeWithTimeout(conn, stream, 10*time.Second) + for _, handler := range s.startHandlers { + if err := handler(yctx); err != nil { + yctx.logger.Error("startHandlers error", err) + yctx.CloseWithError(yerr.ErrorCodeStartHandler, err.Error()) + return + } + } + defer func() { if yctx != nil { yctx.Clean() @@ -600,6 +609,12 @@ func (s *Server) Connector() Connector { return s.connector } +// SetStartHandlers sets a function for operating connection, +// this function executes after handshake successful. +func (s *Server) SetStartHandlers(handlers ...FrameHandler) { + s.startHandlers = append(s.beforeHandlers, handlers...) +} + // SetBeforeHandlers set the before handlers of server. func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) { s.beforeHandlers = append(s.beforeHandlers, handlers...) diff --git a/core/yerr/errors.go b/core/yerr/errors.go index 735445cc5..600d0c468 100644 --- a/core/yerr/errors.go +++ b/core/yerr/errors.go @@ -64,6 +64,8 @@ const ( ErrorCodeUnknownClient ErrorCode = 0xCD // ErrorCodeDuplicateName unknown client error ErrorCodeDuplicateName ErrorCode = 0xC6 + // ErrorCodeStartHandler start handler + ErrorCodeStartHandler ErrorCode = 0xC8 ) var errCodeStringMap = map[ErrorCode]string{ @@ -79,6 +81,7 @@ var errCodeStringMap = map[ErrorCode]string{ ErrorCodeData: "DataFrame", ErrorCodeUnknownClient: "UnknownClient", ErrorCodeDuplicateName: "DuplicateName", + ErrorCodeStartHandler: "StartHandler", } func (e ErrorCode) String() string { From d0dda64041f61f8a7baa43e3d2a897ccbf2368ed Mon Sep 17 00:00:00 2001 From: woorui Date: Thu, 22 Dec 2022 19:28:26 +0800 Subject: [PATCH 2/5] fix(server): set start handlers bug --- core/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/server.go b/core/server.go index ccd7da158..790df8460 100644 --- a/core/server.go +++ b/core/server.go @@ -612,7 +612,7 @@ func (s *Server) Connector() Connector { // SetStartHandlers sets a function for operating connection, // this function executes after handshake successful. func (s *Server) SetStartHandlers(handlers ...FrameHandler) { - s.startHandlers = append(s.beforeHandlers, handlers...) + s.startHandlers = append(s.startHandlers, handlers...) } // SetBeforeHandlers set the before handlers of server. From 1be5c2a3534e2a63a64b2070dd28c7991500d064 Mon Sep 17 00:00:00 2001 From: woorui Date: Thu, 22 Dec 2022 23:17:24 +0800 Subject: [PATCH 3/5] feat(context): make context logger public --- core/context.go | 12 ++++---- core/server.go | 70 ++++++++++++++++++++++----------------------- core/server_test.go | 6 ++-- 3 files changed, 44 insertions(+), 44 deletions(-) diff --git a/core/context.go b/core/context.go index 009771d63..65c04b0d3 100644 --- a/core/context.go +++ b/core/context.go @@ -27,7 +27,7 @@ type Context struct { mu sync.RWMutex - logger *slog.Logger + Logger *slog.Logger } func newContext(conn quic.Connection, stream quic.Stream, logger *slog.Logger) (ctx *Context) { @@ -40,7 +40,7 @@ func newContext(conn quic.Connection, stream quic.Stream, logger *slog.Logger) ( ctx.Conn = conn ctx.Stream = stream ctx.connID = conn.RemoteAddr().String() - ctx.logger = logger.With("conn_id", conn.RemoteAddr().String(), "stream_id", stream.StreamID()) + ctx.Logger = logger.With("conn_id", conn.RemoteAddr().String(), "stream_id", stream.StreamID()) return } @@ -67,7 +67,7 @@ func (c *Context) ClientInfo() *ClientInfo { func (c *Context) WithFrame(f frame.Frame) *Context { if f.Type() == frame.TagOfHandshakeFrame { handshakeFrame := f.(*frame.HandshakeFrame) - c.logger = c.logger.With( + c.Logger = c.Logger.With( "client_id", handshakeFrame.ClientID, "client_type", ClientType(handshakeFrame.ClientType).String(), "client_name", handshakeFrame.Name, @@ -86,7 +86,7 @@ func (c *Context) WithFrame(f frame.Frame) *Context { // Clean the context. func (c *Context) Clean() { - c.logger.Debug("conn context clean", "conn_id", c.connID) + c.Logger.Debug("conn context clean", "conn_id", c.connID) c.reset() ctxPool.Put(c) } @@ -96,7 +96,7 @@ func (c *Context) reset() { c.connID = "" c.Stream = nil c.Frame = nil - c.logger = nil + c.Logger = nil for k := range c.Keys { delete(c.Keys, k) } @@ -104,7 +104,7 @@ func (c *Context) reset() { // CloseWithError closes the stream and cleans the context. func (c *Context) CloseWithError(code yerr.ErrorCode, msg string) { - c.logger.Debug("conn context close, ", "err_code", code, "err_msg", msg) + c.Logger.Debug("conn context close, ", "err_code", code, "err_msg", msg) if c.Stream != nil { c.Stream.Close() } diff --git a/core/server.go b/core/server.go index 790df8460..de1c37a5a 100644 --- a/core/server.go +++ b/core/server.go @@ -163,7 +163,7 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { for _, handler := range s.startHandlers { if err := handler(yctx); err != nil { - yctx.logger.Error("startHandlers error", err) + yctx.Logger.Error("startHandlers error", err) yctx.CloseWithError(yerr.ErrorCodeStartHandler, err.Error()) return } @@ -182,7 +182,7 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { s.logger.Info("stream created", "stream_id", stream.StreamID(), "conn_id", connID) s.handleConnection(yctx) - yctx.logger.Info("stream handleConnection DONE") + yctx.Logger.Info("stream handleConnection DONE") } }(sctx, conn) } @@ -230,7 +230,7 @@ func (s *Server) handshake(conn quic.Connection, stream quic.Stream, fs frame.Re } if frm.Type() != frame.TagOfHandshakeFrame { - c.logger.Info("client not do handshake right off") + c.Logger.Info("client not do handshake right off") if err := fs.WriteFrame(frame.NewGoawayFrame("handshake failed")); err != nil { s.logger.Error("write to client GoawayFrame error", err) } @@ -238,7 +238,7 @@ func (s *Server) handshake(conn quic.Connection, stream quic.Stream, fs frame.Re } if err := s.handleHandshakeFrame(c); err != nil { - c.logger.Info("handshake failed", "error", err) + c.Logger.Info("handshake failed", "error", err) if err := fs.WriteFrame(frame.NewGoawayFrame(err.Error())); err != nil { s.logger.Error("write to client GoawayFrame error", err) } @@ -270,7 +270,7 @@ func (s *Server) Close() error { } // handleConnection handles streams on a connection, -// use c.logger in this function scope for more complete logger information. +// use c.Logger in this function scope for more complete logger information. func (s *Server) handleConnection(c *Context) { fs := NewFrameStream(c.Stream) // check update for stream @@ -281,27 +281,27 @@ func (s *Server) handleConnection(c *Context) { if e, ok := err.(*quic.ApplicationError); ok { if yerr.Is(e.ErrorCode, yerr.ErrorCodeClientAbort) { // client abort - c.logger.Info("client close the connection") + c.Logger.Info("client close the connection") break } else { ye := yerr.New(yerr.Parse(e.ErrorCode), err) - c.logger.Error("read frame error", ye) + c.Logger.Error("read frame error", ye) } } else if err == io.EOF { - c.logger.Info("connection EOF") + c.Logger.Info("connection EOF") break } if errors.Is(err, net.ErrClosed) { // if client close the connection, net.ErrClosed will be raise // by quic-go IdleTimeoutError after connection's KeepAlive config. - c.logger.Warn("connection error", "error", net.ErrClosed) + c.Logger.Warn("connection error", "error", net.ErrClosed) c.CloseWithError(yerr.ErrorCodeClosed, "net.ErrClosed") break } // any error occurred, we should close the stream // after this, conn.AcceptStream() will raise the error c.CloseWithError(yerr.ErrorCodeUnknown, err.Error()) - c.logger.Warn("connection close") + c.Logger.Warn("connection close") break } @@ -311,21 +311,21 @@ func (s *Server) handleConnection(c *Context) { // before frame handlers for _, handler := range s.beforeHandlers { if err := handler(c); err != nil { - c.logger.Error("beforeFrameHandler error", err) + c.Logger.Error("beforeFrameHandler error", err) c.CloseWithError(yerr.ErrorCodeBeforeHandler, err.Error()) return } } // main handler if err := s.mainFrameHandler(c); err != nil { - c.logger.Error("mainFrameHandler error", err) + c.Logger.Error("mainFrameHandler error", err) c.CloseWithError(yerr.ErrorCodeMainHandler, err.Error()) return } // after frame handler for _, handler := range s.afterHandlers { if err := handler(c); err != nil { - c.logger.Error("afterFrameHandler error", err) + c.Logger.Error("afterFrameHandler error", err) c.CloseWithError(yerr.ErrorCodeAfterHandler, err.Error()) return } @@ -338,7 +338,7 @@ func (s *Server) mainFrameHandler(c *Context) error { switch frameType { case frame.TagOfHandshakeFrame: - c.logger.Warn("receive a handshakeFrame, ingonre it") + c.Logger.Warn("receive a handshakeFrame, ingonre it") case frame.TagOfDataFrame: if err := s.handleDataFrame(c); err != nil { c.CloseWithError(yerr.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err)) @@ -349,7 +349,7 @@ func (s *Server) mainFrameHandler(c *Context) error { s.handleBackflowFrame(c) } default: - c.logger.Warn("unexpected frame", "unexpected_frame_type", frameType) + c.Logger.Warn("unexpected frame", "unexpected_frame_type", frameType) } return nil } @@ -364,17 +364,17 @@ func (s *Server) handleHandshakeFrame(c *Context) error { clientType := ClientType(f.ClientType) stream := c.Stream // credential - c.logger.Debug("GOT HandshakeFrame", "client_type", f.ClientType, "client_id", clientID, "auth_name", authName(f.AuthName())) + c.Logger.Debug("GOT HandshakeFrame", "client_type", f.ClientType, "client_id", clientID, "auth_name", authName(f.AuthName())) // authenticate authed := auth.Authenticate(s.opts.auths, f) - c.logger.Debug("authenticate", "authed", authed) + c.Logger.Debug("authenticate", "authed", authed) if !authed { err := fmt.Errorf("handshake authentication fails, client credential name is %s", authName(f.AuthName())) // return err - c.logger.Debug("authenticated", "authed", authed) + c.Logger.Debug("authenticated", "authed", authed) rejectedFrame := frame.NewRejectedFrame(err.Error()) if _, err = stream.Write(rejectedFrame.Encode()); err != nil { - c.logger.Error("write to RejectedFrame failed", err, "authed", authed) + c.Logger.Error("write to RejectedFrame failed", err, "authed", authed) return err } return nil @@ -389,7 +389,7 @@ func (s *Server) handleHandshakeFrame(c *Context) error { if err != nil { return err } - conn = newConnection(f.Name, f.ClientID, clientType, metadata, stream, f.ObserveDataTags, c.logger) + conn = newConnection(f.Name, f.ClientID, clientType, metadata, stream, f.ObserveDataTags, c.Logger) if clientType == ClientTypeStreamFunction { // route @@ -402,10 +402,10 @@ func (s *Server) handleHandshakeFrame(c *Context) error { if e, ok := err.(yerr.DuplicateNameError); ok { existsConnID := e.ConnID() if conn := s.connector.Get(existsConnID); conn != nil { - c.logger.Debug("write GoawayFrame", "error", e.Error(), "exists_conn_id", existsConnID) + c.Logger.Debug("write GoawayFrame", "error", e.Error(), "exists_conn_id", existsConnID) goawayFrame := frame.NewGoawayFrame(e.Error()) if err := conn.Write(goawayFrame); err != nil { - c.logger.Error("write GoawayFrame failed", err) + c.Logger.Error("write GoawayFrame failed", err) return err } } @@ -415,7 +415,7 @@ func (s *Server) handleHandshakeFrame(c *Context) error { } } case ClientTypeUpstreamZipper: - conn = newConnection(f.Name, f.ClientID, clientType, nil, stream, f.ObserveDataTags, c.logger) + conn = newConnection(f.Name, f.ClientID, clientType, nil, stream, f.ObserveDataTags, c.Logger) default: // TODO: There is no need to Remove, // unknown client type is not be add to connector. @@ -426,11 +426,11 @@ func (s *Server) handleHandshakeFrame(c *Context) error { } if _, err := stream.Write(frame.NewHandshakeAckFrame().Encode()); err != nil { - c.logger.Error("write handshakeAckFrame error", err) + c.Logger.Error("write handshakeAckFrame error", err) } s.connector.Add(connID, conn) - c.logger.Info("client is connected!") + c.Logger.Info("client is connected!") return nil } @@ -447,7 +447,7 @@ func (s *Server) handleDataFrame(c *Context) error { fromID := c.ConnID() from := s.connector.Get(fromID) if from == nil { - c.logger.Warn("handleDataFrame connector cannot find", "from_conn_id", fromID) + c.Logger.Warn("handleDataFrame connector cannot find", "from_conn_id", fromID) return fmt.Errorf("handleDataFrame connector cannot find %s", fromID) } @@ -465,7 +465,7 @@ func (s *Server) handleDataFrame(c *Context) error { // route route := s.router.Route(metadata) if route == nil { - c.logger.Warn("handleDataFrame route is nil") + c.Logger.Warn("handleDataFrame route is nil") return fmt.Errorf("handleDataFrame route is nil") } @@ -474,12 +474,12 @@ func (s *Server) handleDataFrame(c *Context) error { for _, toID := range connIDs { conn := s.connector.Get(toID) if conn == nil { - c.logger.Error("Can't find forward conn", errors.New("conn is nil"), "forward_conn_id", toID) + c.Logger.Error("Can't find forward conn", errors.New("conn is nil"), "forward_conn_id", toID) continue } to := conn.Name() - c.logger.Info( + c.Logger.Info( "handleDataFrame", "from_conn_name", from.Name(), "from_conn_id", fromID, @@ -490,7 +490,7 @@ func (s *Server) handleDataFrame(c *Context) error { // write data frame to stream if err := conn.Write(f); err != nil { - c.logger.Error("handleDataFrame conn.Write", err) + c.Logger.Error("handleDataFrame conn.Write", err) } } @@ -507,9 +507,9 @@ func (s *Server) handleBackflowFrame(c *Context) error { sourceConns := s.connector.GetSourceConns(sourceID, tag) for _, source := range sourceConns { if source != nil { - c.logger.Info("handleBackflowFrame", "source_conn_id", sourceID, "back_flow_frame", f.String()) + c.Logger.Info("handleBackflowFrame", "source_conn_id", sourceID, "back_flow_frame", f.String()) if err := source.Write(bf); err != nil { - c.logger.Error("handleBackflowFrame conn.Write", err) + c.Logger.Error("handleBackflowFrame conn.Write", err) return err } } @@ -568,7 +568,7 @@ func (s *Server) AddDownstreamServer(addr string, c frame.Writer) { func (s *Server) dispatchToDownstreams(c *Context) { conn := s.connector.Get(c.connID) if conn == nil { - c.logger.Debug("dispatchToDownstreams failed") + c.Logger.Debug("dispatchToDownstreams failed") } else if conn.ClientType() == ClientTypeSource { f := c.Frame.(*frame.DataFrame) if f.IsBroadcast() { @@ -576,11 +576,11 @@ func (s *Server) dispatchToDownstreams(c *Context) { f.GetMetaFrame().SetMetadata(conn.Metadata().Encode()) } for addr, ds := range s.downstreams { - c.logger.Info("dispatching to", "dispatch_addr", addr, "tid", f.TransactionID()) + c.Logger.Info("dispatching to", "dispatch_addr", addr, "tid", f.TransactionID()) ds.WriteFrame(f) } } else { - c.logger.Info("do not broadcast", "tid", f.TransactionID()) + c.Logger.Info("do not broadcast", "tid", f.TransactionID()) } } } diff --git a/core/server_test.go b/core/server_test.go index 9da04da96..985280f45 100644 --- a/core/server_test.go +++ b/core/server_test.go @@ -141,7 +141,7 @@ func TestHandleDataFrame(t *testing.T) { connID: sourceConnID, Stream: sourceStream, Frame: dataFrame, - logger: server.logger, + Logger: server.logger, } err := server.handleDataFrame(c) @@ -172,7 +172,7 @@ func TestHandleDataFrame(t *testing.T) { connID: zipperConnID, Stream: zipperStream, Frame: dataFrame, - logger: server.logger, + Logger: server.logger, } err := server.handleDataFrame(c) @@ -308,7 +308,7 @@ func TestHandShake(t *testing.T) { connID: clientID, Stream: stream, Frame: frame.NewHandshakeFrame(clientName, clientID, clientType, []frame.Tag{frame.Tag(1)}, "token", token), - logger: server.logger, + Logger: server.logger, } for n := 0; n < tt.handshakeTimes; n++ { From 2443e87e0e5c5238ee0ee05ec320dbf273552fde Mon Sep 17 00:00:00 2001 From: woorui Date: Thu, 22 Dec 2022 23:17:49 +0800 Subject: [PATCH 4/5] feat(context): make context ClientInfo public --- core/context.go | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/core/context.go b/core/context.go index 65c04b0d3..dd985b0cf 100644 --- a/core/context.go +++ b/core/context.go @@ -48,10 +48,14 @@ const clientInfoKey = "client_info" // ClientInfo holds client info, you can use `*Context.ClientInfo()` to get it after handshake. type ClientInfo struct { - clientID string - clientType byte - clientName string - authName string + // ID is client id from handshake. + ID string + // Type is client type from handshake. + Type byte + // Type is client type from handshake. + Name string + // AuthName is client authName from handshake. + AuthName string } // ClientInfo get client info from context. @@ -74,10 +78,10 @@ func (c *Context) WithFrame(f frame.Frame) *Context { "auth_name", handshakeFrame.AuthName(), ) c.Set(clientInfoKey, &ClientInfo{ - clientID: handshakeFrame.ClientID, - clientType: handshakeFrame.ClientType, - clientName: handshakeFrame.Name, - authName: handshakeFrame.AuthName(), + ID: handshakeFrame.ClientID, + Type: handshakeFrame.ClientType, + Name: handshakeFrame.Name, + AuthName: handshakeFrame.AuthName(), }) } c.Frame = f From 469361fffcc9c7be68a905f24b21dd3979a0171a Mon Sep 17 00:00:00 2001 From: woorui Date: Fri, 23 Dec 2022 17:49:29 +0800 Subject: [PATCH 5/5] fix: typo --- core/context.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/context.go b/core/context.go index dd985b0cf..f0173d5b3 100644 --- a/core/context.go +++ b/core/context.go @@ -52,7 +52,7 @@ type ClientInfo struct { ID string // Type is client type from handshake. Type byte - // Type is client type from handshake. + // Name is client name from handshake. Name string // AuthName is client authName from handshake. AuthName string