Skip to content

Commit

Permalink
wire: use quicvarint.Parse to when parsing transport parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed May 5, 2024
1 parent 9e54f95 commit 0e6c728
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 75 deletions.
5 changes: 2 additions & 3 deletions fuzzing/transportparameters/fuzz.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package transportparameters

import (
"bytes"
"errors"
"fmt"

Expand Down Expand Up @@ -55,12 +54,12 @@ func fuzzTransportParameters(data []byte, sentByServer bool) int {

func fuzzTransportParametersForSessionTicket(data []byte) int {
tp := &wire.TransportParameters{}
if err := tp.UnmarshalFromSessionTicket(bytes.NewReader(data)); err != nil {
if err := tp.UnmarshalFromSessionTicket(data); err != nil {
return 0
}
b := tp.MarshalForSessionTicket(nil)
tp2 := &wire.TransportParameters{}
if err := tp2.UnmarshalFromSessionTicket(bytes.NewReader(b)); err != nil {
if err := tp2.UnmarshalFromSessionTicket(b); err != nil {
panic(err)
}
return 1
Expand Down
12 changes: 6 additions & 6 deletions internal/handshake/crypto_setup.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handshake

import (
"bytes"
"context"
"crypto/tls"
"errors"
Expand Down Expand Up @@ -338,25 +337,26 @@ func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (a
return false
}

func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
r := bytes.NewReader(data)
ver, err := quicvarint.Read(r)
func decodeDataFromSessionState(b []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
ver, l, err := quicvarint.Parse(b)
if err != nil {
return 0, nil, err
}
b = b[l:]
if ver != clientSessionStateRevision {
return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
rttEncoded, err := quicvarint.Read(r)
rttEncoded, l, err := quicvarint.Parse(b)
if err != nil {
return 0, nil, err
}
b = b[l:]
rtt := time.Duration(rttEncoded) * time.Microsecond
if !earlyData {
return rtt, nil, nil
}
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return 0, nil, err
}
return rtt, &tp, nil
Expand Down
12 changes: 6 additions & 6 deletions internal/handshake/session_ticket.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handshake

import (
"bytes"
"errors"
"fmt"
"time"
Expand All @@ -28,25 +27,26 @@ func (t *sessionTicket) Marshal() []byte {
}

func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
r := bytes.NewReader(b)
rev, err := quicvarint.Read(r)
rev, l, err := quicvarint.Parse(b)
if err != nil {
return errors.New("failed to read session ticket revision")
}
b = b[l:]
if rev != sessionTicketRevision {
return fmt.Errorf("unknown session ticket revision: %d", rev)
}
rtt, err := quicvarint.Read(r)
rtt, l, err := quicvarint.Parse(b)
if err != nil {
return errors.New("failed to read RTT")
}
b = b[l:]
if using0RTT {
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp
} else if r.Len() > 0 {
} else if len(b) > 0 {
return fmt.Errorf("the session ticket has more bytes than expected")
}
t.RTT = time.Duration(rtt) * time.Microsecond
Expand Down
6 changes: 3 additions & 3 deletions internal/wire/transport_parameter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ var _ = Describe("Transport Parameters", func() {
Expect(params.ValidFor0RTT(params)).To(BeTrue())
b := params.MarshalForSessionTicket(nil)
var tp TransportParameters
Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b))).To(Succeed())
Expect(tp.UnmarshalFromSessionTicket(b)).To(Succeed())
Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal))
Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote))
Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni))
Expand All @@ -518,15 +518,15 @@ var _ = Describe("Transport Parameters", func() {

It("rejects the parameters if it can't parse them", func() {
var p TransportParameters
Expect(p.UnmarshalFromSessionTicket(bytes.NewReader([]byte("foobar")))).ToNot(Succeed())
Expect(p.UnmarshalFromSessionTicket([]byte("foobar"))).ToNot(Succeed())
})

It("rejects the parameters if the version changed", func() {
var p TransportParameters
data := p.MarshalForSessionTicket(nil)
b := quicvarint.Append(nil, transportParameterMarshalingVersion+1)
b = append(b, data[quicvarint.Len(transportParameterMarshalingVersion):]...)
Expect(p.UnmarshalFromSessionTicket(bytes.NewReader(b))).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1)))
Expect(p.UnmarshalFromSessionTicket(b)).To(MatchError(fmt.Sprintf("unknown transport parameter marshaling version: %d", transportParameterMarshalingVersion+1)))
})

