Skip to content

Commit

Permalink
speed up marshaling of transport parameters (#3531)
Browse files Browse the repository at this point in the history
The speedup comes from multiple sources:
1. We now preallocate a byte slice, instead of appending multiple times.
2. Marshaling into a byte slice is faster than using a bytes.Buffer.
3. quicvarint.Write allocates, while quicvarint.Append doesn't.
  • Loading branch information
marten-seemann committed Aug 29, 2022
1 parent 3f1adfd commit 7023b52
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 78 deletions.
5 changes: 1 addition & 4 deletions fuzzing/transportparameters/cmd/corpus.go
@@ -1,7 +1,6 @@
package main

import (
"bytes"
"log"
"math"
"math/rand"
Expand Down Expand Up @@ -78,9 +77,7 @@ func main() {
}
data = tp.Marshal(pers)
} else {
b := &bytes.Buffer{}
tp.MarshalForSessionTicket(b)
data = b.Bytes()
data = tp.MarshalForSessionTicket(nil)
}
if err := helper.WriteCorpusFileWithPrefix("corpus", data, transportparameters.PrefixLen); err != nil {
log.Fatal(err)
Expand Down
5 changes: 2 additions & 3 deletions fuzzing/transportparameters/fuzz.go
Expand Up @@ -51,10 +51,9 @@ func fuzzTransportParametersForSessionTicket(data []byte) int {
if err := tp.UnmarshalFromSessionTicket(bytes.NewReader(data)); err != nil {
return 0
}
buf := &bytes.Buffer{}
tp.MarshalForSessionTicket(buf)
b := tp.MarshalForSessionTicket(nil)
tp2 := &wire.TransportParameters{}
if err := tp2.UnmarshalFromSessionTicket(bytes.NewReader(buf.Bytes())); err != nil {
if err := tp2.UnmarshalFromSessionTicket(bytes.NewReader(b)); err != nil {
panic(err)
}
return 1
Expand Down
9 changes: 4 additions & 5 deletions internal/handshake/crypto_setup.go
Expand Up @@ -432,11 +432,10 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) {

// must be called after receiving the transport parameters
func (h *cryptoSetup) marshalDataForSessionState() []byte {
buf := &bytes.Buffer{}
quicvarint.Write(buf, clientSessionStateRevision)
quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds()))
h.peerParams.MarshalForSessionTicket(buf)
return buf.Bytes()
b := make([]byte, 0, 256)
b = quicvarint.Append(b, clientSessionStateRevision)
b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds()))
return h.peerParams.MarshalForSessionTicket(b)
}

