From 5da735abee6095d0bebcd660ecbd4a23ab3a3b2e Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 11 Aug 2022 22:03:10 +0400 Subject: [PATCH] disable address validation by default We should provide safe defaults. Since we implement the 3x amplification limit, disabling address validation is not unsafe, and will save 1 RTT for every handshake for applications that don't explicitely configure Retries. --- config.go | 2 +- integrationtests/self/handshake_drop_test.go | 10 ++-- integrationtests/self/handshake_rtt_test.go | 3 +- integrationtests/self/handshake_test.go | 3 -- integrationtests/self/mitm_test.go | 1 + integrationtests/self/packetization_test.go | 5 +- integrationtests/self/zero_rtt_test.go | 57 ++++++++------------ server_test.go | 16 ++---- 8 files changed, 35 insertions(+), 62 deletions(-) diff --git a/config.go b/config.go index c54e3454826..558505c153c 100644 --- a/config.go +++ b/config.go @@ -46,7 +46,7 @@ func populateServerConfig(config *Config) *Config { config.MaxRetryTokenAge = protocol.RetryTokenValidity } if config.RequireAddressValidation == nil { - config.RequireAddressValidation = func(net.Addr) bool { return true } + config.RequireAddressValidation = func(net.Addr) bool { return false } } return config } diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 525f19999b7..2bd6e362514 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -37,13 +37,11 @@ var _ = Describe("Handshake drop tests", func() { startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) { conf := getQuicConfig(&quic.Config{ - MaxIdleTimeout: timeout, - HandshakeIdleTimeout: timeout, - Versions: []protocol.VersionNumber{version}, + MaxIdleTimeout: timeout, + HandshakeIdleTimeout: timeout, + Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return doRetry }, }) - if doRetry { - conf.RequireAddressValidation = func(net.Addr) bool { return true } - } var tlsConf *tls.Config if longCertChain { tlsConf = getTLSConfigWithLongCertChain() diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 22496b26432..0f767d1b925 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -101,6 +101,7 @@ var _ = Describe("Handshake RTT tests", func() { // 1 RTT for verifying the source address // 1 RTT for the TLS handshake It("is forward-secure after 2 RTTs", func() { + serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } runServerAndProxy() _, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), @@ -112,7 +113,6 @@ var _ = Describe("Handshake RTT tests", func() { }) It("establishes a connection in 1 RTT when the server doesn't require a token", func() { - serverConfig.RequireAddressValidation = func(net.Addr) bool { return false } runServerAndProxy() _, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), @@ -124,7 +124,6 @@ var _ = Describe("Handshake RTT tests", func() { }) It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() { - serverConfig.RequireAddressValidation = func(net.Addr) bool { return false } serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384} runServerAndProxy() _, err := quic.DialAddr( diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index e494535c920..6cdef36d111 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -344,7 +344,6 @@ var _ = Describe("Handshake tests", func() { } BeforeEach(func() { - serverConfig.RequireAddressValidation = func(net.Addr) bool { return false } var err error // start the server, but don't call Accept server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) @@ -474,8 +473,6 @@ var _ = Describe("Handshake tests", func() { Context("using tokens", func() { It("uses tokens provided in NEW_TOKEN frames", func() { - serverConfig.RequireAddressValidation = func(net.Addr) bool { return false } - server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index 34bb14c6a3f..087fe8d9129 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -379,6 +379,7 @@ var _ = Describe("MITM test", func() { // as it has already accepted a retry. // TODO: determine behavior when server does not send Retry packets It("fails when a forged Retry packet with modified srcConnID is sent to client", func() { + serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } var initialPacketIntercepted bool delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted { diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index 497bd43d3e6..e326f17507d 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -26,9 +26,8 @@ var _ = Describe("Packetization", func() { "localhost:0", getTLSConfig(), getQuicConfig(&quic.Config{ - RequireAddressValidation: func(net.Addr) bool { return false }, - DisablePathMTUDiscovery: true, - Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }), + DisablePathMTUDiscovery: true, + Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }), }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 45d6106b4d4..585f079bc69 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -55,9 +55,7 @@ var _ = Describe("0-RTT", func() { dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) { tlsConf := getTLSConfig() if serverConf == nil { - serverConf = getQuicConfig(&quic.Config{ - RequireAddressValidation: func(net.Addr) bool { return false }, - }) + serverConf = getQuicConfig(nil) serverConf.Versions = []protocol.VersionNumber{version} } ln, err := quic.ListenAddrEarly( @@ -197,9 +195,8 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - RequireAddressValidation: func(net.Addr) bool { return false }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -255,9 +252,8 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - RequireAddressValidation: func(net.Addr) bool { return false }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -400,8 +396,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -452,8 +449,7 @@ var _ = Describe("0-RTT", func() { It("doesn't reject 0-RTT when the server's transport stream limit increased", func() { const maxStreams = 1 tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - MaxIncomingUniStreams: maxStreams, - RequireAddressValidation: func(net.Addr) bool { return false }, + MaxIncomingUniStreams: maxStreams, })) tracer := newPacketTracer() @@ -461,10 +457,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - RequireAddressValidation: func(net.Addr) bool { return false }, - MaxIncomingUniStreams: maxStreams + 1, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + MaxIncomingUniStreams: maxStreams + 1, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -498,8 +493,7 @@ var _ = Describe("0-RTT", func() { It("rejects 0-RTT when the server's stream limit decreased", func() { const maxStreams = 42 tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - MaxIncomingStreams: maxStreams, - RequireAddressValidation: func(net.Addr) bool { return false }, + MaxIncomingStreams: maxStreams, })) tracer := newPacketTracer() @@ -507,10 +501,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - RequireAddressValidation: func(net.Addr) bool { return false }, - MaxIncomingStreams: maxStreams - 1, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + MaxIncomingStreams: maxStreams - 1, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -537,9 +530,8 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - RequireAddressValidation: func(net.Addr) bool { return false }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -560,16 +552,14 @@ var _ = Describe("0-RTT", func() { func(addFlowControlLimit func(*quic.Config, uint64)) { tracer := newPacketTracer() firstConf := getQuicConfig(&quic.Config{ - RequireAddressValidation: func(net.Addr) bool { return false }, - Versions: []protocol.VersionNumber{version}, + Versions: []protocol.VersionNumber{version}, }) addFlowControlLimit(firstConf, 3) tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) secondConf := getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - RequireAddressValidation: func(net.Addr) bool { return false }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }) addFlowControlLimit(secondConf, 100) ln, err := quic.ListenAddrEarly( @@ -722,9 +712,8 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - RequireAddressValidation: func(net.Addr) bool { return false }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/server_test.go b/server_test.go index 5640a05a929..4944bccfd2c 100644 --- a/server_test.go +++ b/server_test.go @@ -241,8 +241,9 @@ var _ = Describe("Server", func() { It("creates a connection when the token is accepted", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} retryToken, err := serv.tokenGenerator.NewRetryToken( - &net.UDPAddr{}, + raddr, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, ) @@ -256,6 +257,7 @@ var _ = Describe("Server", func() { Token: retryToken, } p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + p.remoteAddr = raddr run := make(chan struct{}) var token protocol.StatelessResetToken rand.Read(token[:]) @@ -451,7 +453,6 @@ var _ = Describe("Server", func() { }) It("creates a connection, if no token is required", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return false } hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, @@ -534,7 +535,6 @@ var _ = Describe("Server", func() { }).AnyTimes() tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() - serv.config.RequireAddressValidation = func(net.Addr) bool { return false } acceptConn := make(chan struct{}) var counter uint32 // to be used as an atomic, so we query it in Eventually serv.newConn = func( @@ -588,7 +588,6 @@ var _ = Describe("Server", func() { }) It("only creates a single connection for a duplicate Initial", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return false } var createdConn bool conn := NewMockQuicConn(mockCtrl) serv.newConn = func( @@ -620,8 +619,6 @@ var _ = Describe("Server", func() { }) It("rejects new connection attempts if the accept queue is full", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return false } - serv.newConn = func( _ sendConn, runner connRunner, @@ -688,8 +685,6 @@ var _ = Describe("Server", func() { }) It("doesn't accept new connections if they were closed in the mean time", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return false } - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) ctx, cancel := context.WithCancel(context.Background()) connCreated := make(chan struct{}) @@ -999,7 +994,6 @@ var _ = Describe("Server", func() { }() ctx, cancel := context.WithCancel(context.Background()) // handshake context - serv.config.RequireAddressValidation = func(net.Addr) bool { return false } serv.newConn = func( _ sendConn, runner connRunner, @@ -1073,7 +1067,6 @@ var _ = Describe("Server", func() { }() ready := make(chan struct{}) - serv.config.RequireAddressValidation = func(net.Addr) bool { return false } serv.newConn = func( _ sendConn, runner connRunner, @@ -1114,7 +1107,6 @@ var _ = Describe("Server", func() { }) It("rejects new connection attempts if the accept queue is full", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return false } senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} serv.newConn = func( @@ -1175,8 +1167,6 @@ var _ = Describe("Server", func() { }) It("doesn't accept new connections if they were closed in the mean time", func() { - serv.config.RequireAddressValidation = func(net.Addr) bool { return false } - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) ctx, cancel := context.WithCancel(context.Background()) connCreated := make(chan struct{})