Skip to content

Commit

Permalink
Add support for providing a custom ConnectionID generator via Config
Browse files Browse the repository at this point in the history
This work makes it possible for servers or clients to control how
ConnectionIDs are generated, which in turn will force peers in the
connection to use those ConnectionIDs as destination connection IDs  when sending packets.

This is useful for scenarios where we want to perform some kind
selection on the QUIC packets at the L4 level.
  • Loading branch information
joliveirinha committed Aug 16, 2022
1 parent d5efd34 commit 977b842
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 32 deletions.
7 changes: 3 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ type client struct {
}

var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = protocol.GenerateConnectionID
// make it possible to mock connection ID for initial generation in the tests
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
)

Expand Down Expand Up @@ -193,7 +192,7 @@ func dialContext(
return nil, err
}
config = populateClientConfig(config, createdPacketConn)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -256,7 +255,7 @@ func newClient(
}
}

srcConnID, err := generateConnectionID(config.ConnectionIDLength)
srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
Expand Down
21 changes: 14 additions & 7 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,16 @@ var _ = Describe("Client", func() {
})

Context("Dialing", func() {
var origGenerateConnectionID func(int) (protocol.ConnectionID, error)
var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error)

BeforeEach(func() {
origGenerateConnectionID = generateConnectionID
origGenerateConnectionIDForInitial = generateConnectionIDForInitial
generateConnectionID = func(int) (protocol.ConnectionID, error) {
return connID, nil
}
generateConnectionIDForInitial = func() (protocol.ConnectionID, error) {
return connID, nil
}
})

AfterEach(func() {
generateConnectionID = origGenerateConnectionID
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
})

Expand Down Expand Up @@ -524,7 +518,7 @@ var _ = Describe("Client", func() {
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)

config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}}
c := make(chan struct{})
var cconn sendConn
var version protocol.VersionNumber
Expand Down Expand Up @@ -602,10 +596,23 @@ var _ = Describe("Client", func() {
return conn
}

config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := DialAddr("localhost:7890", tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Expect(counter).To(Equal(2))
})
})
})

type mockedConnIDGenerator struct {
ConnID protocol.ConnectionID
}

func (m *mockedConnIDGenerator) GenerateConnectionID() ([]byte, error) {
return m.ConnID, nil
}