func (h *cryptoSetup) handleDataFromSessionState(data []byte) {
Expand Down
9 changes: 4 additions & 5 deletions internal/handshake/session_ticket.go
Expand Up @@ -18,11 +18,10 @@ type sessionTicket struct {
}

func (t *sessionTicket) Marshal() []byte {
b := &bytes.Buffer{}
quicvarint.Write(b, sessionTicketRevision)
quicvarint.Write(b, uint64(t.RTT.Microseconds()))
t.Parameters.MarshalForSessionTicket(b)
return b.Bytes()
b := make([]byte, 0, 256)
b = quicvarint.Append(b, sessionTicketRevision)
b = quicvarint.Append(b, uint64(t.RTT.Microseconds()))
return t.Parameters.MarshalForSessionTicket(b)
}

func (t *sessionTicket) Unmarshal(b []byte) error {
Expand Down
9 changes: 3 additions & 6 deletions internal/wire/transport_parameter_test.go
Expand Up @@ -486,10 +486,9 @@ var _ = Describe("Transport Parameters", func() {
ActiveConnectionIDLimit: getRandomValue(),
}
Expect(params.ValidFor0RTT(params)).To(BeTrue())
b := &bytes.Buffer{}
params.MarshalForSessionTicket(b)
b := params.MarshalForSessionTicket(nil)
var tp TransportParameters
Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(Succeed())
Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b))).To(Succeed())
Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal))
Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote))
Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni))
Expand All @@ -506,9 +505,7 @@ var _ = Describe("Transport Parameters", func() {

It("rejects the parameters if the version changed", func() {
var p TransportParameters
buf := &bytes.Buffer{}
p.MarshalForSessionTicket(buf)
data := buf.Bytes()
data := p.MarshalForSessionTicket(nil)
b := &bytes.Buffer{}
quicvarint.Write(b, transportParameterMarshalingVersion+1)
b.Write(data[quicvarint.Len(transportParameterMarshalingVersion):])
Expand Down
115 changes: 60 additions & 55 deletions internal/wire/transport_parameters.go
Expand Up @@ -2,6 +2,7 @@ package wire

import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -313,94 +314,98 @@ func (p *TransportParameters) readNumericTransportParameter(

// Marshal the transport parameters
func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
b := &bytes.Buffer{}
// Typical Transport Parameters consume around 110 bytes, depending on the exact values,
// especially the lengths of the Connection IDs.
// Allocate 256 bytes, so we won't have to grow the slice in any case.
b := make([]byte, 0, 256)

// add a greased value
quicvarint.Write(b, uint64(27+31*rand.Intn(100)))
b = quicvarint.Append(b, uint64(27+31*rand.Intn(100)))
length := rand.Intn(16)
randomData := make([]byte, length)
rand.Read(randomData)
quicvarint.Write(b, uint64(length))
b.Write(randomData)
b = quicvarint.Append(b, uint64(length))
b = b[:len(b)+length]
rand.Read(b[len(b)-length:])

// initial_max_stream_data_bidi_local
p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal))
b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal))
// initial_max_stream_data_bidi_remote
p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote))
b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote))
// initial_max_stream_data_uni
p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni))
b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni))
// initial_max_data
p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData))
b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData))
// initial_max_bidi_streams
p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum))
b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum))
// initial_max_uni_streams
p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum))
b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum))
// idle_timeout
p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond))
b = p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond))
// max_packet_size
p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize))
b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize))
// max_ack_delay
// Only send it if is different from the default value.
if p.MaxAckDelay != protocol.DefaultMaxAckDelay {
p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond))
b = p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond))
}
// ack_delay_exponent
// Only send it if is different from the default value.
if p.AckDelayExponent != protocol.DefaultAckDelayExponent {
p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent))
b = p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent))
}
// disable_active_migration
if p.DisableActiveMigration {
quicvarint.Write(b, uint64(disableActiveMigrationParameterID))
quicvarint.Write(b, 0)
b = quicvarint.Append(b, uint64(disableActiveMigrationParameterID))
b = quicvarint.Append(b, 0)
}
if pers == protocol.PerspectiveServer {
// stateless_reset_token
if p.StatelessResetToken != nil {
quicvarint.Write(b, uint64(statelessResetTokenParameterID))
quicvarint.Write(b, 16)
b.Write(p.StatelessResetToken[:])
b = quicvarint.Append(b, uint64(statelessResetTokenParameterID))
b = quicvarint.Append(b, 16)
b = append(b, p.StatelessResetToken[:]...)
}
// original_destination_connection_id
quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID))
quicvarint.Write(b, uint64(p.OriginalDestinationConnectionID.Len()))
b.Write(p.OriginalDestinationConnectionID.Bytes())
b = quicvarint.Append(b, uint64(originalDestinationConnectionIDParameterID))
b = quicvarint.Append(b, uint64(p.OriginalDestinationConnectionID.Len()))
b = append(b, p.OriginalDestinationConnectionID.Bytes()...)
// preferred_address
if p.PreferredAddress != nil {
quicvarint.Write(b, uint64(preferredAddressParameterID))
quicvarint.Write(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
b = quicvarint.Append(b, uint64(preferredAddressParameterID))
b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
ipv4 := p.PreferredAddress.IPv4
b.Write(ipv4[len(ipv4)-4:])
utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port)
b.Write(p.PreferredAddress.IPv6)
utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port)
b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len()))
b.Write(p.PreferredAddress.ConnectionID.Bytes())
b.Write(p.PreferredAddress.StatelessResetToken[:])
b = append(b, ipv4[len(ipv4)-4:]...)
b = append(b, []byte{0, 0}...)
binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv4Port)
b = append(b, p.PreferredAddress.IPv6...)
b = append(b, []byte{0, 0}...)
binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv6Port)
b = append(b, uint8(p.PreferredAddress.ConnectionID.Len()))
b = append(b, p.PreferredAddress.ConnectionID.Bytes()...)
b = append(b, p.PreferredAddress.StatelessResetToken[:]...)
}
}
// active_connection_id_limit
p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit)
b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit)
// initial_source_connection_id
quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID))
quicvarint.Write(b, uint64(p.InitialSourceConnectionID.Len()))
b.Write(p.InitialSourceConnectionID.Bytes())
b = quicvarint.Append(b, uint64(initialSourceConnectionIDParameterID))
b = quicvarint.Append(b, uint64(p.InitialSourceConnectionID.Len()))
b = append(b, p.InitialSourceConnectionID.Bytes()...)
// retry_source_connection_id
if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil {
quicvarint.Write(b, uint64(retrySourceConnectionIDParameterID))
quicvarint.Write(b, uint64(p.RetrySourceConnectionID.Len()))
b.Write(p.RetrySourceConnectionID.Bytes())
b = quicvarint.Append(b, uint64(retrySourceConnectionIDParameterID))
b = quicvarint.Append(b, uint64(p.RetrySourceConnectionID.Len()))
b = append(b, p.RetrySourceConnectionID.Bytes()...)
}
if p.MaxDatagramFrameSize != protocol.InvalidByteCount {
p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize))
b = p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize))
}
return b.Bytes()
return b
}

func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportParameterID, val uint64) {
quicvarint.Write(b, uint64(id))
quicvarint.Write(b, uint64(quicvarint.Len(val)))
quicvarint.Write(b, val)
func (p *TransportParameters) marshalVarintParam(b []byte, id transportParameterID, val uint64) []byte {
b = quicvarint.Append(b, uint64(id))
b = quicvarint.Append(b, uint64(quicvarint.Len(val)))
return quicvarint.Append(b, val)
}

// MarshalForSessionTicket marshals the transport parameters we save in the session ticket.
Expand All @@ -411,23 +416,23 @@ func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportPa
// if the transport parameters changed.
// Since the session ticket is encrypted, the serialization format is defined by the server.
// For convenience, we use the same format that we also use for sending the transport parameters.
func (p *TransportParameters) MarshalForSessionTicket(b *bytes.Buffer) {
quicvarint.Write(b, transportParameterMarshalingVersion)
func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte {
b = quicvarint.Append(b, transportParameterMarshalingVersion)

// initial_max_stream_data_bidi_local
p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal))
b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal))
// initial_max_stream_data_bidi_remote
p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote))
b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote))
// initial_max_stream_data_uni
p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni))
b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni))
// initial_max_data
p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData))
b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData))
// initial_max_bidi_streams
p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum))
b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum))
// initial_max_uni_streams
p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum))
b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum))
// active_connection_id_limit
p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit)
return p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit)
}

// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket.
Expand Down

0 comments on commit 7023b52

Please sign in to comment.