Skip to content

Commit

Permalink
disable address validation by default
Browse files Browse the repository at this point in the history
We should provide safe defaults. Since we implement the 3x amplification
limit, disabling address validation is not unsafe, and will save 1 RTT
for every handshake for applications that don't explicitely configure
Retries.
  • Loading branch information
marten-seemann committed Aug 12, 2022
1 parent fc64970 commit 5da735a
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 62 deletions.
2 changes: 1 addition & 1 deletion config.go
Expand Up @@ -46,7 +46,7 @@ func populateServerConfig(config *Config) *Config {
config.MaxRetryTokenAge = protocol.RetryTokenValidity
}
if config.RequireAddressValidation == nil {
config.RequireAddressValidation = func(net.Addr) bool { return true }
config.RequireAddressValidation = func(net.Addr) bool { return false }
}
return config
}
Expand Down
10 changes: 4 additions & 6 deletions integrationtests/self/handshake_drop_test.go
Expand Up @@ -37,13 +37,11 @@ var _ = Describe("Handshake drop tests", func() {

startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) {
conf := getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
Versions: []protocol.VersionNumber{version},
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return doRetry },
})
if doRetry {
conf.RequireAddressValidation = func(net.Addr) bool { return true }
}
var tlsConf *tls.Config
if longCertChain {
tlsConf = getTLSConfigWithLongCertChain()
Expand Down
3 changes: 1 addition & 2 deletions integrationtests/self/handshake_rtt_test.go
Expand Up @@ -101,6 +101,7 @@ var _ = Describe("Handshake RTT tests", func() {
// 1 RTT for verifying the source address
// 1 RTT for the TLS handshake
It("is forward-secure after 2 RTTs", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
runServerAndProxy()
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
Expand All @@ -112,7 +113,6 @@ var _ = Describe("Handshake RTT tests", func() {
})

It("establishes a connection in 1 RTT when the server doesn't require a token", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
runServerAndProxy()
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
Expand All @@ -124,7 +124,6 @@ var _ = Describe("Handshake RTT tests", func() {
})

It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384}
runServerAndProxy()
_, err := quic.DialAddr(
Expand Down
3 changes: 0 additions & 3 deletions integrationtests/self/handshake_test.go
Expand Up @@ -344,7 +344,6 @@ var _ = Describe("Handshake tests", func() {
}

BeforeEach(func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
var err error
// start the server, but don't call Accept
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expand Down Expand Up @@ -474,8 +473,6 @@ var _ = Describe("Handshake tests", func() {

Context("using tokens", func() {
It("uses tokens provided in NEW_TOKEN frames", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }

server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())

Expand Down
1 change: 1 addition & 0 deletions integrationtests/self/mitm_test.go
Expand Up @@ -379,6 +379,7 @@ var _ = Describe("MITM test", func() {
// as it has already accepted a retry.
// TODO: determine behavior when server does not send Retry packets
It("fails when a forged Retry packet with modified srcConnID is sent to client", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
var initialPacketIntercepted bool
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted {
Expand Down
5 changes: 2 additions & 3 deletions integrationtests/self/packetization_test.go
Expand Up @@ -26,9 +26,8 @@ var _ = Describe("Packetization", func() {
"localhost:0",
getTLSConfig(),
getQuicConfig(&quic.Config{
RequireAddressValidation: func(net.Addr) bool { return false },
DisablePathMTUDiscovery: true,
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
DisablePathMTUDiscovery: true,
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
Expand Down
57 changes: 23 additions & 34 deletions integrationtests/self/zero_rtt_test.go
Expand Up @@ -55,9 +55,7 @@ var _ = Describe("0-RTT", func() {
dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) {
tlsConf := getTLSConfig()
if serverConf == nil {
serverConf = getQuicConfig(&quic.Config{
RequireAddressValidation: func(net.Addr) bool { return false },
})
serverConf = getQuicConfig(nil)
serverConf.Versions = []protocol.VersionNumber{version}
}
ln, err := quic.ListenAddrEarly(
Expand Down Expand Up @@ -197,9 +195,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
Expand Down Expand Up @@ -255,9 +252,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
Expand Down Expand Up @@ -400,8 +396,9 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
Expand Down Expand Up @@ -452,19 +449,17 @@ var _ = Describe("0-RTT", func() {
It("doesn't reject 0-RTT when the server's transport stream limit increased", func() {
const maxStreams = 1
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
MaxIncomingUniStreams: maxStreams,
RequireAddressValidation: func(net.Addr) bool { return false },
MaxIncomingUniStreams: maxStreams,
}))

tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false },
MaxIncomingUniStreams: maxStreams + 1,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
MaxIncomingUniStreams: maxStreams + 1,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
Expand Down Expand Up @@ -498,19 +493,17 @@ var _ = Describe("0-RTT", func() {
It("rejects 0-RTT when the server's stream limit decreased", func() {
const maxStreams = 42
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
MaxIncomingStreams: maxStreams,
RequireAddressValidation: func(net.Addr) bool { return false },
MaxIncomingStreams: maxStreams,
}))

tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false },
MaxIncomingStreams: maxStreams - 1,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
MaxIncomingStreams: maxStreams - 1,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -537,9 +530,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -560,16 +552,14 @@ var _ = Describe("0-RTT", func() {
func(addFlowControlLimit func(*quic.Config, uint64)) {
tracer := newPacketTracer()
firstConf := getQuicConfig(&quic.Config{
RequireAddressValidation: func(net.Addr) bool { return false },
Versions: []protocol.VersionNumber{version},
Versions: []protocol.VersionNumber{version},
})
addFlowControlLimit(firstConf, 3)
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)

secondConf := getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
})
addFlowControlLimit(secondConf, 100)
ln, err := quic.ListenAddrEarly(
Expand Down Expand Up @@ -722,9 +712,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
Expand Down
16 changes: 3 additions & 13 deletions server_test.go
Expand Up @@ -241,8 +241,9 @@ var _ = Describe("Server", func() {

It("creates a connection when the token is accepted", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
retryToken, err := serv.tokenGenerator.NewRetryToken(
&net.UDPAddr{},
raddr,
protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde},
protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad},
)
Expand All @@ -256,6 +257,7 @@ var _ = Describe("Server", func() {
Token: retryToken,
}
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
p.remoteAddr = raddr
run := make(chan struct{})
var token protocol.StatelessResetToken
rand.Read(token[:])
Expand Down Expand Up @@ -451,7 +453,6 @@ var _ = Describe("Server", func() {
})

It("creates a connection, if no token is required", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Expand Down Expand Up @@ -534,7 +535,6 @@ var _ = Describe("Server", func() {
}).AnyTimes()
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()

serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
acceptConn := make(chan struct{})
var counter uint32 // to be used as an atomic, so we query it in Eventually
serv.newConn = func(
Expand Down Expand Up @@ -588,7 +588,6 @@ var _ = Describe("Server", func() {
})

It("only creates a single connection for a duplicate Initial", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
var createdConn bool
conn := NewMockQuicConn(mockCtrl)
serv.newConn = func(
Expand Down Expand Up @@ -620,8 +619,6 @@ var _ = Describe("Server", func() {
})

It("rejects new connection attempts if the accept queue is full", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }

serv.newConn = func(
_ sendConn,
runner connRunner,
Expand Down Expand Up @@ -688,8 +685,6 @@ var _ = Describe("Server", func() {
})

It("doesn't accept new connections if they were closed in the mean time", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }

p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
ctx, cancel := context.WithCancel(context.Background())
connCreated := make(chan struct{})
Expand Down Expand Up @@ -999,7 +994,6 @@ var _ = Describe("Server", func() {
}()

ctx, cancel := context.WithCancel(context.Background()) // handshake context
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
serv.newConn = func(
_ sendConn,
runner connRunner,
Expand Down Expand Up @@ -1073,7 +1067,6 @@ var _ = Describe("Server", func() {
}()

ready := make(chan struct{})
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
serv.newConn = func(
_ sendConn,
runner connRunner,
Expand Down Expand Up @@ -1114,7 +1107,6 @@ var _ = Describe("Server", func() {
})

It("rejects new connection attempts if the accept queue is full", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}

serv.newConn = func(
Expand Down Expand Up @@ -1175,8 +1167,6 @@ var _ = Describe("Server", func() {
})

It("doesn't accept new connections if they were closed in the mean time", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }

p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
ctx, cancel := context.WithCancel(context.Background())
connCreated := make(chan struct{})
Expand Down

0 comments on commit 5da735a

Please sign in to comment.