diff --git a/fuzzing/transportparameters/cmd/corpus.go b/fuzzing/transportparameters/cmd/corpus.go index 0f4720610fc..e427293dc76 100644 --- a/fuzzing/transportparameters/cmd/corpus.go +++ b/fuzzing/transportparameters/cmd/corpus.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "log" "math" "math/rand" @@ -78,9 +77,7 @@ func main() { } data = tp.Marshal(pers) } else { - b := &bytes.Buffer{} - tp.MarshalForSessionTicket(b) - data = b.Bytes() + data = tp.MarshalForSessionTicket(nil) } if err := helper.WriteCorpusFileWithPrefix("corpus", data, transportparameters.PrefixLen); err != nil { log.Fatal(err) diff --git a/fuzzing/transportparameters/fuzz.go b/fuzzing/transportparameters/fuzz.go index a11fc5a1379..4b14e2d6950 100644 --- a/fuzzing/transportparameters/fuzz.go +++ b/fuzzing/transportparameters/fuzz.go @@ -51,10 +51,9 @@ func fuzzTransportParametersForSessionTicket(data []byte) int { if err := tp.UnmarshalFromSessionTicket(bytes.NewReader(data)); err != nil { return 0 } - buf := &bytes.Buffer{} - tp.MarshalForSessionTicket(buf) + b := tp.MarshalForSessionTicket(nil) tp2 := &wire.TransportParameters{} - if err := tp2.UnmarshalFromSessionTicket(bytes.NewReader(buf.Bytes())); err != nil { + if err := tp2.UnmarshalFromSessionTicket(bytes.NewReader(b)); err != nil { panic(err) } return 1 diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 31d9bf0aa40..9a19224a300 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -432,11 +432,10 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) { // must be called after receiving the transport parameters func (h *cryptoSetup) marshalDataForSessionState() []byte { - buf := &bytes.Buffer{} - quicvarint.Write(buf, clientSessionStateRevision) - quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds())) - h.peerParams.MarshalForSessionTicket(buf) - return buf.Bytes() + b := make([]byte, 0, 256) + b = quicvarint.Append(b, clientSessionStateRevision) + b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds())) + return h.peerParams.MarshalForSessionTicket(b) } func (h *cryptoSetup) handleDataFromSessionState(data []byte) { diff --git a/internal/handshake/session_ticket.go b/internal/handshake/session_ticket.go index 75cc04f987d..58b57c5a495 100644 --- a/internal/handshake/session_ticket.go +++ b/internal/handshake/session_ticket.go @@ -18,11 +18,10 @@ type sessionTicket struct { } func (t *sessionTicket) Marshal() []byte { - b := &bytes.Buffer{} - quicvarint.Write(b, sessionTicketRevision) - quicvarint.Write(b, uint64(t.RTT.Microseconds())) - t.Parameters.MarshalForSessionTicket(b) - return b.Bytes() + b := make([]byte, 0, 256) + b = quicvarint.Append(b, sessionTicketRevision) + b = quicvarint.Append(b, uint64(t.RTT.Microseconds())) + return t.Parameters.MarshalForSessionTicket(b) } func (t *sessionTicket) Unmarshal(b []byte) error { diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index b5f478fb79d..68417bb7a07 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -486,10 +486,9 @@ var _ = Describe("Transport Parameters", func() { ActiveConnectionIDLimit: getRandomValue(), } Expect(params.ValidFor0RTT(params)).To(BeTrue()) - b := &bytes.Buffer{} - params.MarshalForSessionTicket(b) + b := params.MarshalForSessionTicket(nil) var tp TransportParameters - Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(Succeed()) + Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b))).To(Succeed()) Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) @@ -506,9 +505,7 @@ var _ = Describe("Transport Parameters", func() { It("rejects the parameters if the version changed", func() { var p TransportParameters - buf := &bytes.Buffer{} - p.MarshalForSessionTicket(buf) - data := buf.Bytes() + data := p.MarshalForSessionTicket(nil) b := &bytes.Buffer{} quicvarint.Write(b, transportParameterMarshalingVersion+1) b.Write(data[quicvarint.Len(transportParameterMarshalingVersion):]) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index 65bcfd55884..2ac374006fd 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "encoding/binary" "errors" "fmt" "io" @@ -313,94 +314,98 @@ func (p *TransportParameters) readNumericTransportParameter( // Marshal the transport parameters func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { - b := &bytes.Buffer{} + // Typical Transport Parameters consume around 110 bytes, depending on the exact values, + // especially the lengths of the Connection IDs. + // Allocate 256 bytes, so we won't have to grow the slice in any case. + b := make([]byte, 0, 256) // add a greased value - quicvarint.Write(b, uint64(27+31*rand.Intn(100))) + b = quicvarint.Append(b, uint64(27+31*rand.Intn(100))) length := rand.Intn(16) - randomData := make([]byte, length) - rand.Read(randomData) - quicvarint.Write(b, uint64(length)) - b.Write(randomData) + b = quicvarint.Append(b, uint64(length)) + b = b[:len(b)+length] + rand.Read(b[len(b)-length:]) // initial_max_stream_data_bidi_local - p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) + b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) // initial_max_stream_data_bidi_remote - p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) + b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) // initial_max_stream_data_uni - p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) + b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) // initial_max_data - p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) + b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) // initial_max_bidi_streams - p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) + b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) // initial_max_uni_streams - p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) // idle_timeout - p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond)) + b = p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond)) // max_packet_size - p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize)) + b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize)) // max_ack_delay // Only send it if is different from the default value. if p.MaxAckDelay != protocol.DefaultMaxAckDelay { - p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) + b = p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) } // ack_delay_exponent // Only send it if is different from the default value. if p.AckDelayExponent != protocol.DefaultAckDelayExponent { - p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) + b = p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) } // disable_active_migration if p.DisableActiveMigration { - quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) - quicvarint.Write(b, 0) + b = quicvarint.Append(b, uint64(disableActiveMigrationParameterID)) + b = quicvarint.Append(b, 0) } if pers == protocol.PerspectiveServer { // stateless_reset_token if p.StatelessResetToken != nil { - quicvarint.Write(b, uint64(statelessResetTokenParameterID)) - quicvarint.Write(b, 16) - b.Write(p.StatelessResetToken[:]) + b = quicvarint.Append(b, uint64(statelessResetTokenParameterID)) + b = quicvarint.Append(b, 16) + b = append(b, p.StatelessResetToken[:]...) } // original_destination_connection_id - quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.OriginalDestinationConnectionID.Len())) - b.Write(p.OriginalDestinationConnectionID.Bytes()) + b = quicvarint.Append(b, uint64(originalDestinationConnectionIDParameterID)) + b = quicvarint.Append(b, uint64(p.OriginalDestinationConnectionID.Len())) + b = append(b, p.OriginalDestinationConnectionID.Bytes()...) // preferred_address if p.PreferredAddress != nil { - quicvarint.Write(b, uint64(preferredAddressParameterID)) - quicvarint.Write(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) + b = quicvarint.Append(b, uint64(preferredAddressParameterID)) + b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) ipv4 := p.PreferredAddress.IPv4 - b.Write(ipv4[len(ipv4)-4:]) - utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port) - b.Write(p.PreferredAddress.IPv6) - utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port) - b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len())) - b.Write(p.PreferredAddress.ConnectionID.Bytes()) - b.Write(p.PreferredAddress.StatelessResetToken[:]) + b = append(b, ipv4[len(ipv4)-4:]...) + b = append(b, []byte{0, 0}...) + binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv4Port) + b = append(b, p.PreferredAddress.IPv6...) + b = append(b, []byte{0, 0}...) + binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv6Port) + b = append(b, uint8(p.PreferredAddress.ConnectionID.Len())) + b = append(b, p.PreferredAddress.ConnectionID.Bytes()...) + b = append(b, p.PreferredAddress.StatelessResetToken[:]...) } } // active_connection_id_limit - p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) + b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) // initial_source_connection_id - quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.InitialSourceConnectionID.Len())) - b.Write(p.InitialSourceConnectionID.Bytes()) + b = quicvarint.Append(b, uint64(initialSourceConnectionIDParameterID)) + b = quicvarint.Append(b, uint64(p.InitialSourceConnectionID.Len())) + b = append(b, p.InitialSourceConnectionID.Bytes()...) // retry_source_connection_id if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil { - quicvarint.Write(b, uint64(retrySourceConnectionIDParameterID)) - quicvarint.Write(b, uint64(p.RetrySourceConnectionID.Len())) - b.Write(p.RetrySourceConnectionID.Bytes()) + b = quicvarint.Append(b, uint64(retrySourceConnectionIDParameterID)) + b = quicvarint.Append(b, uint64(p.RetrySourceConnectionID.Len())) + b = append(b, p.RetrySourceConnectionID.Bytes()...) } if p.MaxDatagramFrameSize != protocol.InvalidByteCount { - p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) + b = p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) } - return b.Bytes() + return b } -func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportParameterID, val uint64) { - quicvarint.Write(b, uint64(id)) - quicvarint.Write(b, uint64(quicvarint.Len(val))) - quicvarint.Write(b, val) +func (p *TransportParameters) marshalVarintParam(b []byte, id transportParameterID, val uint64) []byte { + b = quicvarint.Append(b, uint64(id)) + b = quicvarint.Append(b, uint64(quicvarint.Len(val))) + return quicvarint.Append(b, val) } // MarshalForSessionTicket marshals the transport parameters we save in the session ticket. @@ -411,23 +416,23 @@ func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportPa // if the transport parameters changed. // Since the session ticket is encrypted, the serialization format is defined by the server. // For convenience, we use the same format that we also use for sending the transport parameters. -func (p *TransportParameters) MarshalForSessionTicket(b *bytes.Buffer) { - quicvarint.Write(b, transportParameterMarshalingVersion) +func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte { + b = quicvarint.Append(b, transportParameterMarshalingVersion) // initial_max_stream_data_bidi_local - p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) + b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) // initial_max_stream_data_bidi_remote - p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) + b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) // initial_max_stream_data_uni - p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) + b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) // initial_max_data - p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) + b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) // initial_max_bidi_streams - p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) + b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) // initial_max_uni_streams - p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) // active_connection_id_limit - p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) + return p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) } // UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket.