Skip to content

Commit

Permalink
simplify packing of long header ACK-only packets
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Sep 4, 2022
1 parent 9ef5301 commit 29cb0db
Showing 1 changed file with 40 additions and 60 deletions.
100 changes: 40 additions & 60 deletions packet_packer.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,23 +322,26 @@ func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload)

func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) {
var pay *payload
var encLevel protocol.EncryptionLevel
var sealer sealer
var hdr *wire.ExtendedHeader
encLevel := protocol.EncryptionInitial
if !handshakeConfirmed {
var ack *wire.AckFrame
ack = p.acks.GetAckFrame(protocol.EncryptionInitial, true)
if ack != nil {
encLevel = protocol.EncryptionInitial
} else {
ack = p.acks.GetAckFrame(protocol.EncryptionHandshake, true)
if ack != nil {
encLevel = protocol.EncryptionHandshake
hdr, pay = p.maybeGetCryptoPacket(p.maxPacketSize, protocol.EncryptionInitial, true, true)
if pay != nil {
var err error
sealer, err = p.cryptoSetup.GetInitialSealer()
if err != nil {
return nil, err
}
}

if ack != nil {
pay = &payload{
ack: ack,
length: ack.Length(p.version),
} else {
encLevel = protocol.EncryptionHandshake
hdr, pay = p.maybeGetCryptoPacket(p.maxPacketSize, protocol.EncryptionHandshake, true, true)
if pay != nil {
var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil {
return nil, err
}
}
}
}
Expand All @@ -348,12 +351,14 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke
return nil, nil
}
encLevel = protocol.Encryption1RTT
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
}
hdr = p.getShortHeader(s.KeyPhase())
sealer = s
}

sealer, hdr, err := p.getSealerAndHeader(encLevel)
if err != nil {
return nil, err
}
return p.writeSinglePacket(hdr, pay, encLevel, sealer)
}

Expand Down Expand Up @@ -387,7 +392,7 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) {
}
var size protocol.ByteCount
if initialSealer != nil {
initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), size, protocol.EncryptionInitial)
initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, false, size == 0)
if initialPayload != nil {
size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead())
numPackets++
Expand All @@ -403,7 +408,7 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) {
return nil, err
}
if handshakeSealer != nil {
handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), size, protocol.EncryptionHandshake)
handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, false, size == 0)
if handshakePayload != nil {
s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead())
size += s
Expand Down Expand Up @@ -495,7 +500,17 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
}, nil
}

func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) {
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) {
if onlyAck {
if ack := p.acks.GetAckFrame(encLevel, true); ack != nil {
var payload payload
payload.ack = ack
payload.length = ack.Length(p.version)
return p.getLongHeader(encLevel), &payload
}
return nil, nil
}

var s cryptoStream
var hasRetransmission bool
//nolint:exhaustive // Initial and Handshake are the only two encryption levels here.
Expand All @@ -510,7 +525,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.

hasData := s.HasData()
var ack *wire.AckFrame
if encLevel == protocol.EncryptionInitial || currentSize == 0 {
if ackAllowed {
ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData)
}
if !hasData && !hasRetransmission && ack == nil {
Expand Down Expand Up @@ -677,14 +692,14 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (
if err != nil {
return nil, err
}
hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionInitial)
hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true)
case protocol.EncryptionHandshake:
var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil {
return nil, err
}
hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionHandshake)
hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true)
case protocol.Encryption1RTT:
oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
Expand Down Expand Up @@ -738,41 +753,6 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
}, nil
}

func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) {
switch encLevel {
case protocol.EncryptionInitial:
sealer, err := p.cryptoSetup.GetInitialSealer()
if err != nil {
return nil, nil, err
}
hdr := p.getLongHeader(protocol.EncryptionInitial)
return sealer, hdr, nil
case protocol.Encryption0RTT:
sealer, err := p.cryptoSetup.Get0RTTSealer()
if err != nil {
return nil, nil, err
}
hdr := p.getLongHeader(protocol.Encryption0RTT)
return sealer, hdr, nil
case protocol.EncryptionHandshake:
sealer, err := p.cryptoSetup.GetHandshakeSealer()
if err != nil {
return nil, nil, err
}
hdr := p.getLongHeader(protocol.EncryptionHandshake)
return sealer, hdr, nil
case protocol.Encryption1RTT:
sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, nil, err
}
hdr := p.getShortHeader(sealer.KeyPhase())
return sealer, hdr, nil
default:
return nil, nil, fmt.Errorf("unexpected encryption level: %s", encLevel)
}
}

func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader {
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdr := &wire.ExtendedHeader{}
Expand Down

0 comments on commit 29cb0db

Please sign in to comment.