Skip to content

Commit

Permalink
implement a more intuitive address validation API
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 12, 2022
1 parent bebff46 commit fc64970
Show file tree
Hide file tree
Showing 14 changed files with 352 additions and 436 deletions.
18 changes: 13 additions & 5 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package quic

import (
"errors"
"net"
"time"

"github.com/lucas-clemente/quic-go/internal/utils"

"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)

// Clone clones a Config
Expand Down Expand Up @@ -39,8 +39,14 @@ func populateServerConfig(config *Config) *Config {
if config.ConnectionIDLength == 0 {
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
}
if config.AcceptToken == nil {
config.AcceptToken = defaultAcceptToken
if config.MaxTokenAge == 0 {
config.MaxTokenAge = protocol.TokenValidity
}
if config.MaxRetryTokenAge == 0 {
config.MaxRetryTokenAge = protocol.RetryTokenValidity
}
if config.RequireAddressValidation == nil {
config.RequireAddressValidation = func(net.Addr) bool { return true }
}
return config
}
Expand Down Expand Up @@ -104,7 +110,9 @@ func populateConfig(config *Config) *Config {
Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout,
AcceptToken: config.AcceptToken,
MaxTokenAge: config.MaxTokenAge,
MaxRetryTokenAge: config.MaxRetryTokenAge,
RequireAddressValidation: config.RequireAddressValidation,
KeepAlivePeriod: config.KeepAlivePeriod,
InitialStreamReceiveWindow: initialStreamReceiveWindow,
MaxStreamReceiveWindow: maxStreamReceiveWindow,
Expand Down
35 changes: 19 additions & 16 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ var _ = Describe("Config", func() {
}

switch fn := typ.Field(i).Name; fn {
case "AcceptToken", "GetLogWriter", "AllowConnectionWindowIncrease":
case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease":
// Can't compare functions.
case "Versions":
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
Expand All @@ -55,6 +55,10 @@ var _ = Describe("Config", func() {
f.Set(reflect.ValueOf(time.Second))
case "MaxIdleTimeout":
f.Set(reflect.ValueOf(time.Hour))
case "MaxTokenAge":
f.Set(reflect.ValueOf(2 * time.Hour))
case "MaxRetryTokenAge":
f.Set(reflect.ValueOf(2 * time.Minute))
case "TokenStore":
f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
case "InitialStreamReceiveWindow":
Expand Down Expand Up @@ -100,14 +104,14 @@ var _ = Describe("Config", func() {

Context("cloning", func() {
It("clones function fields", func() {
var calledAcceptToken, calledAllowConnectionWindowIncrease bool
var calledAddrValidation, calledAllowConnectionWindowIncrease bool
c1 := &Config{
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
}
c2 := c1.Clone()
c2.AcceptToken(&net.UDPAddr{}, &Token{})
Expect(calledAcceptToken).To(BeTrue())
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
c2.AllowConnectionWindowIncrease(nil, 1234)
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
})
Expand All @@ -119,27 +123,26 @@ var _ = Describe("Config", func() {

It("returns a copy", func() {
c1 := &Config{
MaxIncomingStreams: 100,
AcceptToken: func(_ net.Addr, _ *Token) bool { return true },
MaxIncomingStreams: 100,
RequireAddressValidation: func(net.Addr) bool { return true },
}
c2 := c1.Clone()
c2.MaxIncomingStreams = 200
c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
c2.RequireAddressValidation = func(net.Addr) bool { return false }

Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue())
Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue())
})
})

Context("populating", func() {
It("populates function fields", func() {
var calledAcceptToken bool
c1 := &Config{
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
}
var calledAddrValidation bool
c1 := &Config{}
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
c2 := populateConfig(c1)
c2.AcceptToken(&net.UDPAddr{}, &Token{})
Expect(calledAcceptToken).To(BeTrue())
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
})

It("copies non-function fields", func() {
Expand All @@ -164,7 +167,7 @@ var _ = Describe("Config", func() {
It("populates empty fields with default values, for the server", func() {
c := populateServerConfig(&Config{})
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
Expect(c.AcceptToken).ToNot(BeNil())
Expect(c.RequireAddressValidation).ToNot(BeNil())
})

It("sets a default connection ID length if we didn't create the conn, for the client", func() {
Expand Down
18 changes: 0 additions & 18 deletions fuzzing/tokens/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package tokens

import (
"encoding/binary"
"fmt"
"math/rand"
"net"
"time"
Expand Down Expand Up @@ -140,22 +139,5 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int {
if !token.RetrySrcConnectionID.Equal(retrySrcConnID) {
panic("retry src conn ID doesn't match")
}
checkAddr(token.RemoteAddr, addr)
return 1
}

func checkAddr(tokenAddr string, addr net.Addr) {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
// For UDP addresses, we encode only the IP (not the port).
if ip := udpAddr.IP.String(); tokenAddr != ip {
fmt.Printf("%s vs %s", tokenAddr, ip)
panic("wrong remote address for a net.UDPAddr")
}
return
}

if tokenAddr != addr.String() {
fmt.Printf("%s vs %s", tokenAddr, addr.String())
panic("wrong remote address")
}
}
4 changes: 2 additions & 2 deletions integrationtests/self/handshake_drop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ var _ = Describe("Handshake drop tests", func() {
HandshakeIdleTimeout: timeout,
Versions: []protocol.VersionNumber{version},
})
if !doRetry {
conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true }
if doRetry {
conf.RequireAddressValidation = func(net.Addr) bool { return true }
}
var tlsConf *tls.Config
if longCertChain {
Expand Down
25 changes: 2 additions & 23 deletions integrationtests/self/handshake_rtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ var _ = Describe("Handshake RTT tests", func() {
})

It("establishes a connection in 1 RTT when the server doesn't require a token", func() {
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
return true
}
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
runServerAndProxy()
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
Expand All @@ -126,9 +124,7 @@ var _ = Describe("Handshake RTT tests", func() {
})

It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() {
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
return true
}
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384}
runServerAndProxy()
_, err := quic.DialAddr(
Expand All @@ -139,21 +135,4 @@ var _ = Describe("Handshake RTT tests", func() {
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})

It("doesn't complete the handshake when the server never accepts the token", func() {
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
return false
}
clientConfig.HandshakeIdleTimeout = 500 * time.Millisecond
runServerAndProxy()
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
clientConfig,
)
Expect(err).To(HaveOccurred())
nerr, ok := err.(net.Error)
Expect(ok).To(BeTrue())
Expect(nerr.Timeout()).To(BeTrue())
})
})
36 changes: 4 additions & 32 deletions integrationtests/self/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,12 +344,7 @@ var _ = Describe("Handshake tests", func() {
}

BeforeEach(func() {
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
if token != nil {
Expect(token.IsRetryToken).To(BeFalse())
}
return true
}
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
var err error
// start the server, but don't call Accept
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expand Down Expand Up @@ -479,13 +474,7 @@ var _ = Describe("Handshake tests", func() {

Context("using tokens", func() {
It("uses tokens provided in NEW_TOKEN frames", func() {
tokenChan := make(chan *quic.Token, 100)
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
if token != nil && !token.IsRetryToken {
tokenChan <- token
}
return true
}
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }

server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -509,7 +498,6 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
Expect(gets).To(Receive())
Eventually(puts).Should(Receive())
Expect(tokenChan).ToNot(Receive())
// received a token. Close this connection.
Expect(conn.CloseWithError(0, "")).To(Succeed())

Expand All @@ -529,17 +517,13 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
Expect(gets).To(Receive())
Expect(tokenChan).To(Receive())

Eventually(done).Should(BeClosed())
})

It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
tokenChan := make(chan *quic.Token, 10)
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
tokenChan <- token
return false
}
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
serverConfig.MaxRetryTokenAge = time.Nanosecond

server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -554,18 +538,6 @@ var _ = Describe("Handshake tests", func() {
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken))
// Receiving a Retry might lead the client to measure a very small RTT.
// Then, it sometimes would retransmit the ClientHello before receiving the ServerHello.
Expect(len(tokenChan)).To(BeNumerically(">=", 2))
token := <-tokenChan
Expect(token).To(BeNil())
token = <-tokenChan
Expect(token).ToNot(BeNil())
// If the ClientHello was retransmitted, make sure that it contained the same Retry token.
for i := 2; i < len(tokenChan); i++ {
Expect(<-tokenChan).To(Equal(token))
}
Expect(token.IsRetryToken).To(BeTrue())
})
})

Expand Down
6 changes: 3 additions & 3 deletions integrationtests/self/packetization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ var _ = Describe("Packetization", func() {
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
AcceptToken: func(net.Addr, *quic.Token) bool { return true },
DisablePathMTUDiscovery: true,
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
RequireAddressValidation: func(net.Addr) bool { return false },
DisablePathMTUDiscovery: true,
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
Expand Down

0 comments on commit fc64970

Please sign in to comment.