Skip to content

Commit

Permalink
wire: use quicvarint.Parse to when parsing transport parameters (#4482)
Browse files Browse the repository at this point in the history
* wire: add a benchmark for parsing of transport parameters

* wire: use quicvarint.Parse to when parsing transport parameters
  • Loading branch information
marten-seemann committed May 5, 2024
1 parent bb6f066 commit 1514095
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 86 deletions.
5 changes: 2 additions & 3 deletions fuzzing/transportparameters/fuzz.go
@@ -1,7 +1,6 @@
package transportparameters

import (
"bytes"
"errors"
"fmt"

Expand Down Expand Up @@ -55,12 +54,12 @@ func fuzzTransportParameters(data []byte, sentByServer bool) int {

func fuzzTransportParametersForSessionTicket(data []byte) int {
tp := &wire.TransportParameters{}
if err := tp.UnmarshalFromSessionTicket(bytes.NewReader(data)); err != nil {
if err := tp.UnmarshalFromSessionTicket(data); err != nil {
return 0
}
b := tp.MarshalForSessionTicket(nil)
tp2 := &wire.TransportParameters{}
if err := tp2.UnmarshalFromSessionTicket(bytes.NewReader(b)); err != nil {
if err := tp2.UnmarshalFromSessionTicket(b); err != nil {
panic(err)
}
return 1
Expand Down
12 changes: 6 additions & 6 deletions internal/handshake/crypto_setup.go
@@ -1,7 +1,6 @@
package handshake

import (
"bytes"
"context"
"crypto/tls"
"errors"
Expand Down Expand Up @@ -338,25 +337,26 @@ func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (a
return false
}

func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
r := bytes.NewReader(data)
ver, err := quicvarint.Read(r)
func decodeDataFromSessionState(b []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
ver, l, err := quicvarint.Parse(b)
if err != nil {
return 0, nil, err
}
b = b[l:]
if ver != clientSessionStateRevision {
return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
rttEncoded, err := quicvarint.Read(r)
rttEncoded, l, err := quicvarint.Parse(b)
if err != nil {
return 0, nil, err
}
b = b[l:]
rtt := time.Duration(rttEncoded) * time.Microsecond
if !earlyData {
return rtt, nil, nil
}
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return 0, nil, err
}
return rtt, &tp, nil
Expand Down
12 changes: 6 additions & 6 deletions internal/handshake/session_ticket.go
@@ -1,7 +1,6 @@
package handshake

import (
"bytes"
"errors"
"fmt"
"time"
Expand All @@ -28,25 +27,26 @@ func (t *sessionTicket) Marshal() []byte {
}

func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
r := bytes.NewReader(b)
rev, err := quicvarint.Read(r)
rev, l, err := quicvarint.Parse(b)
if err != nil {
return errors.New("failed to read session ticket revision")
}
b = b[l:]
if rev != sessionTicketRevision {
return fmt.Errorf("unknown session ticket revision: %d", rev)
}
rtt, err := quicvarint.Read(r)
rtt, l, err := quicvarint.Parse(b)
if err != nil {
return errors.New("failed to read RTT")
}
b = b[l:]
if using0RTT {
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp
} else if r.Len() > 0 {
} else if len(b) > 0 {
return fmt.Errorf("the session ticket has more bytes than expected")
}
t.RTT = time.Duration(rtt) * time.Microsecond
Expand Down
92 changes: 78 additions & 14 deletions internal/wire/transport_parameter_test.go
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"math"
"net/netip"
"testing"
"time"

"golang.org/x/exp/rand"
Expand All @@ -17,20 +18,20 @@ import (
. "github.com/onsi/gomega"
)

var _ = Describe("Transport Parameters", func() {
getRandomValueUpTo := func(max int64) uint64 {
maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4}
m := maxVals[int(rand.Int31n(4))]
if m > max {
m = max
}
return uint64(rand.Int63n(m))
func getRandomValueUpTo(max int64) uint64 {
maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4}
m := maxVals[int(rand.Int31n(4))]
if m > max {
m = max
}
return uint64(rand.Int63n(m))
}

getRandomValue := func() uint64 {
return getRandomValueUpTo(math.MaxInt64)
}
func getRandomValue() uint64 {
return getRandomValueUpTo(math.MaxInt64)
}

var _ = Describe("Transport Parameters", func() {
BeforeEach(func() {
rand.Seed(uint64(GinkgoRandomSeed()))
})
Expand Down Expand Up @@ -504,7 +505,7 @@ var _ = Describe("Transport Parameters", func() {
Expect(params.ValidFor0RTT(params)).To(BeTrue())
b := params.MarshalForSessionTicket(nil)
var tp TransportParameters
Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b))).To(Succeed())
Expect(tp.UnmarshalFromSessionTicket(b)).To(Succeed())
Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal))
Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote))
Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni))
Expand All @@ -517,15 +518,15 @@ var _ = Describe("Transport Parameters", func() {

It("rejects the parameters if it can't parse them", func() {
var p TransportParameters
Expect(p.UnmarshalFromSessionTicket(bytes.NewReader([]byte("foobar")))).ToNot(Succeed())
Expect(p.UnmarshalFromSessionTicket([]byte("foobar"))).ToNot(Succeed())
})

It("rejects the parameters if the version changed", func() {
var p TransportParameters
data := p.MarshalForSessionTicket(nil)
b := quicvarint.Append(nil, transportParameterMarshalingVersion+1)
b = append(b, data[quicvarint.Len(transportParameterMarshalingVersion):]...)
Expect(p.UnmarshalFromSessionTicket(bytes.NewReader(b))).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1)))
Expect(p.UnmarshalFromSessionTicket(b)).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1)))
})

Context("rejects the parameters if they changed", func() {
Expand Down Expand Up @@ -722,3 +723,66 @@ var _ = Describe("Transport Parameters", func() {
})
})
})