func (m *mockedConnIDGenerator) ConnectionIDLen() int {
return m.ConnID.Len()
}
23 changes: 15 additions & 8 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ func validateConfig(config *Config) error {
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config {
config = populateConfig(config)
if config.ConnectionIDLength == 0 {
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
}
config = populateConfig(config, protocol.DefaultConnectionIDLength)
if config.AcceptToken == nil {
config.AcceptToken = defaultAcceptToken
}
Expand All @@ -48,21 +45,26 @@ func populateServerConfig(config *Config) *Config {
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
config = populateConfig(config)
if config.ConnectionIDLength == 0 && !createdPacketConn {
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
var defaultConnIdLen = protocol.DefaultConnectionIDLength
if createdPacketConn {
defaultConnIdLen = 0
}

config = populateConfig(config, defaultConnIdLen)
return config
}

func populateConfig(config *Config) *Config {
func populateConfig(config *Config, defaultConnIDLen int) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
if config.ConnectionIDLength == 0 {
config.ConnectionIDLength = defaultConnIDLen
}
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
if config.HandshakeIdleTimeout != 0 {
handshakeIdleTimeout = config.HandshakeIdleTimeout
Expand Down Expand Up @@ -99,6 +101,10 @@ func populateConfig(config *Config) *Config {
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDGenerator := config.ConnectionIDGenerator
if connIDGenerator == nil {
connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: config.ConnectionIDLength}
}

return &Config{
Versions: versions,
Expand All @@ -114,6 +120,7 @@ func populateConfig(config *Config) *Config {
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
ConnectionIDLength: config.ConnectionIDLength,
ConnectionIDGenerator: connIDGenerator,
StatelessResetKey: config.StatelessResetKey,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,
Expand Down
8 changes: 5 additions & 3 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ var _ = Describe("Config", func() {
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
case "ConnectionIDLength":
f.Set(reflect.ValueOf(8))
case "ConnectionIDGenerator":
f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength}))
case "HandshakeIdleTimeout":
f.Set(reflect.ValueOf(time.Second))
case "MaxIdleTimeout":
Expand Down Expand Up @@ -137,18 +139,18 @@ var _ = Describe("Config", func() {
c1 := &Config{
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
}
c2 := populateConfig(c1)
c2 := populateConfig(c1, protocol.DefaultConnectionIDLength)
c2.AcceptToken(&net.UDPAddr{}, &Token{})
Expect(calledAcceptToken).To(BeTrue())
})

It("copies non-function fields", func() {
c := configWithNonZeroNonFunctionFields()
Expect(populateConfig(c)).To(Equal(c))
Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c))
})

It("populates empty fields with default values", func() {
c := populateConfig(&Config{})
c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength)
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
Expand Down
9 changes: 5 additions & 4 deletions conn_id_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

type connIDGenerator struct {
connIDLen int
generator ConnectionIDGenerator
highestSeq uint64

activeSrcConnIDs map[uint64]protocol.ConnectionID
Expand All @@ -35,10 +35,11 @@ func newConnIDGenerator(
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func(protocol.ConnectionID, packetHandler),
queueControlFrame func(wire.Frame),
generator ConnectionIDGenerator,
version protocol.VersionNumber,
) *connIDGenerator {
m := &connIDGenerator{
connIDLen: initialConnectionID.Len(),
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
addConnectionID: addConnectionID,
getStatelessResetToken: getStatelessResetToken,
Expand All @@ -54,7 +55,7 @@ func newConnIDGenerator(
}

func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
if m.connIDLen == 0 {
if m.generator.ConnectionIDLen() == 0 {
return nil
}
// The active_connection_id_limit transport parameter is the number of
Expand Down Expand Up @@ -99,7 +100,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
}

func (m *connIDGenerator) issueNewConnID() error {
connID, err := protocol.GenerateConnectionID(m.connIDLen)
connID, err := m.generator.GenerateConnectionID()
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions conn_id_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ var _ = Describe("Connection ID Generator", func() {
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h },
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
&protocol.DefaultConnectionIDGenerator{ConnLen: initialConnID.Len()},
protocol.VersionDraft29,
)
})
Expand Down
2 changes: 2 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ var newConnection = func(
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
s.config.ConnectionIDGenerator,
s.version,
)
s.preSetup()
Expand Down Expand Up @@ -407,6 +408,7 @@ var newClientConnection = func(
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
s.config.ConnectionIDGenerator,
s.version,
)
s.preSetup()
Expand Down
17 changes: 17 additions & 0 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,18 @@ type EarlyConnection interface {
NextConnection() Connection
}

// A ConnectionIDGenerator is a interface that allows clients to implement their own format
// for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets.
type ConnectionIDGenerator interface {

// GenerateConnectionID generates a new ConnectionID.
GenerateConnectionID() ([]byte, error)

// ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of
// this interface.
ConnectionIDLen() int
}

// Config contains all configuration data needed for a QUIC server or client.
type Config struct {
// The QUIC versions that can be negotiated.
Expand All @@ -223,6 +235,11 @@ type Config struct {
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
ConnectionIDLength int
// An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection.
// The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers.
// By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used.
// Otherwise, if one is provided, then ConnectionIDLength is ignored.
ConnectionIDGenerator ConnectionIDGenerator
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
// Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted.
// If this value is zero, the timeout is set to 5 seconds.
Expand Down
12 changes: 12 additions & 0 deletions internal/protocol/connection_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,15 @@ func (c ConnectionID) String() string {
}
return fmt.Sprintf("%x", c.Bytes())
}

type DefaultConnectionIDGenerator struct {
ConnLen int
}

func (d *DefaultConnectionIDGenerator) GenerateConnectionID() ([]byte, error) {
return GenerateConnectionID(d.ConnLen)
}

func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int {
return d.ConnLen
}
12 changes: 6 additions & 6 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl
}
}

connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -341,7 +341,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
}
// If we're creating a new connection, the packet will be passed to the connection.
// The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength)
hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDGenerator.ConnectionIDLen())
if err != nil && err != wire.ErrUnsupportedVersion {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
Expand Down Expand Up @@ -450,11 +450,11 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return nil
}

connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
s.logger.Debugf("Changing connection ID to %s.", connID)
s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(connID))
var conn quicConn
tracingID := nextConnTracingID()
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler {
Expand Down Expand Up @@ -535,7 +535,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
// Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the connection.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID()
if err != nil {
return err
}
Expand All @@ -551,7 +551,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
replyHdr.DestConnectionID = hdr.SrcConnectionID
replyHdr.Token = token
if s.logger.Debug() {
s.logger.Debugf("Changing connection ID to %s.", srcConnID)
s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(srcConnID))
s.logger.Debugf("-> Sending Retry")
replyHdr.Log(s.logger)
}
Expand Down

0 comments on commit 977b842

Please sign in to comment.