Context("rejects the parameters if they changed", func() {
Expand Down
119 changes: 62 additions & 57 deletions internal/wire/transport_parameters.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package wire

import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
Expand All @@ -13,7 +12,6 @@ import (

"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)

Expand Down Expand Up @@ -89,7 +87,7 @@ type TransportParameters struct {

// Unmarshal the transport parameters
func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error {
if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil {
if err := p.unmarshal(data, sentBy, false); err != nil {
return &qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
Expand All @@ -98,7 +96,7 @@ func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective
return nil
}

func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error {
func (p *TransportParameters) unmarshal(b []byte, sentBy protocol.Perspective, fromSessionTicket bool) error {
// needed to check that every parameter is only sent at most once
var parameterIDs []transportParameterID

Expand All @@ -112,18 +110,20 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
p.MaxAckDelay = protocol.DefaultMaxAckDelay
p.MaxDatagramFrameSize = protocol.InvalidByteCount

for r.Len() > 0 {
paramIDInt, err := quicvarint.Read(r)
for len(b) > 0 {
paramIDInt, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
paramID := transportParameterID(paramIDInt)
paramLen, err := quicvarint.Read(r)
b = b[l:]
paramLen, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
if uint64(r.Len()) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen)
b = b[l:]
if uint64(len(b)) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(b), paramLen)
}
parameterIDs = append(parameterIDs, paramID)
switch paramID {
Expand All @@ -141,16 +141,18 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
maxAckDelayParameterID,
maxDatagramFrameSizeParameterID,
ackDelayExponentParameterID:
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
if err := p.readNumericTransportParameter(b, paramID, int(paramLen)); err != nil {
return err
}
b = b[paramLen:]
case preferredAddressParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a preferred_address")
}
if err := p.readPreferredAddress(r, int(paramLen)); err != nil {
if err := p.readPreferredAddress(b, int(paramLen)); err != nil {
return err
}
b = b[paramLen:]
case disableActiveMigrationParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen)
Expand All @@ -164,25 +166,41 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
}
var token protocol.StatelessResetToken
r.Read(token[:])
if len(b) < len(token) {
return io.EOF
}
copy(token[:], b)
b = b[len(token):]
p.StatelessResetToken = &token
case originalDestinationConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent an original_destination_connection_id")
}
p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
p.OriginalDestinationConnectionID = protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
readOriginalDestinationConnectionID = true
case initialSourceConnectionIDParameterID:
p.InitialSourceConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
p.InitialSourceConnectionID = protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
readInitialSourceConnectionID = true
case retrySourceConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a retry_source_connection_id")
}
connID, _ := protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
connID := protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
p.RetrySourceConnectionID = &connID
default:
r.Seek(int64(paramLen), io.SeekCurrent)
b = b[paramLen:]
}
}

Expand Down Expand Up @@ -212,60 +230,47 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
return nil
}

func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error {
remainingLen := r.Len()
func (p *TransportParameters) readPreferredAddress(b []byte, expectedLen int) error {
remainingLen := len(b)
pa := &PreferredAddress{}
var ipv4 [4]byte
if _, err := io.ReadFull(r, ipv4[:]); err != nil {
return err
if len(b) < 4+2+16+2+1 {
return io.EOF
}
port, err := utils.BigEndian.ReadUint16(r)
if err != nil {
return err
}
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port)
var ipv4 [4]byte
copy(ipv4[:], b[:4])
port4 := binary.BigEndian.Uint16(b[4:])
b = b[4+2:]
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port4)
var ipv6 [16]byte
if _, err := io.ReadFull(r, ipv6[:]); err != nil {
return err
}
port, err = utils.BigEndian.ReadUint16(r)
if err != nil {
return err
}
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port)
connIDLen, err := r.ReadByte()
if err != nil {
return err
}
copy(ipv6[:], b[:16])
port6 := binary.BigEndian.Uint16(b[16:])
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port6)
b = b[16+2:]
connIDLen := int(b[0])
b = b[1:]
if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d", connIDLen)
}
connID, err := protocol.ReadConnectionID(r, int(connIDLen))
if err != nil {
return err
}
pa.ConnectionID = connID
if _, err := io.ReadFull(r, pa.StatelessResetToken[:]); err != nil {
return err
if len(b) < connIDLen+len(pa.StatelessResetToken) {
return io.EOF
}
if bytesRead := remainingLen - r.Len(); bytesRead != expectedLen {
pa.ConnectionID = protocol.ParseConnectionID(b[:connIDLen])
b = b[connIDLen:]
copy(pa.StatelessResetToken[:], b)
b = b[len(pa.StatelessResetToken):]
if bytesRead := remainingLen - len(b); bytesRead != expectedLen {
return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead)
}
p.PreferredAddress = pa
return nil
}

func (p *TransportParameters) readNumericTransportParameter(
r *bytes.Reader,
paramID transportParameterID,
expectedLen int,
) error {
remainingLen := r.Len()
val, err := quicvarint.Read(r)
func (p *TransportParameters) readNumericTransportParameter(b []byte, paramID transportParameterID, expectedLen int) error {
val, l, err := quicvarint.Parse(b)
if err != nil {
return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err)
}
if remainingLen-r.Len() != expectedLen {
if l != expectedLen {
return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID)
}
//nolint:exhaustive // This only covers the numeric transport parameters.
Expand Down Expand Up @@ -457,15 +462,15 @@ func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte {
}

// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket.
func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error {
version, err := quicvarint.Read(r)
func (p *TransportParameters) UnmarshalFromSessionTicket(b []byte) error {
version, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
if version != transportParameterMarshalingVersion {
return fmt.Errorf("unknown transport parameter marshaling version: %d", version)
}
return p.unmarshal(r, protocol.PerspectiveServer, true)
return p.unmarshal(b[l:], protocol.PerspectiveServer, true)
}

// ValidFor0RTT checks if the transport parameters match those saved in the session ticket.
Expand Down

0 comments on commit 0e6c728

Please sign in to comment.