Skip to content

Commit

Permalink
wire: simplify tracking of parsed length for Long Header parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed May 5, 2024
1 parent 130438d commit d9810a8
Showing 1 changed file with 14 additions and 22 deletions.
36 changes: 14 additions & 22 deletions internal/wire/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,44 +175,40 @@ func parseHeader(b []byte) (*Header, error) {
return h, err
}

func (h *Header) parseLongHeader(b []byte) (l protocol.ByteCount, err error) {
func (h *Header) parseLongHeader(b []byte) (protocol.ByteCount, error) {
startLen := len(b)
if len(b) < 5 {
return 0, io.EOF
}
h.Version = protocol.Version(binary.BigEndian.Uint32(b[:4]))
l = 4
if h.Version != 0 && h.typeByte&0x40 == 0 {
return l, errors.New("not a QUIC packet")
return protocol.ByteCount(startLen - len(b)), errors.New("not a QUIC packet")
}
destConnIDLen := int(b[4])
l++
if destConnIDLen > protocol.MaxConnIDLen {
return l, protocol.ErrInvalidConnectionIDLen
return protocol.ByteCount(startLen - len(b)), protocol.ErrInvalidConnectionIDLen
}
b = b[5:]
if len(b) < destConnIDLen+1 {
return l, io.EOF
return protocol.ByteCount(startLen - len(b)), io.EOF
}
h.DestConnectionID = protocol.ParseConnectionID(b[:destConnIDLen])
l += protocol.ByteCount(destConnIDLen)
srcConnIDLen := int(b[destConnIDLen])
l++
if srcConnIDLen > protocol.MaxConnIDLen {
return l, protocol.ErrInvalidConnectionIDLen
return protocol.ByteCount(startLen - len(b)), protocol.ErrInvalidConnectionIDLen
}
b = b[destConnIDLen+1:]
if len(b) < srcConnIDLen {
return l, io.EOF
return protocol.ByteCount(startLen - len(b)), io.EOF
}
h.SrcConnectionID = protocol.ParseConnectionID(b[:srcConnIDLen])
l += protocol.ByteCount(srcConnIDLen)
b = b[srcConnIDLen:]
if h.Version == 0 { // version negotiation packet
return l, nil
return protocol.ByteCount(startLen - len(b)), nil
}
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
return l, ErrUnsupportedVersion
return protocol.ByteCount(startLen - len(b)), ErrUnsupportedVersion
}

if h.Version == protocol.Version2 {
Expand Down Expand Up @@ -242,37 +238,33 @@ func (h *Header) parseLongHeader(b []byte) (l protocol.ByteCount, err error) {
if h.Type == protocol.PacketTypeRetry {
tokenLen := len(b) - 16
if tokenLen <= 0 {
return l, io.EOF
return protocol.ByteCount(startLen - len(b)), io.EOF
}
h.Token = make([]byte, tokenLen)
copy(h.Token, b[:tokenLen])
l += protocol.ByteCount(tokenLen)
return l + 16, nil
return protocol.ByteCount(startLen-len(b)+tokenLen) + 16, nil
}

if h.Type == protocol.PacketTypeInitial {
tokenLen, n, err := quicvarint.Parse(b)
l += protocol.ByteCount(n)
if err != nil {
return l, err
return protocol.ByteCount(startLen - len(b)), err
}
b = b[n:]
if tokenLen > uint64(len(b)) {
return l, io.EOF
return protocol.ByteCount(startLen - len(b)), io.EOF
}
l += protocol.ByteCount(tokenLen)
h.Token = make([]byte, tokenLen)
copy(h.Token, b[:tokenLen])
b = b[tokenLen:]
}

pl, n, err := quicvarint.Parse(b)
l += protocol.ByteCount(n)
if err != nil {
return 0, err
}
h.Length = protocol.ByteCount(pl)
return l, nil
return protocol.ByteCount(startLen - len(b) + n), nil
}

// ParsedLen returns the number of bytes that were consumed when parsing the header
Expand Down

0 comments on commit d9810a8

Please sign in to comment.