Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement a new API to let servers control client address verification #3501

Merged
merged 3 commits into from Aug 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 13 additions & 5 deletions config.go
Expand Up @@ -2,11 +2,11 @@ package quic

import (
"errors"
"net"
"time"

"github.com/lucas-clemente/quic-go/internal/utils"

"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)

// Clone clones a Config
Expand Down Expand Up @@ -39,8 +39,14 @@ func populateServerConfig(config *Config) *Config {
if config.ConnectionIDLength == 0 {
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
}
if config.AcceptToken == nil {
config.AcceptToken = defaultAcceptToken
if config.MaxTokenAge == 0 {
config.MaxTokenAge = protocol.TokenValidity
}
if config.MaxRetryTokenAge == 0 {
config.MaxRetryTokenAge = protocol.RetryTokenValidity
}
if config.RequireAddressValidation == nil {
config.RequireAddressValidation = func(net.Addr) bool { return false }
}
return config
}
Expand Down Expand Up @@ -104,7 +110,9 @@ func populateConfig(config *Config) *Config {
Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout,
AcceptToken: config.AcceptToken,
MaxTokenAge: config.MaxTokenAge,
MaxRetryTokenAge: config.MaxRetryTokenAge,
RequireAddressValidation: config.RequireAddressValidation,
KeepAlivePeriod: config.KeepAlivePeriod,
InitialStreamReceiveWindow: initialStreamReceiveWindow,
MaxStreamReceiveWindow: maxStreamReceiveWindow,
Expand Down
35 changes: 19 additions & 16 deletions config_test.go
Expand Up @@ -45,7 +45,7 @@ var _ = Describe("Config", func() {
}

switch fn := typ.Field(i).Name; fn {
case "AcceptToken", "GetLogWriter", "AllowConnectionWindowIncrease":
case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease":
// Can't compare functions.
case "Versions":
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
Expand All @@ -55,6 +55,10 @@ var _ = Describe("Config", func() {
f.Set(reflect.ValueOf(time.Second))
case "MaxIdleTimeout":
f.Set(reflect.ValueOf(time.Hour))
case "MaxTokenAge":
f.Set(reflect.ValueOf(2 * time.Hour))
case "MaxRetryTokenAge":
f.Set(reflect.ValueOf(2 * time.Minute))
case "TokenStore":
f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
case "InitialStreamReceiveWindow":
Expand Down Expand Up @@ -100,14 +104,14 @@ var _ = Describe("Config", func() {

Context("cloning", func() {
It("clones function fields", func() {
var calledAcceptToken, calledAllowConnectionWindowIncrease bool
var calledAddrValidation, calledAllowConnectionWindowIncrease bool
c1 := &Config{
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true },
}
c2 := c1.Clone()
c2.AcceptToken(&net.UDPAddr{}, &Token{})
Expect(calledAcceptToken).To(BeTrue())
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
c2.AllowConnectionWindowIncrease(nil, 1234)
Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
})
Expand All @@ -119,27 +123,26 @@ var _ = Describe("Config", func() {

It("returns a copy", func() {
c1 := &Config{
MaxIncomingStreams: 100,
AcceptToken: func(_ net.Addr, _ *Token) bool { return true },
MaxIncomingStreams: 100,
RequireAddressValidation: func(net.Addr) bool { return true },
}
c2 := c1.Clone()
c2.MaxIncomingStreams = 200
c2.AcceptToken = func(_ net.Addr, _ *Token) bool { return false }
c2.RequireAddressValidation = func(net.Addr) bool { return false }

Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
Expect(c1.AcceptToken(&net.UDPAddr{}, nil)).To(BeTrue())
Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue())
})
})

Context("populating", func() {
It("populates function fields", func() {
var calledAcceptToken bool
c1 := &Config{
AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true },
}
var calledAddrValidation bool
c1 := &Config{}
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
c2 := populateConfig(c1)
c2.AcceptToken(&net.UDPAddr{}, &Token{})
Expect(calledAcceptToken).To(BeTrue())
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
})

It("copies non-function fields", func() {
Expand All @@ -164,7 +167,7 @@ var _ = Describe("Config", func() {
It("populates empty fields with default values, for the server", func() {
c := populateServerConfig(&Config{})
Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
Expect(c.AcceptToken).ToNot(BeNil())
Expect(c.RequireAddressValidation).ToNot(BeNil())
})

It("sets a default connection ID length if we didn't create the conn, for the client", func() {
Expand Down
9 changes: 7 additions & 2 deletions connection.go
Expand Up @@ -543,7 +543,11 @@ func (s *connection) run() error {

s.timer = utils.NewTimer()

go s.cryptoStreamHandler.RunHandshake()
handshaking := make(chan struct{})
go func() {
defer close(handshaking)
s.cryptoStreamHandler.RunHandshake()
}()
go func() {
if err := s.sendQueue.Run(); err != nil {
s.destroyImpl(err)
Expand Down Expand Up @@ -694,12 +698,13 @@ runLoop:
}
}

s.cryptoStreamHandler.Close()
<-handshaking
s.handleCloseError(&closeErr)
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil {
s.tracer.Close()
}
s.logger.Infof("Connection %s closed.", s.logID)
s.cryptoStreamHandler.Close()
s.sendQueue.Close()
s.timer.Stop()
return closeErr.err
Expand Down
19 changes: 0 additions & 19 deletions fuzzing/tokens/fuzz.go
Expand Up @@ -2,7 +2,6 @@ package tokens

import (
"encoding/binary"
"fmt"
"math/rand"
"net"
"time"
Expand Down Expand Up @@ -77,7 +76,6 @@ func newToken(tg *handshake.TokenGenerator, data []byte) int {
if token.OriginalDestConnectionID != nil || token.RetrySrcConnectionID != nil {
panic("didn't expect connection IDs")
}
checkAddr(token.RemoteAddr, addr)
return 1
}

Expand Down Expand Up @@ -140,22 +138,5 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int {
if !token.RetrySrcConnectionID.Equal(retrySrcConnID) {
panic("retry src conn ID doesn't match")
}
checkAddr(token.RemoteAddr, addr)
return 1
}

func checkAddr(tokenAddr string, addr net.Addr) {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
// For UDP addresses, we encode only the IP (not the port).
if ip := udpAddr.IP.String(); tokenAddr != ip {
fmt.Printf("%s vs %s", tokenAddr, ip)
panic("wrong remote address for a net.UDPAddr")
}
return
}

if tokenAddr != addr.String() {
fmt.Printf("%s vs %s", tokenAddr, addr.String())
panic("wrong remote address")
}
}
10 changes: 4 additions & 6 deletions integrationtests/self/handshake_drop_test.go
Expand Up @@ -37,13 +37,11 @@ var _ = Describe("Handshake drop tests", func() {

startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) {
conf := getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
Versions: []protocol.VersionNumber{version},
MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout,
Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return doRetry },
})
if !doRetry {
conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true }
}
var tlsConf *tls.Config
if longCertChain {
tlsConf = getTLSConfigWithLongCertChain()
Expand Down
24 changes: 1 addition & 23 deletions integrationtests/self/handshake_rtt_test.go
Expand Up @@ -101,6 +101,7 @@ var _ = Describe("Handshake RTT tests", func() {
// 1 RTT for verifying the source address
// 1 RTT for the TLS handshake
It("is forward-secure after 2 RTTs", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
runServerAndProxy()
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
Expand All @@ -112,9 +113,6 @@ var _ = Describe("Handshake RTT tests", func() {
})

It("establishes a connection in 1 RTT when the server doesn't require a token", func() {
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
Expand All @@ -126,9 +124,6 @@ var _ = Describe("Handshake RTT tests", func() {
})

It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() {
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
return true
}
serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384}
runServerAndProxy()
_, err := quic.DialAddr(
Expand All @@ -139,21 +134,4 @@ var _ = Describe("Handshake RTT tests", func() {
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})

It("doesn't complete the handshake when the server never accepts the token", func() {
serverConfig.AcceptToken = func(_ net.Addr, _ *quic.Token) bool {
return false
}
clientConfig.HandshakeIdleTimeout = 500 * time.Millisecond
runServerAndProxy()
_, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
getTLSClientConfig(),
clientConfig,
)
Expect(err).To(HaveOccurred())
nerr, ok := err.(net.Error)
Expect(ok).To(BeTrue())
Expect(nerr.Timeout()).To(BeTrue())
})
})
35 changes: 2 additions & 33 deletions integrationtests/self/handshake_test.go
Expand Up @@ -344,12 +344,6 @@ var _ = Describe("Handshake tests", func() {
}

BeforeEach(func() {
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
if token != nil {
Expect(token.IsRetryToken).To(BeFalse())
}
return true
}
var err error
// start the server, but don't call Accept
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expand Down Expand Up @@ -479,14 +473,6 @@ var _ = Describe("Handshake tests", func() {

Context("using tokens", func() {
It("uses tokens provided in NEW_TOKEN frames", func() {
tokenChan := make(chan *quic.Token, 100)
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
if token != nil && !token.IsRetryToken {
tokenChan <- token
}
return true
}

server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())

Expand All @@ -509,7 +495,6 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
Expect(gets).To(Receive())
Eventually(puts).Should(Receive())
Expect(tokenChan).ToNot(Receive())
// received a token. Close this connection.
Expect(conn.CloseWithError(0, "")).To(Succeed())

Expand All @@ -529,17 +514,13 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
Expect(gets).To(Receive())
Expect(tokenChan).To(Receive())

Eventually(done).Should(BeClosed())
})

It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
tokenChan := make(chan *quic.Token, 10)
serverConfig.AcceptToken = func(addr net.Addr, token *quic.Token) bool {
tokenChan <- token
return false
}
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
serverConfig.MaxRetryTokenAge = time.Nanosecond

server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -554,18 +535,6 @@ var _ = Describe("Handshake tests", func() {
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken))
// Receiving a Retry might lead the client to measure a very small RTT.
// Then, it sometimes would retransmit the ClientHello before receiving the ServerHello.
Expect(len(tokenChan)).To(BeNumerically(">=", 2))
token := <-tokenChan
Expect(token).To(BeNil())
token = <-tokenChan
Expect(token).ToNot(BeNil())
// If the ClientHello was retransmitted, make sure that it contained the same Retry token.
for i := 2; i < len(tokenChan); i++ {
Expect(<-tokenChan).To(Equal(token))
}
Expect(token.IsRetryToken).To(BeTrue())
})
})

Expand Down