diff --git a/core/client_test.go b/core/client_test.go index 3e788aab1..6cec945dc 100644 --- a/core/client_test.go +++ b/core/client_test.go @@ -50,6 +50,12 @@ func TestFrameRoundTrip(t *testing.T) { server.ConfigMetadataBuilder(metadata.DefaultBuilder()) server.ConfigRouter(router.Default([]config.App{{Name: "sfn-1"}})) + // test server hooks + ht := &hookTester{t} + server.SetStartHandlers(ht.startHandler) + server.SetBeforeHandlers(ht.beforeHandler) + server.SetAfterHandlers(ht.afterHandler) + w := newMockFrameWriter() server.AddDownstreamServer("mockAddr", w) @@ -114,6 +120,32 @@ func TestFrameRoundTrip(t *testing.T) { assert.NoError(t, sfn.Close(), "sfn client.Close() should not return error") } +type hookTester struct { + t *testing.T +} + +func (a *hookTester) startHandler(ctx *Context) error { + ctx.Set("start", "yes") + return nil +} + +func (a *hookTester) beforeHandler(ctx *Context) error { + ctx.Set("before", "ok") + return nil +} + +func (a *hookTester) afterHandler(ctx *Context) error { + v, ok := ctx.Get("start") + assert.True(a.t, ok) + assert.Equal(a.t, v, "yes") + + v, ok = ctx.Get("before") + assert.True(a.t, ok) + assert.Equal(a.t, v, "ok") + + return nil +} + // mockFrameWriter mock a FrameWriter type mockFrameWriter struct { mu sync.Mutex diff --git a/core/context.go b/core/context.go index 009771d63..f0173d5b3 100644 --- a/core/context.go +++ b/core/context.go @@ -27,7 +27,7 @@ type Context struct { mu sync.RWMutex - logger *slog.Logger + Logger *slog.Logger } func newContext(conn quic.Connection, stream quic.Stream, logger *slog.Logger) (ctx *Context) { @@ -40,7 +40,7 @@ func newContext(conn quic.Connection, stream quic.Stream, logger *slog.Logger) ( ctx.Conn = conn ctx.Stream = stream ctx.connID = conn.RemoteAddr().String() - ctx.logger = logger.With("conn_id", conn.RemoteAddr().String(), "stream_id", stream.StreamID()) + ctx.Logger = logger.With("conn_id", conn.RemoteAddr().String(), "stream_id", stream.StreamID()) return } @@ -48,10 +48,14 @@ const clientInfoKey = "client_info" // ClientInfo holds client info, you can use `*Context.ClientInfo()` to get it after handshake. type ClientInfo struct { - clientID string - clientType byte - clientName string - authName string + // ID is client id from handshake. + ID string + // Type is client type from handshake. + Type byte + // Name is client name from handshake. + Name string + // AuthName is client authName from handshake. + AuthName string } // ClientInfo get client info from context. @@ -67,17 +71,17 @@ func (c *Context) ClientInfo() *ClientInfo { func (c *Context) WithFrame(f frame.Frame) *Context { if f.Type() == frame.TagOfHandshakeFrame { handshakeFrame := f.(*frame.HandshakeFrame) - c.logger = c.logger.With( + c.Logger = c.Logger.With( "client_id", handshakeFrame.ClientID, "client_type", ClientType(handshakeFrame.ClientType).String(), "client_name", handshakeFrame.Name, "auth_name", handshakeFrame.AuthName(), ) c.Set(clientInfoKey, &ClientInfo{ - clientID: handshakeFrame.ClientID, - clientType: handshakeFrame.ClientType, - clientName: handshakeFrame.Name, - authName: handshakeFrame.AuthName(), + ID: handshakeFrame.ClientID, + Type: handshakeFrame.ClientType, + Name: handshakeFrame.Name, + AuthName: handshakeFrame.AuthName(), }) } c.Frame = f @@ -86,7 +90,7 @@ func (c *Context) WithFrame(f frame.Frame) *Context { // Clean the context. func (c *Context) Clean() { - c.logger.Debug("conn context clean", "conn_id", c.connID) + c.Logger.Debug("conn context clean", "conn_id", c.connID) c.reset() ctxPool.Put(c) } @@ -96,7 +100,7 @@ func (c *Context) reset() { c.connID = "" c.Stream = nil c.Frame = nil - c.logger = nil + c.Logger = nil for k := range c.Keys { delete(c.Keys, k) } @@ -104,7 +108,7 @@ func (c *Context) reset() { // CloseWithError closes the stream and cleans the context. func (c *Context) CloseWithError(code yerr.ErrorCode, msg string) { - c.logger.Debug("conn context close, ", "err_code", code, "err_msg", msg) + c.Logger.Debug("conn context close, ", "err_code", code, "err_msg", msg) if c.Stream != nil { c.Stream.Close() } diff --git a/core/server.go b/core/server.go index 4452be7c7..8f26065ca 100644 --- a/core/server.go +++ b/core/server.go @@ -39,6 +39,7 @@ type Server struct { downstreams map[string]frame.Writer mu sync.Mutex opts *serverOptions + startHandlers []FrameHandler beforeHandlers []FrameHandler afterHandlers []FrameHandler connectionCloseHandlers []ConnectionHandler @@ -160,6 +161,14 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { yctx, ok := s.handshakeWithTimeout(conn, stream, 10*time.Second) + for _, handler := range s.startHandlers { + if err := handler(yctx); err != nil { + yctx.Logger.Error("startHandlers error", err) + yctx.CloseWithError(yerr.ErrorCodeStartHandler, err.Error()) + return + } + } + defer func() { if yctx != nil { yctx.Clean() @@ -173,7 +182,7 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { s.logger.Info("stream created", "stream_id", stream.StreamID(), "conn_id", connID) s.handleConnection(yctx) - yctx.logger.Info("stream handleConnection DONE") + yctx.Logger.Info("stream handleConnection DONE") } }(sctx, conn) } @@ -223,7 +232,7 @@ func (s *Server) handshake(conn quic.Connection, stream quic.Stream, fs frame.Re c = c.WithFrame(frm) if frm.Type() != frame.TagOfHandshakeFrame { - c.logger.Info("client not do handshake right off") + c.Logger.Info("client not do handshake right off") if err := fs.WriteFrame(frame.NewGoawayFrame("handshake failed")); err != nil { s.logger.Error("write to client GoawayFrame error", err) } @@ -231,7 +240,7 @@ func (s *Server) handshake(conn quic.Connection, stream quic.Stream, fs frame.Re } if err := s.handleHandshakeFrame(c); err != nil { - c.logger.Info("handshake failed", "error", err) + c.Logger.Info("handshake failed", "error", err) if err := fs.WriteFrame(frame.NewGoawayFrame(err.Error())); err != nil { s.logger.Error("write to client GoawayFrame error", err) } @@ -263,7 +272,7 @@ func (s *Server) Close() error { } // handleConnection handles streams on a connection, -// use c.logger in this function scope for more complete logger information. +// use c.Logger in this function scope for more complete logger information. func (s *Server) handleConnection(c *Context) { fs := NewFrameStream(c.Stream) // check update for stream @@ -274,27 +283,27 @@ func (s *Server) handleConnection(c *Context) { if e, ok := err.(*quic.ApplicationError); ok { if yerr.Is(e.ErrorCode, yerr.ErrorCodeClientAbort) { // client abort - c.logger.Info("client close the connection") + c.Logger.Info("client close the connection") break } else { ye := yerr.New(yerr.Parse(e.ErrorCode), err) - c.logger.Error("read frame error", ye) + c.Logger.Error("read frame error", ye) } } else if err == io.EOF { - c.logger.Info("connection EOF") + c.Logger.Info("connection EOF") break } if errors.Is(err, net.ErrClosed) { // if client close the connection, net.ErrClosed will be raise // by quic-go IdleTimeoutError after connection's KeepAlive config. - c.logger.Warn("connection error", "error", net.ErrClosed) + c.Logger.Warn("connection error", "error", net.ErrClosed) c.CloseWithError(yerr.ErrorCodeClosed, "net.ErrClosed") break } // any error occurred, we should close the stream // after this, conn.AcceptStream() will raise the error c.CloseWithError(yerr.ErrorCodeUnknown, err.Error()) - c.logger.Warn("connection close") + c.Logger.Warn("connection close") break } @@ -304,21 +313,21 @@ func (s *Server) handleConnection(c *Context) { // before frame handlers for _, handler := range s.beforeHandlers { if err := handler(c); err != nil { - c.logger.Error("beforeFrameHandler error", err) + c.Logger.Error("beforeFrameHandler error", err) c.CloseWithError(yerr.ErrorCodeBeforeHandler, err.Error()) return } } // main handler if err := s.mainFrameHandler(c); err != nil { - c.logger.Error("mainFrameHandler error", err) + c.Logger.Error("mainFrameHandler error", err) c.CloseWithError(yerr.ErrorCodeMainHandler, err.Error()) return } // after frame handler for _, handler := range s.afterHandlers { if err := handler(c); err != nil { - c.logger.Error("afterFrameHandler error", err) + c.Logger.Error("afterFrameHandler error", err) c.CloseWithError(yerr.ErrorCodeAfterHandler, err.Error()) return } @@ -331,7 +340,7 @@ func (s *Server) mainFrameHandler(c *Context) error { switch frameType { case frame.TagOfHandshakeFrame: - c.logger.Warn("receive a handshakeFrame, ingonre it") + c.Logger.Warn("receive a handshakeFrame, ingonre it") case frame.TagOfDataFrame: if err := s.handleDataFrame(c); err != nil { c.CloseWithError(yerr.ErrorCodeData, fmt.Sprintf("handleDataFrame err: %v", err)) @@ -342,7 +351,7 @@ func (s *Server) mainFrameHandler(c *Context) error { s.handleBackflowFrame(c) } default: - c.logger.Warn("unexpected frame", "unexpected_frame_type", frameType) + c.Logger.Warn("unexpected frame", "unexpected_frame_type", frameType) } return nil } @@ -357,17 +366,17 @@ func (s *Server) handleHandshakeFrame(c *Context) error { clientType := ClientType(f.ClientType) stream := c.Stream // credential - c.logger.Debug("GOT HandshakeFrame", "client_type", f.ClientType, "client_id", clientID, "auth_name", authName(f.AuthName())) + c.Logger.Debug("GOT HandshakeFrame", "client_type", f.ClientType, "client_id", clientID, "auth_name", authName(f.AuthName())) // authenticate authed := auth.Authenticate(s.opts.auths, f) - c.logger.Debug("authenticate", "authed", authed) + c.Logger.Debug("authenticate", "authed", authed) if !authed { err := fmt.Errorf("handshake authentication fails, client credential name is %s", authName(f.AuthName())) // return err - c.logger.Debug("authenticated", "authed", authed) + c.Logger.Debug("authenticated", "authed", authed) rejectedFrame := frame.NewRejectedFrame(err.Error()) if _, err = stream.Write(rejectedFrame.Encode()); err != nil { - c.logger.Error("write to RejectedFrame failed", err, "authed", authed) + c.Logger.Error("write to RejectedFrame failed", err, "authed", authed) return err } return nil @@ -382,7 +391,7 @@ func (s *Server) handleHandshakeFrame(c *Context) error { if err != nil { return err } - conn = newConnection(f.Name, f.ClientID, clientType, metadata, stream, f.ObserveDataTags, c.logger) + conn = newConnection(f.Name, f.ClientID, clientType, metadata, stream, f.ObserveDataTags, c.Logger) if clientType == ClientTypeStreamFunction { // route @@ -395,10 +404,10 @@ func (s *Server) handleHandshakeFrame(c *Context) error { if e, ok := err.(yerr.DuplicateNameError); ok { existsConnID := e.ConnID() if conn := s.connector.Get(existsConnID); conn != nil { - c.logger.Debug("write GoawayFrame", "error", e.Error(), "exists_conn_id", existsConnID) + c.Logger.Debug("write GoawayFrame", "error", e.Error(), "exists_conn_id", existsConnID) goawayFrame := frame.NewGoawayFrame(e.Error()) if err := conn.Write(goawayFrame); err != nil { - c.logger.Error("write GoawayFrame failed", err) + c.Logger.Error("write GoawayFrame failed", err) return err } } @@ -408,7 +417,7 @@ func (s *Server) handleHandshakeFrame(c *Context) error { } } case ClientTypeUpstreamZipper: - conn = newConnection(f.Name, f.ClientID, clientType, nil, stream, f.ObserveDataTags, c.logger) + conn = newConnection(f.Name, f.ClientID, clientType, nil, stream, f.ObserveDataTags, c.Logger) default: // TODO: There is no need to Remove, // unknown client type is not be add to connector. @@ -419,11 +428,11 @@ func (s *Server) handleHandshakeFrame(c *Context) error { } if _, err := stream.Write(frame.NewHandshakeAckFrame().Encode()); err != nil { - c.logger.Error("write handshakeAckFrame error", err) + c.Logger.Error("write handshakeAckFrame error", err) } s.connector.Add(connID, conn) - c.logger.Info("client is connected!") + c.Logger.Info("client is connected!") return nil } @@ -440,7 +449,7 @@ func (s *Server) handleDataFrame(c *Context) error { fromID := c.ConnID() from := s.connector.Get(fromID) if from == nil { - c.logger.Warn("handleDataFrame connector cannot find", "from_conn_id", fromID) + c.Logger.Warn("handleDataFrame connector cannot find", "from_conn_id", fromID) return fmt.Errorf("handleDataFrame connector cannot find %s", fromID) } @@ -458,7 +467,7 @@ func (s *Server) handleDataFrame(c *Context) error { // route route := s.router.Route(metadata) if route == nil { - c.logger.Warn("handleDataFrame route is nil") + c.Logger.Warn("handleDataFrame route is nil") return fmt.Errorf("handleDataFrame route is nil") } @@ -467,12 +476,12 @@ func (s *Server) handleDataFrame(c *Context) error { for _, toID := range connIDs { conn := s.connector.Get(toID) if conn == nil { - c.logger.Error("Can't find forward conn", errors.New("conn is nil"), "forward_conn_id", toID) + c.Logger.Error("Can't find forward conn", errors.New("conn is nil"), "forward_conn_id", toID) continue } to := conn.Name() - c.logger.Info( + c.Logger.Info( "handleDataFrame", "from_conn_name", from.Name(), "from_conn_id", fromID, @@ -483,7 +492,7 @@ func (s *Server) handleDataFrame(c *Context) error { // write data frame to stream if err := conn.Write(f); err != nil { - c.logger.Error("handleDataFrame conn.Write", err) + c.Logger.Error("handleDataFrame conn.Write", err) } } @@ -500,9 +509,9 @@ func (s *Server) handleBackflowFrame(c *Context) error { sourceConns := s.connector.GetSourceConns(sourceID, tag) for _, source := range sourceConns { if source != nil { - c.logger.Info("handleBackflowFrame", "source_conn_id", sourceID, "back_flow_frame", f.String()) + c.Logger.Info("handleBackflowFrame", "source_conn_id", sourceID, "back_flow_frame", f.String()) if err := source.Write(bf); err != nil { - c.logger.Error("handleBackflowFrame conn.Write", err) + c.Logger.Error("handleBackflowFrame conn.Write", err) return err } } @@ -561,7 +570,7 @@ func (s *Server) AddDownstreamServer(addr string, c frame.Writer) { func (s *Server) dispatchToDownstreams(c *Context) { conn := s.connector.Get(c.connID) if conn == nil { - c.logger.Debug("dispatchToDownstreams failed") + c.Logger.Debug("dispatchToDownstreams failed") } else if conn.ClientType() == ClientTypeSource { f := c.Frame.(*frame.DataFrame) if f.IsBroadcast() { @@ -569,7 +578,7 @@ func (s *Server) dispatchToDownstreams(c *Context) { f.GetMetaFrame().SetMetadata(conn.Metadata().Encode()) } for addr, ds := range s.downstreams { - c.logger.Info("dispatching to", "dispatch_addr", addr, "tid", f.TransactionID()) + c.Logger.Info("dispatching to", "dispatch_addr", addr, "tid", f.TransactionID()) ds.WriteFrame(f) } } @@ -600,6 +609,12 @@ func (s *Server) Connector() Connector { return s.connector } +// SetStartHandlers sets a function for operating connection, +// this function executes after handshake successful. +func (s *Server) SetStartHandlers(handlers ...FrameHandler) { + s.startHandlers = append(s.startHandlers, handlers...) +} + // SetBeforeHandlers set the before handlers of server. func (s *Server) SetBeforeHandlers(handlers ...FrameHandler) { s.beforeHandlers = append(s.beforeHandlers, handlers...) diff --git a/core/server_test.go b/core/server_test.go index 9da04da96..985280f45 100644 --- a/core/server_test.go +++ b/core/server_test.go @@ -141,7 +141,7 @@ func TestHandleDataFrame(t *testing.T) { connID: sourceConnID, Stream: sourceStream, Frame: dataFrame, - logger: server.logger, + Logger: server.logger, } err := server.handleDataFrame(c) @@ -172,7 +172,7 @@ func TestHandleDataFrame(t *testing.T) { connID: zipperConnID, Stream: zipperStream, Frame: dataFrame, - logger: server.logger, + Logger: server.logger, } err := server.handleDataFrame(c) @@ -308,7 +308,7 @@ func TestHandShake(t *testing.T) { connID: clientID, Stream: stream, Frame: frame.NewHandshakeFrame(clientName, clientID, clientType, []frame.Tag{frame.Tag(1)}, "token", token), - logger: server.logger, + Logger: server.logger, } for n := 0; n < tt.handshakeTimes; n++ { diff --git a/core/yerr/errors.go b/core/yerr/errors.go index 735445cc5..600d0c468 100644 --- a/core/yerr/errors.go +++ b/core/yerr/errors.go @@ -64,6 +64,8 @@ const ( ErrorCodeUnknownClient ErrorCode = 0xCD // ErrorCodeDuplicateName unknown client error ErrorCodeDuplicateName ErrorCode = 0xC6 + // ErrorCodeStartHandler start handler + ErrorCodeStartHandler ErrorCode = 0xC8 ) var errCodeStringMap = map[ErrorCode]string{ @@ -79,6 +81,7 @@ var errCodeStringMap = map[ErrorCode]string{ ErrorCodeData: "DataFrame", ErrorCodeUnknownClient: "UnknownClient", ErrorCodeDuplicateName: "DuplicateName", + ErrorCodeStartHandler: "StartHandler", } func (e ErrorCode) String() string {