diff --git a/fuzzing/transportparameters/fuzz.go b/fuzzing/transportparameters/fuzz.go index d0dd975fdcc..267979ce3b1 100644 --- a/fuzzing/transportparameters/fuzz.go +++ b/fuzzing/transportparameters/fuzz.go @@ -1,7 +1,6 @@ package transportparameters import ( - "bytes" "errors" "fmt" @@ -55,12 +54,12 @@ func fuzzTransportParameters(data []byte, sentByServer bool) int { func fuzzTransportParametersForSessionTicket(data []byte) int { tp := &wire.TransportParameters{} - if err := tp.UnmarshalFromSessionTicket(bytes.NewReader(data)); err != nil { + if err := tp.UnmarshalFromSessionTicket(data); err != nil { return 0 } b := tp.MarshalForSessionTicket(nil) tp2 := &wire.TransportParameters{} - if err := tp2.UnmarshalFromSessionTicket(bytes.NewReader(b)); err != nil { + if err := tp2.UnmarshalFromSessionTicket(b); err != nil { panic(err) } return 1 diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index adf74fe7481..4273a4a6ad3 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -1,7 +1,6 @@ package handshake import ( - "bytes" "context" "crypto/tls" "errors" @@ -338,25 +337,26 @@ func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (a return false } -func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) { - r := bytes.NewReader(data) - ver, err := quicvarint.Read(r) +func decodeDataFromSessionState(b []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) { + ver, l, err := quicvarint.Parse(b) if err != nil { return 0, nil, err } + b = b[l:] if ver != clientSessionStateRevision { return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision) } - rttEncoded, err := quicvarint.Read(r) + rttEncoded, l, err := quicvarint.Parse(b) if err != nil { return 0, nil, err } + b = b[l:] rtt := time.Duration(rttEncoded) * time.Microsecond if !earlyData { return rtt, nil, nil } var tp wire.TransportParameters - if err := tp.UnmarshalFromSessionTicket(r); err != nil { + if err := tp.UnmarshalFromSessionTicket(b); err != nil { return 0, nil, err } return rtt, &tp, nil diff --git a/internal/handshake/session_ticket.go b/internal/handshake/session_ticket.go index 9481af563e9..b67f0101e42 100644 --- a/internal/handshake/session_ticket.go +++ b/internal/handshake/session_ticket.go @@ -1,7 +1,6 @@ package handshake import ( - "bytes" "errors" "fmt" "time" @@ -28,25 +27,26 @@ func (t *sessionTicket) Marshal() []byte { } func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error { - r := bytes.NewReader(b) - rev, err := quicvarint.Read(r) + rev, l, err := quicvarint.Parse(b) if err != nil { return errors.New("failed to read session ticket revision") } + b = b[l:] if rev != sessionTicketRevision { return fmt.Errorf("unknown session ticket revision: %d", rev) } - rtt, err := quicvarint.Read(r) + rtt, l, err := quicvarint.Parse(b) if err != nil { return errors.New("failed to read RTT") } + b = b[l:] if using0RTT { var tp wire.TransportParameters - if err := tp.UnmarshalFromSessionTicket(r); err != nil { + if err := tp.UnmarshalFromSessionTicket(b); err != nil { return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) } t.Parameters = &tp - } else if r.Len() > 0 { + } else if len(b) > 0 { return fmt.Errorf("the session ticket has more bytes than expected") } t.RTT = time.Duration(rtt) * time.Microsecond diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 04e31f85f8b..721ad62fb59 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -5,6 +5,7 @@ import ( "fmt" "math" "net/netip" + "testing" "time" "golang.org/x/exp/rand" @@ -17,20 +18,20 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Transport Parameters", func() { - getRandomValueUpTo := func(max int64) uint64 { - maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} - m := maxVals[int(rand.Int31n(4))] - if m > max { - m = max - } - return uint64(rand.Int63n(m)) +func getRandomValueUpTo(max int64) uint64 { + maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} + m := maxVals[int(rand.Int31n(4))] + if m > max { + m = max } + return uint64(rand.Int63n(m)) +} - getRandomValue := func() uint64 { - return getRandomValueUpTo(math.MaxInt64) - } +func getRandomValue() uint64 { + return getRandomValueUpTo(math.MaxInt64) +} +var _ = Describe("Transport Parameters", func() { BeforeEach(func() { rand.Seed(uint64(GinkgoRandomSeed())) }) @@ -504,7 +505,7 @@ var _ = Describe("Transport Parameters", func() { Expect(params.ValidFor0RTT(params)).To(BeTrue()) b := params.MarshalForSessionTicket(nil) var tp TransportParameters - Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b))).To(Succeed()) + Expect(tp.UnmarshalFromSessionTicket(b)).To(Succeed()) Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) @@ -517,7 +518,7 @@ var _ = Describe("Transport Parameters", func() { It("rejects the parameters if it can't parse them", func() { var p TransportParameters - Expect(p.UnmarshalFromSessionTicket(bytes.NewReader([]byte("foobar")))).ToNot(Succeed()) + Expect(p.UnmarshalFromSessionTicket([]byte("foobar"))).ToNot(Succeed()) }) It("rejects the parameters if the version changed", func() { @@ -525,7 +526,7 @@ var _ = Describe("Transport Parameters", func() { data := p.MarshalForSessionTicket(nil) b := quicvarint.Append(nil, transportParameterMarshalingVersion+1) b = append(b, data[quicvarint.Len(transportParameterMarshalingVersion):]...) - Expect(p.UnmarshalFromSessionTicket(bytes.NewReader(b))).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1))) + Expect(p.UnmarshalFromSessionTicket(b)).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1))) }) Context("rejects the parameters if they changed", func() { @@ -722,3 +723,66 @@ var _ = Describe("Transport Parameters", func() { }) }) }) + +func BenchmarkTransportParameters(b *testing.B) { + b.Run("without preferred address", func(b *testing.B) { benchmarkTransportParameters(b, false) }) + b.Run("with preferred address", func(b *testing.B) { benchmarkTransportParameters(b, true) }) +} + +func benchmarkTransportParameters(b *testing.B, withPreferredAddress bool) { + var token protocol.StatelessResetToken + rand.Read(token[:]) + rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) + params := &TransportParameters{ + InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), + InitialMaxData: protocol.ByteCount(getRandomValue()), + MaxIdleTimeout: 0xcafe * time.Second, + MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), + DisableActiveMigration: true, + StatelessResetToken: &token, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), + RetrySourceConnectionID: &rcid, + AckDelayExponent: 13, + MaxAckDelay: 42 * time.Millisecond, + ActiveConnectionIDLimit: 2 + getRandomValueUpTo(math.MaxInt64-2), + MaxDatagramFrameSize: protocol.ByteCount(getRandomValue()), + } + var token2 protocol.StatelessResetToken + rand.Read(token2[:]) + if withPreferredAddress { + var ip4 [4]byte + var ip6 [16]byte + rand.Read(ip4[:]) + rand.Read(ip6[:]) + params.PreferredAddress = &PreferredAddress{ + IPv4: netip.AddrPortFrom(netip.AddrFrom4(ip4), 1234), + IPv6: netip.AddrPortFrom(netip.AddrFrom16(ip6), 4321), + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + StatelessResetToken: token2, + } + } + data := params.Marshal(protocol.PerspectiveServer) + + b.ResetTimer() + b.ReportAllocs() + var p TransportParameters + for i := 0; i < b.N; i++ { + if err := p.Unmarshal(data, protocol.PerspectiveServer); err != nil { + b.Fatal(err) + } + // check a few fields + if p.DisableActiveMigration != params.DisableActiveMigration || + p.InitialMaxStreamDataBidiLocal != params.InitialMaxStreamDataBidiLocal || + *p.StatelessResetToken != *params.StatelessResetToken || + p.AckDelayExponent != params.AckDelayExponent { + b.Fatalf("params mismatch: %v vs %v", p, params) + } + if withPreferredAddress && *p.PreferredAddress != *params.PreferredAddress { + b.Fatalf("preferred address mismatch: %v vs %v", p.PreferredAddress, params.PreferredAddress) + } + } +} diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index c03be3cd739..43c3c184c3c 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -1,7 +1,6 @@ package wire import ( - "bytes" "crypto/rand" "encoding/binary" "errors" @@ -13,7 +12,6 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" - "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/quicvarint" ) @@ -89,7 +87,7 @@ type TransportParameters struct { // Unmarshal the transport parameters func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error { - if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil { + if err := p.unmarshal(data, sentBy, false); err != nil { return &qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: err.Error(), @@ -98,7 +96,7 @@ func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective return nil } -func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error { +func (p *TransportParameters) unmarshal(b []byte, sentBy protocol.Perspective, fromSessionTicket bool) error { // needed to check that every parameter is only sent at most once var parameterIDs []transportParameterID @@ -112,18 +110,20 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec p.MaxAckDelay = protocol.DefaultMaxAckDelay p.MaxDatagramFrameSize = protocol.InvalidByteCount - for r.Len() > 0 { - paramIDInt, err := quicvarint.Read(r) + for len(b) > 0 { + paramIDInt, l, err := quicvarint.Parse(b) if err != nil { return err } paramID := transportParameterID(paramIDInt) - paramLen, err := quicvarint.Read(r) + b = b[l:] + paramLen, l, err := quicvarint.Parse(b) if err != nil { return err } - if uint64(r.Len()) < paramLen { - return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen) + b = b[l:] + if uint64(len(b)) < paramLen { + return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(b), paramLen) } parameterIDs = append(parameterIDs, paramID) switch paramID { @@ -141,16 +141,18 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec maxAckDelayParameterID, maxDatagramFrameSizeParameterID, ackDelayExponentParameterID: - if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { + if err := p.readNumericTransportParameter(b, paramID, int(paramLen)); err != nil { return err } + b = b[paramLen:] case preferredAddressParameterID: if sentBy == protocol.PerspectiveClient { return errors.New("client sent a preferred_address") } - if err := p.readPreferredAddress(r, int(paramLen)); err != nil { + if err := p.readPreferredAddress(b, int(paramLen)); err != nil { return err } + b = b[paramLen:] case disableActiveMigrationParameterID: if paramLen != 0 { return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen) @@ -164,25 +166,41 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen) } var token protocol.StatelessResetToken - r.Read(token[:]) + if len(b) < len(token) { + return io.EOF + } + copy(token[:], b) + b = b[len(token):] p.StatelessResetToken = &token case originalDestinationConnectionIDParameterID: if sentBy == protocol.PerspectiveClient { return errors.New("client sent an original_destination_connection_id") } - p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) + if paramLen > protocol.MaxConnIDLen { + return protocol.ErrInvalidConnectionIDLen + } + p.OriginalDestinationConnectionID = protocol.ParseConnectionID(b[:paramLen]) + b = b[paramLen:] readOriginalDestinationConnectionID = true case initialSourceConnectionIDParameterID: - p.InitialSourceConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen)) + if paramLen > protocol.MaxConnIDLen { + return protocol.ErrInvalidConnectionIDLen + } + p.InitialSourceConnectionID = protocol.ParseConnectionID(b[:paramLen]) + b = b[paramLen:] readInitialSourceConnectionID = true case retrySourceConnectionIDParameterID: if sentBy == protocol.PerspectiveClient { return errors.New("client sent a retry_source_connection_id") } - connID, _ := protocol.ReadConnectionID(r, int(paramLen)) + if paramLen > protocol.MaxConnIDLen { + return protocol.ErrInvalidConnectionIDLen + } + connID := protocol.ParseConnectionID(b[:paramLen]) + b = b[paramLen:] p.RetrySourceConnectionID = &connID default: - r.Seek(int64(paramLen), io.SeekCurrent) + b = b[paramLen:] } } @@ -212,60 +230,47 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec return nil } -func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error { - remainingLen := r.Len() +func (p *TransportParameters) readPreferredAddress(b []byte, expectedLen int) error { + remainingLen := len(b) pa := &PreferredAddress{} - var ipv4 [4]byte - if _, err := io.ReadFull(r, ipv4[:]); err != nil { - return err + if len(b) < 4+2+16+2+1 { + return io.EOF } - port, err := utils.BigEndian.ReadUint16(r) - if err != nil { - return err - } - pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port) + var ipv4 [4]byte + copy(ipv4[:], b[:4]) + port4 := binary.BigEndian.Uint16(b[4:]) + b = b[4+2:] + pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port4) var ipv6 [16]byte - if _, err := io.ReadFull(r, ipv6[:]); err != nil { - return err - } - port, err = utils.BigEndian.ReadUint16(r) - if err != nil { - return err - } - pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port) - connIDLen, err := r.ReadByte() - if err != nil { - return err - } + copy(ipv6[:], b[:16]) + port6 := binary.BigEndian.Uint16(b[16:]) + pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port6) + b = b[16+2:] + connIDLen := int(b[0]) + b = b[1:] if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen { return fmt.Errorf("invalid connection ID length: %d", connIDLen) } - connID, err := protocol.ReadConnectionID(r, int(connIDLen)) - if err != nil { - return err - } - pa.ConnectionID = connID - if _, err := io.ReadFull(r, pa.StatelessResetToken[:]); err != nil { - return err + if len(b) < connIDLen+len(pa.StatelessResetToken) { + return io.EOF } - if bytesRead := remainingLen - r.Len(); bytesRead != expectedLen { + pa.ConnectionID = protocol.ParseConnectionID(b[:connIDLen]) + b = b[connIDLen:] + copy(pa.StatelessResetToken[:], b) + b = b[len(pa.StatelessResetToken):] + if bytesRead := remainingLen - len(b); bytesRead != expectedLen { return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead) } p.PreferredAddress = pa return nil } -func (p *TransportParameters) readNumericTransportParameter( - r *bytes.Reader, - paramID transportParameterID, - expectedLen int, -) error { - remainingLen := r.Len() - val, err := quicvarint.Read(r) +func (p *TransportParameters) readNumericTransportParameter(b []byte, paramID transportParameterID, expectedLen int) error { + val, l, err := quicvarint.Parse(b) if err != nil { return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err) } - if remainingLen-r.Len() != expectedLen { + if l != expectedLen { return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID) } //nolint:exhaustive // This only covers the numeric transport parameters. @@ -457,15 +462,15 @@ func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte { } // UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket. -func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error { - version, err := quicvarint.Read(r) +func (p *TransportParameters) UnmarshalFromSessionTicket(b []byte) error { + version, l, err := quicvarint.Parse(b) if err != nil { return err } if version != transportParameterMarshalingVersion { return fmt.Errorf("unknown transport parameter marshaling version: %d", version) } - return p.unmarshal(r, protocol.PerspectiveServer, true) + return p.unmarshal(b[l:], protocol.PerspectiveServer, true) } // ValidFor0RTT checks if the transport parameters match those saved in the session ticket.