func BenchmarkTransportParameters(b *testing.B) {
b.Run("without preferred address", func(b *testing.B) { benchmarkTransportParameters(b, false) })
b.Run("with preferred address", func(b *testing.B) { benchmarkTransportParameters(b, true) })
}

func benchmarkTransportParameters(b *testing.B, withPreferredAddress bool) {
var token protocol.StatelessResetToken
rand.Read(token[:])
rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})
params := &TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()),
InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()),
InitialMaxData: protocol.ByteCount(getRandomValue()),
MaxIdleTimeout: 0xcafe * time.Second,
MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))),
MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))),
DisableActiveMigration: true,
StatelessResetToken: &token,
OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
RetrySourceConnectionID: &rcid,
AckDelayExponent: 13,
MaxAckDelay: 42 * time.Millisecond,
ActiveConnectionIDLimit: 2 + getRandomValueUpTo(math.MaxInt64-2),
MaxDatagramFrameSize: protocol.ByteCount(getRandomValue()),
}
var token2 protocol.StatelessResetToken
rand.Read(token2[:])
if withPreferredAddress {
var ip4 [4]byte
var ip6 [16]byte
rand.Read(ip4[:])
rand.Read(ip6[:])
params.PreferredAddress = &PreferredAddress{
IPv4: netip.AddrPortFrom(netip.AddrFrom4(ip4), 1234),
IPv6: netip.AddrPortFrom(netip.AddrFrom16(ip6), 4321),
ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
StatelessResetToken: token2,
}
}
data := params.Marshal(protocol.PerspectiveServer)

b.ResetTimer()
b.ReportAllocs()
var p TransportParameters
for i := 0; i < b.N; i++ {
if err := p.Unmarshal(data, protocol.PerspectiveServer); err != nil {
b.Fatal(err)
}
// check a few fields
if p.DisableActiveMigration != params.DisableActiveMigration ||
p.InitialMaxStreamDataBidiLocal != params.InitialMaxStreamDataBidiLocal ||
*p.StatelessResetToken != *params.StatelessResetToken ||
p.AckDelayExponent != params.AckDelayExponent {
b.Fatalf("params mismatch: %v vs %v", p, params)
}
if withPreferredAddress && *p.PreferredAddress != *params.PreferredAddress {
b.Fatalf("preferred address mismatch: %v vs %v", p.PreferredAddress, params.PreferredAddress)
}
}
}

0 comments on commit 1514095

Please sign in to comment.