diff --git a/client.go b/client.go index 61cd7526526..670e7a7c678 100644 --- a/client.go +++ b/client.go @@ -34,7 +34,7 @@ type client struct { conn quicConn - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer tracingID uint64 logger utils.Logger } @@ -153,7 +153,7 @@ func dial( if c.config.Tracer != nil { c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) } - if c.tracer != nil { + if c.tracer != nil && c.tracer.StartedConnection != nil { c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) } if err := c.dial(ctx); err != nil { diff --git a/client_test.go b/client_test.go index e908b82f139..3fa5fb5a22b 100644 --- a/client_test.go +++ b/client_test.go @@ -43,7 +43,7 @@ var _ = Describe("Client", func() { initialPacketNumber protocol.PacketNumber, enable0RTT bool, hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, v protocol.VersionNumber, @@ -54,10 +54,11 @@ var _ = Describe("Client", func() { tlsConf = &tls.Config{NextProtos: []string{"proto1"}} connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37}) originalClientConnConstructor = newClientConnection - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) config = &Config{ - Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) logging.ConnectionTracer { - return tracer + Tracer: func(ctx context.Context, perspective logging.Perspective, id ConnectionID) *logging.ConnectionTracer { + return tr }, Versions: []protocol.VersionNumber{protocol.Version1}, } @@ -70,7 +71,7 @@ var _ = Describe("Client", func() { destConnID: connID, version: protocol.Version1, sendConn: packetConn, - tracer: tracer, + tracer: tr, logger: utils.DefaultLogger, } getMultiplexer() // make the sync.Once execute @@ -121,7 +122,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, enable0RTT bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -158,7 +159,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, enable0RTT bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -195,7 +196,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -280,7 +281,7 @@ var _ = Describe("Client", func() { _ protocol.PacketNumber, _ bool, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, versionP protocol.VersionNumber, @@ -323,7 +324,7 @@ var _ = Describe("Client", func() { pn protocol.PacketNumber, _ bool, hasNegotiatedVersion bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, versionP protocol.VersionNumber, diff --git a/codecov.yml b/codecov.yml index 0b95a1032b5..a24c7a15e0d 100644 --- a/codecov.yml +++ b/codecov.yml @@ -5,7 +5,6 @@ coverage: - interop/ - internal/handshake/cipher_suite.go - internal/utils/linkedlist/linkedlist.go - - logging/null_tracer.go - fuzzing/ - metrics/ status: diff --git a/config_test.go b/config_test.go index 9cb665a30b8..c5701608fea 100644 --- a/config_test.go +++ b/config_test.go @@ -125,7 +125,7 @@ var _ = Describe("Config", func() { GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, - Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer { + Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer { calledTracer = true return nil }, diff --git a/connection.go b/connection.go index 10a36ba005e..91a386e7d32 100644 --- a/connection.go +++ b/connection.go @@ -208,7 +208,7 @@ type connection struct { connState ConnectionState logID string - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger } @@ -232,7 +232,7 @@ var newConnection = func( tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, clientAddressValidated bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, v protocol.VersionNumber, @@ -311,7 +311,7 @@ var newConnection = func( } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentTransportParameters != nil { s.tracer.SentTransportParameters(params) } cs := handshake.NewCryptoSetupServer( @@ -345,7 +345,7 @@ var newClientConnection = func( initialPacketNumber protocol.PacketNumber, enable0RTT bool, hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, tracingID uint64, logger utils.Logger, v protocol.VersionNumber, @@ -418,7 +418,7 @@ var newClientConnection = func( } else { params.MaxDatagramFrameSize = protocol.InvalidByteCount } - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentTransportParameters != nil { s.tracer.SentTransportParameters(params) } cs := handshake.NewCryptoSetupClient( @@ -642,8 +642,10 @@ runLoop: s.cryptoStreamHandler.Close() s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE s.handleCloseError(&closeErr) - if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { - s.tracer.Close() + if s.tracer != nil && s.tracer.Close != nil { + if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) { + s.tracer.Close() + } } s.logger.Infof("Connection %s closed.", s.logID) s.timer.Stop() @@ -800,14 +802,14 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { var err error destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) } s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err) break } if destConnID != lastConnID { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID) @@ -818,7 +820,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { if wire.IsLongHeaderPacket(p.data[0]) { hdr, packetData, rest, err := wire.ParsePacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { dropReason := logging.PacketDropHeaderParseError if err == wire.ErrUnsupportedVersion { dropReason = logging.PacketDropUnsupportedVersion @@ -831,7 +833,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool { lastConnID = hdr.DestConnectionID if hdr.Version != s.version { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) } s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) @@ -890,14 +892,14 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) { s.logger.Debugf("Dropping (potentially) duplicate packet.") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate) } return false } var log func([]logging.Frame) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedShortHeaderPacket != nil { log = func(frames []logging.Frame) { s.tracer.ReceivedShortHeaderPacket( &logging.ShortHeader{ @@ -936,7 +938,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) } s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID) @@ -944,7 +946,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) } // drop 0-RTT packets, if we are a client if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketType0RTT, p.Size(), logging.PacketDropKeyUnavailable) } return false @@ -963,7 +965,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.hdr.PacketNumber, packet.encryptionLevel) { s.logger.Debugf("Dropping (potentially) duplicate packet.") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate) } return false @@ -979,7 +981,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) { switch err { case handshake.ErrKeysDropped: - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable) } s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size()) @@ -995,7 +997,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P }) case handshake.ErrDecryptionFailed: // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err) @@ -1003,7 +1005,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P var headerErr *headerParseError if errors.As(err, &headerErr) { // This might be a packet injected by an attacker. Drop it. - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err) @@ -1018,14 +1020,14 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry.") return false } if s.receivedFirstPacket { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since we already received a packet.") @@ -1033,7 +1035,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti } destConnID := s.connIDManager.Get() if hdr.SrcConnectionID == destConnID { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") @@ -1048,7 +1050,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version) if !bytes.Equal(data[len(data)-16:], tag[:]) { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.") @@ -1060,7 +1062,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID) } - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedRetry != nil { s.tracer.ReceivedRetry(hdr) } newDestConnID := hdr.SrcConnectionID @@ -1081,7 +1083,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime ti func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) } return @@ -1089,7 +1091,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing Version Negotiation packet: %s", err) @@ -1098,7 +1100,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { for _, v := range supportedVersions { if v == s.version { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) } // The Version Negotiation packet contains the version that we offered. @@ -1108,7 +1110,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { } s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedVersionNegotiationPacket != nil { s.tracer.ReceivedVersionNegotiationPacket(dest, src, supportedVersions) } newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) @@ -1120,7 +1122,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { s.logger.Infof("No compatible QUIC version found.") return } - if s.tracer != nil { + if s.tracer != nil && s.tracer.NegotiatedVersion != nil { s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions) } @@ -1140,7 +1142,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( ) error { if !s.receivedFirstPacket { s.receivedFirstPacket = true - if !s.versionNegotiated && s.tracer != nil { + if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil { var clientVersions, serverVersions []protocol.VersionNumber switch s.perspective { case protocol.PerspectiveClient: @@ -1167,7 +1169,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( s.handshakeDestConnID = packet.hdr.SrcConnectionID s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) } - if s.tracer != nil { + if s.tracer != nil && s.tracer.StartedConnection != nil { s.tracer.StartedConnection( s.conn.LocalAddr(), s.conn.RemoteAddr(), @@ -1191,7 +1193,7 @@ func (s *connection) handleUnpackedLongHeaderPacket( s.keepAlivePingSent = false var log func([]logging.Frame) - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedLongHeaderPacket != nil { log = func(frames []logging.Frame) { s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames) } @@ -1339,7 +1341,7 @@ func (s *connection) handlePacket(p receivedPacket) { select { case s.receivedPackets <- p: default: - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } @@ -1619,7 +1621,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { s.datagramQueue.CloseWithError(e) } - if s.tracer != nil && !errors.As(e, &recreateErr) { + if s.tracer != nil && s.tracer.ClosedConnection != nil && !errors.As(e, &recreateErr) { s.tracer.ClosedConnection(e) } @@ -1646,7 +1648,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { } func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedEncryptionLevel != nil { s.tracer.DroppedEncryptionLevel(encLevel) } s.sentPacketHandler.DropPackets(encLevel) @@ -1681,7 +1683,7 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters } func (s *connection) handleTransportParameters(params *wire.TransportParameters) error { - if s.tracer != nil { + if s.tracer != nil && s.tracer.ReceivedTransportParameters != nil { s.tracer.ReceivedTransportParameters(params) } if err := s.checkTransportParameters(params); err != nil { @@ -2133,7 +2135,7 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) } // tracing - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentLongHeaderPacket != nil { frames := make([]logging.Frame, 0, len(p.frames)) for _, f := range p.frames { frames = append(frames, logutils.ConvertFrame(f.Frame)) @@ -2179,7 +2181,7 @@ func (s *connection) logShortHeaderPacket( } // tracing - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentShortHeaderPacket != nil { fs := make([]logging.Frame, 0, len(frames)+len(streamFrames)) for _, f := range frames { fs = append(fs, logutils.ConvertFrame(f.Frame)) @@ -2301,14 +2303,14 @@ func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging panic("shouldn't queue undecryptable packets after handshake completion") } if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropDOSPrevention) } s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size()) return } s.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.BufferedPacket != nil { s.tracer.BufferedPacket(pt, p.Size()) } s.undecryptablePackets = append(s.undecryptablePackets, p) diff --git a/connection_test.go b/connection_test.go index 67b02443b7d..06132cc8aff 100644 --- a/connection_test.go +++ b/connection_test.go @@ -102,7 +102,8 @@ var _ = Describe("Connection", func() { mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) Expect(err).ToNot(HaveOccurred()) - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() @@ -121,7 +122,7 @@ var _ = Describe("Connection", func() { &tls.Config{}, tokenGenerator, false, - tracer, + tr, 1234, utils.DefaultLogger, protocol.Version1, @@ -2541,7 +2542,8 @@ var _ = Describe("Client Connection", func() { tlsConf = &tls.Config{} } connRunner = NewMockConnRunner(mockCtrl) - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() @@ -2557,7 +2559,7 @@ var _ = Describe("Client Connection", func() { 42, // initial packet number false, false, - tracer, + tr, 1234, utils.DefaultLogger, protocol.Version1, diff --git a/example/client/main.go b/example/client/main.go index 83f810fd1a7..5b86faa37c8 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -58,7 +58,7 @@ func main() { var qconf quic.Config if *enableQlog { - qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + qconf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { filename := fmt.Sprintf("client_%x.qlog", connID) f, err := os.Create(filename) if err != nil { diff --git a/example/main.go b/example/main.go index 058144050e8..c2fb574d49d 100644 --- a/example/main.go +++ b/example/main.go @@ -163,7 +163,7 @@ func main() { handler := setupHandler(*www) quicConf := &quic.Config{} if *enableQlog { - quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + quicConf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { filename := fmt.Sprintf("server_%x.qlog", connID) f, err := os.Create(filename) if err != nil { diff --git a/integrationtests/self/key_update_test.go b/integrationtests/self/key_update_test.go index 7975a77bb5a..3704f96eb82 100644 --- a/integrationtests/self/key_update_test.go +++ b/integrationtests/self/key_update_test.go @@ -38,18 +38,13 @@ func countKeyPhases() (sent, received int) { return } -type keyUpdateConnTracer struct { - logging.NullConnectionTracer -} - -var _ logging.ConnectionTracer = &keyUpdateConnTracer{} - -func (t *keyUpdateConnTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) { - sentHeaders = append(sentHeaders, hdr) -} - -func (t *keyUpdateConnTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) { - receivedHeaders = append(receivedHeaders, hdr) +var keyUpdateConnTracer = &logging.ConnectionTracer{ + SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ *logging.AckFrame, _ []logging.Frame) { + sentHeaders = append(sentHeaders, hdr) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, _ []logging.Frame) { + receivedHeaders = append(receivedHeaders, hdr) + }, } var _ = Describe("Key Update tests", func() { @@ -77,8 +72,8 @@ var _ = Describe("Key Update tests", func() { context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return &keyUpdateConnTracer{} + getQuicConfig(&quic.Config{Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return keyUpdateConnTracer }}), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index 740956c5492..8c59bc0b1a1 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -21,7 +21,7 @@ var _ = Describe("Packetization", func() { It("bundles ACKs", func() { const numMsg = 100 - serverTracer := newPacketTracer() + serverCounter, serverTracer := newPacketTracer() server, err := quic.ListenAddr( "localhost:0", getTLSConfig(), @@ -43,7 +43,7 @@ var _ = Describe("Packetization", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - clientTracer := newPacketTracer() + clientCounter, clientTracer := newPacketTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), @@ -104,8 +104,8 @@ var _ = Describe("Packetization", func() { return } - numBundledIncoming := countBundledPackets(clientTracer.getRcvdShortHeaderPackets()) - numBundledOutgoing := countBundledPackets(serverTracer.getRcvdShortHeaderPackets()) + numBundledIncoming := countBundledPackets(clientCounter.getRcvdShortHeaderPackets()) + numBundledOutgoing := countBundledPackets(serverCounter.getRcvdShortHeaderPackets()) fmt.Fprintf(GinkgoWriter, "bundled incoming packets: %d / %d\n", numBundledIncoming, numMsg) fmt.Fprintf(GinkgoWriter, "bundled outgoing packets: %d / %d\n", numBundledOutgoing, numMsg) Expect(numBundledIncoming).To(And( diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 6396d3a7aac..22b2952c176 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -86,7 +86,7 @@ var ( logBuf *syncedBuffer versionParam string - qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer + qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer enableQlog bool version quic.VersionNumber @@ -177,10 +177,16 @@ func getQuicConfig(conf *quic.Config) *quic.Config { } if enableQlog { if conf.Tracer == nil { - conf.Tracer = qlogTracer + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { + return logging.NewMultiplexedConnectionTracer( + qlogTracer(ctx, p, connID), + // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere + &logging.ConnectionTracer{}, + ) + } } else if qlogTracer != nil { origTracer := conf.Tracer - conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { return logging.NewMultiplexedConnectionTracer( qlogTracer(ctx, p, connID), origTracer(ctx, p, connID), @@ -242,8 +248,8 @@ func scaleDuration(d time.Duration) time.Duration { return time.Duration(scaleFactor) * d } -func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer } +func newTracer(tracer *logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return tracer } } type packet struct { @@ -258,51 +264,46 @@ type shortHeaderPacket struct { frames []logging.Frame } -type packetTracer struct { - logging.NullConnectionTracer +type packetCounter struct { closed chan struct{} sentShortHdr, rcvdShortHdr []shortHeaderPacket rcvdLongHdr []packet } -var _ logging.ConnectionTracer = &packetTracer{} - -func newPacketTracer() *packetTracer { - return &packetTracer{closed: make(chan struct{})} -} - -func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { - t.rcvdLongHdr = append(t.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames}) -} - -func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { - t.rcvdShortHdr = append(t.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) -} - -func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) { - if ack != nil { - frames = append(frames, ack) - } - t.sentShortHdr = append(t.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) -} - -func (t *packetTracer) Close() { close(t.closed) } - -func (t *packetTracer) getSentShortHeaderPackets() []shortHeaderPacket { +func (t *packetCounter) getSentShortHeaderPackets() []shortHeaderPacket { <-t.closed return t.sentShortHdr } -func (t *packetTracer) getRcvdLongHeaderPackets() []packet { +func (t *packetCounter) getRcvdLongHeaderPackets() []packet { <-t.closed return t.rcvdLongHdr } -func (t *packetTracer) getRcvdShortHeaderPackets() []shortHeaderPacket { +func (t *packetCounter) getRcvdShortHeaderPackets() []shortHeaderPacket { <-t.closed return t.rcvdShortHdr } +func newPacketTracer() (*packetCounter, *logging.ConnectionTracer) { + c := &packetCounter{} + return c, &logging.ConnectionTracer{ + ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { + c.rcvdLongHdr = append(c.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames}) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { + c.rcvdShortHdr = append(c.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) + }, + SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) { + if ack != nil { + frames = append(frames, ack) + } + c.sentShortHdr = append(c.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) + }, + Close: func() { close(c.closed) }, + } +} + func TestSelf(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Self integration tests") diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 47569b77e74..ddd0b973846 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -194,7 +194,7 @@ var _ = Describe("Timeout tests", func() { close(serverConnClosed) }() - tr := newPacketTracer() + counter, tr := newPacketTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -215,7 +215,7 @@ var _ = Describe("Timeout tests", func() { }() Eventually(done, 2*idleTimeout).Should(BeClosed()) var lastAckElicitingPacketSentAt time.Time - for _, p := range tr.getSentShortHeaderPackets() { + for _, p := range counter.getSentShortHeaderPackets() { var hasAckElicitingFrame bool for _, f := range p.frames { if _, ok := f.(*logging.AckFrame); ok { @@ -228,7 +228,7 @@ var _ = Describe("Timeout tests", func() { lastAckElicitingPacketSentAt = p.time } } - rcvdPackets := tr.getRcvdShortHeaderPackets() + rcvdPackets := counter.getRcvdShortHeaderPackets() lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time // We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout. // This is ok since we're dealing with a lossless connection here, diff --git a/integrationtests/self/tracer_test.go b/integrationtests/self/tracer_test.go index 3bfae3c6b8a..5179646e99a 100644 --- a/integrationtests/self/tracer_test.go +++ b/integrationtests/self/tracer_test.go @@ -26,9 +26,9 @@ var _ = Describe("Handshake tests", func() { fmt.Fprintf(GinkgoWriter, "%s using qlog: %t, custom: %t\n", pers, enableQlog, enableCustomTracer) - var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer + var tracerConstructors []func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer if enableQlog { - tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + tracerConstructors = append(tracerConstructors, func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { if mrand.Int()%2 == 0 { // simulate that a qlog collector might only want to log some connections fmt.Fprintf(GinkgoWriter, "%s qlog tracer deciding to not trace connection %x\n", p, connID) return nil @@ -38,13 +38,13 @@ var _ = Describe("Handshake tests", func() { }) } if enableCustomTracer { - tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return logging.NullConnectionTracer{} + tracerConstructors = append(tracerConstructors, func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return &logging.ConnectionTracer{} }) } c := conf.Clone() - c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { - tracers := make([]logging.ConnectionTracer, 0, len(tracerConstructors)) + c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { + tracers := make([]*logging.ConnectionTracer, 0, len(tracerConstructors)) for _, c := range tracerConstructors { if tr := c(ctx, p, connID); tr != nil { tracers = append(tracers, tr) diff --git a/integrationtests/self/zero_rtt_oldgo_test.go b/integrationtests/self/zero_rtt_oldgo_test.go index 3e9277b937a..eb2302d3d85 100644 --- a/integrationtests/self/zero_rtt_oldgo_test.go +++ b/integrationtests/self/zero_rtt_oldgo_test.go @@ -202,7 +202,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) } - // can be used to extract 0-RTT from a packetTracer + // can be used to extract 0-RTT from a packetCounter get0RTTPackets := func(packets []packet) []protocol.PacketNumber { var zeroRTTPackets []protocol.PacketNumber for _, p := range packets { @@ -219,7 +219,7 @@ var _ = Describe("0-RTT", func() { It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() { tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -244,7 +244,7 @@ var _ = Describe("0-RTT", func() { ) var numNewConnIDs int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if _, ok := f.(*logging.NewConnectionIDFrame); ok { numNewConnIDs++ @@ -260,7 +260,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0))) }) @@ -273,7 +273,7 @@ var _ = Describe("0-RTT", func() { zeroRTTData := GeneratePRData(5 << 10) oneRTTData := PRData - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -330,7 +330,7 @@ var _ = Describe("0-RTT", func() { // check that 0-RTT packets only contain STREAM frames for the first stream var num0RTT int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { if p.hdr.Header.Type != protocol.PacketType0RTT { continue } @@ -355,7 +355,7 @@ var _ = Describe("0-RTT", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -406,7 +406,7 @@ var _ = Describe("0-RTT", func() { fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) Expect(numDropped).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) }) It("retransmits all 0-RTT data when the server performs a Retry", func() { @@ -430,7 +430,7 @@ var _ = Describe("0-RTT", func() { return } - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -480,7 +480,7 @@ var _ = Describe("0-RTT", func() { defer mutex.Unlock() Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) }) @@ -491,14 +491,12 @@ var _ = Describe("0-RTT", func() { MaxIncomingUniStreams: maxStreams, })) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, getQuicConfig(&quic.Config{ MaxIncomingUniStreams: maxStreams + 1, Allow0RTT: true, - Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -536,7 +534,7 @@ var _ = Describe("0-RTT", func() { MaxIncomingStreams: maxStreams, })) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -556,7 +554,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the ALPN changed", func() { @@ -565,7 +563,7 @@ var _ = Describe("0-RTT", func() { // now close the listener and dial new connection with a different ALPN clientConf.NextProtos = []string{"new-alpn"} tlsConf.NextProtos = []string{"new-alpn"} - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -585,14 +583,14 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the application doesn't allow it", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) // now close the listener and dial new connection with a different ALPN - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -612,12 +610,12 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) DescribeTable("flow control limits", func(addFlowControlLimit func(*quic.Config, uint64)) { - tracer := newPacketTracer() + counter, tracer := newPacketTracer() firstConf := getQuicConfig(&quic.Config{Allow0RTT: true}) addFlowControlLimit(firstConf, 3) tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) @@ -669,7 +667,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) var processedFirst bool - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if sf, ok := f.(*logging.StreamFrame); ok { if !processedFirst { @@ -695,7 +693,7 @@ var _ = Describe("0-RTT", func() { It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) // now dial new connection with different transport parameters - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -764,14 +762,14 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) } It("queues 0-RTT packets, if the Initial is delayed", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -796,8 +794,8 @@ var _ = Describe("0-RTT", func() { transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) - Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) }) @@ -807,7 +805,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, })) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -856,7 +854,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(zeroRTTPackets).To(HaveLen(1)) }) @@ -865,7 +863,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, })) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -911,6 +909,6 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) }) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index c9bdeff6990..1f750d3ac25 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -232,7 +232,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) } - // can be used to extract 0-RTT from a packetTracer + // can be used to extract 0-RTT from a packetCounter get0RTTPackets := func(packets []packet) []protocol.PacketNumber { var zeroRTTPackets []protocol.PacketNumber for _, p := range packets { @@ -251,7 +251,7 @@ var _ = Describe("0-RTT", func() { clientTLSConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -276,7 +276,7 @@ var _ = Describe("0-RTT", func() { ) var numNewConnIDs int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if _, ok := f.(*logging.NewConnectionIDFrame); ok { numNewConnIDs++ @@ -292,7 +292,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0))) }) @@ -307,7 +307,7 @@ var _ = Describe("0-RTT", func() { zeroRTTData := GeneratePRData(5 << 10) oneRTTData := PRData - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -364,7 +364,7 @@ var _ = Describe("0-RTT", func() { // check that 0-RTT packets only contain STREAM frames for the first stream var num0RTT int - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { if p.hdr.Header.Type != protocol.PacketType0RTT { continue } @@ -391,7 +391,7 @@ var _ = Describe("0-RTT", func() { clientConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -442,7 +442,7 @@ var _ = Describe("0-RTT", func() { fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) Expect(numDropped).ToNot(BeZero()) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) }) It("retransmits all 0-RTT data when the server performs a Retry", func() { @@ -468,7 +468,7 @@ var _ = Describe("0-RTT", func() { return } - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -518,7 +518,7 @@ var _ = Describe("0-RTT", func() { defer mutex.Unlock() Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) }) @@ -531,14 +531,12 @@ var _ = Describe("0-RTT", func() { MaxIncomingUniStreams: maxStreams, }), clientConf) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, getQuicConfig(&quic.Config{ MaxIncomingUniStreams: maxStreams + 1, Allow0RTT: true, - Tracer: newTracer(tracer), }), ) Expect(err).ToNot(HaveOccurred()) @@ -578,7 +576,7 @@ var _ = Describe("0-RTT", func() { MaxIncomingStreams: maxStreams, }), clientConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -599,7 +597,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the ALPN changed", func() { @@ -612,7 +610,7 @@ var _ = Describe("0-RTT", func() { // Append to the client's ALPN. // crypto/tls will attempt to resume with the ALPN from the original connection clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn") - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -632,7 +630,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) It("rejects 0-RTT when the application doesn't allow it", func() { @@ -641,7 +639,7 @@ var _ = Describe("0-RTT", func() { dialAndReceiveSessionTicket(tlsConf, nil, clientConf) // now close the listener and dial new connection with a different ALPN - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -661,12 +659,12 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) DescribeTable("flow control limits", func(addFlowControlLimit func(*quic.Config, uint64)) { - tracer := newPacketTracer() + counter, tracer := newPacketTracer() firstConf := getQuicConfig(&quic.Config{Allow0RTT: true}) addFlowControlLimit(firstConf, 3) tlsConf := getTLSConfig() @@ -720,7 +718,7 @@ var _ = Describe("0-RTT", func() { Eventually(conn.Context().Done()).Should(BeClosed()) var processedFirst bool - for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, p := range counter.getRcvdLongHeaderPackets() { for _, f := range p.frames { if sf, ok := f.(*logging.StreamFrame); ok { if !processedFirst { @@ -748,7 +746,7 @@ var _ = Describe("0-RTT", func() { clientConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientConf) // now dial new connection with different transport parameters - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -817,7 +815,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) } @@ -826,7 +824,7 @@ var _ = Describe("0-RTT", func() { clientConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -851,8 +849,8 @@ var _ = Describe("0-RTT", func() { transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) - Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(counter.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) }) @@ -878,14 +876,10 @@ var _ = Describe("0-RTT", func() { clientTLSConf := getTLSClientConfig() dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }), + getQuicConfig(&quic.Config{Allow0RTT: true}), ) Expect(err).ToNot(HaveOccurred()) defer ln.Close() @@ -916,14 +910,10 @@ var _ = Describe("0-RTT", func() { } dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) - tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, - getQuicConfig(&quic.Config{ - Allow0RTT: true, - Tracer: newTracer(tracer), - }), + getQuicConfig(&quic.Config{Allow0RTT: true}), ) Expect(err).ToNot(HaveOccurred()) defer ln.Close() @@ -946,7 +936,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, }), clientTLSConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -994,7 +984,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + zeroRTTPackets := get0RTTPackets(counter.getRcvdLongHeaderPackets()) Expect(zeroRTTPackets).To(HaveLen(1)) }) @@ -1005,7 +995,7 @@ var _ = Describe("0-RTT", func() { EnableDatagrams: true, }), clientTLSConf) - tracer := newPacketTracer() + counter, tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -1051,6 +1041,6 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(get0RTTPackets(counter.getRcvdLongHeaderPackets())).To(BeEmpty()) }) }) diff --git a/integrationtests/tools/qlog.go b/integrationtests/tools/qlog.go index 352e0a613d3..ea37456e8b1 100644 --- a/integrationtests/tools/qlog.go +++ b/integrationtests/tools/qlog.go @@ -14,8 +14,8 @@ import ( "github.com/quic-go/quic-go/qlog" ) -func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { - return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { +func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { + return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { role := "server" if p == logging.PerspectiveClient { role = "client" diff --git a/integrationtests/versionnegotiation/handshake_test.go b/integrationtests/versionnegotiation/handshake_test.go index a079f6e12d2..eefcd8e7cd3 100644 --- a/integrationtests/versionnegotiation/handshake_test.go +++ b/integrationtests/versionnegotiation/handshake_test.go @@ -21,29 +21,29 @@ type versioner interface { GetVersion() protocol.VersionNumber } -type versionNegotiationTracer struct { - logging.NullConnectionTracer - +type result struct { loggedVersions bool receivedVersionNegotiation bool chosen logging.VersionNumber clientVersions, serverVersions []logging.VersionNumber } -var _ logging.ConnectionTracer = &versionNegotiationTracer{} - -func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { - if t.loggedVersions { - Fail("only expected one call to NegotiatedVersions") +func newVersionNegotiationTracer() (*result, *logging.ConnectionTracer) { + r := &result{} + return r, &logging.ConnectionTracer{ + NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { + if r.loggedVersions { + Fail("only expected one call to NegotiatedVersions") + } + r.loggedVersions = true + r.chosen = chosen + r.clientVersions = clientVersions + r.serverVersions = serverVersions + }, + ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { + r.receivedVersionNegotiation = true + }, } - t.loggedVersions = true - t.chosen = chosen - t.clientVersions = clientVersions - t.serverVersions = serverVersions -} - -func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { - t.receivedVersionNegotiation = true } var _ = Describe("Handshake tests", func() { @@ -86,54 +86,54 @@ var _ = Describe("Handshake tests", func() { // but it supports a bunch of versions that the client doesn't speak serverConfig := &quic.Config{} serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} - serverTracer := &versionNegotiationTracer{} - serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + serverResult, serverTracer := newVersionNegotiationTracer() + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return serverTracer } server, cl := startServer(getTLSConfig(), serverConfig) defer cl() - clientTracer := &versionNegotiationTracer{} + clientResult, clientTracer := newVersionNegotiationTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), - maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) logging.ConnectionTracer { + maybeAddQLOGTracer(&quic.Config{Tracer: func(ctx context.Context, perspective logging.Perspective, id quic.ConnectionID) *logging.ConnectionTracer { return clientTracer }}), ) Expect(err).ToNot(HaveOccurred()) Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion)) Expect(conn.CloseWithError(0, "")).To(Succeed()) - Expect(clientTracer.chosen).To(Equal(expectedVersion)) - Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) - Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions)) - Expect(clientTracer.serverVersions).To(BeEmpty()) - Expect(serverTracer.chosen).To(Equal(expectedVersion)) - Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) - Expect(serverTracer.clientVersions).To(BeEmpty()) + Expect(clientResult.chosen).To(Equal(expectedVersion)) + Expect(clientResult.receivedVersionNegotiation).To(BeFalse()) + Expect(clientResult.clientVersions).To(Equal(protocol.SupportedVersions)) + Expect(clientResult.serverVersions).To(BeEmpty()) + Expect(serverResult.chosen).To(Equal(expectedVersion)) + Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions)) + Expect(serverResult.clientVersions).To(BeEmpty()) }) It("when the client supports more versions than the server supports", func() { expectedVersion := protocol.SupportedVersions[0] // The server doesn't support the highest supported version, which is the first one the client will try, // but it supports a bunch of versions that the client doesn't speak - serverTracer := &versionNegotiationTracer{} + serverResult, serverTracer := newVersionNegotiationTracer() serverConfig := &quic.Config{} serverConfig.Versions = supportedVersions - serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return serverTracer } server, cl := startServer(getTLSConfig(), serverConfig) defer cl() clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} - clientTracer := &versionNegotiationTracer{} + clientResult, clientTracer := newVersionNegotiationTracer() conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), maybeAddQLOGTracer(&quic.Config{ Versions: clientVersions, - Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return clientTracer }, }), @@ -141,22 +141,22 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0])) Expect(conn.CloseWithError(0, "")).To(Succeed()) - Expect(clientTracer.chosen).To(Equal(expectedVersion)) - Expect(clientTracer.receivedVersionNegotiation).To(BeTrue()) - Expect(clientTracer.clientVersions).To(Equal(clientVersions)) - Expect(clientTracer.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions - Expect(serverTracer.chosen).To(Equal(expectedVersion)) - Expect(serverTracer.serverVersions).To(Equal(serverConfig.Versions)) - Expect(serverTracer.clientVersions).To(BeEmpty()) + Expect(clientResult.chosen).To(Equal(expectedVersion)) + Expect(clientResult.receivedVersionNegotiation).To(BeTrue()) + Expect(clientResult.clientVersions).To(Equal(clientVersions)) + Expect(clientResult.serverVersions).To(ContainElements(supportedVersions)) // may contain greased versions + Expect(serverResult.chosen).To(Equal(expectedVersion)) + Expect(serverResult.serverVersions).To(Equal(serverConfig.Versions)) + Expect(serverResult.clientVersions).To(BeEmpty()) }) It("fails if the server disables version negotiation", func() { // The server doesn't support the highest supported version, which is the first one the client will try, // but it supports a bunch of versions that the client doesn't speak - serverTracer := &versionNegotiationTracer{} + _, serverTracer := newVersionNegotiationTracer() serverConfig := &quic.Config{} serverConfig.Versions = supportedVersions - serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + serverConfig.Tracer = func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return serverTracer } conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}) @@ -170,14 +170,14 @@ var _ = Describe("Handshake tests", func() { defer ln.Close() clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} - clientTracer := &versionNegotiationTracer{} + clientResult, clientTracer := newVersionNegotiationTracer() _, err = quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", conn.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), maybeAddQLOGTracer(&quic.Config{ Versions: clientVersions, - Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { + Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return clientTracer }, HandshakeIdleTimeout: 100 * time.Millisecond, @@ -187,7 +187,7 @@ var _ = Describe("Handshake tests", func() { var nerr net.Error Expect(errors.As(err, &nerr)).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) - Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) + Expect(clientResult.receivedVersionNegotiation).To(BeFalse()) }) } }) diff --git a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go index a01ac1f8a58..150181f257c 100644 --- a/integrationtests/versionnegotiation/versionnegotiation_suite_test.go +++ b/integrationtests/versionnegotiation/versionnegotiation_suite_test.go @@ -70,7 +70,7 @@ func maybeAddQLOGTracer(c *quic.Config) *quic.Config { c.Tracer = qlogger } else if qlogger != nil { origTracer := c.Tracer - c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { + c.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { return logging.NewMultiplexedConnectionTracer( qlogger(ctx, p, connID), origTracer(ctx, p, connID), diff --git a/interface.go b/interface.go index 8faf0eb8e27..a65f00139c5 100644 --- a/interface.go +++ b/interface.go @@ -325,7 +325,7 @@ type Config struct { Allow0RTT bool // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool - Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer + Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer } type ClientHelloInfo struct { diff --git a/internal/ackhandler/ackhandler.go b/internal/ackhandler/ackhandler.go index 036ed5ead90..cb28582a337 100644 --- a/internal/ackhandler/ackhandler.go +++ b/internal/ackhandler/ackhandler.go @@ -16,7 +16,7 @@ func NewAckHandler( clientAddressValidated bool, enableECN bool, pers protocol.Perspective, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, ) (SentPacketHandler, ReceivedPacketHandler) { sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger) diff --git a/internal/ackhandler/ecn.go b/internal/ackhandler/ecn.go index bb8b37f8331..43cc3dd6e10 100644 --- a/internal/ackhandler/ecn.go +++ b/internal/ackhandler/ecn.go @@ -45,13 +45,13 @@ type ecnTracker struct { numSentECT0, numSentECT1 int64 numAckedECT0, numAckedECT1, numAckedECNCE int64 - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger } var _ ecnHandler = &ecnTracker{} -func newECNTracker(logger utils.Logger, tracer logging.ConnectionTracer) *ecnTracker { +func newECNTracker(logger utils.Logger, tracer *logging.ConnectionTracer) *ecnTracker { return &ecnTracker{ firstTestingPacket: protocol.InvalidPacketNumber, lastTestingPacket: protocol.InvalidPacketNumber, @@ -92,7 +92,7 @@ func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) { e.firstTestingPacket = pn } if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets { - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger) } e.state = ecnStateUnknown @@ -103,7 +103,7 @@ func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) { func (e *ecnTracker) Mode() protocol.ECN { switch e.state { case ecnStateInitial: - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger) } e.state = ecnStateTesting @@ -127,7 +127,7 @@ func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) { e.numLostTesting++ if e.numLostTesting >= e.numSentTesting { e.logger.Debugf("Disabling ECN. All testing packets were lost.") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets) } e.state = ecnStateFailed @@ -146,7 +146,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // the total number of packets sent with each corresponding ECT codepoint. if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 { e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent) } e.state = ecnStateFailed @@ -172,7 +172,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // * peers that don't report any ECN counts if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 { e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts) } e.state = ecnStateFailed @@ -189,7 +189,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // Any decrease means that the peer's counting logic is broken. if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 { e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts) } e.state = ecnStateFailed @@ -201,7 +201,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // This could be the result of (partial) bleaching. if newECT0+newECNCE < ackedECT0 { e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) } e.state = ecnStateFailed @@ -211,7 +211,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // the number of newly acknowledged packets sent with an ECT(1) marking. if newECT1+newECNCE < ackedECT1 { e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts) } e.state = ecnStateFailed @@ -226,7 +226,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 if e.state == ecnStateTesting || e.state == ecnStateUnknown { // Detect mangling (a path remarking all ECN-marked testing packets as CE). if e.numSentECT0+e.numSentECT1 == e.numAckedECNCE && e.numAckedECNCE >= numECNTestingPackets { - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected) } e.state = ecnStateFailed @@ -243,7 +243,7 @@ func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64 // This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE). if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) { e.logger.Debugf("ECN capability confirmed.") - if e.tracer != nil { + if e.tracer != nil && e.tracer.ECNStateUpdated != nil { e.tracer.ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger) } e.state = ecnStateCapable diff --git a/internal/ackhandler/ecn_test.go b/internal/ackhandler/ecn_test.go index 984a97a9660..26d54eed5f3 100644 --- a/internal/ackhandler/ecn_test.go +++ b/internal/ackhandler/ecn_test.go @@ -23,8 +23,9 @@ var _ = Describe("ECN tracker", func() { } BeforeEach(func() { - tracer = mocklogging.NewMockConnectionTracer(mockCtrl) - ecnTracker = newECNTracker(utils.DefaultLogger, tracer) + var tr *logging.ConnectionTracer + tr, tracer = mocklogging.NewMockConnectionTracer(mockCtrl) + ecnTracker = newECNTracker(utils.DefaultLogger, tr) }) It("sends exactly 10 testing packets", func() { diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index b498ce4b77d..c8265a78d93 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -97,7 +97,7 @@ type sentPacketHandler struct { perspective protocol.Perspective - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger } @@ -115,7 +115,7 @@ func newSentPacketHandler( clientAddressValidated bool, enableECN bool, pers protocol.Perspective, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, ) *sentPacketHandler { congestion := congestion.NewCubicSender( @@ -196,7 +196,7 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { default: panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) } - if h.tracer != nil && h.ptoCount != 0 { + if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 { h.tracer.UpdatedPTOCount(0) } h.ptoCount = 0 @@ -286,7 +286,7 @@ func (h *sentPacketHandler) SentPacket( p.includedInBytesInFlight = true pnSpace.history.SentAckElicitingPacket(p) - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } h.setLossDetectionTimer() @@ -376,14 +376,14 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En // Reset the pto_count unless the client is unsure if the server has validated the client's address. if h.peerCompletedAddressValidation { - if h.tracer != nil && h.ptoCount != 0 { + if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 { h.tracer.UpdatedPTOCount(0) } h.ptoCount = 0 } h.numProbesToSend = 0 - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } @@ -462,7 +462,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL if err := pnSpace.history.Remove(p.PacketNumber); err != nil { return nil, err } - if h.tracer != nil { + if h.tracer != nil && h.tracer.AcknowledgedPacket != nil { h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber) } } @@ -555,7 +555,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { if !lossTime.IsZero() { // Early retransmit timer or time loss detection. h.alarm = lossTime - if h.tracer != nil && h.alarm != oldAlarm { + if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm { h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm) } return @@ -566,7 +566,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { h.alarm = time.Time{} if !oldAlarm.IsZero() { h.logger.Debugf("Canceling loss detection timer. Amplification limited.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } @@ -578,7 +578,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() { h.alarm = time.Time{} if !oldAlarm.IsZero() { h.logger.Debugf("Canceling loss detection timer. No packets in flight.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } @@ -591,14 +591,14 @@ func (h *sentPacketHandler) setLossDetectionTimer() { if !oldAlarm.IsZero() { h.alarm = time.Time{} h.logger.Debugf("Canceling loss detection timer. No PTO needed..") - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } return } h.alarm = ptoTime - if h.tracer != nil && h.alarm != oldAlarm { + if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm { h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm) } } @@ -629,7 +629,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LostPacket != nil { h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) } } @@ -639,7 +639,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E if h.logger.Debug() { h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LostPacket != nil { h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) } } @@ -676,7 +676,7 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { if h.logger.Debug() { h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.LossTimerExpired != nil { h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) } // Early retransmit or time loss detection @@ -713,8 +713,12 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error { h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) } if h.tracer != nil { - h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) - h.tracer.UpdatedPTOCount(h.ptoCount) + if h.tracer.LossTimerExpired != nil { + h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) + } + if h.tracer.UpdatedPTOCount != nil { + h.tracer.UpdatedPTOCount(h.ptoCount) + } } h.numProbesToSend += 2 //nolint:exhaustive // We never arm a PTO timer for 0-RTT packets. @@ -890,7 +894,7 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error { if h.logger.Debug() { h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedMetrics != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } } @@ -899,8 +903,10 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error { oldAlarm := h.alarm h.alarm = time.Time{} if h.tracer != nil { - h.tracer.UpdatedPTOCount(0) - if !oldAlarm.IsZero() { + if h.tracer.UpdatedPTOCount != nil { + h.tracer.UpdatedPTOCount(0) + } + if !oldAlarm.IsZero() && h.tracer.LossTimerCanceled != nil { h.tracer.LossTimerCanceled() } } diff --git a/internal/congestion/cubic_sender.go b/internal/congestion/cubic_sender.go index 10eb4667dfa..ee558f2d5ab 100644 --- a/internal/congestion/cubic_sender.go +++ b/internal/congestion/cubic_sender.go @@ -56,7 +56,7 @@ type cubicSender struct { maxDatagramSize protocol.ByteCount lastState logging.CongestionState - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer } var ( @@ -70,7 +70,7 @@ func NewCubicSender( rttStats *utils.RTTStats, initialMaxDatagramSize protocol.ByteCount, reno bool, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, ) *cubicSender { return newCubicSender( clock, @@ -90,7 +90,7 @@ func newCubicSender( initialMaxDatagramSize, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, ) *cubicSender { c := &cubicSender{ rttStats: rttStats, @@ -108,7 +108,7 @@ func newCubicSender( maxDatagramSize: initialMaxDatagramSize, } c.pacer = newPacer(c.BandwidthEstimate) - if c.tracer != nil { + if c.tracer != nil && c.tracer.UpdatedCongestionState != nil { c.lastState = logging.CongestionStateSlowStart c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) } @@ -296,7 +296,7 @@ func (c *cubicSender) OnConnectionMigration() { } func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { - if c.tracer == nil || new == c.lastState { + if c.tracer == nil || c.tracer.UpdatedCongestionState == nil || new == c.lastState { return } c.tracer.UpdatedCongestionState(new) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index b5fb404bf6d..cae02873c83 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -43,7 +43,7 @@ type cryptoSetup struct { rttStats *utils.RTTStats - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger perspective protocol.Perspective @@ -77,7 +77,7 @@ func NewCryptoSetupClient( tlsConf *tls.Config, enable0RTT bool, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber, ) CryptoSetup { @@ -111,7 +111,7 @@ func NewCryptoSetupServer( tlsConf *tls.Config, allow0RTT bool, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber, ) CryptoSetup { @@ -165,13 +165,13 @@ func newCryptoSetup( connID protocol.ConnectionID, tp *wire.TransportParameters, rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, + tracer *logging.ConnectionTracer, logger utils.Logger, perspective protocol.Perspective, version protocol.VersionNumber, ) *cryptoSetup { initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) - if tracer != nil { + if tracer != nil && tracer.UpdatedKeyFromTLS != nil { tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } @@ -193,7 +193,7 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version) h.initialSealer = initialSealer h.initialOpener = initialOpener - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } @@ -457,7 +457,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr } h.mutex.Unlock() h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) } } @@ -479,7 +479,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) } // don't set used0RTT here. 0-RTT might still get rejected. @@ -503,7 +503,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t h.used0RTT.Store(true) h.zeroRTTSealer = nil h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil { h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) } } @@ -511,7 +511,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t panic("unexpected write encryption level") } h.mutex.Unlock() - if h.tracer != nil { + if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil { h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) } } @@ -651,7 +651,7 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { h.zeroRTTOpener = nil h.logger.Debugf("Dropping 0-RTT keys.") - if h.tracer != nil { + if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil { h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) } } diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 919b8a5bcf0..a583f27732d 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -57,7 +57,7 @@ type updatableAEAD struct { rttStats *utils.RTTStats - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer logger utils.Logger version protocol.VersionNumber @@ -70,7 +70,7 @@ var ( _ ShortHeaderSealer = &updatableAEAD{} ) -func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { +func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { return &updatableAEAD{ firstPacketNumber: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, @@ -86,7 +86,7 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, func (a *updatableAEAD) rollKeys() { if a.prevRcvAEAD != nil { a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) - if a.tracer != nil { + if a.tracer != nil && a.tracer.DroppedKey != nil { a.tracer.DroppedKey(a.keyPhase - 1) } a.prevRcvAEADExpiry = time.Time{} @@ -182,7 +182,7 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac a.prevRcvAEAD = nil a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) a.prevRcvAEADExpiry = time.Time{} - if a.tracer != nil { + if a.tracer != nil && a.tracer.DroppedKey != nil { a.tracer.DroppedKey(a.keyPhase - 1) } } @@ -216,7 +216,7 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac // The peer initiated this key update. It's safe to drop the keys for the previous generation now. // Start a timer to drop the previous key generation. a.startKeyDropTimer(rcvTime) - if a.tracer != nil { + if a.tracer != nil && a.tracer.UpdatedKey != nil { a.tracer.UpdatedKey(a.keyPhase, true) } a.firstRcvdWithCurrentKey = pn @@ -308,7 +308,7 @@ func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { if a.shouldInitiateKeyUpdate() { a.rollKeys() a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) - if a.tracer != nil { + if a.tracer != nil && a.tracer.UpdatedKey != nil { a.tracer.UpdatedKey(a.keyPhase, false) } } diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index a4a91f01044..a57030bf73e 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -11,6 +11,7 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -62,7 +63,8 @@ var _ = Describe("Updatable AEAD", func() { ) BeforeEach(func() { - serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) + var tr *logging.ConnectionTracer + tr, serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) trafficSecret1 := make([]byte, 16) trafficSecret2 := make([]byte, 16) rand.Read(trafficSecret1) @@ -70,7 +72,7 @@ var _ = Describe("Updatable AEAD", func() { rttStats = utils.NewRTTStats() client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v) - server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger, v) + server = newUpdatableAEAD(rttStats, tr, utils.DefaultLogger, v) client.SetReadKey(cs, trafficSecret2) client.SetWriteKey(cs, trafficSecret1) server.SetReadKey(cs, trafficSecret1) diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index 811d1e99370..a2c74b1eb39 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -1,388 +1,108 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go/logging (interfaces: ConnectionTracer) +//go:build !gomock && !generate -// Package mocklogging is a generated GoMock package. package mocklogging import ( - net "net" - reflect "reflect" - time "time" + "net" + "time" - protocol "github.com/quic-go/quic-go/internal/protocol" - utils "github.com/quic-go/quic-go/internal/utils" - wire "github.com/quic-go/quic-go/internal/wire" - logging "github.com/quic-go/quic-go/logging" - gomock "go.uber.org/mock/gomock" -) - -// MockConnectionTracer is a mock of ConnectionTracer interface. -type MockConnectionTracer struct { - ctrl *gomock.Controller - recorder *MockConnectionTracerMockRecorder -} - -// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. -type MockConnectionTracerMockRecorder struct { - mock *MockConnectionTracer -} - -// NewMockConnectionTracer creates a new mock instance. -func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { - mock := &MockConnectionTracer{ctrl: ctrl} - mock.recorder = &MockConnectionTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { - return m.recorder -} - -// AcknowledgedPacket mocks base method. -func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) -} - -// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. -func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) -} - -// BufferedPacket mocks base method. -func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "BufferedPacket", arg0, arg1) -} - -// BufferedPacket indicates an expected call of BufferedPacket. -func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0, arg1) -} - -// Close mocks base method. -func (m *MockConnectionTracer) Close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Close") -} - -// Close indicates an expected call of Close. -func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) -} - -// ClosedConnection mocks base method. -func (m *MockConnectionTracer) ClosedConnection(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ClosedConnection", arg0) -} - -// ClosedConnection indicates an expected call of ClosedConnection. -func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) -} - -// Debug mocks base method. -func (m *MockConnectionTracer) Debug(arg0, arg1 string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Debug", arg0, arg1) -} - -// Debug indicates an expected call of Debug. -func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) -} - -// DroppedEncryptionLevel mocks base method. -func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) -} + "github.com/quic-go/quic-go/internal/mocks/logging/internal" + "github.com/quic-go/quic-go/logging" -// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. -func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) -} - -// DroppedKey mocks base method. -func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedKey", arg0) -} - -// DroppedKey indicates an expected call of DroppedKey. -func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) -} - -// DroppedPacket mocks base method. -func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount, arg2 logging.PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) -} - -// ECNStateUpdated mocks base method. -func (m *MockConnectionTracer) ECNStateUpdated(arg0 logging.ECNState, arg1 logging.ECNStateTrigger) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ECNStateUpdated", arg0, arg1) -} - -// ECNStateUpdated indicates an expected call of ECNStateUpdated. -func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1) -} - -// LossTimerCanceled mocks base method. -func (m *MockConnectionTracer) LossTimerCanceled() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerCanceled") -} - -// LossTimerCanceled indicates an expected call of LossTimerCanceled. -func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) -} - -// LossTimerExpired mocks base method. -func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 protocol.EncryptionLevel) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) -} - -// LossTimerExpired indicates an expected call of LossTimerExpired. -func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) -} - -// LostPacket mocks base method. -func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 logging.PacketLossReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) -} - -// LostPacket indicates an expected call of LostPacket. -func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) -} - -// NegotiatedVersion mocks base method. -func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) -} - -// NegotiatedVersion indicates an expected call of NegotiatedVersion. -func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) -} - -// ReceivedLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2, arg3) -} - -// ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3) -} - -// ReceivedRetry mocks base method. -func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedRetry", arg0) -} - -// ReceivedRetry indicates an expected call of ReceivedRetry. -func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) -} - -// ReceivedShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2, arg3) -} - -// ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3) -} - -// ReceivedTransportParameters mocks base method. -func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedTransportParameters", arg0) -} - -// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. -func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) -} - -// ReceivedVersionNegotiationPacket mocks base method. -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 protocol.ArbitraryLenConnectionID, arg2 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1, arg2) -} - -// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) -} - -// RestoredTransportParameters mocks base method. -func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RestoredTransportParameters", arg0) -} - -// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. -func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) -} - -// SentLongHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3, arg4) -} - -// SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4) -} - -// SentShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3, arg4) -} - -// SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket. -func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4) -} - -// SentTransportParameters mocks base method. -func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentTransportParameters", arg0) -} - -// SentTransportParameters indicates an expected call of SentTransportParameters. -func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) -} - -// SetLossTimer mocks base method. -func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) -} - -// SetLossTimer indicates an expected call of SetLossTimer. -func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) -} - -// StartedConnection mocks base method. -func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) -} - -// StartedConnection indicates an expected call of StartedConnection. -func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) -} - -// UpdatedCongestionState mocks base method. -func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionState) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedCongestionState", arg0) -} - -// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. -func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) -} - -// UpdatedKey mocks base method. -func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKey", arg0, arg1) -} - -// UpdatedKey indicates an expected call of UpdatedKey. -func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) -} - -// UpdatedKeyFromTLS mocks base method. -func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) -} - -// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. -func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) -} - -// UpdatedMetrics mocks base method. -func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) -} - -// UpdatedMetrics indicates an expected call of UpdatedMetrics. -func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) -} - -// UpdatedPTOCount mocks base method. -func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdatedPTOCount", arg0) -} + "go.uber.org/mock/gomock" +) -// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. -func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) +type MockConnectionTracer = internal.MockConnectionTracer + +func NewMockConnectionTracer(ctrl *gomock.Controller) (*logging.ConnectionTracer, *MockConnectionTracer) { + t := internal.NewMockConnectionTracer(ctrl) + return &logging.ConnectionTracer{ + StartedConnection: func(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) { + t.StartedConnection(local, remote, srcConnID, destConnID) + }, + NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + }, + ClosedConnection: func(e error) { + t.ClosedConnection(e) + }, + SentTransportParameters: func(tp *logging.TransportParameters) { + t.SentTransportParameters(tp) + }, + ReceivedTransportParameters: func(tp *logging.TransportParameters) { + t.ReceivedTransportParameters(tp) + }, + RestoredTransportParameters: func(tp *logging.TransportParameters) { + t.RestoredTransportParameters(tp) + }, + SentLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + }, + SentShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + }, + ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, versions []logging.VersionNumber) { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + }, + ReceivedRetry: func(hdr *logging.Header) { + t.ReceivedRetry(hdr) + }, + ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + }, + BufferedPacket: func(typ logging.PacketType, size logging.ByteCount) { + t.BufferedPacket(typ, size) + }, + DroppedPacket: func(typ logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) { + t.DroppedPacket(typ, size, reason) + }, + UpdatedMetrics: func(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + }, + AcknowledgedPacket: func(encLevel logging.EncryptionLevel, pn logging.PacketNumber) { + t.AcknowledgedPacket(encLevel, pn) + }, + LostPacket: func(encLevel logging.EncryptionLevel, pn logging.PacketNumber, reason logging.PacketLossReason) { + t.LostPacket(encLevel, pn, reason) + }, + UpdatedCongestionState: func(state logging.CongestionState) { + t.UpdatedCongestionState(state) + }, + UpdatedPTOCount: func(value uint32) { + t.UpdatedPTOCount(value) + }, + UpdatedKeyFromTLS: func(encLevel logging.EncryptionLevel, perspective logging.Perspective) { + t.UpdatedKeyFromTLS(encLevel, perspective) + }, + UpdatedKey: func(generation logging.KeyPhase, remote bool) { + t.UpdatedKey(generation, remote) + }, + DroppedEncryptionLevel: func(encLevel logging.EncryptionLevel) { + t.DroppedEncryptionLevel(encLevel) + }, + DroppedKey: func(generation logging.KeyPhase) { + t.DroppedKey(generation) + }, + SetLossTimer: func(typ logging.TimerType, encLevel logging.EncryptionLevel, exp time.Time) { + t.SetLossTimer(typ, encLevel, exp) + }, + LossTimerExpired: func(typ logging.TimerType, encLevel logging.EncryptionLevel) { + t.LossTimerExpired(typ, encLevel) + }, + LossTimerCanceled: func() { + t.LossTimerCanceled() + }, + ECNStateUpdated: func(state logging.ECNState, trigger logging.ECNStateTrigger) { + t.ECNStateUpdated(state, trigger) + }, + Close: func() { + t.Close() + }, + Debug: func(name, msg string) { + t.Debug(name, msg) + }, + }, t } diff --git a/internal/mocks/logging/internal/connection_tracer.go b/internal/mocks/logging/internal/connection_tracer.go new file mode 100644 index 00000000000..1cc0bd2e25e --- /dev/null +++ b/internal/mocks/logging/internal/connection_tracer.go @@ -0,0 +1,388 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go/internal/mocks/logging (interfaces: ConnectionTracer) + +// Package internal is a generated GoMock package. +package internal + +import ( + net "net" + reflect "reflect" + time "time" + + protocol "github.com/quic-go/quic-go/internal/protocol" + utils "github.com/quic-go/quic-go/internal/utils" + wire "github.com/quic-go/quic-go/internal/wire" + logging "github.com/quic-go/quic-go/logging" + gomock "go.uber.org/mock/gomock" +) + +// MockConnectionTracer is a mock of ConnectionTracer interface. +type MockConnectionTracer struct { + ctrl *gomock.Controller + recorder *MockConnectionTracerMockRecorder +} + +// MockConnectionTracerMockRecorder is the mock recorder for MockConnectionTracer. +type MockConnectionTracerMockRecorder struct { + mock *MockConnectionTracer +} + +// NewMockConnectionTracer creates a new mock instance. +func NewMockConnectionTracer(ctrl *gomock.Controller) *MockConnectionTracer { + mock := &MockConnectionTracer{ctrl: ctrl} + mock.recorder = &MockConnectionTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionTracer) EXPECT() *MockConnectionTracerMockRecorder { + return m.recorder +} + +// AcknowledgedPacket mocks base method. +func (m *MockConnectionTracer) AcknowledgedPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AcknowledgedPacket", arg0, arg1) +} + +// AcknowledgedPacket indicates an expected call of AcknowledgedPacket. +func (mr *MockConnectionTracerMockRecorder) AcknowledgedPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcknowledgedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).AcknowledgedPacket), arg0, arg1) +} + +// BufferedPacket mocks base method. +func (m *MockConnectionTracer) BufferedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "BufferedPacket", arg0, arg1) +} + +// BufferedPacket indicates an expected call of BufferedPacket. +func (mr *MockConnectionTracerMockRecorder) BufferedPacket(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).BufferedPacket), arg0, arg1) +} + +// Close mocks base method. +func (m *MockConnectionTracer) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockConnectionTracerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnectionTracer)(nil).Close)) +} + +// ClosedConnection mocks base method. +func (m *MockConnectionTracer) ClosedConnection(arg0 error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClosedConnection", arg0) +} + +// ClosedConnection indicates an expected call of ClosedConnection. +func (mr *MockConnectionTracerMockRecorder) ClosedConnection(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClosedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).ClosedConnection), arg0) +} + +// Debug mocks base method. +func (m *MockConnectionTracer) Debug(arg0, arg1 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0, arg1) +} + +// Debug indicates an expected call of Debug. +func (mr *MockConnectionTracerMockRecorder) Debug(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockConnectionTracer)(nil).Debug), arg0, arg1) +} + +// DroppedEncryptionLevel mocks base method. +func (m *MockConnectionTracer) DroppedEncryptionLevel(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedEncryptionLevel", arg0) +} + +// DroppedEncryptionLevel indicates an expected call of DroppedEncryptionLevel. +func (mr *MockConnectionTracerMockRecorder) DroppedEncryptionLevel(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedEncryptionLevel", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedEncryptionLevel), arg0) +} + +// DroppedKey mocks base method. +func (m *MockConnectionTracer) DroppedKey(arg0 protocol.KeyPhase) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedKey", arg0) +} + +// DroppedKey indicates an expected call of DroppedKey. +func (mr *MockConnectionTracerMockRecorder) DroppedKey(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedKey", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedKey), arg0) +} + +// DroppedPacket mocks base method. +func (m *MockConnectionTracer) DroppedPacket(arg0 logging.PacketType, arg1 protocol.ByteCount, arg2 logging.PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockConnectionTracerMockRecorder) DroppedPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockConnectionTracer)(nil).DroppedPacket), arg0, arg1, arg2) +} + +// ECNStateUpdated mocks base method. +func (m *MockConnectionTracer) ECNStateUpdated(arg0 logging.ECNState, arg1 logging.ECNStateTrigger) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ECNStateUpdated", arg0, arg1) +} + +// ECNStateUpdated indicates an expected call of ECNStateUpdated. +func (mr *MockConnectionTracerMockRecorder) ECNStateUpdated(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ECNStateUpdated", reflect.TypeOf((*MockConnectionTracer)(nil).ECNStateUpdated), arg0, arg1) +} + +// LossTimerCanceled mocks base method. +func (m *MockConnectionTracer) LossTimerCanceled() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerCanceled") +} + +// LossTimerCanceled indicates an expected call of LossTimerCanceled. +func (mr *MockConnectionTracerMockRecorder) LossTimerCanceled() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerCanceled", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerCanceled)) +} + +// LossTimerExpired mocks base method. +func (m *MockConnectionTracer) LossTimerExpired(arg0 logging.TimerType, arg1 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LossTimerExpired", arg0, arg1) +} + +// LossTimerExpired indicates an expected call of LossTimerExpired. +func (mr *MockConnectionTracerMockRecorder) LossTimerExpired(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LossTimerExpired", reflect.TypeOf((*MockConnectionTracer)(nil).LossTimerExpired), arg0, arg1) +} + +// LostPacket mocks base method. +func (m *MockConnectionTracer) LostPacket(arg0 protocol.EncryptionLevel, arg1 protocol.PacketNumber, arg2 logging.PacketLossReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "LostPacket", arg0, arg1, arg2) +} + +// LostPacket indicates an expected call of LostPacket. +func (mr *MockConnectionTracerMockRecorder) LostPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LostPacket", reflect.TypeOf((*MockConnectionTracer)(nil).LostPacket), arg0, arg1, arg2) +} + +// NegotiatedVersion mocks base method. +func (m *MockConnectionTracer) NegotiatedVersion(arg0 protocol.VersionNumber, arg1, arg2 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NegotiatedVersion", arg0, arg1, arg2) +} + +// NegotiatedVersion indicates an expected call of NegotiatedVersion. +func (mr *MockConnectionTracerMockRecorder) NegotiatedVersion(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiatedVersion", reflect.TypeOf((*MockConnectionTracer)(nil).NegotiatedVersion), arg0, arg1, arg2) +} + +// ReceivedLongHeaderPacket mocks base method. +func (m *MockConnectionTracer) ReceivedLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedLongHeaderPacket", arg0, arg1, arg2, arg3) +} + +// ReceivedLongHeaderPacket indicates an expected call of ReceivedLongHeaderPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedLongHeaderPacket), arg0, arg1, arg2, arg3) +} + +// ReceivedRetry mocks base method. +func (m *MockConnectionTracer) ReceivedRetry(arg0 *wire.Header) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedRetry", arg0) +} + +// ReceivedRetry indicates an expected call of ReceivedRetry. +func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedRetry", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedRetry), arg0) +} + +// ReceivedShortHeaderPacket mocks base method. +func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2, arg3) +} + +// ReceivedShortHeaderPacket indicates an expected call of ReceivedShortHeaderPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedShortHeaderPacket), arg0, arg1, arg2, arg3) +} + +// ReceivedTransportParameters mocks base method. +func (m *MockConnectionTracer) ReceivedTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedTransportParameters", arg0) +} + +// ReceivedTransportParameters indicates an expected call of ReceivedTransportParameters. +func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedTransportParameters), arg0) +} + +// ReceivedVersionNegotiationPacket mocks base method. +func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 protocol.ArbitraryLenConnectionID, arg2 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1, arg2) +} + +// ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) +} + +// RestoredTransportParameters mocks base method. +func (m *MockConnectionTracer) RestoredTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RestoredTransportParameters", arg0) +} + +// RestoredTransportParameters indicates an expected call of RestoredTransportParameters. +func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0) +} + +// SentLongHeaderPacket mocks base method. +func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3, arg4) +} + +// SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket. +func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3, arg4) +} + +// SentShortHeaderPacket mocks base method. +func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 protocol.ECN, arg3 *wire.AckFrame, arg4 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3, arg4) +} + +// SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket. +func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3, arg4) +} + +// SentTransportParameters mocks base method. +func (m *MockConnectionTracer) SentTransportParameters(arg0 *wire.TransportParameters) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentTransportParameters", arg0) +} + +// SentTransportParameters indicates an expected call of SentTransportParameters. +func (mr *MockConnectionTracerMockRecorder) SentTransportParameters(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).SentTransportParameters), arg0) +} + +// SetLossTimer mocks base method. +func (m *MockConnectionTracer) SetLossTimer(arg0 logging.TimerType, arg1 protocol.EncryptionLevel, arg2 time.Time) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetLossTimer", arg0, arg1, arg2) +} + +// SetLossTimer indicates an expected call of SetLossTimer. +func (mr *MockConnectionTracerMockRecorder) SetLossTimer(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLossTimer", reflect.TypeOf((*MockConnectionTracer)(nil).SetLossTimer), arg0, arg1, arg2) +} + +// StartedConnection mocks base method. +func (m *MockConnectionTracer) StartedConnection(arg0, arg1 net.Addr, arg2, arg3 protocol.ConnectionID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartedConnection", arg0, arg1, arg2, arg3) +} + +// StartedConnection indicates an expected call of StartedConnection. +func (mr *MockConnectionTracerMockRecorder) StartedConnection(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartedConnection", reflect.TypeOf((*MockConnectionTracer)(nil).StartedConnection), arg0, arg1, arg2, arg3) +} + +// UpdatedCongestionState mocks base method. +func (m *MockConnectionTracer) UpdatedCongestionState(arg0 logging.CongestionState) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedCongestionState", arg0) +} + +// UpdatedCongestionState indicates an expected call of UpdatedCongestionState. +func (mr *MockConnectionTracerMockRecorder) UpdatedCongestionState(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedCongestionState", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedCongestionState), arg0) +} + +// UpdatedKey mocks base method. +func (m *MockConnectionTracer) UpdatedKey(arg0 protocol.KeyPhase, arg1 bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKey", arg0, arg1) +} + +// UpdatedKey indicates an expected call of UpdatedKey. +func (mr *MockConnectionTracerMockRecorder) UpdatedKey(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKey", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKey), arg0, arg1) +} + +// UpdatedKeyFromTLS mocks base method. +func (m *MockConnectionTracer) UpdatedKeyFromTLS(arg0 protocol.EncryptionLevel, arg1 protocol.Perspective) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedKeyFromTLS", arg0, arg1) +} + +// UpdatedKeyFromTLS indicates an expected call of UpdatedKeyFromTLS. +func (mr *MockConnectionTracerMockRecorder) UpdatedKeyFromTLS(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedKeyFromTLS", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedKeyFromTLS), arg0, arg1) +} + +// UpdatedMetrics mocks base method. +func (m *MockConnectionTracer) UpdatedMetrics(arg0 *utils.RTTStats, arg1, arg2 protocol.ByteCount, arg3 int) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedMetrics", arg0, arg1, arg2, arg3) +} + +// UpdatedMetrics indicates an expected call of UpdatedMetrics. +func (mr *MockConnectionTracerMockRecorder) UpdatedMetrics(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedMetrics", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedMetrics), arg0, arg1, arg2, arg3) +} + +// UpdatedPTOCount mocks base method. +func (m *MockConnectionTracer) UpdatedPTOCount(arg0 uint32) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdatedPTOCount", arg0) +} + +// UpdatedPTOCount indicates an expected call of UpdatedPTOCount. +func (mr *MockConnectionTracerMockRecorder) UpdatedPTOCount(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatedPTOCount", reflect.TypeOf((*MockConnectionTracer)(nil).UpdatedPTOCount), arg0) +} diff --git a/internal/mocks/logging/internal/tracer.go b/internal/mocks/logging/internal/tracer.go new file mode 100644 index 00000000000..15dbcaed47e --- /dev/null +++ b/internal/mocks/logging/internal/tracer.go @@ -0,0 +1,74 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/quic-go/quic-go/internal/mocks/logging (interfaces: Tracer) + +// Package internal is a generated GoMock package. +package internal + +import ( + net "net" + reflect "reflect" + + protocol "github.com/quic-go/quic-go/internal/protocol" + wire "github.com/quic-go/quic-go/internal/wire" + logging "github.com/quic-go/quic-go/logging" + gomock "go.uber.org/mock/gomock" +) + +// MockTracer is a mock of Tracer interface. +type MockTracer struct { + ctrl *gomock.Controller + recorder *MockTracerMockRecorder +} + +// MockTracerMockRecorder is the mock recorder for MockTracer. +type MockTracerMockRecorder struct { + mock *MockTracer +} + +// NewMockTracer creates a new mock instance. +func NewMockTracer(ctrl *gomock.Controller) *MockTracer { + mock := &MockTracer{ctrl: ctrl} + mock.recorder = &MockTracerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTracer) EXPECT() *MockTracerMockRecorder { + return m.recorder +} + +// DroppedPacket mocks base method. +func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) +} + +// DroppedPacket indicates an expected call of DroppedPacket. +func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) +} + +// SentPacket mocks base method. +func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []logging.Frame) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) +} + +// SentPacket indicates an expected call of SentPacket. +func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) +} + +// SentVersionNegotiationPacket mocks base method. +func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 protocol.ArbitraryLenConnectionID, arg3 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentVersionNegotiationPacket", arg0, arg1, arg2, arg3) +} + +// SentVersionNegotiationPacket indicates an expected call of SentVersionNegotiationPacket. +func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) +} diff --git a/internal/mocks/logging/mockgen.go b/internal/mocks/logging/mockgen.go new file mode 100644 index 00000000000..8bf08dc8ecc --- /dev/null +++ b/internal/mocks/logging/mockgen.go @@ -0,0 +1,51 @@ +//go:build gomock || generate + +package mocklogging + +import ( + "net" + "time" + + "github.com/quic-go/quic-go/logging" +) + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package internal -destination internal/tracer.go github.com/quic-go/quic-go/internal/mocks/logging Tracer" +type Tracer interface { + SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) + SentVersionNegotiationPacket(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) + DroppedPacket(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) +} + +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package internal -destination internal/connection_tracer.go github.com/quic-go/quic-go/internal/mocks/logging ConnectionTracer" +type ConnectionTracer interface { + StartedConnection(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) + NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) + ClosedConnection(error) + SentTransportParameters(*logging.TransportParameters) + ReceivedTransportParameters(*logging.TransportParameters) + RestoredTransportParameters(parameters *logging.TransportParameters) // for 0-RTT + SentLongHeaderPacket(*logging.ExtendedHeader, logging.ByteCount, logging.ECN, *logging.AckFrame, []logging.Frame) + SentShortHeaderPacket(*logging.ShortHeader, logging.ByteCount, logging.ECN, *logging.AckFrame, []logging.Frame) + ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) + ReceivedRetry(*logging.Header) + ReceivedLongHeaderPacket(*logging.ExtendedHeader, logging.ByteCount, logging.ECN, []logging.Frame) + ReceivedShortHeaderPacket(*logging.ShortHeader, logging.ByteCount, logging.ECN, []logging.Frame) + BufferedPacket(logging.PacketType, logging.ByteCount) + DroppedPacket(logging.PacketType, logging.ByteCount, logging.PacketDropReason) + UpdatedMetrics(rttStats *logging.RTTStats, cwnd, bytesInFlight logging.ByteCount, packetsInFlight int) + AcknowledgedPacket(logging.EncryptionLevel, logging.PacketNumber) + LostPacket(logging.EncryptionLevel, logging.PacketNumber, logging.PacketLossReason) + UpdatedCongestionState(logging.CongestionState) + UpdatedPTOCount(value uint32) + UpdatedKeyFromTLS(logging.EncryptionLevel, logging.Perspective) + UpdatedKey(generation logging.KeyPhase, remote bool) + DroppedEncryptionLevel(logging.EncryptionLevel) + DroppedKey(generation logging.KeyPhase) + SetLossTimer(logging.TimerType, logging.EncryptionLevel, time.Time) + LossTimerExpired(logging.TimerType, logging.EncryptionLevel) + LossTimerCanceled() + ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) + // Close is called when the connection is closed. + Close() + Debug(name, msg string) +} diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go index 2cdd2335701..115f578a2c7 100644 --- a/internal/mocks/logging/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -1,74 +1,29 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go/logging (interfaces: Tracer) +//go:build !gomock && !generate -// Package mocklogging is a generated GoMock package. package mocklogging import ( - net "net" - reflect "reflect" + "net" - protocol "github.com/quic-go/quic-go/internal/protocol" - wire "github.com/quic-go/quic-go/internal/wire" - logging "github.com/quic-go/quic-go/logging" - gomock "go.uber.org/mock/gomock" -) - -// MockTracer is a mock of Tracer interface. -type MockTracer struct { - ctrl *gomock.Controller - recorder *MockTracerMockRecorder -} - -// MockTracerMockRecorder is the mock recorder for MockTracer. -type MockTracerMockRecorder struct { - mock *MockTracer -} + "github.com/quic-go/quic-go/internal/mocks/logging/internal" + "github.com/quic-go/quic-go/logging" -// NewMockTracer creates a new mock instance. -func NewMockTracer(ctrl *gomock.Controller) *MockTracer { - mock := &MockTracer{ctrl: ctrl} - mock.recorder = &MockTracerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTracer) EXPECT() *MockTracerMockRecorder { - return m.recorder -} - -// DroppedPacket mocks base method. -func (m *MockTracer) DroppedPacket(arg0 net.Addr, arg1 logging.PacketType, arg2 protocol.ByteCount, arg3 logging.PacketDropReason) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "DroppedPacket", arg0, arg1, arg2, arg3) -} - -// DroppedPacket indicates an expected call of DroppedPacket. -func (mr *MockTracerMockRecorder) DroppedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DroppedPacket", reflect.TypeOf((*MockTracer)(nil).DroppedPacket), arg0, arg1, arg2, arg3) -} - -// SentPacket mocks base method. -func (m *MockTracer) SentPacket(arg0 net.Addr, arg1 *wire.Header, arg2 protocol.ByteCount, arg3 []logging.Frame) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3) -} - -// SentPacket indicates an expected call of SentPacket. -func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) -} - -// SentVersionNegotiationPacket mocks base method. -func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 protocol.ArbitraryLenConnectionID, arg3 []protocol.VersionNumber) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SentVersionNegotiationPacket", arg0, arg1, arg2, arg3) -} + "go.uber.org/mock/gomock" +) -// SentVersionNegotiationPacket indicates an expected call of SentVersionNegotiationPacket. -func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) +type MockTracer = internal.MockTracer + +func NewMockTracer(ctrl *gomock.Controller) (*logging.Tracer, *MockTracer) { + t := internal.NewMockTracer(ctrl) + return &logging.Tracer{ + SentPacket: func(remote net.Addr, hdr *logging.Header, size logging.ByteCount, frames []logging.Frame) { + t.SentPacket(remote, hdr, size, frames) + }, + SentVersionNegotiationPacket: func(remote net.Addr, dest, src logging.ArbitraryLenConnectionID, versions []logging.VersionNumber) { + t.SentVersionNegotiationPacket(remote, dest, src, versions) + }, + DroppedPacket: func(remote net.Addr, typ logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) { + t.DroppedPacket(remote, typ, size, reason) + }, + }, t } diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 9bee4270f94..23bcda009f5 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -1,19 +1,19 @@ +//go:build gomock || generate + package mocks -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockquic -destination quic/stream.go github.com/quic-go/quic-go Stream" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/quic-go/quic-go EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocklogging -destination logging/tracer.go github.com/quic-go/quic-go/logging Tracer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocklogging -destination logging/connection_tracer.go github.com/quic-go/quic-go/logging ConnectionTracer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination short_header_sealer.go github.com/quic-go/quic-go/internal/handshake ShortHeaderSealer" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination short_header_opener.go github.com/quic-go/quic-go/internal/handshake ShortHeaderOpener" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination long_header_opener.go github.com/quic-go/quic-go/internal/handshake LongHeaderOpener" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination crypto_setup_tmp.go github.com/quic-go/quic-go/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/quic-go/quic-go/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination stream_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol StreamFlowController" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination congestion.go github.com/quic-go/quic-go/internal/congestion SendAlgorithmWithDebugInfos" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocks -destination connection_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol ConnectionFlowController" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler SentPacketHandler" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler ReceivedPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockquic -destination quic/stream.go github.com/quic-go/quic-go Stream" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockquic -destination quic/early_conn_tmp.go github.com/quic-go/quic-go EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && go run golang.org/x/tools/cmd/goimports -w quic/early_conn.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination short_header_sealer.go github.com/quic-go/quic-go/internal/handshake ShortHeaderSealer" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination short_header_opener.go github.com/quic-go/quic-go/internal/handshake ShortHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination long_header_opener.go github.com/quic-go/quic-go/internal/handshake LongHeaderOpener" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination crypto_setup_tmp.go github.com/quic-go/quic-go/internal/handshake CryptoSetup && sed -E 's~github.com/quic-go/qtls[[:alnum:]_-]*~github.com/quic-go/quic-go/internal/qtls~g; s~qtls.ConnectionStateWith0RTT~qtls.ConnectionState~g' crypto_setup_tmp.go > crypto_setup.go && rm crypto_setup_tmp.go && go run golang.org/x/tools/cmd/goimports -w crypto_setup.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination stream_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol StreamFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination congestion.go github.com/quic-go/quic-go/internal/congestion SendAlgorithmWithDebugInfos" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocks -destination connection_flow_controller.go github.com/quic-go/quic-go/internal/flowcontrol ConnectionFlowController" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockackhandler -destination ackhandler/sent_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler SentPacketHandler" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mockackhandler -destination ackhandler/received_packet_handler.go github.com/quic-go/quic-go/internal/ackhandler ReceivedPacketHandler" // The following command produces a warning message on OSX, however, it still generates the correct mock file. // See https://github.com/golang/mock/issues/339 for details. -//go:generate sh -c "go run go.uber.org/mock/mockgen -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" +//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package mocktls -destination tls/client_session_cache.go crypto/tls ClientSessionCache" diff --git a/interop/utils/logging.go b/interop/utils/logging.go index 30e3f663f7a..3cb42244b0f 100644 --- a/interop/utils/logging.go +++ b/interop/utils/logging.go @@ -29,7 +29,7 @@ func GetSSLKeyLog() (io.WriteCloser, error) { } // NewQLOGConnectionTracer create a qlog file in QLOGDIR -func NewQLOGConnectionTracer(_ context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { +func NewQLOGConnectionTracer(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { qlogDir := os.Getenv("QLOGDIR") if len(qlogDir) == 0 { return nil diff --git a/logging/connection_tracer.go b/logging/connection_tracer.go new file mode 100644 index 00000000000..e3f322d91d1 --- /dev/null +++ b/logging/connection_tracer.go @@ -0,0 +1,255 @@ +package logging + +import ( + "net" + "time" +) + +// A ConnectionTracer records events. +type ConnectionTracer struct { + StartedConnection func(local, remote net.Addr, srcConnID, destConnID ConnectionID) + NegotiatedVersion func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) + ClosedConnection func(error) + SentTransportParameters func(*TransportParameters) + ReceivedTransportParameters func(*TransportParameters) + RestoredTransportParameters func(parameters *TransportParameters) // for 0-RTT + SentLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) + SentShortHeaderPacket func(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) + ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, _ []VersionNumber) + ReceivedRetry func(*Header) + ReceivedLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, []Frame) + ReceivedShortHeaderPacket func(*ShortHeader, ByteCount, ECN, []Frame) + BufferedPacket func(PacketType, ByteCount) + DroppedPacket func(PacketType, ByteCount, PacketDropReason) + UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) + AcknowledgedPacket func(EncryptionLevel, PacketNumber) + LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason) + UpdatedCongestionState func(CongestionState) + UpdatedPTOCount func(value uint32) + UpdatedKeyFromTLS func(EncryptionLevel, Perspective) + UpdatedKey func(generation KeyPhase, remote bool) + DroppedEncryptionLevel func(EncryptionLevel) + DroppedKey func(generation KeyPhase) + SetLossTimer func(TimerType, EncryptionLevel, time.Time) + LossTimerExpired func(TimerType, EncryptionLevel) + LossTimerCanceled func() + ECNStateUpdated func(state ECNState, trigger ECNStateTrigger) + // Close is called when the connection is closed. + Close func() + Debug func(name, msg string) +} + +// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. +func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &ConnectionTracer{ + StartedConnection: func(local, remote net.Addr, srcConnID, destConnID ConnectionID) { + for _, t := range tracers { + if t.StartedConnection != nil { + t.StartedConnection(local, remote, srcConnID, destConnID) + } + } + }, + NegotiatedVersion: func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { + for _, t := range tracers { + if t.NegotiatedVersion != nil { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + } + } + }, + ClosedConnection: func(e error) { + for _, t := range tracers { + if t.ClosedConnection != nil { + t.ClosedConnection(e) + } + } + }, + SentTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.SentTransportParameters != nil { + t.SentTransportParameters(tp) + } + } + }, + ReceivedTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.ReceivedTransportParameters != nil { + t.ReceivedTransportParameters(tp) + } + } + }, + RestoredTransportParameters: func(tp *TransportParameters) { + for _, t := range tracers { + if t.RestoredTransportParameters != nil { + t.RestoredTransportParameters(tp) + } + } + }, + SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentLongHeaderPacket != nil { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { + for _, t := range tracers { + if t.SentShortHeaderPacket != nil { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + } + } + }, + ReceivedVersionNegotiationPacket: func(dest, src ArbitraryLenConnectionID, versions []VersionNumber) { + for _, t := range tracers { + if t.ReceivedVersionNegotiationPacket != nil { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + } + } + }, + ReceivedRetry: func(hdr *Header) { + for _, t := range tracers { + if t.ReceivedRetry != nil { + t.ReceivedRetry(hdr) + } + } + }, + ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedLongHeaderPacket != nil { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + } + } + }, + ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) { + for _, t := range tracers { + if t.ReceivedShortHeaderPacket != nil { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + } + } + }, + BufferedPacket: func(typ PacketType, size ByteCount) { + for _, t := range tracers { + if t.BufferedPacket != nil { + t.BufferedPacket(typ, size) + } + } + }, + DroppedPacket: func(typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(typ, size, reason) + } + } + }, + UpdatedMetrics: func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { + for _, t := range tracers { + if t.UpdatedMetrics != nil { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + } + } + }, + AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) { + for _, t := range tracers { + if t.AcknowledgedPacket != nil { + t.AcknowledgedPacket(encLevel, pn) + } + } + }, + LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { + for _, t := range tracers { + if t.LostPacket != nil { + t.LostPacket(encLevel, pn, reason) + } + } + }, + UpdatedCongestionState: func(state CongestionState) { + for _, t := range tracers { + if t.UpdatedCongestionState != nil { + t.UpdatedCongestionState(state) + } + } + }, + UpdatedPTOCount: func(value uint32) { + for _, t := range tracers { + if t.UpdatedPTOCount != nil { + t.UpdatedPTOCount(value) + } + } + }, + UpdatedKeyFromTLS: func(encLevel EncryptionLevel, perspective Perspective) { + for _, t := range tracers { + if t.UpdatedKeyFromTLS != nil { + t.UpdatedKeyFromTLS(encLevel, perspective) + } + } + }, + UpdatedKey: func(generation KeyPhase, remote bool) { + for _, t := range tracers { + if t.UpdatedKey != nil { + t.UpdatedKey(generation, remote) + } + } + }, + DroppedEncryptionLevel: func(encLevel EncryptionLevel) { + for _, t := range tracers { + if t.DroppedEncryptionLevel != nil { + t.DroppedEncryptionLevel(encLevel) + } + } + }, + DroppedKey: func(generation KeyPhase) { + for _, t := range tracers { + if t.DroppedKey != nil { + t.DroppedKey(generation) + } + } + }, + SetLossTimer: func(typ TimerType, encLevel EncryptionLevel, exp time.Time) { + for _, t := range tracers { + if t.SetLossTimer != nil { + t.SetLossTimer(typ, encLevel, exp) + } + } + }, + LossTimerExpired: func(typ TimerType, encLevel EncryptionLevel) { + for _, t := range tracers { + if t.LossTimerExpired != nil { + t.LossTimerExpired(typ, encLevel) + } + } + }, + LossTimerCanceled: func() { + for _, t := range tracers { + if t.LossTimerCanceled != nil { + t.LossTimerCanceled() + } + } + }, + ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) { + for _, t := range tracers { + if t.ECNStateUpdated != nil { + t.ECNStateUpdated(state, trigger) + } + } + }, + Close: func() { + for _, t := range tracers { + if t.Close != nil { + t.Close() + } + } + }, + Debug: func(name, msg string) { + for _, t := range tracers { + if t.Debug != nil { + t.Debug(name, msg) + } + } + }, + } +} diff --git a/logging/interface.go b/logging/interface.go index 62028ebc07b..10ac038fb2f 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -3,9 +3,6 @@ package logging import ( - "net" - "time" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" @@ -112,44 +109,3 @@ type ShortHeader struct { PacketNumberLen protocol.PacketNumberLen KeyPhase KeyPhaseBit } - -// A Tracer traces events. -type Tracer interface { - SentPacket(net.Addr, *Header, ByteCount, []Frame) - SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) - DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) -} - -// A ConnectionTracer records events. -type ConnectionTracer interface { - StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) - NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) - ClosedConnection(error) - SentTransportParameters(*TransportParameters) - ReceivedTransportParameters(*TransportParameters) - RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT - SentLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) - SentShortHeaderPacket(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) - ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) - ReceivedRetry(*Header) - ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, []Frame) - ReceivedShortHeaderPacket(*ShortHeader, ByteCount, ECN, []Frame) - BufferedPacket(PacketType, ByteCount) - DroppedPacket(PacketType, ByteCount, PacketDropReason) - UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) - AcknowledgedPacket(EncryptionLevel, PacketNumber) - LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) - UpdatedCongestionState(CongestionState) - UpdatedPTOCount(value uint32) - UpdatedKeyFromTLS(EncryptionLevel, Perspective) - UpdatedKey(generation KeyPhase, remote bool) - DroppedEncryptionLevel(EncryptionLevel) - DroppedKey(generation KeyPhase) - SetLossTimer(TimerType, EncryptionLevel, time.Time) - LossTimerExpired(TimerType, EncryptionLevel) - LossTimerCanceled() - ECNStateUpdated(state ECNState, trigger ECNStateTrigger) - // Close is called when the connection is closed. - Close() - Debug(name, msg string) -} diff --git a/logging/multiplex.go b/logging/multiplex.go deleted file mode 100644 index ef91f3f2b96..00000000000 --- a/logging/multiplex.go +++ /dev/null @@ -1,232 +0,0 @@ -package logging - -import ( - "net" - "time" -) - -type tracerMultiplexer struct { - tracers []Tracer -} - -var _ Tracer = &tracerMultiplexer{} - -// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. -func NewMultiplexedTracer(tracers ...Tracer) Tracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &tracerMultiplexer{tracers} -} - -func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { - for _, t := range m.tracers { - t.SentPacket(remote, hdr, size, frames) - } -} - -func (m *tracerMultiplexer) SentVersionNegotiationPacket(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) { - for _, t := range m.tracers { - t.SentVersionNegotiationPacket(remote, dest, src, versions) - } -} - -func (m *tracerMultiplexer) DroppedPacket(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(remote, typ, size, reason) - } -} - -type connTracerMultiplexer struct { - tracers []ConnectionTracer -} - -var _ ConnectionTracer = &connTracerMultiplexer{} - -// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers. -func NewMultiplexedConnectionTracer(tracers ...ConnectionTracer) ConnectionTracer { - if len(tracers) == 0 { - return nil - } - if len(tracers) == 1 { - return tracers[0] - } - return &connTracerMultiplexer{tracers: tracers} -} - -func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { - for _, t := range m.tracers { - t.StartedConnection(local, remote, srcConnID, destConnID) - } -} - -func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { - for _, t := range m.tracers { - t.NegotiatedVersion(chosen, clientVersions, serverVersions) - } -} - -func (m *connTracerMultiplexer) ClosedConnection(e error) { - for _, t := range m.tracers { - t.ClosedConnection(e) - } -} - -func (m *connTracerMultiplexer) SentTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.SentTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) ReceivedTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.ReceivedTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParameters) { - for _, t := range m.tracers { - t.RestoredTransportParameters(tp) - } -} - -func (m *connTracerMultiplexer) SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { - for _, t := range m.tracers { - t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) - } -} - -func (m *connTracerMultiplexer) SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) { - for _, t := range m.tracers { - t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) - } -} - -func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, versions []VersionNumber) { - for _, t := range m.tracers { - t.ReceivedVersionNegotiationPacket(dest, src, versions) - } -} - -func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) { - for _, t := range m.tracers { - t.ReceivedRetry(hdr) - } -} - -func (m *connTracerMultiplexer) ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) { - for _, t := range m.tracers { - t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) - } -} - -func (m *connTracerMultiplexer) ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) { - for _, t := range m.tracers { - t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) - } -} - -func (m *connTracerMultiplexer) BufferedPacket(typ PacketType, size ByteCount) { - for _, t := range m.tracers { - t.BufferedPacket(typ, size) - } -} - -func (m *connTracerMultiplexer) DroppedPacket(typ PacketType, size ByteCount, reason PacketDropReason) { - for _, t := range m.tracers { - t.DroppedPacket(typ, size, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedCongestionState(state CongestionState) { - for _, t := range m.tracers { - t.UpdatedCongestionState(state) - } -} - -func (m *connTracerMultiplexer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFLight ByteCount, packetsInFlight int) { - for _, t := range m.tracers { - t.UpdatedMetrics(rttStats, cwnd, bytesInFLight, packetsInFlight) - } -} - -func (m *connTracerMultiplexer) AcknowledgedPacket(encLevel EncryptionLevel, pn PacketNumber) { - for _, t := range m.tracers { - t.AcknowledgedPacket(encLevel, pn) - } -} - -func (m *connTracerMultiplexer) LostPacket(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) { - for _, t := range m.tracers { - t.LostPacket(encLevel, pn, reason) - } -} - -func (m *connTracerMultiplexer) UpdatedPTOCount(value uint32) { - for _, t := range m.tracers { - t.UpdatedPTOCount(value) - } -} - -func (m *connTracerMultiplexer) UpdatedKeyFromTLS(encLevel EncryptionLevel, perspective Perspective) { - for _, t := range m.tracers { - t.UpdatedKeyFromTLS(encLevel, perspective) - } -} - -func (m *connTracerMultiplexer) UpdatedKey(generation KeyPhase, remote bool) { - for _, t := range m.tracers { - t.UpdatedKey(generation, remote) - } -} - -func (m *connTracerMultiplexer) DroppedEncryptionLevel(encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.DroppedEncryptionLevel(encLevel) - } -} - -func (m *connTracerMultiplexer) DroppedKey(generation KeyPhase) { - for _, t := range m.tracers { - t.DroppedKey(generation) - } -} - -func (m *connTracerMultiplexer) SetLossTimer(typ TimerType, encLevel EncryptionLevel, exp time.Time) { - for _, t := range m.tracers { - t.SetLossTimer(typ, encLevel, exp) - } -} - -func (m *connTracerMultiplexer) LossTimerExpired(typ TimerType, encLevel EncryptionLevel) { - for _, t := range m.tracers { - t.LossTimerExpired(typ, encLevel) - } -} - -func (m *connTracerMultiplexer) LossTimerCanceled() { - for _, t := range m.tracers { - t.LossTimerCanceled() - } -} - -func (m *connTracerMultiplexer) ECNStateUpdated(state ECNState, trigger ECNStateTrigger) { - for _, t := range m.tracers { - t.ECNStateUpdated(state, trigger) - } -} - -func (m *connTracerMultiplexer) Debug(name, msg string) { - for _, t := range m.tracers { - t.Debug(name, msg) - } -} - -func (m *connTracerMultiplexer) Close() { - for _, t := range m.tracers { - t.Close() - } -} diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index c0f784ec276..69606346262 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -21,21 +21,22 @@ var _ = Describe("Tracing", func() { }) It("returns the raw tracer if only one tracer is passed in", func() { - tr := mocklogging.NewMockTracer(mockCtrl) + tr := &Tracer{} tracer := NewMultiplexedTracer(tr) - Expect(tracer).To(BeAssignableToTypeOf(&mocklogging.MockTracer{})) + Expect(tracer).To(Equal(tr)) }) Context("tracing events", func() { var ( - tracer Tracer + tracer *Tracer tr1, tr2 *mocklogging.MockTracer ) BeforeEach(func() { - tr1 = mocklogging.NewMockTracer(mockCtrl) - tr2 = mocklogging.NewMockTracer(mockCtrl) - tracer = NewMultiplexedTracer(tr1, tr2) + var t1, t2 *Tracer + t1, tr1 = mocklogging.NewMockTracer(mockCtrl) + t2, tr2 = mocklogging.NewMockTracer(mockCtrl) + tracer = NewMultiplexedTracer(t1, t2, &Tracer{}) }) It("traces the PacketSent event", func() { @@ -68,18 +69,19 @@ var _ = Describe("Tracing", func() { Context("Connection Tracer", func() { var ( - tracer ConnectionTracer + tracer *ConnectionTracer tr1 *mocklogging.MockConnectionTracer tr2 *mocklogging.MockConnectionTracer ) BeforeEach(func() { - tr1 = mocklogging.NewMockConnectionTracer(mockCtrl) - tr2 = mocklogging.NewMockConnectionTracer(mockCtrl) - tracer = NewMultiplexedConnectionTracer(tr1, tr2) + var t1, t2 *ConnectionTracer + t1, tr1 = mocklogging.NewMockConnectionTracer(mockCtrl) + t2, tr2 = mocklogging.NewMockConnectionTracer(mockCtrl) + tracer = NewMultiplexedConnectionTracer(t1, t2) }) - It("trace the ConnectionStarted event", func() { + It("traces the StartedConnection event", func() { local := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4)} remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} dest := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) @@ -89,6 +91,15 @@ var _ = Describe("Tracing", func() { tracer.StartedConnection(local, remote, src, dest) }) + It("traces the NegotiatedVersion event", func() { + chosen := protocol.Version2 + client := []protocol.VersionNumber{protocol.Version1} + server := []protocol.VersionNumber{13, 37} + tr1.EXPECT().NegotiatedVersion(chosen, client, server) + tr2.EXPECT().NegotiatedVersion(chosen, client, server) + tracer.NegotiatedVersion(chosen, client, server) + }) + It("traces the ClosedConnection event", func() { e := errors.New("test err") tr1.EXPECT().ClosedConnection(e) diff --git a/logging/null_tracer.go b/logging/null_tracer.go deleted file mode 100644 index 3e3458f7e6d..00000000000 --- a/logging/null_tracer.go +++ /dev/null @@ -1,63 +0,0 @@ -package logging - -import ( - "net" - "time" -) - -// The NullTracer is a Tracer that does nothing. -// It is useful for embedding. -type NullTracer struct{} - -var _ Tracer = &NullTracer{} - -func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {} -func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) { -} -func (n NullTracer) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) {} - -// The NullConnectionTracer is a ConnectionTracer that does nothing. -// It is useful for embedding. -type NullConnectionTracer struct{} - -var _ ConnectionTracer = &NullConnectionTracer{} - -func (n NullConnectionTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { -} - -func (n NullConnectionTracer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { -} -func (n NullConnectionTracer) ClosedConnection(err error) {} -func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) SentLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame) { -} - -func (n NullConnectionTracer) SentShortHeaderPacket(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame) { -} - -func (n NullConnectionTracer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) { -} -func (n NullConnectionTracer) ReceivedRetry(*Header) {} -func (n NullConnectionTracer) ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, ECN, []Frame) {} -func (n NullConnectionTracer) ReceivedShortHeaderPacket(*ShortHeader, ByteCount, ECN, []Frame) {} -func (n NullConnectionTracer) BufferedPacket(PacketType, ByteCount) {} -func (n NullConnectionTracer) DroppedPacket(PacketType, ByteCount, PacketDropReason) {} - -func (n NullConnectionTracer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) { -} -func (n NullConnectionTracer) AcknowledgedPacket(EncryptionLevel, PacketNumber) {} -func (n NullConnectionTracer) LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) {} -func (n NullConnectionTracer) UpdatedCongestionState(CongestionState) {} -func (n NullConnectionTracer) UpdatedPTOCount(uint32) {} -func (n NullConnectionTracer) UpdatedKeyFromTLS(EncryptionLevel, Perspective) {} -func (n NullConnectionTracer) UpdatedKey(keyPhase KeyPhase, remote bool) {} -func (n NullConnectionTracer) DroppedEncryptionLevel(EncryptionLevel) {} -func (n NullConnectionTracer) DroppedKey(KeyPhase) {} -func (n NullConnectionTracer) SetLossTimer(TimerType, EncryptionLevel, time.Time) {} -func (n NullConnectionTracer) LossTimerExpired(TimerType, EncryptionLevel) {} -func (n NullConnectionTracer) LossTimerCanceled() {} -func (n NullConnectionTracer) ECNStateUpdated(ECNState, ECNStateTrigger) {} -func (n NullConnectionTracer) Close() {} -func (n NullConnectionTracer) Debug(name, msg string) {} diff --git a/logging/tracer.go b/logging/tracer.go new file mode 100644 index 00000000000..5918f30f842 --- /dev/null +++ b/logging/tracer.go @@ -0,0 +1,43 @@ +package logging + +import "net" + +// A Tracer traces events. +type Tracer struct { + SentPacket func(net.Addr, *Header, ByteCount, []Frame) + SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) + DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason) +} + +// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers. +func NewMultiplexedTracer(tracers ...*Tracer) *Tracer { + if len(tracers) == 0 { + return nil + } + if len(tracers) == 1 { + return tracers[0] + } + return &Tracer{ + SentPacket: func(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { + for _, t := range tracers { + if t.SentPacket != nil { + t.SentPacket(remote, hdr, size, frames) + } + } + }, + SentVersionNegotiationPacket: func(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) { + for _, t := range tracers { + if t.SentVersionNegotiationPacket != nil { + t.SentVersionNegotiationPacket(remote, dest, src, versions) + } + } + }, + DroppedPacket: func(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { + for _, t := range tracers { + if t.DroppedPacket != nil { + t.DroppedPacket(remote, typ, size, reason) + } + } + }, + } +} diff --git a/qlog/qlog.go b/qlog/qlog.go index 943bbc3d11b..7801b646358 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -63,11 +63,9 @@ type connectionTracer struct { lastMetrics *metrics } -var _ logging.ConnectionTracer = &connectionTracer{} - // NewConnectionTracer creates a new tracer to record a qlog for a connection. -func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer { - t := &connectionTracer{ +func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protocol.ConnectionID) *logging.ConnectionTracer { + t := connectionTracer{ w: w, perspective: p, odcid: odcid, @@ -76,7 +74,84 @@ func NewConnectionTracer(w io.WriteCloser, p protocol.Perspective, odcid protoco referenceTime: time.Now(), } go t.run() - return t + return &logging.ConnectionTracer{ + StartedConnection: func(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) { + t.StartedConnection(local, remote, srcConnID, destConnID) + }, + NegotiatedVersion: func(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { + t.NegotiatedVersion(chosen, clientVersions, serverVersions) + }, + ClosedConnection: func(e error) { t.ClosedConnection(e) }, + SentTransportParameters: func(tp *wire.TransportParameters) { t.SentTransportParameters(tp) }, + ReceivedTransportParameters: func(tp *wire.TransportParameters) { t.ReceivedTransportParameters(tp) }, + RestoredTransportParameters: func(tp *wire.TransportParameters) { t.RestoredTransportParameters(tp) }, + SentLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentLongHeaderPacket(hdr, size, ecn, ack, frames) + }, + SentShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { + t.SentShortHeaderPacket(hdr, size, ecn, ack, frames) + }, + ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedLongHeaderPacket(hdr, size, ecn, frames) + }, + ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { + t.ReceivedShortHeaderPacket(hdr, size, ecn, frames) + }, + ReceivedRetry: func(hdr *wire.Header) { + t.ReceivedRetry(hdr) + }, + ReceivedVersionNegotiationPacket: func(dest, src logging.ArbitraryLenConnectionID, versions []logging.VersionNumber) { + t.ReceivedVersionNegotiationPacket(dest, src, versions) + }, + BufferedPacket: func(pt logging.PacketType, size protocol.ByteCount) { + t.BufferedPacket(pt, size) + }, + DroppedPacket: func(pt logging.PacketType, size protocol.ByteCount, reason logging.PacketDropReason) { + t.DroppedPacket(pt, size, reason) + }, + UpdatedMetrics: func(rttStats *utils.RTTStats, cwnd, bytesInFlight protocol.ByteCount, packetsInFlight int) { + t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight) + }, + LostPacket: func(encLevel protocol.EncryptionLevel, pn protocol.PacketNumber, lossReason logging.PacketLossReason) { + t.LostPacket(encLevel, pn, lossReason) + }, + UpdatedCongestionState: func(state logging.CongestionState) { + t.UpdatedCongestionState(state) + }, + UpdatedPTOCount: func(value uint32) { + t.UpdatedPTOCount(value) + }, + UpdatedKeyFromTLS: func(encLevel protocol.EncryptionLevel, pers protocol.Perspective) { + t.UpdatedKeyFromTLS(encLevel, pers) + }, + UpdatedKey: func(generation protocol.KeyPhase, remote bool) { + t.UpdatedKey(generation, remote) + }, + DroppedEncryptionLevel: func(encLevel protocol.EncryptionLevel) { + t.DroppedEncryptionLevel(encLevel) + }, + DroppedKey: func(generation protocol.KeyPhase) { + t.DroppedKey(generation) + }, + SetLossTimer: func(tt logging.TimerType, encLevel protocol.EncryptionLevel, timeout time.Time) { + t.SetLossTimer(tt, encLevel, timeout) + }, + LossTimerExpired: func(tt logging.TimerType, encLevel protocol.EncryptionLevel) { + t.LossTimerExpired(tt, encLevel) + }, + LossTimerCanceled: func() { + t.LossTimerCanceled() + }, + ECNStateUpdated: func(state logging.ECNState, trigger logging.ECNStateTrigger) { + t.ECNStateUpdated(state, trigger) + }, + Debug: func(name, msg string) { + t.Debug(name, msg) + }, + Close: func() { + t.Close() + }, + } } func (t *connectionTracer) run() { diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index c19a0d3e1da..a58f02780c8 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -70,7 +70,7 @@ var _ = Describe("Tracing", func() { Context("connection tracer", func() { var ( - tracer logging.ConnectionTracer + tracer *logging.ConnectionTracer buf *bytes.Buffer ) diff --git a/server.go b/server.go index d41c6465ac4..813caafb431 100644 --- a/server.go +++ b/server.go @@ -93,7 +93,7 @@ type baseServer struct { *tls.Config, *handshake.TokenGenerator, bool, /* client address validated by an address validation token */ - logging.ConnectionTracer, + *logging.ConnectionTracer, uint64, utils.Logger, protocol.VersionNumber, @@ -109,7 +109,7 @@ type baseServer struct { connQueue chan quicConn connQueueLen int32 // to be used as an atomic - tracer logging.Tracer + tracer *logging.Tracer logger utils.Logger } @@ -225,7 +225,7 @@ func newServer( connIDGenerator ConnectionIDGenerator, tlsConf *tls.Config, config *Config, - tracer logging.Tracer, + tracer *logging.Tracer, onClose func(), disableVersionNegotiation bool, acceptEarly bool, @@ -353,7 +353,7 @@ func (s *baseServer) handlePacket(p receivedPacket) { case s.receivedPackets <- p: default: s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } @@ -366,7 +366,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st if wire.IsVersionNegotiationPacket(p.data) { s.logger.Debugf("Dropping Version Negotiation packet.") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -379,7 +379,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // drop the packet if we failed to parse the protocol version if err != nil { s.logger.Debugf("Dropping a packet with an unknown version") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -392,7 +392,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st if p.Size() < protocol.MinUnknownVersionPacketSize { s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -402,7 +402,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st if wire.Is0RTTPacket(p.data) { if !s.acceptEarlyConns { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -414,7 +414,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) @@ -422,7 +422,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st } if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -433,7 +433,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st // There's little point in sending a Stateless Reset, since the client // might not have received the token yet. s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) } return false @@ -452,7 +452,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { connID, err := wire.ParseConnectionID(p.data, 0) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) } return false @@ -466,7 +466,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { if q, ok := s.zeroRTTQueues[connID]; ok { if len(q.packets) >= protocol.Max0RTTQueueLen { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false @@ -476,7 +476,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { } if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } return false @@ -504,7 +504,7 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) { continue } for _, p := range q.packets { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) } p.buffer.Release() @@ -540,7 +540,7 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { p.buffer.Release() - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return errors.New("too short connection ID") @@ -625,7 +625,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error } config = populateConfig(conf) } - var tracer logging.ConnectionTracer + var tracer *logging.ConnectionTracer if config.Tracer != nil { // Use the same connection ID that is passed to the client's GetLogWriter callback. connID := hdr.DestConnectionID @@ -742,7 +742,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe // append the Retry integrity tag tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) buf.Data = append(buf.Data, tag[:]...) - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentPacket != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) @@ -763,7 +763,7 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) { hdr, _, _, err := wire.ParsePacket(p.data) if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) @@ -778,14 +778,14 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) { // Only send INVALID_TOKEN if we can unprotect the packet. // This makes sure that we won't send it for packets that were corrupted. if err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) } return } hdrLen := extHdr.ParsedLen() if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) } return @@ -841,7 +841,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han replyHdr.Log(s.logger) wire.LogFrame(s.logger, ccf, true) - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentPacket != nil { s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported) @@ -870,7 +870,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) if err != nil { // should never happen s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") - if s.tracer != nil { + if s.tracer != nil && s.tracer.DroppedPacket != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return @@ -879,7 +879,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) - if s.tracer != nil { + if s.tracer != nil && s.tracer.SentVersionNegotiationPacket != nil { s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil { diff --git a/server_test.go b/server_test.go index 31a977648d8..71c968bd739 100644 --- a/server_test.go +++ b/server_test.go @@ -177,8 +177,9 @@ var _ = Describe("Server", func() { ) BeforeEach(func() { - tracer = mocklogging.NewMockTracer(mockCtrl) - tr = &Transport{Conn: conn, Tracer: tracer} + var t *logging.Tracer + t, tracer = mocklogging.NewMockTracer(mockCtrl) + tr = &Transport{Conn: conn, Tracer: t} ln, err := tr.Listen(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) serv = ln.baseServer @@ -291,7 +292,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -493,7 +494,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -552,7 +553,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -605,7 +606,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -641,7 +642,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -712,7 +713,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1022,7 +1023,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1099,7 +1100,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1172,7 +1173,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1215,7 +1216,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1279,7 +1280,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, @@ -1329,8 +1330,9 @@ var _ = Describe("Server", func() { ) BeforeEach(func() { - tracer = mocklogging.NewMockTracer(mockCtrl) - tr = &Transport{Conn: conn, Tracer: tracer} + var t *logging.Tracer + t, tracer = mocklogging.NewMockTracer(mockCtrl) + tr = &Transport{Conn: conn, Tracer: t} ln, err := tr.ListenEarly(tlsConf, nil) Expect(err).ToNot(HaveOccurred()) phm = NewMockPacketHandlerManager(mockCtrl) @@ -1404,7 +1406,7 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ logging.ConnectionTracer, + _ *logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, diff --git a/transport.go b/transport.go index c021be77d7a..6aca9b21d8b 100644 --- a/transport.go +++ b/transport.go @@ -63,7 +63,7 @@ type Transport struct { DisableVersionNegotiationPackets bool // A Tracer traces events that don't belong to a single QUIC connection. - Tracer logging.Tracer + Tracer *logging.Tracer handlerMap packetHandlerManager @@ -351,7 +351,7 @@ func (t *Transport) handlePacket(p receivedPacket) { connID, err := wire.ParseConnectionID(p.data, t.connIDLen) if err != nil { t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) - if t.Tracer != nil { + if t.Tracer != nil && t.Tracer.DroppedPacket != nil { t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } p.buffer.MaybeRelease() @@ -446,7 +446,7 @@ func (t *Transport) handleNonQUICPacket(p receivedPacket) { select { case t.nonQUICPackets <- p: default: - if t.Tracer != nil { + if t.Tracer != nil && t.Tracer.DroppedPacket != nil { t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } diff --git a/transport_test.go b/transport_test.go index 14c6fdbbcf8..2501971982e 100644 --- a/transport_test.go +++ b/transport_test.go @@ -126,11 +126,11 @@ var _ = Describe("Transport", func() { It("drops unparseable QUIC packets", func() { addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) - tracer := mocklogging.NewMockTracer(mockCtrl) + t, tracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ Conn: newMockPacketConn(packetChan), ConnectionIDLength: 10, - Tracer: tracer, + Tracer: t, } tr.init(true) dropped := make(chan struct{}) @@ -328,11 +328,9 @@ var _ = Describe("Transport", func() { It("allows receiving non-QUIC packets", func() { remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) - tracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ Conn: newMockPacketConn(packetChan), ConnectionIDLength: 10, - Tracer: tracer, } tr.init(true) receivedPacketChan := make(chan []byte) @@ -362,11 +360,11 @@ var _ = Describe("Transport", func() { It("drops non-QUIC packet if the application doesn't process them quickly enough", func() { remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) - tracer := mocklogging.NewMockTracer(mockCtrl) + t, tracer := mocklogging.NewMockTracer(mockCtrl) tr := &Transport{ Conn: newMockPacketConn(packetChan), ConnectionIDLength: 10, - Tracer: tracer, + Tracer: t, } tr.init(true)