diff --git a/config.go b/config.go index 377c3ae9134..93735e7217c 100644 --- a/config.go +++ b/config.go @@ -2,11 +2,11 @@ package quic import ( "errors" + "net" "time" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" ) // Clone clones a Config @@ -39,8 +39,14 @@ func populateServerConfig(config *Config) *Config { if config.ConnectionIDLength == 0 { config.ConnectionIDLength = protocol.DefaultConnectionIDLength } - if config.AcceptToken == nil { - config.AcceptToken = defaultAcceptToken + if config.MaxTokenAge == 0 { + config.MaxTokenAge = protocol.TokenValidity + } + if config.MaxRetryTokenAge == 0 { + config.MaxRetryTokenAge = protocol.RetryTokenValidity + } + if config.RequireAddressValidation == nil { + config.RequireAddressValidation = func(net.Addr) bool { return false } } return config } @@ -104,7 +110,9 @@ func populateConfig(config *Config) *Config { Versions: versions, HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, - AcceptToken: config.AcceptToken, + MaxTokenAge: config.MaxTokenAge, + MaxRetryTokenAge: config.MaxRetryTokenAge, + RequireAddressValidation: config.RequireAddressValidation, KeepAlivePeriod: config.KeepAlivePeriod, InitialStreamReceiveWindow: initialStreamReceiveWindow, MaxStreamReceiveWindow: maxStreamReceiveWindow, diff --git a/config_test.go b/config_test.go index 692952f4fa8..f4cfe41d12c 100644 --- a/config_test.go +++ b/config_test.go @@ -45,7 +45,7 @@ var _ = Describe("Config", func() { } switch fn := typ.Field(i).Name; fn { - case "AcceptToken", "GetLogWriter", "AllowConnectionWindowIncrease": + case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease": // Can't compare functions. case "Versions": f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) @@ -55,6 +55,10 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf(time.Second)) case "MaxIdleTimeout": f.Set(reflect.ValueOf(time.Hour)) + case "MaxTokenAge": + f.Set(reflect.ValueOf(2 * time.Hour)) + case "MaxRetryTokenAge": + f.Set(reflect.ValueOf(2 * time.Minute)) case "TokenStore": f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) case "InitialStreamReceiveWindow": @@ -100,14 +104,14 @@ var _ = Describe("Config", func() { Context("cloning", func() { It("clones function fields", func() { - var calledAcceptToken, calledAllowConnectionWindowIncrease bool + var calledAddrValidation, calledAllowConnectionWindowIncrease bool c1 := &Config{ - AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, + RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, } c2 := c1.Clone() - c2.AcceptToken(&net.UDPAddr{}, &Token{}) - Expect(calledAcceptToken).To(BeTrue()) + c2.RequireAddressValidation(&net.UDPAddr{}) + Expect(calledAddrValidation).To(BeTrue()) c2.AllowConnectionWindowIncrease(nil, 1234) Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) }) @@ -119,27 +123,26 @@ var _ = Describe("Config", func() { It("returns a copy", func() { c1 := &Config{ - MaxIncomingStreams: 100, - AcceptToken: func(_ net.Addr, _ *Token) bool { return true }, + MaxIncomingStreams: 100, + RequireAddressValidation: func(net.Addr) bool { return true }, } c2 := c1.Clone() c2.MaxIncomingStreams = 200 - c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + c2.RequireAddressValidation = func(net.Addr) bool { return false } Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) - Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue()) + Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue()) }) }) Context("populating", func() { It("populates function fields", func() { - var calledAcceptToken bool - c1 := &Config{ - AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, - } + var calledAddrValidation bool + c1 := &Config{} + c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } c2 := populateConfig(c1) - c2.AcceptToken(&net.UDPAddr{}, &Token{}) - Expect(calledAcceptToken).To(BeTrue()) + c2.RequireAddressValidation(&net.UDPAddr{}) + Expect(calledAddrValidation).To(BeTrue()) }) It("copies non-function fields", func() { @@ -164,7 +167,7 @@ var _ = Describe("Config", func() { It("populates empty fields with default values, for the server", func() { c := populateServerConfig(&Config{}) Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength)) - Expect(c.AcceptToken).ToNot(BeNil()) + Expect(c.RequireAddressValidation).ToNot(BeNil()) }) It("sets a default connection ID length if we didn't create the conn, for the client", func() { diff --git a/connection.go b/connection.go index e731480c624..24760a788cd 100644 --- a/connection.go +++ b/connection.go @@ -543,7 +543,11 @@ func (s *connection) run() error { s.timer = utils.NewTimer() - go s.cryptoStreamHandler.RunHandshake() + handshaking := make(chan struct{}) + go func() { + defer close(handshaking) + s.cryptoStreamHandler.RunHandshake() + }() go func() { if err := s.sendQueue.Run(); err != nil { s.destroyImpl(err) @@ -694,12 +698,13 @@ runLoop: } } + s.cryptoStreamHandler.Close() + <-handshaking s.handleCloseError(&closeErr) if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { s.tracer.Close() } s.logger.Infof("Connection %s closed.", s.logID) - s.cryptoStreamHandler.Close() s.sendQueue.Close() s.timer.Stop() return closeErr.err diff --git a/fuzzing/tokens/fuzz.go b/fuzzing/tokens/fuzz.go index 9d414f77f19..1e1904ba3cd 100644 --- a/fuzzing/tokens/fuzz.go +++ b/fuzzing/tokens/fuzz.go @@ -2,7 +2,6 @@ package tokens import ( "encoding/binary" - "fmt" "math/rand" "net" "time" @@ -77,7 +76,6 @@ func newToken(tg *handshake.TokenGenerator, data []byte) int { if token.OriginalDestConnectionID != nil || token.RetrySrcConnectionID != nil { panic("didn't expect connection IDs") } - checkAddr(token.RemoteAddr, addr) return 1 } @@ -140,22 +138,5 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int { if !token.RetrySrcConnectionID.Equal(retrySrcConnID) { panic("retry src conn ID doesn't match") } - checkAddr(token.RemoteAddr, addr) return 1 } - -func checkAddr(tokenAddr string, addr net.Addr) { - if udpAddr, ok := addr.(*net.UDPAddr); ok { - // For UDP addresses, we encode only the IP (not the port). - if ip := udpAddr.IP.String(); tokenAddr != ip { - fmt.Printf("%s vs %s", tokenAddr, ip) - panic("wrong remote address for a net.UDPAddr") - } - return - } - - if tokenAddr != addr.String() { - fmt.Printf("%s vs %s", tokenAddr, addr.String()) - panic("wrong remote address") - } -} diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index b788533dfb0..2bd6e362514 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -37,13 +37,11 @@ var _ = Describe("Handshake drop tests", func() { startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) { conf := getQuicConfig(&quic.Config{ - MaxIdleTimeout: timeout, - HandshakeIdleTimeout: timeout, - Versions: []protocol.VersionNumber{version}, + MaxIdleTimeout: timeout, + HandshakeIdleTimeout: timeout, + Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return doRetry }, }) - if !doRetry { - conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true } - } var tlsConf *tls.Config if longCertChain { tlsConf = getTLSConfigWithLongCertChain() diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index df092086cc5..0f767d1b925 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -101,6 +101,7 @@ var _ = Describe("Handshake RTT tests", func() { // 1 RTT for verifying the source address // 1 RTT for the TLS handshake It("is forward-secure after 2 RTTs", func() { + serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } runServerAndProxy() _, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), @@ -112,9 +113,6 @@ var _ = Describe("Handshake RTT tests", func() { }) It("establishes a connection in 1 RTT when the server doesn't require a token", func() { - serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool { - return true - } runServerAndProxy() _, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), @@ -126,9 +124,6 @@ var _ = Describe("Handshake RTT tests", func() { }) It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() { - serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool { - return true - } serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384} runServerAndProxy() _, err := quic.DialAddr( @@ -139,21 +134,4 @@ var _ = Describe("Handshake RTT tests", func() { Expect(err).ToNot(HaveOccurred()) expectDurationInRTTs(2) }) - - It("doesn't complete the handshake when the server never accepts the token", func() { - serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool { - return false - } - clientConfig.HandshakeIdleTimeout = 500 * time.Millisecond - runServerAndProxy() - _, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), - getTLSClientConfig(), - clientConfig, - ) - Expect(err).To(HaveOccurred()) - nerr, ok := err.(net.Error) - Expect(ok).To(BeTrue()) - Expect(nerr.Timeout()).To(BeTrue()) - }) }) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index dee162f7504..6cdef36d111 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -344,12 +344,6 @@ var _ = Describe("Handshake tests", func() { } BeforeEach(func() { - serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool { - if token != nil { - Expect(token.IsRetryToken).To(BeFalse()) - } - return true - } var err error // start the server, but don't call Accept server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) @@ -479,14 +473,6 @@ var _ = Describe("Handshake tests", func() { Context("using tokens", func() { It("uses tokens provided in NEW_TOKEN frames", func() { - tokenChan := make(chan *quic.Token, 100) - serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool { - if token != nil && !token.IsRetryToken { - tokenChan <- token - } - return true - } - server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) @@ -509,7 +495,6 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) Expect(gets).To(Receive()) Eventually(puts).Should(Receive()) - Expect(tokenChan).ToNot(Receive()) // received a token. Close this connection. Expect(conn.CloseWithError(0, "")).To(Succeed()) @@ -529,17 +514,13 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) defer conn.CloseWithError(0, "") Expect(gets).To(Receive()) - Expect(tokenChan).To(Receive()) Eventually(done).Should(BeClosed()) }) It("rejects invalid Retry token with the INVALID_TOKEN error", func() { - tokenChan := make(chan *quic.Token, 10) - serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool { - tokenChan <- token - return false - } + serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } + serverConfig.MaxRetryTokenAge = time.Nanosecond server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) @@ -554,18 +535,6 @@ var _ = Describe("Handshake tests", func() { var transportErr *quic.TransportError Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken)) - // Receiving a Retry might lead the client to measure a very small RTT. - // Then, it sometimes would retransmit the ClientHello before receiving the ServerHello. - Expect(len(tokenChan)).To(BeNumerically(">=", 2)) - token := <-tokenChan - Expect(token).To(BeNil()) - token = <-tokenChan - Expect(token).ToNot(BeNil()) - // If the ClientHello was retransmitted, make sure that it contained the same Retry token. - for i := 2; i < len(tokenChan); i++ { - Expect(<-tokenChan).To(Equal(token)) - } - Expect(token.IsRetryToken).To(BeTrue()) }) }) diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index 34bb14c6a3f..f7c6e6c7c02 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -29,24 +29,27 @@ var _ = Describe("MITM test", func() { const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it var ( - proxy *quicproxy.QuicProxy serverUDPConn, clientUDPConn *net.UDPConn serverConn quic.Connection serverConfig *quic.Config ) - startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) { + startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) { addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) serverUDPConn, err = net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) ln, err := quic.Listen(serverUDPConn, getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) go func() { defer GinkgoRecover() + defer close(done) var err error serverConn, err = ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) + if err != nil { + return + } str, err := serverConn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(PRData) @@ -54,12 +57,18 @@ var _ = Describe("MITM test", func() { Expect(str.Close()).To(Succeed()) }() serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), DelayPacket: delayCb, DropPacket: dropCb, }) Expect(err).ToNot(HaveOccurred()) + return proxy.LocalPort(), func() { + proxy.Close() + ln.Close() + serverUDPConn.Close() + <-done + } } BeforeEach(func() { @@ -79,8 +88,6 @@ var _ = Describe("MITM test", func() { // Test shutdown is tricky due to the proxy. Just wait for a bit. time.Sleep(50 * time.Millisecond) Expect(clientUDPConn.Close()).To(Succeed()) - Expect(serverUDPConn.Close()).To(Succeed()) - Expect(proxy.Close()).To(Succeed()) }) Context("injecting invalid packets", func() { @@ -120,13 +127,14 @@ var _ = Describe("MITM test", func() { } runTest := func(delayCb quicproxy.DelayCallback) { - startServerAndProxy(delayCb, nil) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + proxyPort, closeFn := startServerAndProxy(delayCb, nil) + defer closeFn() + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( clientUDPConn, raddr, - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + fmt.Sprintf("localhost:%d", proxyPort), getTLSClientConfig(), getQuicConfig(&quic.Config{ Versions: []protocol.VersionNumber{version}, @@ -166,13 +174,14 @@ var _ = Describe("MITM test", func() { }) runTest := func(dropCb quicproxy.DropCallback) { - startServerAndProxy(nil, dropCb) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + proxyPort, closeFn := startServerAndProxy(nil, dropCb) + defer closeFn() + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( clientUDPConn, raddr, - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + fmt.Sprintf("localhost:%d", proxyPort), getTLSClientConfig(), getQuicConfig(&quic.Config{ Versions: []protocol.VersionNumber{version}, @@ -283,65 +292,14 @@ var _ = Describe("MITM test", func() { const rtt = 20 * time.Millisecond - // AfterEach closes the proxy, but each function is responsible - // for closing client and server connections - AfterEach(func() { - // Test shutdown is tricky due to the proxy. Just wait for a bit. - time.Sleep(50 * time.Millisecond) - Expect(proxy.Close()).To(Succeed()) - }) - - // sendForgedVersionNegotiationPacket sends a fake VN packet with no supported versions - // from serverUDPConn to client's remoteAddr - // expects hdr from an Initial packet intercepted from client - sendForgedVersionNegotationPacket := func(conn net.PacketConn, remoteAddr net.Addr, hdr *wire.Header) { - // Create fake version negotiation packet with no supported versions - versions := []protocol.VersionNumber{} - packet := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) - - // Send the packet - _, err := conn.WriteTo(packet, remoteAddr) - Expect(err).ToNot(HaveOccurred()) - } - - // sendForgedRetryPacket sends a fake Retry packet with a modified srcConnID - // from serverUDPConn to client's remoteAddr - // expects hdr from an Initial packet intercepted from client - sendForgedRetryPacket := func(conn net.PacketConn, remoteAddr net.Addr, hdr *wire.Header) { - var x byte = 0x12 - fakeSrcConnID := protocol.ConnectionID{x, x, x, x, x, x, x, x} - retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version) - - _, err := conn.WriteTo(retryPacket, remoteAddr) - Expect(err).ToNot(HaveOccurred()) - } - - // Send a forged Initial packet with no frames to client - // expects hdr from an Initial packet intercepted from client - sendForgedInitialPacket := func(conn net.PacketConn, remoteAddr net.Addr, hdr *wire.Header) { - initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil) - _, err := conn.WriteTo(initialPacket, remoteAddr) - Expect(err).ToNot(HaveOccurred()) - } - - // Send a forged Initial packet with ACK for random packet to client - // expects hdr from an Initial packet intercepted from client - sendForgedInitialPacketWithAck := func(conn net.PacketConn, remoteAddr net.Addr, hdr *wire.Header) { - // Fake Initial with ACK for packet 2 (unsent) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack}) - _, err := conn.WriteTo(initialPacket, remoteAddr) - Expect(err).ToNot(HaveOccurred()) - } - - runTest := func(delayCb quicproxy.DelayCallback) error { - startServerAndProxy(delayCb, nil) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) { + proxyPort, closeFn := startServerAndProxy(delayCb, nil) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) _, err = quic.Dial( clientUDPConn, raddr, - fmt.Sprintf("localhost:%d", proxy.LocalPort()), + fmt.Sprintf("localhost:%d", proxyPort), getTLSClientConfig(), getQuicConfig(&quic.Config{ Versions: []protocol.VersionNumber{version}, @@ -349,11 +307,12 @@ var _ = Describe("MITM test", func() { HandshakeIdleTimeout: 2 * time.Second, }), ) - return err + return closeFn, err } // fails immediately because client connection closes when it can't find compatible version It("fails when a forged version negotiation packet is sent to client", func() { + done := make(chan struct{}) delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { if dir == quicproxy.DirectionIncoming { defer GinkgoRecover() @@ -365,24 +324,36 @@ var _ = Describe("MITM test", func() { return 0 } - sendForgedVersionNegotationPacket(serverUDPConn, clientUDPConn.LocalAddr(), hdr) + // Create fake version negotiation packet with no supported versions + versions := []protocol.VersionNumber{} + packet := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) + + // Send the packet + _, err = serverUDPConn.WriteTo(packet, clientUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + close(done) } return rtt / 2 } - err := runTest(delayCb) + closeFn, err := runTest(delayCb) + defer closeFn() Expect(err).To(HaveOccurred()) vnErr := &quic.VersionNegotiationError{} Expect(errors.As(err, &vnErr)).To(BeTrue()) + Eventually(done).Should(BeClosed()) }) // times out, because client doesn't accept subsequent real retry packets from server // as it has already accepted a retry. // TODO: determine behavior when server does not send Retry packets It("fails when a forged Retry packet with modified srcConnID is sent to client", func() { + serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } var initialPacketIntercepted bool + done := make(chan struct{}) delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted { defer GinkgoRecover() + defer close(done) hdr, _, _, err := wire.ParsePacket(raw, connIDLen) Expect(err).ToNot(HaveOccurred()) @@ -392,59 +363,82 @@ var _ = Describe("MITM test", func() { } initialPacketIntercepted = true - sendForgedRetryPacket(serverUDPConn, clientUDPConn.LocalAddr(), hdr) + fakeSrcConnID := protocol.ConnectionID{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12} + retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version) + + _, err = serverUDPConn.WriteTo(retryPacket, clientUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) } return rtt / 2 } - err := runTest(delayCb) + closeFn, err := runTest(delayCb) + defer closeFn() Expect(err).To(HaveOccurred()) Expect(err.(net.Error).Timeout()).To(BeTrue()) + Eventually(done).Should(BeClosed()) }) // times out, because client doesn't accept real retry packets from server because // it has already accepted an initial. // TODO: determine behavior when server does not send Retry packets It("fails when a forged initial packet is sent to client", func() { + done := make(chan struct{}) + var injected bool delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { if dir == quicproxy.DirectionIncoming { defer GinkgoRecover() hdr, _, _, err := wire.ParsePacket(raw, connIDLen) Expect(err).ToNot(HaveOccurred()) - - if hdr.Type != protocol.PacketTypeInitial { + if hdr.Type != protocol.PacketTypeInitial || injected { return 0 } - - sendForgedInitialPacket(serverUDPConn, clientUDPConn.LocalAddr(), hdr) + defer close(done) + injected = true + initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil) + _, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) } return rtt } - err := runTest(delayCb) + closeFn, err := runTest(delayCb) + defer closeFn() Expect(err).To(HaveOccurred()) Expect(err.(net.Error).Timeout()).To(BeTrue()) + Eventually(done).Should(BeClosed()) }) // client connection closes immediately on receiving ack for unsent packet It("fails when a forged initial packet with ack for unsent packet is sent to client", func() { - clientAddr := clientUDPConn.LocalAddr() + done := make(chan struct{}) + var injected bool delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { if dir == quicproxy.DirectionIncoming { + defer GinkgoRecover() + hdr, _, _, err := wire.ParsePacket(raw, connIDLen) Expect(err).ToNot(HaveOccurred()) - if hdr.Type != protocol.PacketTypeInitial { + if hdr.Type != protocol.PacketTypeInitial || injected { return 0 } - sendForgedInitialPacketWithAck(serverUDPConn, clientAddr, hdr) + defer close(done) + injected = true + // Fake Initial with ACK for packet 2 (unsent) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} + initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack}) + _, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) } return rtt } - err := runTest(delayCb) + closeFn, err := runTest(delayCb) + defer closeFn() Expect(err).To(HaveOccurred()) var transportErr *quic.TransportError Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ProtocolViolation)) Expect(transportErr.ErrorMessage).To(ContainSubstring("received ACK for an unsent packet")) + Eventually(done).Should(BeClosed()) }) }) }) diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index d7d1c65908c..e326f17507d 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -26,7 +26,6 @@ var _ = Describe("Packetization", func() { "localhost:0", getTLSConfig(), getQuicConfig(&quic.Config{ - AcceptToken: func(net.Addr, *quic.Token) bool { return true }, DisablePathMTUDiscovery: true, Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }), }), diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index f172174abd0..242f8602d27 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -17,7 +17,9 @@ import ( mrand "math/rand" "net" "os" + "runtime/pprof" "strconv" + "strings" "sync" "testing" "time" @@ -294,7 +296,14 @@ var _ = BeforeEach(func() { } }) +func areHandshakesRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "RunHandshake") +} + var _ = AfterEach(func() { + Expect(areHandshakesRunning()).To(BeFalse()) if debugLog() { logFile, err := os.Create(logFileName) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 2b61d4ec534..a0ad73f732e 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -1,14 +1,11 @@ package self_test import ( - "bytes" "context" "fmt" "io" mrand "math/rand" "net" - "runtime/pprof" - "strings" "sync/atomic" "time" @@ -45,12 +42,6 @@ func (c *faultyConn) WriteTo(p []byte, addr net.Addr) (int, error) { return 0, io.ErrClosedPipe } -func areHandshakesRunning() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "RunHandshake") -} - var _ = Describe("Timeout tests", func() { checkTimeoutError := func(err error) { ExpectWithOffset(1, err).To(MatchError(&quic.IdleTimeoutError{})) @@ -382,14 +373,6 @@ var _ = Describe("Timeout tests", func() { Context("faulty packet conns", func() { const handshakeTimeout = time.Second / 2 - BeforeEach(func() { - Expect(areHandshakesRunning()).To(BeFalse()) - }) - - AfterEach(func() { - Expect(areHandshakesRunning()).To(BeFalse()) - }) - runServer := func(ln quic.Listener) error { conn, err := ln.Accept(context.Background()) if err != nil { diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 39be9eadbbb..c39a02dda75 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -55,9 +55,7 @@ var _ = Describe("0-RTT", func() { dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) { tlsConf := getTLSConfig() if serverConf == nil { - serverConf = getQuicConfig(&quic.Config{ - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - }) + serverConf = getQuicConfig(nil) serverConf.Versions = []protocol.VersionNumber{version} } ln, err := quic.ListenAddrEarly( @@ -197,9 +195,8 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -255,9 +252,8 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -400,8 +396,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + RequireAddressValidation: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -453,7 +450,6 @@ var _ = Describe("0-RTT", func() { const maxStreams = 1 tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ MaxIncomingUniStreams: maxStreams, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, })) tracer := newPacketTracer() @@ -462,7 +458,6 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, MaxIncomingUniStreams: maxStreams + 1, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), @@ -499,7 +494,6 @@ var _ = Describe("0-RTT", func() { const maxStreams = 42 tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ MaxIncomingStreams: maxStreams, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, })) tracer := newPacketTracer() @@ -508,7 +502,6 @@ var _ = Describe("0-RTT", func() { tlsConf, getQuicConfig(&quic.Config{ Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, MaxIncomingStreams: maxStreams - 1, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), @@ -537,9 +530,8 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -560,16 +552,14 @@ var _ = Describe("0-RTT", func() { func(addFlowControlLimit func(*quic.Config, uint64)) { tracer := newPacketTracer() firstConf := getQuicConfig(&quic.Config{ - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Versions: []protocol.VersionNumber{version}, + Versions: []protocol.VersionNumber{version}, }) addFlowControlLimit(firstConf, 3) tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) secondConf := getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }) addFlowControlLimit(secondConf, 100) ln, err := quic.ListenAddrEarly( @@ -722,9 +712,8 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/interface.go b/interface.go index 6130b5497f4..c19a5c9338e 100644 --- a/interface.go +++ b/interface.go @@ -26,16 +26,6 @@ const ( Version2 = protocol.Version2 ) -// A Token can be used to verify the ownership of the client address. -type Token struct { - // IsRetryToken encodes how the client received the token. There are two ways: - // * In a Retry packet sent when trying to establish a new connection. - // * In a NEW_TOKEN frame on a previous connection. - IsRetryToken bool - RemoteAddr string - SentTime time.Time -} - // A ClientToken is a token received by the client. // It can be used to skip address validation on future connection attempts. type ClientToken struct { @@ -233,14 +223,18 @@ type Config struct { // If the timeout is exceeded, the connection is closed. // If this value is zero, the timeout is set to 30 seconds. MaxIdleTimeout time.Duration - // AcceptToken determines if a Token is accepted. - // It is called with token = nil if the client didn't send a token. - // If not set, a default verification function is used: - // * it verifies that the address matches, and - // * if the token is a retry token, that it was issued within the last 5 seconds - // * else, that it was issued within the last 24 hours. - // This option is only valid for the server. - AcceptToken func(clientAddr net.Addr, token *Token) bool + // RequireAddressValidation determines if a QUIC Retry packet is sent. + // This allows the server to verify the client's address, at the cost of increasing the handshake latency by 1 RTT. + // See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details. + // If not set, every client is forced to prove its remote address. + RequireAddressValidation func(net.Addr) bool + // MaxRetryTokenAge is the maximum age of a Retry token. + // If not set, it defaults to 5 seconds. Only valid for a server. + MaxRetryTokenAge time.Duration + // MaxTokenAge is the maximum age of the token presented during the handshake, + // for tokens that were issued on a previous connection. + // If not set, it defaults to 24 hours. Only valid for a server. + MaxTokenAge time.Duration // The TokenStore stores tokens received from the server. // Tokens are used to skip address validation on future connection attempts. // The key used to store tokens is the ServerName from the tls.Config, if set diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go index 2df5fcd8c3e..228b2fa68b2 100644 --- a/internal/handshake/token_generator.go +++ b/internal/handshake/token_generator.go @@ -1,6 +1,7 @@ package handshake import ( + "bytes" "encoding/asn1" "fmt" "io" @@ -17,14 +18,18 @@ const ( // A Token is derived from the client address and can be used to verify the ownership of this address. type Token struct { - IsRetryToken bool - RemoteAddr string - SentTime time.Time + IsRetryToken bool + SentTime time.Time + encodedRemoteAddr []byte // only set for retry tokens OriginalDestConnectionID protocol.ConnectionID RetrySrcConnectionID protocol.ConnectionID } +func (t *Token) ValidateRemoteAddr(addr net.Addr) bool { + return bytes.Equal(encodeRemoteAddr(addr), t.encodedRemoteAddr) +} + // token is the struct that is used for ASN1 serialization and deserialization type token struct { IsRetryToken bool @@ -101,9 +106,9 @@ func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) { return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) } token := &Token{ - IsRetryToken: t.IsRetryToken, - RemoteAddr: decodeRemoteAddr(t.RemoteAddr), - SentTime: time.Unix(0, t.Timestamp), + IsRetryToken: t.IsRetryToken, + SentTime: time.Unix(0, t.Timestamp), + encodedRemoteAddr: t.RemoteAddr, } if t.IsRetryToken { token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) @@ -119,16 +124,3 @@ func encodeRemoteAddr(remoteAddr net.Addr) []byte { } return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...) } - -// decodeRemoteAddr decodes the remote address saved in the token -func decodeRemoteAddr(data []byte) string { - // data will never be empty for a token that we generated. - // Check it to be on the safe side - if len(data) == 0 { - return "" - } - if data[0] == tokenPrefixIP { - return net.IP(data[1:]).String() - } - return string(data[1:]) -} diff --git a/internal/handshake/token_generator_test.go b/internal/handshake/token_generator_test.go index 3aef6a3d126..4d4be0f239e 100644 --- a/internal/handshake/token_generator_test.go +++ b/internal/handshake/token_generator_test.go @@ -35,16 +35,13 @@ var _ = Describe("Token Generator", func() { }) It("accepts a valid token", func() { - ip := net.IPv4(192, 168, 0, 1) - tokenEnc, err := tokenGen.NewRetryToken( - &net.UDPAddr{IP: ip, Port: 1337}, - nil, - nil, - ) + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + tokenEnc, err := tokenGen.NewRetryToken(addr, nil, nil) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal("192.168.0.1")) + Expect(token.ValidateRemoteAddr(addr)).To(BeTrue()) + Expect(token.ValidateRemoteAddr(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 2), Port: 1337})).To(BeFalse()) Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) Expect(token.OriginalDestConnectionID.Len()).To(BeZero()) Expect(token.RetrySrcConnectionID.Len()).To(BeZero()) @@ -110,7 +107,7 @@ var _ = Describe("Token Generator", func() { Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal(ip.String())) + Expect(token.ValidateRemoteAddr(raddr)).To(BeTrue()) Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) } }) @@ -121,7 +118,8 @@ var _ = Describe("Token Generator", func() { Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) - Expect(token.RemoteAddr).To(Equal("192.168.13.37:1337")) + Expect(token.ValidateRemoteAddr(raddr)).To(BeTrue()) + Expect(token.ValidateRemoteAddr(&net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1338})).To(BeFalse()) Expect(token.SentTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) }) }) diff --git a/interop/server/main.go b/interop/server/main.go index 2674fd5bdc6..d92ec1d355b 100644 --- a/interop/server/main.go +++ b/interop/server/main.go @@ -44,8 +44,8 @@ func main() { } // a quic.Config that doesn't do a Retry quicConf := &quic.Config{ - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - Tracer: qlog.NewTracer(getLogWriter), + RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" }, + Tracer: qlog.NewTracer(getLogWriter), } cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key") if err != nil { @@ -58,15 +58,11 @@ func main() { } switch testcase { - case "versionnegotiation", "handshake", "transfer", "resumption", "zerortt", "multiconnect": + case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "zerortt", "multiconnect": err = runHTTP09Server(quicConf) case "chacha20": tlsConf.CipherSuites = []uint16{tls.TLS_CHACHA20_POLY1305_SHA256} err = runHTTP09Server(quicConf) - case "retry": - // By default, quic-go performs a Retry on every incoming connection. - quicConf.AcceptToken = nil - err = runHTTP09Server(quicConf) case "http3": err = runHTTP3Server(quicConf) default: diff --git a/server.go b/server.go index 0e64297050c..1d14302ec19 100644 --- a/server.go +++ b/server.go @@ -241,26 +241,6 @@ func (s *baseServer) run() { } } -var defaultAcceptToken = func(clientAddr net.Addr, token *Token) bool { - if token == nil { - return false - } - validity := protocol.TokenValidity - if token.IsRetryToken { - validity = protocol.RetryTokenValidity - } - if time.Now().After(token.SentTime.Add(validity)) { - return false - } - var sourceAddr string - if udpAddr, ok := clientAddr.(*net.UDPAddr); ok { - sourceAddr = udpAddr.IP.String() - } else { - sourceAddr = clientAddr.String() - } - return sourceAddr == token.RemoteAddr -} - // Accept returns connections that already completed the handshake. // It is only valid if acceptEarlyConns is false. func (s *baseServer) Accept(ctx context.Context) (Connection, error) { @@ -405,33 +385,45 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } var ( - token *Token + token *handshake.Token retrySrcConnID *protocol.ConnectionID ) origDestConnID := hdr.DestConnectionID if len(hdr.Token) > 0 { - c, err := s.tokenGenerator.DecodeToken(hdr.Token) + tok, err := s.tokenGenerator.DecodeToken(hdr.Token) if err == nil { - token = &Token{ - IsRetryToken: c.IsRetryToken, - RemoteAddr: c.RemoteAddr, - SentTime: c.SentTime, - } - if token.IsRetryToken { - origDestConnID = c.OriginalDestConnectionID - retrySrcConnID = &c.RetrySrcConnectionID + if tok.IsRetryToken { + origDestConnID = tok.OriginalDestConnectionID + retrySrcConnID = &tok.RetrySrcConnectionID } + token = tok } } - if !s.config.AcceptToken(p.remoteAddr, token) { + if token != nil { + addrIsValid := token.ValidateRemoteAddr(p.remoteAddr) + // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. + // We just ignore them, and act as if there was no token on this packet at all. + // This also means we might send a Retry later. + if !token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxTokenAge || !addrIsValid) { + token = nil + } else if token.IsRetryToken && (time.Since(token.SentTime) > s.config.MaxRetryTokenAge || !addrIsValid) { + // For Retry tokens, we send an INVALID_ERROR if + // * the token is too old, or + // * the token is invalid, in case of a retry token. + go func() { + defer p.buffer.Release() + if token != nil && token.IsRetryToken { + if err := s.maybeSendInvalidToken(p, hdr); err != nil { + s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) + } + } + }() + return nil + } + } + if token == nil && s.config.RequireAddressValidation(p.remoteAddr) { go func() { defer p.buffer.Release() - if token != nil && token.IsRetryToken { - if err := s.maybeSendInvalidToken(p, hdr); err != nil { - s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) - } - return - } if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { s.logger.Debugf("Error sending Retry: %s", err) } diff --git a/server_test.go b/server_test.go index 2cdb98d06db..4944bccfd2c 100644 --- a/server_test.go +++ b/server_test.go @@ -126,22 +126,22 @@ var _ = Describe("Server", func() { Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) Expect(server.config.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) Expect(server.config.MaxIdleTimeout).To(Equal(protocol.DefaultIdleTimeout)) - Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(defaultAcceptToken))) - Expect(server.config.KeepAlivePeriod).To(Equal(0 * time.Second)) + Expect(server.config.RequireAddressValidation).ToNot(BeNil()) + Expect(server.config.KeepAlivePeriod).To(BeZero()) // stop the listener Expect(ln.Close()).To(Succeed()) }) It("setups with the right values", func() { supportedVersions := []protocol.VersionNumber{protocol.VersionTLS} - acceptToken := func(_ net.Addr, _ *Token) bool { return true } + requireAddrVal := func(net.Addr) bool { return true } config := Config{ - Versions: supportedVersions, - AcceptToken: acceptToken, - HandshakeIdleTimeout: 1337 * time.Hour, - MaxIdleTimeout: 42 * time.Minute, - KeepAlivePeriod: 5 * time.Second, - StatelessResetKey: []byte("foobar"), + Versions: supportedVersions, + HandshakeIdleTimeout: 1337 * time.Hour, + MaxIdleTimeout: 42 * time.Minute, + KeepAlivePeriod: 5 * time.Second, + StatelessResetKey: []byte("foobar"), + RequireAddressValidation: requireAddrVal, } ln, err := Listen(conn, tlsConf, &config) Expect(err).ToNot(HaveOccurred()) @@ -150,7 +150,7 @@ var _ = Describe("Server", func() { Expect(server.config.Versions).To(Equal(supportedVersions)) Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) - Expect(reflect.ValueOf(server.config.AcceptToken)).To(Equal(reflect.ValueOf(acceptToken))) + Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal))) Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar"))) // stop the listener @@ -239,62 +239,11 @@ var _ = Describe("Server", func() { time.Sleep(50 * time.Millisecond) }) - It("decodes the token from the Token field", func() { - raddr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 13, 37), - Port: 1337, - } - done := make(chan struct{}) - serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { - Expect(addr).To(Equal(raddr)) - Expect(token).ToNot(BeNil()) - close(done) - return false - } - token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) - Expect(err).ToNot(HaveOccurred()) - packet := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Token: token, - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("passes an empty token to the callback, if decoding fails", func() { - raddr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 13, 37), - Port: 1337, - } - done := make(chan struct{}) - serv.config.AcceptToken = func(addr net.Addr, token *Token) bool { - Expect(addr).To(Equal(raddr)) - Expect(token).To(BeNil()) - close(done) - return false - } - packet := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - Token: []byte("foobar"), - Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = raddr - conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) - tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - It("creates a connection when the token is accepted", func() { - serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true } + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} retryToken, err := serv.tokenGenerator.NewRetryToken( - &net.UDPAddr{}, + raddr, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, ) @@ -308,6 +257,7 @@ var _ = Describe("Server", func() { Token: retryToken, } p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + p.remoteAddr = raddr run := make(chan struct{}) var token protocol.StatelessResetToken rand.Read(token[:]) @@ -469,8 +419,8 @@ var _ = Describe("Server", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) }) - It("replies with a Retry packet, if a Token is required", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } + It("replies with a Retry packet, if a token is required", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, @@ -502,81 +452,7 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) - It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Token: token, - Version: protocol.VersionTLS, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet - raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(frames).To(HaveLen(1)) - Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := frames[0].(*logging.ConnectionCloseFrame) - Expect(ccf.IsApplicationError).To(BeFalse()) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) - }) - done := make(chan struct{}) - conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { - defer close(done) - replyHdr := parseHeader(b) - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - _, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) - extHdr, err := unpackHeader(opener, replyHdr, b, hdr.Version) - Expect(err).ToNot(HaveOccurred()) - data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) - Expect(err).ToNot(HaveOccurred()) - f, err := wire.NewFrameParser(false, hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := f.(*wire.ConnectionCloseFrame) - Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) - Expect(ccf.ReasonPhrase).To(BeEmpty()) - return len(b), nil - }) - serv.handlePacket(packet) - Eventually(done).Should(BeClosed()) - }) - - It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return false } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) - Expect(err).ToNot(HaveOccurred()) - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - Token: token, - Version: protocol.VersionTLS, - } - packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet - packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - done := make(chan struct{}) - tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) - serv.handlePacket(packet) - // make sure there are no Write calls on the packet conn - time.Sleep(50 * time.Millisecond) - Eventually(done).Should(BeClosed()) - }) - - It("creates a connection, if no Token is required", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } + It("creates a connection, if no token is required", func() { hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, @@ -659,7 +535,6 @@ var _ = Describe("Server", func() { }).AnyTimes() tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() - serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } acceptConn := make(chan struct{}) var counter uint32 // to be used as an atomic, so we query it in Eventually serv.newConn = func( @@ -713,7 +588,6 @@ var _ = Describe("Server", func() { }) It("only creates a single connection for a duplicate Initial", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } var createdConn bool conn := NewMockQuicConn(mockCtrl) serv.newConn = func( @@ -745,8 +619,6 @@ var _ = Describe("Server", func() { }) It("rejects new connection attempts if the accept queue is full", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - serv.newConn = func( _ sendConn, runner connRunner, @@ -813,8 +685,6 @@ var _ = Describe("Server", func() { }) It("doesn't accept new connections if they were closed in the mean time", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) ctx, cancel := context.WithCancel(context.Background()) connCreated := make(chan struct{}) @@ -877,6 +747,200 @@ var _ = Describe("Server", func() { }) }) + Context("token validation", func() { + checkInvalidToken := func(b []byte, origHdr *wire.Header) { + replyHdr := parseHeader(b) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(origHdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(origHdr.SrcConnectionID)) + _, opener := handshake.NewInitialAEAD(origHdr.DestConnectionID, protocol.PerspectiveClient, replyHdr.Version) + extHdr, err := unpackHeader(opener, replyHdr, b, origHdr.Version) + Expect(err).ToNot(HaveOccurred()) + data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) + Expect(err).ToNot(HaveOccurred()) + f, err := wire.NewFrameParser(false, origHdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := f.(*wire.ConnectionCloseFrame) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) + Expect(ccf.ReasonPhrase).To(BeEmpty()) + } + + It("decodes the token from the token field", func() { + raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} + token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) + Expect(err).ToNot(HaveOccurred()) + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: token, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) + + done := make(chan struct{}) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) }) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + packet.remoteAddr = raddr + tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(frames).To(HaveLen(1)) + Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := frames[0].(*logging.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) + }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + checkInvalidToken(b, hdr) + return len(b), nil + }) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("sends an INVALID_TOKEN error, if an expired retry token is received", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + serv.config.MaxRetryTokenAge = time.Millisecond + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(2 * time.Millisecond) // make sure the token is expired + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(frames).To(HaveLen(1)) + Expect(frames[0]).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := frames[0].(*logging.ConnectionCloseFrame) + Expect(ccf.IsApplicationError).To(BeFalse()) + Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.InvalidToken)) + }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + checkInvalidToken(b, hdr) + return len(b), nil + }) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't send an INVALID_TOKEN error, if an invalid non-retry token is received", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + packet.remoteAddr = raddr + tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + replyHdr := parseHeader(b) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) + return len(b), nil + }) + serv.handlePacket(packet) + // make sure there are no Write calls on the packet conn + Eventually(done).Should(BeClosed()) + }) + + It("sends an INVALID_TOKEN error, if an expired non-retry token is received", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + serv.config.MaxTokenAge = time.Millisecond + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + token, err := serv.tokenGenerator.NewToken(raddr) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(2 * time.Millisecond) // make sure the token is expired + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) + }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + return len(b), nil + }) + serv.handlePacket(packet) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { + serv.config.RequireAddressValidation = func(net.Addr) bool { return true } + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Token: token, + Version: protocol.VersionTLS, + } + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.data[len(packet.data)-10] ^= 0xff // corrupt the packet + packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + done := make(chan struct{}) + tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) + serv.handlePacket(packet) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + Eventually(done).Should(BeClosed()) + }) + }) + Context("accepting connections", func() { It("returns Accept when an error occurs", func() { testErr := errors.New("test err") @@ -930,7 +994,6 @@ var _ = Describe("Server", func() { }() ctx, cancel := context.WithCancel(context.Background()) // handshake context - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } serv.newConn = func( _ sendConn, runner connRunner, @@ -1004,7 +1067,6 @@ var _ = Describe("Server", func() { }() ready := make(chan struct{}) - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } serv.newConn = func( _ sendConn, runner connRunner, @@ -1045,7 +1107,6 @@ var _ = Describe("Server", func() { }) It("rejects new connection attempts if the accept queue is full", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} serv.newConn = func( @@ -1106,8 +1167,6 @@ var _ = Describe("Server", func() { }) It("doesn't accept new connections if they were closed in the mean time", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) ctx, cancel := context.WithCancel(context.Background()) connCreated := make(chan struct{}) @@ -1166,72 +1225,3 @@ var _ = Describe("Server", func() { }) }) }) - -var _ = Describe("default source address verification", func() { - It("accepts a token", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1", - SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(time.Second), // will expire in 1 second - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) - }) - - It("requests verification if no token is provided", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - Expect(defaultAcceptToken(remoteAddr, nil)).To(BeFalse()) - }) - - It("rejects a token if the address doesn't match", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "127.0.0.1", - SentTime: time.Now(), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) - }) - - It("accepts a token for a remote address is not a UDP address", func() { - remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1:1337", - SentTime: time.Now(), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) - }) - - It("rejects an invalid token for a remote address is not a UDP address", func() { - remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1:7331", // mismatching port - SentTime: time.Now(), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) - }) - - It("rejects an expired token", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: true, - RemoteAddr: "192.168.0.1", - SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), // expired 1 second ago - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeFalse()) - }) - - It("accepts a non-retry token", func() { - Expect(protocol.RetryTokenValidity).To(BeNumerically("<", protocol.TokenValidity)) - remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)} - token := &Token{ - IsRetryToken: false, - RemoteAddr: "192.168.0.1", - // if this was a retry token, it would have expired one second ago - SentTime: time.Now().Add(-protocol.RetryTokenValidity).Add(-time.Second), - } - Expect(defaultAcceptToken(remoteAddr, token)).To(BeTrue()) - }) -})