diff --git a/.circleci/config.yml b/.circleci/config.yml index db1e8f5e8fc..74170a533cc 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,11 @@ version: 2.1 executors: + test-go118: + docker: + - image: "cimg/go:1.18" + environment: + runrace: true + TIMESCALE_FACTOR: 3 test-go117: docker: - image: "cimg/go:1.17" @@ -15,7 +21,7 @@ executors: jobs: "test": &test - executor: test-go117 + executor: test-go118 steps: - checkout - run: @@ -42,8 +48,11 @@ jobs: - run: name: "Run self integration tests with qlog" command: ginkgo -v -randomizeAllSpecs -trace integrationtests/self -- -qlog + go118: + <<: *test go117: <<: *test + executor: test-go117 go116: <<: *test executor: test-go116 @@ -51,5 +60,6 @@ jobs: workflows: workflow: jobs: - - go116 + - go118 - go117 + - go116 diff --git a/.github/workflows/cross-compile.yml b/.github/workflows/cross-compile.yml index 770d5d66564..64af3506dd7 100644 --- a/.github/workflows/cross-compile.yml +++ b/.github/workflows/cross-compile.yml @@ -4,7 +4,7 @@ jobs: strategy: fail-fast: false matrix: - go: [ "1.16.x", "1.17.x" ] + go: [ "1.17.x", "1.18.x" ] runs-on: ubuntu-latest name: "Cross Compilation (Go ${{matrix.go}})" steps: diff --git a/.github/workflows/go-generate.sh b/.github/workflows/go-generate.sh index ff168aa7081..37edcacc138 100755 --- a/.github/workflows/go-generate.sh +++ b/.github/workflows/go-generate.sh @@ -20,4 +20,4 @@ go generate ./... cd .. # don't compare fuzzing corpora -diff --exclude=corpus -ruN orig generated +diff --exclude=corpus --exclude=.git -ruN orig generated diff --git a/.github/workflows/go-generate.yml b/.github/workflows/go-generate.yml index 6f563f348cd..70a86bd477d 100644 --- a/.github/workflows/go-generate.yml +++ b/.github/workflows/go-generate.yml @@ -6,7 +6,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-go@v2 with: - go-version: "1.17.x" + go-version: "1.18.x" - name: Install dependencies run: go build - name: Install code generators diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 073eaedddc8..df81f7abc3e 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -5,7 +5,7 @@ jobs: strategy: fail-fast: false matrix: - go: [ "1.16.x", "1.17.x", "1.18.0-beta1" ] + go: [ "1.16.x", "1.17.x", "1.18.x" ] runs-on: ubuntu-latest env: DEBUG: false # set this to true to export qlogs and save them as artifacts diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index fe8b55235c0..504180ed5e0 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,7 +7,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-go@v2 with: - go-version: "1.17.x" + go-version: "1.18.x" - name: Check that no non-test files import Ginkgo or Gomega run: .github/workflows/no_ginkgo.sh - name: Check that go.mod is tidied @@ -28,4 +28,4 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v2 with: - version: v1.41.1 + version: v1.45.2 diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index 180c36c17f0..426ef3a7a2b 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -7,7 +7,7 @@ jobs: fail-fast: false matrix: os: [ "ubuntu", "windows", "macos" ] - go: [ "1.16.x", "1.17.x", "1.18.0-beta1" ] + go: [ "1.16.x", "1.17.x", "1.18.x" ] runs-on: ${{ matrix.os }}-latest name: Unit tests (${{ matrix.os}}, Go ${{ matrix.go }}) steps: diff --git a/.golangci.yml b/.golangci.yml index 05ddb79ac92..2589c053892 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -28,7 +28,6 @@ linters: - ineffassign - misspell - prealloc - - scopelint - staticcheck - stylecheck - structcheck diff --git a/README.md b/README.md index a5995879ca4..f284174c20b 100644 --- a/README.md +++ b/README.md @@ -3,17 +3,15 @@ [![PkgGoDev](https://pkg.go.dev/badge/github.com/lucas-clemente/quic-go)](https://pkg.go.dev/github.com/lucas-clemente/quic-go) -[![Travis Build Status](https://img.shields.io/travis/lucas-clemente/quic-go/master.svg?style=flat-square&label=Travis+build)](https://travis-ci.org/lucas-clemente/quic-go) -[![CircleCI Build Status](https://img.shields.io/circleci/project/github/lucas-clemente/quic-go.svg?style=flat-square&label=CircleCI+build)](https://circleci.com/gh/lucas-clemente/quic-go) -[![Windows Build Status](https://img.shields.io/appveyor/ci/lucas-clemente/quic-go/master.svg?style=flat-square&label=windows+build)](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master) [![Code Coverage](https://img.shields.io/codecov/c/github/lucas-clemente/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/lucas-clemente/quic-go/) -quic-go is an implementation of the [QUIC protocol, RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000) protocol in Go. +quic-go is an implementation of the [QUIC protocol, RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000) protocol in Go, including the [Unreliable Datagram Extension, RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221). + In addition to RFC 9000, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem. ## Guides -*We currently support Go 1.16.x and Go 1.17.x.* +*We currently support Go 1.16.x, Go 1.17.x, and Go 1.18.x.* Running tests: @@ -51,11 +49,12 @@ http.Client{ | [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | | [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | | [go-ipfs](https://github.com/ipfs/go-ipfs) | IPFS implementation in go | ![GitHub Repo stars](https://img.shields.io/github/stars/ipfs/go-ipfs?style=flat-square) | -| [nextdns](https://github.com/nextdns/nextdns) | NextDNS CLI client (DoH Proxy) | ![GitHub Repo stars](https://img.shields.io/github/stars/nextdns/nextdns?style=flat-square) | | [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | | [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) | | [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) | -| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | +| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | +| [OONI Probe](https://github.com/ooni/probe-cli) | The Open Observatory of Network Interference (OONI) aims to empower decentralized efforts in documenting Internet censorship around the world. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | + ## Contributing diff --git a/client.go b/client.go index 9dbe4ac5c3e..be8390e6524 100644 --- a/client.go +++ b/client.go @@ -14,7 +14,7 @@ import ( ) type client struct { - conn sendConn + sconn sendConn // If the client is created with DialAddr, we create a packet conn. // If it is started with Dial, we take a packet conn as a parameter. createdPacketConn bool @@ -35,7 +35,7 @@ type client struct { handshakeChan chan struct{} - session quicSession + conn quicConn tracer logging.ConnectionTracer tracingID uint64 @@ -49,26 +49,26 @@ var ( ) // DialAddr establishes a new QUIC connection to a server. -// It uses a new UDP connection and closes this connection when the QUIC session is closed. +// It uses a new UDP connection and closes this connection when the QUIC connection is closed. // The hostname for SNI is taken from the given address. // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. func DialAddr( addr string, tlsConf *tls.Config, config *Config, -) (Session, error) { +) (Connection, error) { return DialAddrContext(context.Background(), addr, tlsConf, config) } // DialAddrEarly establishes a new 0-RTT QUIC connection to a server. -// It uses a new UDP connection and closes this connection when the QUIC session is closed. +// It uses a new UDP connection and closes this connection when the QUIC connection is closed. // The hostname for SNI is taken from the given address. // The tls.Config.CipherSuites allows setting of TLS 1.3 cipher suites. func DialAddrEarly( addr string, tlsConf *tls.Config, config *Config, -) (EarlySession, error) { +) (EarlyConnection, error) { return DialAddrEarlyContext(context.Background(), addr, tlsConf, config) } @@ -79,13 +79,13 @@ func DialAddrEarlyContext( addr string, tlsConf *tls.Config, config *Config, -) (EarlySession, error) { - sess, err := dialAddrContext(ctx, addr, tlsConf, config, true) +) (EarlyConnection, error) { + conn, err := dialAddrContext(ctx, addr, tlsConf, config, true) if err != nil { return nil, err } - utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early session") - return sess, nil + utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection") + return conn, nil } // DialAddrContext establishes a new QUIC connection to a server using the provided context. @@ -95,7 +95,7 @@ func DialAddrContext( addr string, tlsConf *tls.Config, config *Config, -) (Session, error) { +) (Connection, error) { return dialAddrContext(ctx, addr, tlsConf, config, false) } @@ -105,7 +105,7 @@ func dialAddrContext( tlsConf *tls.Config, config *Config, use0RTT bool, -) (quicSession, error) { +) (quicConn, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -131,7 +131,7 @@ func Dial( host string, tlsConf *tls.Config, config *Config, -) (Session, error) { +) (Connection, error) { return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false) } @@ -146,7 +146,7 @@ func DialEarly( host string, tlsConf *tls.Config, config *Config, -) (EarlySession, error) { +) (EarlyConnection, error) { return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config) } @@ -159,7 +159,7 @@ func DialEarlyContext( host string, tlsConf *tls.Config, config *Config, -) (EarlySession, error) { +) (EarlyConnection, error) { return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false) } @@ -172,7 +172,7 @@ func DialContext( host string, tlsConf *tls.Config, config *Config, -) (Session, error) { +) (Connection, error) { return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false) } @@ -185,7 +185,7 @@ func dialContext( config *Config, use0RTT bool, createdPacketConn bool, -) (quicSession, error) { +) (quicConn, error) { if tlsConf == nil { return nil, errors.New("quic: tls.Config not set") } @@ -203,21 +203,21 @@ func dialContext( } c.packetHandlers = packetHandlers - c.tracingID = nextSessionTracingID() + c.tracingID = nextConnTracingID() if c.config.Tracer != nil { c.tracer = c.config.Tracer.TracerForConnection( - context.WithValue(ctx, SessionTracingKey, c.tracingID), + context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID, ) } if c.tracer != nil { - c.tracer.StartedConnection(c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID) + c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID) } if err := c.dial(ctx); err != nil { return nil, err } - return c.session, nil + return c.conn, nil } func newClient( @@ -231,6 +231,8 @@ func newClient( ) (*client, error) { if tlsConf == nil { tlsConf = &tls.Config{} + } else { + tlsConf = tlsConf.Clone() } if tlsConf.ServerName == "" { sni := host @@ -265,7 +267,7 @@ func newClient( c := &client{ srcConnID: srcConnID, destConnID: destConnID, - conn: newSendPconn(pconn, remoteAddr), + sconn: newSendPconn(pconn, remoteAddr), createdPacketConn: createdPacketConn, use0RTT: use0RTT, tlsConf: tlsConf, @@ -278,10 +280,10 @@ func newClient( } func (c *client) dial(ctx context.Context) error { - c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) - c.session = newClientSession( - c.conn, + c.conn = newClientConnection( + c.sconn, c.packetHandlers, c.destConnID, c.srcConnID, @@ -295,11 +297,11 @@ func (c *client) dial(ctx context.Context) error { c.logger, c.version, ) - c.packetHandlers.Add(c.srcConnID, c.session) + c.packetHandlers.Add(c.srcConnID, c.conn) errorChan := make(chan error, 1) go func() { - err := c.session.run() // returns as soon as the session is closed + err := c.conn.run() // returns as soon as the connection is closed if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn { c.packetHandlers.Destroy() @@ -308,15 +310,15 @@ func (c *client) dial(ctx context.Context) error { }() // only set when we're using 0-RTT - // Otherwise, earlySessionChan will be nil. Receiving from a nil chan blocks forever. - var earlySessionChan <-chan struct{} + // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. + var earlyConnChan <-chan struct{} if c.use0RTT { - earlySessionChan = c.session.earlySessionReady() + earlyConnChan = c.conn.earlyConnReady() } select { case <-ctx.Done(): - c.session.shutdown() + c.conn.shutdown() return ctx.Err() case err := <-errorChan: var recreateErr *errCloseForRecreating @@ -327,10 +329,10 @@ func (c *client) dial(ctx context.Context) error { return c.dial(ctx) } return err - case <-earlySessionChan: + case <-earlyConnChan: // ready to send 0-RTT data return nil - case <-c.session.HandshakeComplete().Done(): + case <-c.conn.HandshakeComplete().Done(): // handshake successfully completed return nil } diff --git a/client_test.go b/client_test.go index 42031a47419..c7fbc0d3885 100644 --- a/client_test.go +++ b/client_test.go @@ -31,9 +31,9 @@ var _ = Describe("Client", func() { tracer *mocklogging.MockConnectionTracer config *Config - originalClientSessConstructor func( + originalClientConnConstructor func( conn sendConn, - runner sessionRunner, + runner connRunner, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, conf *Config, @@ -45,19 +45,18 @@ var _ = Describe("Client", func() { tracingID uint64, logger utils.Logger, v protocol.VersionNumber, - ) quicSession + ) quicConn ) BeforeEach(func() { tlsConf = &tls.Config{NextProtos: []string{"proto1"}} connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} - originalClientSessConstructor = newClientSession + originalClientConnConstructor = newClientConnection tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tr := mocklogging.NewMockTracer(mockCtrl) tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1) config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.VersionTLS}} - Eventually(areSessionsRunning).Should(BeFalse()) - // sess = NewMockQuicSession(mockCtrl) + Eventually(areConnsRunning).Should(BeFalse()) addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} packetConn = NewMockPacketConn(mockCtrl) packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() @@ -65,7 +64,7 @@ var _ = Describe("Client", func() { srcConnID: connID, destConnID: connID, version: protocol.VersionTLS, - conn: newSendPconn(packetConn, addr), + sconn: newSendPconn(packetConn, addr), tracer: tracer, logger: utils.DefaultLogger, } @@ -78,14 +77,14 @@ var _ = Describe("Client", func() { AfterEach(func() { connMuxer = origMultiplexer - newClientSession = originalClientSessConstructor + newClientConnection = originalClientConnConstructor }) AfterEach(func() { - if s, ok := cl.session.(*session); ok { + if s, ok := cl.conn.(*connection); ok { s.shutdown() } - Eventually(areSessionsRunning).Should(BeFalse()) + Eventually(areConnsRunning).Should(BeFalse()) }) Context("Dialing", func() { @@ -119,9 +118,9 @@ var _ = Describe("Client", func() { mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) remoteAddrChan := make(chan string, 1) - newClientSession = func( - conn sendConn, - _ sessionRunner, + newClientConnection = func( + sconn sendConn, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, @@ -133,12 +132,12 @@ var _ = Describe("Client", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { - remoteAddrChan <- conn.RemoteAddr().String() - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().run() - sess.EXPECT().HandshakeComplete().Return(context.Background()) - return sess + ) quicConn { + remoteAddrChan <- sconn.RemoteAddr().String() + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run() + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn } _, err := DialAddr("localhost:17890", tlsConf, &Config{HandshakeIdleTimeout: time.Millisecond}) Expect(err).ToNot(HaveOccurred()) @@ -152,9 +151,9 @@ var _ = Describe("Client", func() { mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) hostnameChan := make(chan string, 1) - newClientSession = func( + newClientConnection = func( _ sendConn, - _ sessionRunner, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, @@ -166,12 +165,12 @@ var _ = Describe("Client", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { + ) quicConn { hostnameChan <- tlsConf.ServerName - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().run() - sess.EXPECT().HandshakeComplete().Return(context.Background()) - return sess + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run() + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn } tlsConf.ServerName = "foobar" _, err := DialAddr("localhost:17890", tlsConf, nil) @@ -185,9 +184,9 @@ var _ = Describe("Client", func() { mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) hostnameChan := make(chan string, 1) - newClientSession = func( + newClientConnection = func( _ sendConn, - _ sessionRunner, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, @@ -199,12 +198,12 @@ var _ = Describe("Client", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { + ) quicConn { hostnameChan <- tlsConf.ServerName - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().HandshakeComplete().Return(context.Background()) - sess.EXPECT().run() - return sess + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().run() + return conn } tracer.EXPECT().StartedConnection(packetConn.LocalAddr(), addr, gomock.Any(), gomock.Any()) _, err := Dial( @@ -224,9 +223,9 @@ var _ = Describe("Client", func() { mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) run := make(chan struct{}) - newClientSession = func( + newClientConnection = func( _ sendConn, - runner sessionRunner, + runner connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, @@ -238,14 +237,14 @@ var _ = Describe("Client", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { + ) quicConn { Expect(enable0RTT).To(BeFalse()) - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().run().Do(func() { close(run) }) + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run().Do(func() { close(run) }) ctx, cancel := context.WithCancel(context.Background()) cancel() - sess.EXPECT().HandshakeComplete().Return(ctx) - return sess + conn.EXPECT().HandshakeComplete().Return(ctx) + return conn } tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) s, err := Dial( @@ -260,16 +259,16 @@ var _ = Describe("Client", func() { Eventually(run).Should(BeClosed()) }) - It("returns early sessions", func() { + It("returns early connections", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) readyChan := make(chan struct{}) done := make(chan struct{}) - newClientSession = func( + newClientConnection = func( _ sendConn, - runner sessionRunner, + runner connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, @@ -281,13 +280,13 @@ var _ = Describe("Client", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { + ) quicConn { Expect(enable0RTT).To(BeTrue()) - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().run().Do(func() { <-done }) - sess.EXPECT().HandshakeComplete().Return(context.Background()) - sess.EXPECT().earlySessionReady().Return(readyChan) - return sess + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run().Do(func() { <-done }) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().earlyConnReady().Return(readyChan) + return conn } go func() { @@ -315,9 +314,9 @@ var _ = Describe("Client", func() { mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) testErr := errors.New("early handshake error") - newClientSession = func( + newClientConnection = func( _ sendConn, - _ sessionRunner, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, @@ -329,11 +328,11 @@ var _ = Describe("Client", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().run().Return(testErr) - sess.EXPECT().HandshakeComplete().Return(context.Background()) - return sess + ) quicConn { + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run().Return(testErr) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn } tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) _, err := Dial( @@ -346,21 +345,21 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError(testErr)) }) - It("closes the session when the context is canceled", func() { + It("closes the connection when the context is canceled", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - sessionRunning := make(chan struct{}) - defer close(sessionRunning) - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().run().Do(func() { - <-sessionRunning + connRunning := make(chan struct{}) + defer close(connRunning) + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run().Do(func() { + <-connRunning }) - sess.EXPECT().HandshakeComplete().Return(context.Background()) - newClientSession = func( + conn.EXPECT().HandshakeComplete().Return(context.Background()) + newClientConnection = func( _ sendConn, - _ sessionRunner, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, @@ -372,8 +371,8 @@ var _ = Describe("Client", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { - return sess + ) quicConn { + return conn } ctx, cancel := context.WithCancel(context.Background()) dialed := make(chan struct{}) @@ -392,7 +391,7 @@ var _ = Describe("Client", func() { close(dialed) }() Consistently(dialed).ShouldNot(BeClosed()) - sess.EXPECT().shutdown() + conn.EXPECT().shutdown() cancel() Eventually(dialed).Should(BeClosed()) }) @@ -406,13 +405,13 @@ var _ = Describe("Client", func() { mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - var conn sendConn + var sconn sendConn run := make(chan struct{}) - sessionCreated := make(chan struct{}) - sess := NewMockQuicSession(mockCtrl) - newClientSession = func( + connCreated := make(chan struct{}) + conn := NewMockQuicConn(mockCtrl) + newClientConnection = func( connP sendConn, - _ sessionRunner, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, _ *Config, @@ -424,15 +423,15 @@ var _ = Describe("Client", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { - conn = connP - close(sessionCreated) - return sess + ) quicConn { + sconn = connP + close(connCreated) + return conn } - sess.EXPECT().run().Do(func() { + conn.EXPECT().run().Do(func() { <-run }) - sess.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(context.Background()) done := make(chan struct{}) go func() { @@ -442,10 +441,10 @@ var _ = Describe("Client", func() { close(done) }() - Eventually(sessionCreated).Should(BeClosed()) + Eventually(connCreated).Should(BeClosed()) // check that the connection is not closed - Expect(conn.Write([]byte("foobar"))).To(Succeed()) + Expect(sconn.Write([]byte("foobar"))).To(Succeed()) manager.EXPECT().Destroy() close(run) @@ -520,7 +519,7 @@ var _ = Describe("Client", func() { }) }) - It("creates new sessions with the right parameters", func() { + It("creates new connections with the right parameters", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) @@ -530,9 +529,9 @@ var _ = Describe("Client", func() { var cconn sendConn var version protocol.VersionNumber var conf *Config - newClientSession = func( + newClientConnection = func( connP sendConn, - _ sessionRunner, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, configP *Config, @@ -544,16 +543,16 @@ var _ = Describe("Client", func() { _ uint64, _ utils.Logger, versionP protocol.VersionNumber, - ) quicSession { + ) quicConn { cconn = connP version = versionP conf = configP close(c) // TODO: check connection IDs? - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().run() - sess.EXPECT().HandshakeComplete().Return(context.Background()) - return sess + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().run() + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn } _, err := Dial(packetConn, addr, "localhost:1337", tlsConf, config) Expect(err).ToNot(HaveOccurred()) @@ -563,16 +562,16 @@ var _ = Describe("Client", func() { Expect(conf.Versions).To(Equal(config.Versions)) }) - It("creates a new session after version negotiation", func() { + It("creates a new connections after version negotiation", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()).Times(2) manager.EXPECT().Destroy() mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) var counter int - newClientSession = func( + newClientConnection = func( _ sendConn, - _ sessionRunner, + _ connRunner, _ protocol.ConnectionID, _ protocol.ConnectionID, configP *Config, @@ -584,23 +583,23 @@ var _ = Describe("Client", func() { _ uint64, _ utils.Logger, versionP protocol.VersionNumber, - ) quicSession { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().HandshakeComplete().Return(context.Background()) + ) quicConn { + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().HandshakeComplete().Return(context.Background()) if counter == 0 { Expect(pn).To(BeZero()) Expect(hasNegotiatedVersion).To(BeFalse()) - sess.EXPECT().run().Return(&errCloseForRecreating{ + conn.EXPECT().run().Return(&errCloseForRecreating{ nextPacketNumber: 109, nextVersion: 789, }) } else { Expect(pn).To(Equal(protocol.PacketNumber(109))) Expect(hasNegotiatedVersion).To(BeTrue()) - sess.EXPECT().run() + conn.EXPECT().run() } counter++ - return sess + return conn } tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) diff --git a/closed_session.go b/closed_conn.go similarity index 51% rename from closed_session.go rename to closed_conn.go index 31279020246..35c2d7390a5 100644 --- a/closed_session.go +++ b/closed_conn.go @@ -7,15 +7,15 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -// A closedLocalSession is a session that we closed locally. -// When receiving packets for such a session, we need to retransmit the packet containing the CONNECTION_CLOSE frame, +// A closedLocalConn is a connection that we closed locally. +// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, // with an exponential backoff. -type closedLocalSession struct { +type closedLocalConn struct { conn sendConn connClosePacket []byte closeOnce sync.Once - closeChan chan struct{} // is closed when the session is closed or destroyed + closeChan chan struct{} // is closed when the connection is closed or destroyed receivedPackets chan *receivedPacket counter uint64 // number of packets received @@ -25,16 +25,16 @@ type closedLocalSession struct { logger utils.Logger } -var _ packetHandler = &closedLocalSession{} +var _ packetHandler = &closedLocalConn{} -// newClosedLocalSession creates a new closedLocalSession and runs it. -func newClosedLocalSession( +// newClosedLocalConn creates a new closedLocalConn and runs it. +func newClosedLocalConn( conn sendConn, connClosePacket []byte, perspective protocol.Perspective, logger utils.Logger, ) packetHandler { - s := &closedLocalSession{ + s := &closedLocalConn{ conn: conn, connClosePacket: connClosePacket, perspective: perspective, @@ -46,7 +46,7 @@ func newClosedLocalSession( return s } -func (s *closedLocalSession) run() { +func (s *closedLocalConn) run() { for { select { case p := <-s.receivedPackets: @@ -57,14 +57,14 @@ func (s *closedLocalSession) run() { } } -func (s *closedLocalSession) handlePacket(p *receivedPacket) { +func (s *closedLocalConn) handlePacket(p *receivedPacket) { select { case s.receivedPackets <- p: default: } } -func (s *closedLocalSession) handlePacketImpl(_ *receivedPacket) { +func (s *closedLocalConn) handlePacketImpl(_ *receivedPacket) { s.counter++ // exponential backoff // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving @@ -79,34 +79,34 @@ func (s *closedLocalSession) handlePacketImpl(_ *receivedPacket) { } } -func (s *closedLocalSession) shutdown() { +func (s *closedLocalConn) shutdown() { s.destroy(nil) } -func (s *closedLocalSession) destroy(error) { +func (s *closedLocalConn) destroy(error) { s.closeOnce.Do(func() { close(s.closeChan) }) } -func (s *closedLocalSession) getPerspective() protocol.Perspective { +func (s *closedLocalConn) getPerspective() protocol.Perspective { return s.perspective } -// A closedRemoteSession is a session that was closed remotely. -// For such a session, we might receive reordered packets that were sent before the CONNECTION_CLOSE. +// A closedRemoteConn is a connection that was closed remotely. +// For such a connection, we might receive reordered packets that were sent before the CONNECTION_CLOSE. // We can just ignore those packets. -type closedRemoteSession struct { +type closedRemoteConn struct { perspective protocol.Perspective } -var _ packetHandler = &closedRemoteSession{} +var _ packetHandler = &closedRemoteConn{} -func newClosedRemoteSession(pers protocol.Perspective) packetHandler { - return &closedRemoteSession{perspective: pers} +func newClosedRemoteConn(pers protocol.Perspective) packetHandler { + return &closedRemoteConn{perspective: pers} } -func (s *closedRemoteSession) handlePacket(*receivedPacket) {} -func (s *closedRemoteSession) shutdown() {} -func (s *closedRemoteSession) destroy(error) {} -func (s *closedRemoteSession) getPerspective() protocol.Perspective { return s.perspective } +func (s *closedRemoteConn) handlePacket(*receivedPacket) {} +func (s *closedRemoteConn) shutdown() {} +func (s *closedRemoteConn) destroy(error) {} +func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } diff --git a/closed_session_test.go b/closed_conn_test.go similarity index 59% rename from closed_session_test.go rename to closed_conn_test.go index c329d79202d..e81b0050ed7 100644 --- a/closed_session_test.go +++ b/closed_conn_test.go @@ -12,45 +12,45 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Closed local session", func() { +var _ = Describe("Closed local connection", func() { var ( - sess packetHandler + conn packetHandler mconn *MockSendConn ) BeforeEach(func() { mconn = NewMockSendConn(mockCtrl) - sess = newClosedLocalSession(mconn, []byte("close"), protocol.PerspectiveClient, utils.DefaultLogger) + conn = newClosedLocalConn(mconn, []byte("close"), protocol.PerspectiveClient, utils.DefaultLogger) }) AfterEach(func() { - Eventually(areClosedSessionsRunning).Should(BeFalse()) + Eventually(areClosedConnsRunning).Should(BeFalse()) }) It("tells its perspective", func() { - Expect(sess.getPerspective()).To(Equal(protocol.PerspectiveClient)) - // stop the session - sess.shutdown() + Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient)) + // stop the connection + conn.shutdown() }) It("repeats the packet containing the CONNECTION_CLOSE frame", func() { written := make(chan []byte) mconn.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }).AnyTimes() for i := 1; i <= 20; i++ { - sess.handlePacket(&receivedPacket{}) + conn.handlePacket(&receivedPacket{}) if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 { Eventually(written).Should(Receive(Equal([]byte("close")))) // receive the CONNECTION_CLOSE } else { Consistently(written, 10*time.Millisecond).Should(HaveLen(0)) } } - // stop the session - sess.shutdown() + // stop the connection + conn.shutdown() }) - It("destroys sessions", func() { - Eventually(areClosedSessionsRunning).Should(BeTrue()) - sess.destroy(errors.New("destroy")) - Eventually(areClosedSessionsRunning).Should(BeFalse()) + It("destroys connections", func() { + Eventually(areClosedConnsRunning).Should(BeTrue()) + conn.destroy(errors.New("destroy")) + Eventually(areClosedConnsRunning).Should(BeFalse()) }) }) diff --git a/config_test.go b/config_test.go index a8574f8d472..36f644e4098 100644 --- a/config_test.go +++ b/config_test.go @@ -103,7 +103,7 @@ var _ = Describe("Config", func() { var calledAcceptToken, calledAllowConnectionWindowIncrease bool c1 := &Config{ AcceptToken: func(_ net.Addr, _ *Token) bool { calledAcceptToken = true; return true }, - AllowConnectionWindowIncrease: func(Session, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, + AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, } c2 := c1.Clone() c2.AcceptToken(&net.UDPAddr{}, &Token{}) diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 26efae2cb85..8162b2ff6f8 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -171,7 +171,7 @@ var _ = Describe("Connection ID Generator", func() { } }) - It("replaces with a closed session for all connection IDs", func() { + It("replaces with a closed connection for all connection IDs", func() { Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) Expect(queuedFrames).To(HaveLen(4)) sess := NewMockPacketHandler(mockCtrl) diff --git a/session.go b/connection.go similarity index 87% rename from session.go rename to connection.go index 9c901caf30c..8c59b75fca0 100644 --- a/session.go +++ b/connection.go @@ -90,7 +90,7 @@ func (p *receivedPacket) Clone() *receivedPacket { } } -type sessionRunner interface { +type connRunner interface { Add(protocol.ConnectionID, packetHandler) bool GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken Retire(protocol.ConnectionID) @@ -124,18 +124,14 @@ type errCloseForRecreating struct { } func (e *errCloseForRecreating) Error() string { - return "closing session in order to recreate it" + return "closing connection in order to recreate it" } -var sessionTracingID uint64 // to be accessed atomically -func nextSessionTracingID() uint64 { return atomic.AddUint64(&sessionTracingID, 1) } +var connTracingID uint64 // to be accessed atomically +func nextConnTracingID() uint64 { return atomic.AddUint64(&connTracingID, 1) } -func pathMTUDiscoveryEnabled(config *Config) bool { - return !disablePathMTUDiscovery && !config.DisablePathMTUDiscovery -} - -// A Session is a QUIC session -type session struct { +// A Connection is a QUIC connection +type connection struct { // Destination connection ID used during the handshake. // Used to check source connection ID on incoming packets. handshakeDestConnID protocol.ConnectionID @@ -192,7 +188,7 @@ type session struct { undecryptablePacketsToProcess []*receivedPacket clientHelloWritten <-chan *wire.TransportParameters - earlySessionReadyChan chan struct{} + earlyConnReadyChan chan struct{} handshakeCompleteChan chan struct{} // is closed when the handshake completes handshakeComplete bool handshakeConfirmed bool @@ -201,8 +197,8 @@ type session struct { versionNegotiated bool receivedFirstPacket bool - idleTimeout time.Duration - sessionCreationTime time.Time + idleTimeout time.Duration + creationTime time.Time // The idle timeout is set based on the max of the time we received the last packet... lastPacketReceivedTime time.Time // ... and the time we sent a new ack-eliciting packet after receiving a packet. @@ -226,15 +222,15 @@ type session struct { } var ( - _ Session = &session{} - _ EarlySession = &session{} - _ streamSender = &session{} - deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine + _ Connection = &connection{} + _ EarlyConnection = &connection{} + _ streamSender = &connection{} + deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine ) -var newSession = func( +var newConnection = func( conn sendConn, - runner sessionRunner, + runner connRunner, origDestConnID protocol.ConnectionID, retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, @@ -250,8 +246,8 @@ var newSession = func( tracingID uint64, logger utils.Logger, v protocol.VersionNumber, -) quicSession { - s := &session{ +) quicConn { + s := &connection{ conn: conn, config: conf, handshakeDestConnID: destConnID, @@ -287,7 +283,7 @@ var newSession = func( s.version, ) s.preSetup() - s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), SessionTracingKey, tracingID)) + s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( 0, getMaxPacketSize(s.conn.RemoteAddr()), @@ -368,9 +364,9 @@ var newSession = func( } // declare this as a variable, such that we can it mock it in the tests -var newClientSession = func( +var newClientConnection = func( conn sendConn, - runner sessionRunner, + runner connRunner, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, conf *Config, @@ -382,8 +378,8 @@ var newClientSession = func( tracingID uint64, logger utils.Logger, v protocol.VersionNumber, -) quicSession { - s := &session{ +) quicConn { + s := &connection{ conn: conn, config: conf, origDestConnID: destConnID, @@ -415,7 +411,7 @@ var newClientSession = func( s.version, ) s.preSetup() - s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), SessionTracingKey, tracingID)) + s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( initialPacketNumber, getMaxPacketSize(s.conn.RemoteAddr()), @@ -500,7 +496,7 @@ var newClientSession = func( return s } -func (s *session) preSetup() { +func (s *connection) preSetup() { s.sendQueue = newSendQueue(s.conn) s.retransmissionQueue = newRetransmissionQueue(s.version) s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams, s.version) @@ -518,7 +514,7 @@ func (s *session) preSetup() { s.rttStats, s.logger, ) - s.earlySessionReadyChan = make(chan struct{}) + s.earlyConnReadyChan = make(chan struct{}) s.streamsMap = newStreamsMap( s, s.newFlowController, @@ -528,14 +524,14 @@ func (s *session) preSetup() { s.version, ) s.framer = newFramer(s.streamsMap, s.version) - s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) + s.receivedPackets = make(chan *receivedPacket, protocol.MaxConnUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background()) now := time.Now() s.lastPacketReceivedTime = now - s.sessionCreationTime = now + s.creationTime = now s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) if s.config.EnableDatagrams { @@ -543,8 +539,8 @@ func (s *session) preSetup() { } } -// run the session main loop -func (s *session) run() error { +// run the connection main loop +func (s *connection) run() error { defer s.ctxCancel() s.timer = utils.NewTimer() @@ -562,7 +558,7 @@ func (s *session) run() error { s.scheduleSending() if zeroRTTParams != nil { s.restoreTransportParameters(zeroRTTParams) - close(s.earlySessionReadyChan) + close(s.earlyConnReadyChan) } case closeErr := <-s.closeChan: // put the close error back into the channel, so that the run loop can receive it @@ -596,7 +592,7 @@ runLoop: if processed := s.handlePacketImpl(p); processed { processedUndecryptablePacket = true } - // Don't set timers and send packets if the packet made us close the session. + // Don't set timers and send packets if the packet made us close the connection. select { case closeErr = <-s.closeChan: break runLoop @@ -619,7 +615,7 @@ runLoop: case <-sendQueueAvailable: case firstPacket := <-s.receivedPackets: wasProcessed := s.handlePacketImpl(firstPacket) - // Don't set timers and send packets if the packet made us close the session. + // Don't set timers and send packets if the packet made us close the connection. select { case closeErr = <-s.closeChan: break runLoop @@ -668,11 +664,11 @@ runLoop: } if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() && !now.Before(keepAliveTime) { - // send a PING frame since there is no activity in the session + // send a PING frame since there is no activity in the connection s.logger.Debugf("Sending a keep-alive PING to keep the connection alive.") s.framer.QueueControlFrame(&wire.PingFrame{}) s.keepAlivePingSent = true - } else if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= s.config.handshakeTimeout() { + } else if !s.handshakeComplete && now.Sub(s.creationTime) >= s.config.handshakeTimeout() { s.destroyImpl(qerr.ErrHandshakeTimeout) continue } else { @@ -711,24 +707,24 @@ runLoop: return closeErr.err } -// blocks until the early session can be used -func (s *session) earlySessionReady() <-chan struct{} { - return s.earlySessionReadyChan +// blocks until the early connection can be used +func (s *connection) earlyConnReady() <-chan struct{} { + return s.earlyConnReadyChan } -func (s *session) HandshakeComplete() context.Context { +func (s *connection) HandshakeComplete() context.Context { return s.handshakeCtx } -func (s *session) Context() context.Context { +func (s *connection) Context() context.Context { return s.ctx } -func (s *session) supportsDatagrams() bool { +func (s *connection) supportsDatagrams() bool { return s.peerParams.MaxDatagramFrameSize != protocol.InvalidByteCount } -func (s *session) ConnectionState() ConnectionState { +func (s *connection) ConnectionState() ConnectionState { return ConnectionState{ TLS: s.cryptoStreamHandler.ConnectionState(), SupportsDatagrams: s.supportsDatagrams(), @@ -737,18 +733,18 @@ func (s *session) ConnectionState() ConnectionState { // Time when the next keep-alive packet should be sent. // It returns a zero time if no keep-alive should be sent. -func (s *session) nextKeepAliveTime() time.Time { +func (s *connection) nextKeepAliveTime() time.Time { if !s.config.KeepAlive || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() { return time.Time{} } return s.lastPacketReceivedTime.Add(s.keepAliveInterval) } -func (s *session) maybeResetTimer() { +func (s *connection) maybeResetTimer() { var deadline time.Time if !s.handshakeComplete { deadline = utils.MinTime( - s.sessionCreationTime.Add(s.config.handshakeTimeout()), + s.creationTime.Add(s.config.handshakeTimeout()), s.idleTimeoutStartTime().Add(s.config.HandshakeIdleTimeout), ) } else { @@ -758,7 +754,7 @@ func (s *session) maybeResetTimer() { deadline = s.idleTimeoutStartTime().Add(s.idleTimeout) } } - if s.handshakeConfirmed && pathMTUDiscoveryEnabled(s.config) { + if s.handshakeConfirmed && !s.config.DisablePathMTUDiscovery { if probeTime := s.mtuDiscoverer.NextProbeTime(); !probeTime.IsZero() { deadline = utils.MinTime(deadline, probeTime) } @@ -777,11 +773,11 @@ func (s *session) maybeResetTimer() { s.timer.Reset(deadline) } -func (s *session) idleTimeoutStartTime() time.Time { +func (s *connection) idleTimeoutStartTime() time.Time { return utils.MaxTime(s.lastPacketReceivedTime, s.firstAckElicitingPacketAfterIdleSentTime) } -func (s *session) handleHandshakeComplete() { +func (s *connection) handleHandshakeComplete() { s.handshakeComplete = true s.handshakeCompleteChan = nil // prevent this case from ever being selected again defer s.handshakeCtxCancel() @@ -817,12 +813,12 @@ func (s *session) handleHandshakeComplete() { s.queueControlFrame(&wire.HandshakeDoneFrame{}) } -func (s *session) handleHandshakeConfirmed() { +func (s *connection) handleHandshakeConfirmed() { s.handshakeConfirmed = true s.sentPacketHandler.SetHandshakeConfirmed() s.cryptoStreamHandler.SetHandshakeConfirmed() - if pathMTUDiscoveryEnabled(s.config) { + if !s.config.DisablePathMTUDiscovery { maxPacketSize := s.peerParams.MaxUDPPayloadSize if maxPacketSize == 0 { maxPacketSize = protocol.MaxByteCount @@ -840,7 +836,7 @@ func (s *session) handleHandshakeConfirmed() { } } -func (s *session) handlePacketImpl(rp *receivedPacket) bool { +func (s *connection) handlePacketImpl(rp *receivedPacket) bool { s.sentPacketHandler.ReceivedBytes(rp.Size()) if wire.IsVersionNegotiationPacket(rp.data) { @@ -908,7 +904,7 @@ func (s *session) handlePacketImpl(rp *receivedPacket) bool { return processed } -func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { +func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { var wasQueued bool defer func() { @@ -1000,7 +996,7 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool / return true } -func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { +func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) @@ -1062,7 +1058,7 @@ func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was t return true } -func (s *session) handleVersionNegotiationPacket(p *receivedPacket) { +func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets if s.tracer != nil { @@ -1116,7 +1112,7 @@ func (s *session) handleVersionNegotiationPacket(p *receivedPacket) { }) } -func (s *session) handleUnpackedPacket( +func (s *connection) handleUnpackedPacket( packet *unpackedPacket, ecn protocol.ECN, rcvTime time.Time, @@ -1148,10 +1144,10 @@ func (s *session) handleUnpackedPacket( s.handshakeDestConnID = cid s.connIDManager.ChangeInitialConnID(cid) } - // We create the session as soon as we receive the first packet from the client. + // We create the connection as soon as we receive the first packet from the client. // We do that before authenticating the packet. // That means that if the source connection ID was corrupted, - // we might have create a session with an incorrect source connection ID. + // we might have create a connection with an incorrect source connection ID. // Once we authenticate the first packet, we need to update it. if s.perspective == protocol.PerspectiveServer { if !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { @@ -1216,7 +1212,7 @@ func (s *session) handleUnpackedPacket( return s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) } -func (s *session) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { +func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error { var err error wire.LogFrame(s.logger, f, false) switch frame := f.(type) { @@ -1264,9 +1260,9 @@ func (s *session) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, d } // handlePacket is called by the server with a new packet -func (s *session) handlePacket(p *receivedPacket) { +func (s *connection) handlePacket(p *receivedPacket) { // Discard packets once the amount of queued packets is larger than - // the channel size, protocol.MaxSessionUnprocessedPackets + // the channel size, protocol.MaxConnUnprocessedPackets select { case s.receivedPackets <- p: default: @@ -1276,7 +1272,7 @@ func (s *session) handlePacket(p *receivedPacket) { } } -func (s *session) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) { +func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) { if frame.IsApplicationError { s.closeRemote(&qerr.ApplicationError{ Remote: true, @@ -1293,7 +1289,7 @@ func (s *session) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) { }) } -func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { +func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) if err != nil { return err @@ -1306,7 +1302,7 @@ func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.E return nil } -func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { +func (s *connection) handleStreamFrame(frame *wire.StreamFrame) error { str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) if err != nil { return err @@ -1319,11 +1315,11 @@ func (s *session) handleStreamFrame(frame *wire.StreamFrame) error { return str.handleStreamFrame(frame) } -func (s *session) handleMaxDataFrame(frame *wire.MaxDataFrame) { +func (s *connection) handleMaxDataFrame(frame *wire.MaxDataFrame) { s.connFlowController.UpdateSendWindow(frame.MaximumData) } -func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { +func (s *connection) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error { str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) if err != nil { return err @@ -1336,11 +1332,11 @@ func (s *session) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error return nil } -func (s *session) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) { +func (s *connection) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) { s.streamsMap.HandleMaxStreamsFrame(frame) } -func (s *session) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { +func (s *connection) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID) if err != nil { return err @@ -1352,7 +1348,7 @@ func (s *session) handleResetStreamFrame(frame *wire.ResetStreamFrame) error { return str.handleResetStreamFrame(frame) } -func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error { +func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error { str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID) if err != nil { return err @@ -1365,11 +1361,11 @@ func (s *session) handleStopSendingFrame(frame *wire.StopSendingFrame) error { return nil } -func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) { +func (s *connection) handlePathChallengeFrame(frame *wire.PathChallengeFrame) { s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data}) } -func (s *session) handleNewTokenFrame(frame *wire.NewTokenFrame) error { +func (s *connection) handleNewTokenFrame(frame *wire.NewTokenFrame) error { if s.perspective == protocol.PerspectiveServer { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, @@ -1382,15 +1378,15 @@ func (s *session) handleNewTokenFrame(frame *wire.NewTokenFrame) error { return nil } -func (s *session) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) error { +func (s *connection) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) error { return s.connIDManager.Add(f) } -func (s *session) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error { +func (s *connection) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error { return s.connIDGenerator.Retire(f.SequenceNumber, destConnID) } -func (s *session) handleHandshakeDoneFrame() error { +func (s *connection) handleHandshakeDoneFrame() error { if s.perspective == protocol.PerspectiveServer { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, @@ -1403,7 +1399,7 @@ func (s *session) handleHandshakeDoneFrame() error { return nil } -func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { +func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error { acked1RTTPacket, err := s.sentPacketHandler.ReceivedAck(frame, encLevel, s.lastPacketReceivedTime) if err != nil { return err @@ -1417,7 +1413,7 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked()) } -func (s *session) handleDatagramFrame(f *wire.DatagramFrame) error { +func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error { if f.Length(s.version) > protocol.MaxDatagramFrameSize { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, @@ -1428,50 +1424,50 @@ func (s *session) handleDatagramFrame(f *wire.DatagramFrame) error { return nil } -// closeLocal closes the session and send a CONNECTION_CLOSE containing the error -func (s *session) closeLocal(e error) { +// closeLocal closes the connection and send a CONNECTION_CLOSE containing the error +func (s *connection) closeLocal(e error) { s.closeOnce.Do(func() { if e == nil { - s.logger.Infof("Closing session.") + s.logger.Infof("Closing connection.") } else { - s.logger.Errorf("Closing session with error: %s", e) + s.logger.Errorf("Closing connection with error: %s", e) } s.closeChan <- closeError{err: e, immediate: false, remote: false} }) } -// destroy closes the session without sending the error on the wire -func (s *session) destroy(e error) { +// destroy closes the connection without sending the error on the wire +func (s *connection) destroy(e error) { s.destroyImpl(e) <-s.ctx.Done() } -func (s *session) destroyImpl(e error) { +func (s *connection) destroyImpl(e error) { s.closeOnce.Do(func() { if nerr, ok := e.(net.Error); ok && nerr.Timeout() { - s.logger.Errorf("Destroying session: %s", e) + s.logger.Errorf("Destroying connection: %s", e) } else { - s.logger.Errorf("Destroying session with error: %s", e) + s.logger.Errorf("Destroying connection with error: %s", e) } s.closeChan <- closeError{err: e, immediate: true, remote: false} }) } -func (s *session) closeRemote(e error) { +func (s *connection) closeRemote(e error) { s.closeOnce.Do(func() { - s.logger.Errorf("Peer closed session with error: %s", e) + s.logger.Errorf("Peer closed connection with error: %s", e) s.closeChan <- closeError{err: e, immediate: true, remote: true} }) } // Close the connection. It sends a NO_ERROR application error. // It waits until the run loop has stopped before returning -func (s *session) shutdown() { +func (s *connection) shutdown() { s.closeLocal(nil) <-s.ctx.Done() } -func (s *session) CloseWithError(code ApplicationErrorCode, desc string) error { +func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error { s.closeLocal(&qerr.ApplicationError{ ErrorCode: code, ErrorMessage: desc, @@ -1480,7 +1476,7 @@ func (s *session) CloseWithError(code ApplicationErrorCode, desc string) error { return nil } -func (s *session) handleCloseError(closeErr *closeError) { +func (s *connection) handleCloseError(closeErr *closeError) { e := closeErr.err if e == nil { e = &qerr.ApplicationError{} @@ -1524,7 +1520,7 @@ func (s *session) handleCloseError(closeErr *closeError) { // If this is a remote close we're done here if closeErr.remote { - s.connIDGenerator.ReplaceWithClosed(newClosedRemoteSession(s.perspective)) + s.connIDGenerator.ReplaceWithClosed(newClosedRemoteConn(s.perspective)) return } if closeErr.immediate { @@ -1535,11 +1531,11 @@ func (s *session) handleCloseError(closeErr *closeError) { if err != nil { s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) } - cs := newClosedLocalSession(s.conn, connClosePacket, s.perspective, s.logger) + cs := newClosedLocalConn(s.conn, connClosePacket, s.perspective, s.logger) s.connIDGenerator.ReplaceWithClosed(cs) } -func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { +func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { s.sentPacketHandler.DropPackets(encLevel) s.receivedPacketHandler.DropPackets(encLevel) if s.tracer != nil { @@ -1557,7 +1553,7 @@ func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { } // is called for the client, when restoring transport parameters saved for 0-RTT -func (s *session) restoreTransportParameters(params *wire.TransportParameters) { +func (s *connection) restoreTransportParameters(params *wire.TransportParameters) { if s.logger.Debug() { s.logger.Debugf("Restoring Transport Parameters: %s", params) } @@ -1568,7 +1564,7 @@ func (s *session) restoreTransportParameters(params *wire.TransportParameters) { s.streamsMap.UpdateLimits(params) } -func (s *session) handleTransportParameters(params *wire.TransportParameters) { +func (s *connection) handleTransportParameters(params *wire.TransportParameters) { if err := s.checkTransportParameters(params); err != nil { s.closeLocal(&qerr.TransportError{ ErrorCode: qerr.TransportParameterError, @@ -1580,13 +1576,13 @@ func (s *session) handleTransportParameters(params *wire.TransportParameters) { // During a 0-RTT connection, we are only allowed to use the new transport parameters for 1-RTT packets. if s.perspective == protocol.PerspectiveServer { s.applyTransportParameters() - // On the server side, the early session is ready as soon as we processed + // On the server side, the early connection is ready as soon as we processed // the client's transport parameters. - close(s.earlySessionReadyChan) + close(s.earlyConnReadyChan) } } -func (s *session) checkTransportParameters(params *wire.TransportParameters) error { +func (s *connection) checkTransportParameters(params *wire.TransportParameters) error { if s.logger.Debug() { s.logger.Debugf("Processed Transport Parameters: %s", params) } @@ -1619,7 +1615,7 @@ func (s *session) checkTransportParameters(params *wire.TransportParameters) err return nil } -func (s *session) applyTransportParameters() { +func (s *connection) applyTransportParameters() { params := s.peerParams // Our local idle timeout will always be > 0. s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) @@ -1640,7 +1636,7 @@ func (s *session) applyTransportParameters() { } } -func (s *session) sendPackets() error { +func (s *connection) sendPackets() error { s.pacingDeadline = time.Time{} var sentPacket bool // only used in for packets sent in send mode SendAny @@ -1706,7 +1702,7 @@ func (s *session) sendPackets() error { } } -func (s *session) maybeSendAckOnlyPacket() error { +func (s *connection) maybeSendAckOnlyPacket() error { packet, err := s.packer.MaybePackAckPacket(s.handshakeConfirmed) if err != nil { return err @@ -1718,7 +1714,7 @@ func (s *session) maybeSendAckOnlyPacket() error { return nil } -func (s *session) sendProbePacket(encLevel protocol.EncryptionLevel) error { +func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { // Queue probe packets until we actually send out a packet, // or until there are no more packets to queue. var packet *packedPacket @@ -1754,13 +1750,13 @@ func (s *session) sendProbePacket(encLevel protocol.EncryptionLevel) error { } } if packet == nil || packet.packetContents == nil { - return fmt.Errorf("session BUG: couldn't pack %s probe packet", encLevel) + return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) } s.sendPackedPacket(packet, time.Now()) return nil } -func (s *session) sendPacket() (bool, error) { +func (s *connection) sendPacket() (bool, error) { if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset}) } @@ -1783,7 +1779,7 @@ func (s *session) sendPacket() (bool, error) { s.sendQueue.Send(packet.buffer) return true, nil } - if pathMTUDiscoveryEnabled(s.config) && s.mtuDiscoverer.ShouldSendProbe(now) { + if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) { packet, err := s.packer.PackMTUProbePacket(s.mtuDiscoverer.GetPing()) if err != nil { return false, err @@ -1799,7 +1795,7 @@ func (s *session) sendPacket() (bool, error) { return true, nil } -func (s *session) sendPackedPacket(packet *packedPacket, now time.Time) { +func (s *connection) sendPackedPacket(packet *packedPacket, now time.Time) { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && packet.IsAckEliciting() { s.firstAckElicitingPacketAfterIdleSentTime = now } @@ -1809,7 +1805,7 @@ func (s *session) sendPackedPacket(packet *packedPacket, now time.Time) { s.sendQueue.Send(packet.buffer) } -func (s *session) sendConnectionClose(e error) ([]byte, error) { +func (s *connection) sendConnectionClose(e error) ([]byte, error) { var packet *coalescedPacket var err error var transportErr *qerr.TransportError @@ -1821,7 +1817,7 @@ func (s *session) sendConnectionClose(e error) ([]byte, error) { } else { packet, err = s.packer.PackConnectionClose(&qerr.TransportError{ ErrorCode: qerr.InternalError, - ErrorMessage: fmt.Sprintf("session BUG: unspecified error type (msg: %s)", e.Error()), + ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()), }) } if err != nil { @@ -1831,7 +1827,7 @@ func (s *session) sendConnectionClose(e error) ([]byte, error) { return packet.buffer.Data, s.conn.Write(packet.buffer.Data) } -func (s *session) logPacketContents(p *packetContents) { +func (s *connection) logPacketContents(p *packetContents) { // tracing if s.tracer != nil { frames := make([]logging.Frame, 0, len(p.frames)) @@ -1854,7 +1850,7 @@ func (s *session) logPacketContents(p *packetContents) { } } -func (s *session) logCoalescedPacket(packet *coalescedPacket) { +func (s *connection) logCoalescedPacket(packet *coalescedPacket) { if s.logger.Debug() { if len(packet.packets) > 1 { s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.packets), packet.buffer.Len(), s.logID) @@ -1867,7 +1863,7 @@ func (s *session) logCoalescedPacket(packet *coalescedPacket) { } } -func (s *session) logPacket(packet *packedPacket) { +func (s *connection) logPacket(packet *packedPacket) { if s.logger.Debug() { s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.header.PacketNumber, packet.buffer.Len(), s.logID, packet.EncryptionLevel()) } @@ -1875,32 +1871,32 @@ func (s *session) logPacket(packet *packedPacket) { } // AcceptStream returns the next stream openend by the peer -func (s *session) AcceptStream(ctx context.Context) (Stream, error) { +func (s *connection) AcceptStream(ctx context.Context) (Stream, error) { return s.streamsMap.AcceptStream(ctx) } -func (s *session) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { +func (s *connection) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { return s.streamsMap.AcceptUniStream(ctx) } // OpenStream opens a stream -func (s *session) OpenStream() (Stream, error) { +func (s *connection) OpenStream() (Stream, error) { return s.streamsMap.OpenStream() } -func (s *session) OpenStreamSync(ctx context.Context) (Stream, error) { +func (s *connection) OpenStreamSync(ctx context.Context) (Stream, error) { return s.streamsMap.OpenStreamSync(ctx) } -func (s *session) OpenUniStream() (SendStream, error) { +func (s *connection) OpenUniStream() (SendStream, error) { return s.streamsMap.OpenUniStream() } -func (s *session) OpenUniStreamSync(ctx context.Context) (SendStream, error) { +func (s *connection) OpenUniStreamSync(ctx context.Context) (SendStream, error) { return s.streamsMap.OpenUniStreamSync(ctx) } -func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController { +func (s *connection) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController { initialSendWindow := s.peerParams.InitialMaxStreamDataUni if id.Type() == protocol.StreamTypeBidi { if id.InitiatedBy() == s.perspective { @@ -1922,14 +1918,14 @@ func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlow } // scheduleSending signals that we have data for sending -func (s *session) scheduleSending() { +func (s *connection) scheduleSending() { select { case s.sendingScheduled <- struct{}{}: default: } } -func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket, hdr *wire.Header) { +func (s *connection) tryQueueingUndecryptablePacket(p *receivedPacket, hdr *wire.Header) { if s.handshakeComplete { panic("shouldn't queue undecryptable packets after handshake completion") } @@ -1947,33 +1943,33 @@ func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket, hdr *wire.He s.undecryptablePackets = append(s.undecryptablePackets, p) } -func (s *session) queueControlFrame(f wire.Frame) { +func (s *connection) queueControlFrame(f wire.Frame) { s.framer.QueueControlFrame(f) s.scheduleSending() } -func (s *session) onHasStreamWindowUpdate(id protocol.StreamID) { +func (s *connection) onHasStreamWindowUpdate(id protocol.StreamID) { s.windowUpdateQueue.AddStream(id) s.scheduleSending() } -func (s *session) onHasConnectionWindowUpdate() { +func (s *connection) onHasConnectionWindowUpdate() { s.windowUpdateQueue.AddConnection() s.scheduleSending() } -func (s *session) onHasStreamData(id protocol.StreamID) { +func (s *connection) onHasStreamData(id protocol.StreamID) { s.framer.AddActiveStream(id) s.scheduleSending() } -func (s *session) onStreamCompleted(id protocol.StreamID) { +func (s *connection) onStreamCompleted(id protocol.StreamID) { if err := s.streamsMap.DeleteStream(id); err != nil { s.closeLocal(err) } } -func (s *session) SendMessage(p []byte) error { +func (s *connection) SendMessage(p []byte) error { f := &wire.DatagramFrame{DataLenPresent: true} if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) { return errors.New("message too large") @@ -1983,27 +1979,27 @@ func (s *session) SendMessage(p []byte) error { return s.datagramQueue.AddAndWait(f) } -func (s *session) ReceiveMessage() ([]byte, error) { +func (s *connection) ReceiveMessage() ([]byte, error) { return s.datagramQueue.Receive() } -func (s *session) LocalAddr() net.Addr { +func (s *connection) LocalAddr() net.Addr { return s.conn.LocalAddr() } -func (s *session) RemoteAddr() net.Addr { +func (s *connection) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } -func (s *session) getPerspective() protocol.Perspective { +func (s *connection) getPerspective() protocol.Perspective { return s.perspective } -func (s *session) GetVersion() protocol.VersionNumber { +func (s *connection) GetVersion() protocol.VersionNumber { return s.version } -func (s *session) NextSession() Session { +func (s *connection) NextConnection() Connection { <-s.HandshakeComplete().Done() s.streamsMap.UseResetMaps() return s diff --git a/session_test.go b/connection_test.go similarity index 81% rename from session_test.go rename to connection_test.go index ac6da15287b..a95da57709e 100644 --- a/session_test.go +++ b/connection_test.go @@ -9,7 +9,6 @@ import ( "fmt" "io" "net" - "runtime" "runtime/pprof" "strings" "time" @@ -32,22 +31,22 @@ import ( . "github.com/onsi/gomega" ) -func areSessionsRunning() bool { +func areConnsRunning() bool { var b bytes.Buffer pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic-go.(*session).run") + return strings.Contains(b.String(), "quic-go.(*connection).run") } -func areClosedSessionsRunning() bool { +func areClosedConnsRunning() bool { var b bytes.Buffer pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic-go.(*closedLocalSession).run") + return strings.Contains(b.String(), "quic-go.(*closedLocalConn).run") } -var _ = Describe("Session", func() { +var _ = Describe("Connection", func() { var ( - sess *session - sessionRunner *MockSessionRunner + conn *connection + connRunner *MockConnRunner mconn *MockSendConn streamManager *MockStreamManager packer *MockPacker @@ -73,18 +72,18 @@ var _ = Describe("Session", func() { } expectReplaceWithClosed := func() { - sessionRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).MaxTimes(1) - sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) + connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).MaxTimes(1) + connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{})) s.shutdown() - Eventually(areClosedSessionsRunning).Should(BeFalse()) + Eventually(areClosedConnsRunning).Should(BeFalse()) }) } BeforeEach(func() { - Eventually(areSessionsRunning).Should(BeFalse()) + Eventually(areConnsRunning).Should(BeFalse()) - sessionRunner = NewMockSessionRunner(mockCtrl) + connRunner = NewMockConnRunner(mockCtrl) mconn = NewMockSendConn(mockCtrl) mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() @@ -95,9 +94,9 @@ var _ = Describe("Session", func() { tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() tracer.EXPECT().UpdatedCongestionState(gomock.Any()) - sess = newSession( + conn = newConnection( mconn, - sessionRunner, + connRunner, nil, nil, clientDestConnID, @@ -113,19 +112,19 @@ var _ = Describe("Session", func() { 1234, utils.DefaultLogger, protocol.VersionTLS, - ).(*session) + ).(*connection) streamManager = NewMockStreamManager(mockCtrl) - sess.streamsMap = streamManager + conn.streamsMap = streamManager packer = NewMockPacker(mockCtrl) - sess.packer = packer + conn.packer = packer cryptoSetup = mocks.NewMockCryptoSetup(mockCtrl) - sess.cryptoStreamHandler = cryptoSetup - sess.handshakeComplete = true - sess.idleTimeout = time.Hour + conn.cryptoStreamHandler = cryptoSetup + conn.handshakeComplete = true + conn.idleTimeout = time.Hour }) AfterEach(func() { - Eventually(areSessionsRunning).Should(BeFalse()) + Eventually(areConnsRunning).Should(BeFalse()) }) Context("frame handling", func() { @@ -138,7 +137,7 @@ var _ = Describe("Session", func() { str := NewMockReceiveStreamI(mockCtrl) str.EXPECT().handleStreamFrame(f) streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil) - Expect(sess.handleStreamFrame(f)).To(Succeed()) + Expect(conn.handleStreamFrame(f)).To(Succeed()) }) It("returns errors", func() { @@ -150,12 +149,12 @@ var _ = Describe("Session", func() { str := NewMockReceiveStreamI(mockCtrl) str.EXPECT().handleStreamFrame(f).Return(testErr) streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(str, nil) - Expect(sess.handleStreamFrame(f)).To(MatchError(testErr)) + Expect(conn.handleStreamFrame(f)).To(MatchError(testErr)) }) It("ignores STREAM frames for closed streams", func() { streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(5)).Return(nil, nil) // for closed streams, the streamManager returns nil - Expect(sess.handleStreamFrame(&wire.StreamFrame{ + Expect(conn.handleStreamFrame(&wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), })).To(Succeed()) @@ -167,8 +166,8 @@ var _ = Describe("Session", func() { f := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 3}}} sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().ReceivedAck(f, protocol.EncryptionHandshake, gomock.Any()) - sess.sentPacketHandler = sph - err := sess.handleAckFrame(f, protocol.EncryptionHandshake) + conn.sentPacketHandler = sph + err := conn.handleAckFrame(f, protocol.EncryptionHandshake) Expect(err).ToNot(HaveOccurred()) }) }) @@ -183,7 +182,7 @@ var _ = Describe("Session", func() { str := NewMockReceiveStreamI(mockCtrl) streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(555)).Return(str, nil) str.EXPECT().handleResetStreamFrame(f) - err := sess.handleResetStreamFrame(f) + err := conn.handleResetStreamFrame(f) Expect(err).ToNot(HaveOccurred()) }) @@ -196,13 +195,13 @@ var _ = Describe("Session", func() { str := NewMockReceiveStreamI(mockCtrl) streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(7)).Return(str, nil) str.EXPECT().handleResetStreamFrame(f).Return(testErr) - err := sess.handleResetStreamFrame(f) + err := conn.handleResetStreamFrame(f) Expect(err).To(MatchError(testErr)) }) It("ignores RESET_STREAM frames for closed streams", func() { streamManager.EXPECT().GetOrOpenReceiveStream(protocol.StreamID(3)).Return(nil, nil) - Expect(sess.handleFrame(&wire.ResetStreamFrame{ + Expect(conn.handleFrame(&wire.ResetStreamFrame{ StreamID: 3, ErrorCode: 42, }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) @@ -214,7 +213,7 @@ var _ = Describe("Session", func() { BeforeEach(func() { connFC = mocks.NewMockConnectionFlowController(mockCtrl) - sess.connFlowController = connFC + conn.connFlowController = connFC }) It("updates the flow control window of a stream", func() { @@ -225,18 +224,18 @@ var _ = Describe("Session", func() { str := NewMockSendStreamI(mockCtrl) streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(12345)).Return(str, nil) str.EXPECT().updateSendWindow(protocol.ByteCount(0x1337)) - Expect(sess.handleMaxStreamDataFrame(f)).To(Succeed()) + Expect(conn.handleMaxStreamDataFrame(f)).To(Succeed()) }) It("updates the flow control window of the connection", func() { offset := protocol.ByteCount(0x800000) connFC.EXPECT().UpdateSendWindow(offset) - sess.handleMaxDataFrame(&wire.MaxDataFrame{MaximumData: offset}) + conn.handleMaxDataFrame(&wire.MaxDataFrame{MaximumData: offset}) }) It("ignores MAX_STREAM_DATA frames for a closed stream", func() { streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(10)).Return(nil, nil) - Expect(sess.handleFrame(&wire.MaxStreamDataFrame{ + Expect(conn.handleFrame(&wire.MaxStreamDataFrame{ StreamID: 10, MaximumStreamData: 1337, }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) @@ -250,7 +249,7 @@ var _ = Describe("Session", func() { MaxStreamNum: 10, } streamManager.EXPECT().HandleMaxStreamsFrame(f) - sess.handleMaxStreamsFrame(f) + conn.handleMaxStreamsFrame(f) }) }) @@ -263,13 +262,13 @@ var _ = Describe("Session", func() { str := NewMockSendStreamI(mockCtrl) streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(5)).Return(str, nil) str.EXPECT().handleStopSendingFrame(f) - err := sess.handleStopSendingFrame(f) + err := conn.handleStopSendingFrame(f) Expect(err).ToNot(HaveOccurred()) }) It("ignores STOP_SENDING frames for a closed stream", func() { streamManager.EXPECT().GetOrOpenSendStream(protocol.StreamID(3)).Return(nil, nil) - Expect(sess.handleFrame(&wire.StopSendingFrame{ + Expect(conn.handleFrame(&wire.StopSendingFrame{ StreamID: 3, ErrorCode: 1337, }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) @@ -277,50 +276,50 @@ var _ = Describe("Session", func() { }) It("handles NEW_CONNECTION_ID frames", func() { - Expect(sess.handleFrame(&wire.NewConnectionIDFrame{ + Expect(conn.handleFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 10, ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) - Expect(sess.connIDManager.queue.Back().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(conn.connIDManager.queue.Back().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) }) It("handles PING frames", func() { - err := sess.handleFrame(&wire.PingFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) + err := conn.handleFrame(&wire.PingFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) Expect(err).NotTo(HaveOccurred()) }) It("rejects PATH_RESPONSE frames", func() { - err := sess.handleFrame(&wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, protocol.Encryption1RTT, protocol.ConnectionID{}) + err := conn.handleFrame(&wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}, protocol.Encryption1RTT, protocol.ConnectionID{}) Expect(err).To(MatchError("unexpected PATH_RESPONSE frame")) }) It("handles PATH_CHALLENGE frames", func() { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - err := sess.handleFrame(&wire.PathChallengeFrame{Data: data}, protocol.Encryption1RTT, protocol.ConnectionID{}) + err := conn.handleFrame(&wire.PathChallengeFrame{Data: data}, protocol.Encryption1RTT, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) - frames, _ := sess.framer.AppendControlFrames(nil, 1000) + frames, _ := conn.framer.AppendControlFrames(nil, 1000) Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &wire.PathResponseFrame{Data: data}}})) }) It("rejects NEW_TOKEN frames", func() { - err := sess.handleNewTokenFrame(&wire.NewTokenFrame{}) + err := conn.handleNewTokenFrame(&wire.NewTokenFrame{}) Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ProtocolViolation)) }) It("handles BLOCKED frames", func() { - err := sess.handleFrame(&wire.DataBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) + err := conn.handleFrame(&wire.DataBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) Expect(err).NotTo(HaveOccurred()) }) It("handles STREAM_BLOCKED frames", func() { - err := sess.handleFrame(&wire.StreamDataBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) + err := conn.handleFrame(&wire.StreamDataBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) Expect(err).NotTo(HaveOccurred()) }) It("handles STREAMS_BLOCKED frames", func() { - err := sess.handleFrame(&wire.StreamsBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) + err := conn.handleFrame(&wire.StreamsBlockedFrame{}, protocol.Encryption1RTT, protocol.ConnectionID{}) Expect(err).NotTo(HaveOccurred()) }) @@ -331,11 +330,11 @@ var _ = Describe("Session", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(expectedErr) - sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) + connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) }) - sessionRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) + connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) }) cryptoSetup.EXPECT().Close() gomock.InOrder( @@ -346,13 +345,13 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - Expect(sess.run()).To(MatchError(expectedErr)) + Expect(conn.run()).To(MatchError(expectedErr)) }() - Expect(sess.handleFrame(&wire.ConnectionCloseFrame{ + Expect(conn.handleFrame(&wire.ConnectionCloseFrame{ ErrorCode: uint64(qerr.StreamLimitError), ReasonPhrase: "foobar", }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) - Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("handles CONNECTION_CLOSE frames, with an application error code", func() { @@ -362,11 +361,11 @@ var _ = Describe("Session", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(testErr) - sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) + connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) }) - sessionRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedRemoteSession{})) + connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) }) cryptoSetup.EXPECT().Close() gomock.InOrder( @@ -377,19 +376,19 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - Expect(sess.run()).To(MatchError(testErr)) + Expect(conn.run()).To(MatchError(testErr)) }() ccf := &wire.ConnectionCloseFrame{ ErrorCode: 0x1337, ReasonPhrase: "foobar", IsApplicationError: true, } - Expect(sess.handleFrame(ccf, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) - Eventually(sess.Context().Done()).Should(BeClosed()) + Expect(conn.handleFrame(ccf, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("errors on HANDSHAKE_DONE frames", func() { - Expect(sess.handleHandshakeDoneFrame()).To(MatchError(&qerr.TransportError{ + Expect(conn.handleHandshakeDoneFrame()).To(MatchError(&qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "received a HANDSHAKE_DONE frame", })) @@ -397,8 +396,8 @@ var _ = Describe("Session", func() { }) It("tells its versions", func() { - sess.version = 4242 - Expect(sess.GetVersion()).To(Equal(protocol.VersionNumber(4242))) + conn.version = 4242 + Expect(conn.GetVersion()).To(Equal(protocol.VersionNumber(4242))) }) Context("closing", func() { @@ -420,18 +419,18 @@ var _ = Describe("Session", func() { } }) - runSession := func() { + runConn := func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - runErr <- sess.run() + runErr <- conn.run() }() - Eventually(areSessionsRunning).Should(BeTrue()) + Eventually(areConnsRunning).Should(BeTrue()) } It("shuts down without error", func() { - sess.handshakeComplete = true - runSession() + conn.handshakeComplete = true + runConn() streamManager.EXPECT().CloseWithError(&qerr.ApplicationError{}) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() @@ -452,13 +451,13 @@ var _ = Describe("Session", func() { }), tracer.EXPECT().Close(), ) - sess.shutdown() - Eventually(areSessionsRunning).Should(BeFalse()) - Expect(sess.Context().Done()).To(BeClosed()) + conn.shutdown() + Eventually(areConnsRunning).Should(BeFalse()) + Expect(conn.Context().Done()).To(BeClosed()) }) It("only closes once", func() { - runSession() + runConn() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() @@ -466,14 +465,14 @@ var _ = Describe("Session", func() { mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.shutdown() - sess.shutdown() - Eventually(areSessionsRunning).Should(BeFalse()) - Expect(sess.Context().Done()).To(BeClosed()) + conn.shutdown() + conn.shutdown() + Eventually(areConnsRunning).Should(BeFalse()) + Expect(conn.Context().Done()).To(BeClosed()) }) It("closes with an error", func() { - runSession() + runConn() expectedErr := &qerr.ApplicationError{ ErrorCode: 0x1337, ErrorMessage: "test error", @@ -487,13 +486,13 @@ var _ = Describe("Session", func() { tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().Close(), ) - sess.CloseWithError(0x1337, "test error") - Eventually(areSessionsRunning).Should(BeFalse()) - Expect(sess.Context().Done()).To(BeClosed()) + conn.CloseWithError(0x1337, "test error") + Eventually(areConnsRunning).Should(BeFalse()) + Expect(conn.Context().Done()).To(BeClosed()) }) It("includes the frame type in transport-level close frames", func() { - runSession() + runConn() expectedErr := &qerr.TransportError{ ErrorCode: 0x1337, FrameType: 0x42, @@ -508,16 +507,16 @@ var _ = Describe("Session", func() { tracer.EXPECT().ClosedConnection(expectedErr), tracer.EXPECT().Close(), ) - sess.closeLocal(expectedErr) - Eventually(areSessionsRunning).Should(BeFalse()) - Expect(sess.Context().Done()).To(BeClosed()) + conn.closeLocal(expectedErr) + Eventually(areConnsRunning).Should(BeFalse()) + Expect(conn.Context().Done()).To(BeClosed()) }) - It("destroys the session", func() { - runSession() + It("destroys the connection", func() { + runConn() testErr := errors.New("close") streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() // don't EXPECT any calls to mconn.Write() gomock.InOrder( @@ -529,8 +528,8 @@ var _ = Describe("Session", func() { }), tracer.EXPECT().Close(), ) - sess.destroy(testErr) - Eventually(areSessionsRunning).Should(BeFalse()) + conn.destroy(testErr) + Eventually(areConnsRunning).Should(BeFalse()) expectedRunErr = &qerr.TransportError{ ErrorCode: qerr.InternalError, ErrorMessage: testErr.Error(), @@ -538,7 +537,7 @@ var _ = Describe("Session", func() { }) It("cancels the context when the run loop exists", func() { - runSession() + runConn() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() cryptoSetup.EXPECT().Close() @@ -546,7 +545,7 @@ var _ = Describe("Session", func() { returned := make(chan struct{}) go func() { defer GinkgoRecover() - ctx := sess.Context() + ctx := conn.Context() <-ctx.Done() Expect(ctx.Err()).To(MatchError(context.Canceled)) close(returned) @@ -555,27 +554,27 @@ var _ = Describe("Session", func() { mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.shutdown() + conn.shutdown() Eventually(returned).Should(BeClosed()) }) It("doesn't send any more packets after receiving a CONNECTION_CLOSE", func() { unpacker := NewMockUnpacker(mockCtrl) - sess.handshakeConfirmed = true - sess.unpacker = unpacker - runSession() + conn.handshakeConfirmed = true + conn.unpacker = unpacker + runConn() cryptoSetup.EXPECT().Close() streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes() + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes() buf := &bytes.Buffer{} hdr := &wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumberLen: protocol.PacketNumberLen2, } - Expect(hdr.Write(buf, sess.version)).To(Succeed()) + Expect(hdr.Write(buf, conn.version)).To(Succeed()) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) { buf := &bytes.Buffer{} - Expect((&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Write(buf, sess.version)).To(Succeed()) + Expect((&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Write(buf, conn.version)).To(Succeed()) return &unpackedPacket{ hdr: hdr, data: buf.Bytes(), @@ -589,21 +588,21 @@ var _ = Describe("Session", func() { tracer.EXPECT().Close(), ) // don't EXPECT any calls to packer.PackPacket() - sess.handlePacket(&receivedPacket{ + conn.handlePacket(&receivedPacket{ rcvTime: time.Now(), remoteAddr: &net.UDPAddr{}, buffer: getPacketBuffer(), data: buf.Bytes(), }) // Consistently(pack).ShouldNot(Receive()) - Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("closes when the sendQueue encounters an error", func() { - sess.handshakeConfirmed = true - conn := NewMockSendConn(mockCtrl) - conn.EXPECT().Write(gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() - sess.sendQueue = newSendQueue(conn) + conn.handshakeConfirmed = true + sconn := NewMockSendConn(mockCtrl) + sconn.EXPECT().Write(gomock.Any()).Return(io.ErrClosedPipe).AnyTimes() + conn.sendQueue = newSendQueue(sconn) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -614,21 +613,21 @@ var _ = Describe("Session", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph p := getPacket(1) packer.EXPECT().PackPacket().Return(p, nil) packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() - runSession() - sess.queueControlFrame(&wire.PingFrame{}) - sess.scheduleSending() - Eventually(sess.Context().Done()).Should(BeClosed()) + runConn() + conn.queueControlFrame(&wire.PingFrame{}) + conn.scheduleSending() + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("closes due to a stateless reset", func() { token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - runSession() + runConn() gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { var srErr *StatelessResetError @@ -638,9 +637,9 @@ var _ = Describe("Session", func() { tracer.EXPECT().Close(), ) streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() - sess.destroy(&StatelessResetError{Token: token}) + conn.destroy(&StatelessResetError{Token: token}) }) }) @@ -649,12 +648,12 @@ var _ = Describe("Session", func() { BeforeEach(func() { unpacker = NewMockUnpacker(mockCtrl) - sess.unpacker = unpacker + conn.unpacker = unpacker }) getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { buf := &bytes.Buffer{} - Expect(extHdr.Write(buf, sess.version)).To(Succeed()) + Expect(extHdr.Write(buf, conn.version)).To(Succeed()) return &receivedPacket{ data: append(buf.Bytes(), data...), buffer: getPacketBuffer(), @@ -668,18 +667,17 @@ var _ = Describe("Session", func() { Type: protocol.PacketTypeRetry, DestConnectionID: destConnID, SrcConnectionID: srcConnID, - Version: sess.version, + Version: conn.version, Token: []byte("foobar"), }}, make([]byte, 16) /* Retry integrity tag */) tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) - Expect(sess.handlePacketImpl(p)).To(BeFalse()) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("drops Version Negotiation packets", func() { - b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, sess.config.Versions) - Expect(err).ToNot(HaveOccurred()) + b := wire.ComposeVersionNegotiation(srcConnID, destConnID, conn.config.Versions) tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) - Expect(sess.handlePacketImpl(&receivedPacket{ + Expect(conn.handlePacketImpl(&receivedPacket{ data: b, buffer: getPacketBuffer(), })).To(BeFalse()) @@ -690,13 +688,13 @@ var _ = Describe("Session", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - Version: sess.version, + Version: conn.version, }, PacketNumberLen: protocol.PacketNumberLen2, }, nil) p.data[0] ^= 0x40 // unset the QUIC bit tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) - Expect(sess.handlePacketImpl(p)).To(BeFalse()) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("drops packets for which the version is unsupported", func() { @@ -704,12 +702,12 @@ var _ = Describe("Session", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - Version: sess.version + 1, + Version: conn.version + 1, }, PacketNumberLen: protocol.PacketNumberLen2, }, nil) tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnsupportedVersion) - Expect(sess.handlePacketImpl(p)).To(BeFalse()) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("drops packets with an unsupported version", func() { @@ -719,19 +717,19 @@ var _ = Describe("Session", func() { protocol.SupportedVersions = origSupportedVersions }() - protocol.SupportedVersions = append(protocol.SupportedVersions, sess.version+1) + protocol.SupportedVersions = append(protocol.SupportedVersions, conn.version+1) p := getPacket(&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, DestConnectionID: destConnID, SrcConnectionID: srcConnID, - Version: sess.version + 1, + Version: conn.version + 1, }, PacketNumberLen: protocol.PacketNumberLen2, }, nil) tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedVersion) - Expect(sess.handlePacketImpl(p)).To(BeFalse()) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("informs the ReceivedPacketHandler about non-ack-eliciting packets", func() { @@ -754,11 +752,11 @@ var _ = Describe("Session", func() { rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial), rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNCE, protocol.EncryptionInitial, rcvTime, false), ) - sess.receivedPacketHandler = rph + conn.receivedPacketHandler = rph packet.rcvTime = rcvTime tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), []logging.Frame{}) - Expect(sess.handlePacketImpl(packet)).To(BeTrue()) + Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) It("informs the ReceivedPacketHandler about ack-eliciting packets", func() { @@ -769,7 +767,7 @@ var _ = Describe("Session", func() { } rcvTime := time.Now().Add(-10 * time.Second) buf := &bytes.Buffer{} - Expect((&wire.PingFrame{}).Write(buf, sess.version)).To(Succeed()) + Expect((&wire.PingFrame{}).Write(buf, conn.version)).To(Succeed()) packet := getPacket(hdr, nil) packet.ecn = protocol.ECT1 unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{ @@ -783,11 +781,11 @@ var _ = Describe("Session", func() { rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT), rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECT1, protocol.Encryption1RTT, rcvTime, true), ) - sess.receivedPacketHandler = rph + conn.receivedPacketHandler = rph packet.rcvTime = rcvTime tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}}) - Expect(sess.handlePacketImpl(packet)).To(BeTrue()) + Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) It("drops duplicate packets", func() { @@ -805,9 +803,9 @@ var _ = Describe("Session", func() { }, nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true) - sess.receivedPacketHandler = rph + conn.receivedPacketHandler = rph tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate) - Expect(sess.handlePacketImpl(packet)).To(BeFalse()) + Expect(conn.handlePacketImpl(packet)).To(BeFalse()) }) It("drops a packet when unpacking fails", func() { @@ -818,7 +816,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() expectReplaceWithClosed() p := getPacket(&wire.ExtendedHeader{ @@ -826,25 +824,25 @@ var _ = Describe("Session", func() { IsLongHeader: true, Type: protocol.PacketTypeHandshake, DestConnectionID: srcConnID, - Version: sess.version, + Version: conn.version, Length: 2 + 6, }, PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen2, }, []byte("foobar")) tracer.EXPECT().DroppedPacket(logging.PacketTypeHandshake, p.Size(), logging.PacketDropPayloadDecryptError) - sess.handlePacket(p) - Consistently(sess.Context().Done()).ShouldNot(BeClosed()) + conn.handlePacket(p) + Consistently(conn.Context().Done()).ShouldNot(BeClosed()) // make the go routine return tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) - sess.closeLocal(errors.New("close")) - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.closeLocal(errors.New("close")) + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("processes multiple received packets before sending one", func() { - sess.sessionCreationTime = time.Now() + conn.creationTime = time.Now() var pn protocol.PacketNumber unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { pn++ @@ -861,7 +859,7 @@ var _ = Describe("Session", func() { packer.EXPECT().PackCoalescedPacket() // only expect a single call for i := 0; i < 3; i++ { - sess.handlePacket(getPacket(&wire.ExtendedHeader{ + conn.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen2, @@ -871,9 +869,9 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - Consistently(sess.Context().Done()).ShouldNot(BeClosed()) + Consistently(conn.Context().Done()).ShouldNot(BeClosed()) // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) @@ -883,13 +881,13 @@ var _ = Describe("Session", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) - sess.closeLocal(errors.New("close")) - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.closeLocal(errors.New("close")) + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("doesn't processes multiple received packets before sending one before handshake completion", func() { - sess.handshakeComplete = false - sess.sessionCreationTime = time.Now() + conn.handshakeComplete = false + conn.creationTime = time.Now() var pn protocol.PacketNumber unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { pn++ @@ -906,7 +904,7 @@ var _ = Describe("Session", func() { packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call for i := 0; i < 3; i++ { - sess.handlePacket(getPacket(&wire.ExtendedHeader{ + conn.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumber: 0x1337, PacketNumberLen: protocol.PacketNumberLen2, @@ -916,9 +914,9 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - Consistently(sess.Context().Done()).ShouldNot(BeClosed()) + Consistently(conn.Context().Done()).ShouldNot(BeClosed()) // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) @@ -928,11 +926,11 @@ var _ = Describe("Session", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) - sess.closeLocal(errors.New("close")) - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.closeLocal(errors.New("close")) + Eventually(conn.Context().Done()).Should(BeClosed()) }) - It("closes the session when unpacking fails because the reserved bits were incorrect", func() { + It("closes the connection when unpacking fails because the reserved bits were incorrect", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() @@ -941,7 +939,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := sess.run() + err := conn.run() Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ProtocolViolation)) @@ -955,8 +953,8 @@ var _ = Describe("Session", func() { }, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.handlePacket(packet) - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.handlePacket(packet) + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("ignores packets when unpacking the header fails", func() { @@ -968,11 +966,11 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - runErr <- sess.run() + runErr <- conn.run() }() expectReplaceWithClosed() tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, gomock.Any(), logging.PacketDropHeaderParseError) - sess.handlePacket(getPacket(&wire.ExtendedHeader{ + conn.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, }, nil)) @@ -982,11 +980,11 @@ var _ = Describe("Session", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) }) - It("closes the session when unpacking fails because of an error other than a decryption error", func() { + It("closes the connection when unpacking fails because of an error other than a decryption error", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() @@ -995,7 +993,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := sess.run() + err := conn.run() Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.ConnectionIDLimitError)) @@ -1009,8 +1007,8 @@ var _ = Describe("Session", func() { }, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.handlePacket(packet) - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.handlePacket(packet) + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("rejects packets with empty payload", func() { @@ -1026,7 +1024,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - Expect(sess.run()).To(MatchError(&qerr.TransportError{ + Expect(conn.run()).To(MatchError(&qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "empty packet", })) @@ -1036,7 +1034,7 @@ var _ = Describe("Session", func() { mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.handlePacket(getPacket(&wire.ExtendedHeader{ + conn.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, }, nil)) @@ -1051,7 +1049,7 @@ var _ = Describe("Session", func() { DestConnectionID: destConnID, SrcConnectionID: srcConnID, Length: 1, - Version: sess.version, + Version: conn.version, }, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 1, @@ -1063,7 +1061,7 @@ var _ = Describe("Session", func() { DestConnectionID: destConnID, SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, Length: 1, - Version: sess.version, + Version: conn.version, }, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 2, @@ -1079,15 +1077,15 @@ var _ = Describe("Session", func() { p1 := getPacket(hdr1, nil) tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any()) - Expect(sess.handlePacketImpl(p1)).To(BeTrue()) + Expect(conn.handlePacketImpl(p1)).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. p2 := getPacket(hdr2, nil) tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.ByteCount(len(p2.data)), logging.PacketDropUnknownConnectionID) - Expect(sess.handlePacketImpl(p2)).To(BeFalse()) + Expect(conn.handlePacketImpl(p2)).To(BeFalse()) }) It("queues undecryptable packets", func() { - sess.handshakeComplete = false + conn.handshakeComplete = false hdr := &wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, @@ -1095,7 +1093,7 @@ var _ = Describe("Session", func() { DestConnectionID: destConnID, SrcConnectionID: srcConnID, Length: 1, - Version: sess.version, + Version: conn.version, }, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 1, @@ -1103,8 +1101,8 @@ var _ = Describe("Session", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable) packet := getPacket(hdr, nil) tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake) - Expect(sess.handlePacketImpl(packet)).To(BeFalse()) - Expect(sess.undecryptablePackets).To(Equal([]*receivedPacket{packet})) + Expect(conn.handlePacketImpl(packet)).To(BeFalse()) + Expect(conn.undecryptablePackets).To(Equal([]*receivedPacket{packet})) }) Context("updating the remote address", func() { @@ -1121,7 +1119,7 @@ var _ = Describe("Session", func() { packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) - Expect(sess.handlePacketImpl(packet)).To(BeTrue()) + Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) }) @@ -1141,7 +1139,7 @@ var _ = Describe("Session", func() { }, PacketNumberLen: protocol.PacketNumberLen3, } - hdrLen := hdr.GetLength(sess.version) + hdrLen := hdr.GetLength(conn.version) b := make([]byte, 1) rand.Read(b) packet := getPacket(hdr, bytes.Repeat(b, int(length)-3)) @@ -1159,7 +1157,7 @@ var _ = Describe("Session", func() { }, nil }) tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) - Expect(sess.handlePacketImpl(packet)).To(BeTrue()) + Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) It("handles coalesced packets", func() { @@ -1188,11 +1186,11 @@ var _ = Describe("Session", func() { tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), ) packet1.data = append(packet1.data, packet2.data...) - Expect(sess.handlePacketImpl(packet1)).To(BeTrue()) + Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) }) It("works with undecryptable packets", func() { - sess.handshakeComplete = false + conn.handshakeComplete = false hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) gomock.InOrder( @@ -1211,10 +1209,10 @@ var _ = Describe("Session", func() { tracer.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), gomock.Any()), ) packet1.data = append(packet1.data, packet2.data...) - Expect(sess.handlePacketImpl(packet1)).To(BeTrue()) + Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) - Expect(sess.undecryptablePackets).To(HaveLen(1)) - Expect(sess.undecryptablePackets[0].data).To(HaveLen(hdrLen1 + 456 - 3)) + Expect(conn.undecryptablePackets).To(HaveLen(1)) + Expect(conn.undecryptablePackets[0].data).To(HaveLen(hdrLen1 + 456 - 3)) }) It("ignores coalesced packet parts if the destination connection IDs don't match", func() { @@ -1236,23 +1234,23 @@ var _ = Describe("Session", func() { tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.ByteCount(len(packet2.data)), logging.PacketDropUnknownConnectionID), ) packet1.data = append(packet1.data, packet2.data...) - Expect(sess.handlePacketImpl(packet1)).To(BeTrue()) + Expect(conn.handlePacketImpl(packet1)).To(BeTrue()) }) }) }) Context("sending packets", func() { var ( - sessionDone chan struct{} - sender *MockSender + connDone chan struct{} + sender *MockSender ) BeforeEach(func() { sender = NewMockSender(mockCtrl) sender.EXPECT().Run() sender.EXPECT().WouldBlock().AnyTimes() - sess.sendQueue = sender - sessionDone = make(chan struct{}) + conn.sendQueue = sender + connDone = make(chan struct{}) }) AfterEach(func() { @@ -1264,30 +1262,30 @@ var _ = Describe("Session", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) - Eventually(sessionDone).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) + Eventually(connDone).Should(BeClosed()) }) - runSession := func() { + runConn := func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() - close(sessionDone) + conn.run() + close(connDone) }() } It("sends packets", func() { - sess.handshakeConfirmed = true + conn.handshakeConfirmed = true sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SentPacket(gomock.Any()) - sess.sentPacketHandler = sph - runSession() + conn.sentPacketHandler = sph + runConn() p := getPacket(1) packer.EXPECT().PackPacket().Return(p, nil) packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() @@ -1295,16 +1293,16 @@ var _ = Describe("Session", func() { sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) tracer.EXPECT().SentPacket(p.header, p.buffer.Len(), nil, []logging.Frame{}) - sess.scheduleSending() + conn.scheduleSending() Eventually(sent).Should(BeClosed()) }) It("doesn't send packets if there's nothing to send", func() { - sess.handshakeConfirmed = true - runSession() + conn.handshakeConfirmed = true + runConn() packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() - sess.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) - sess.scheduleSending() + conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) + conn.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() }) @@ -1315,35 +1313,35 @@ var _ = Describe("Session", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAck) done := make(chan struct{}) packer.EXPECT().MaybePackAckPacket(false).Do(func(bool) { close(done) }) - sess.sentPacketHandler = sph - runSession() - sess.scheduleSending() + conn.sentPacketHandler = sph + runConn() + conn.scheduleSending() Eventually(done).Should(BeClosed()) }) It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { - sess.handshakeConfirmed = true + conn.handshakeConfirmed = true sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SentPacket(gomock.Any()) - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) fc.EXPECT().IsNewlyBlocked() p := getPacket(1) packer.EXPECT().PackPacket().Return(p, nil) packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() - sess.connFlowController = fc - runSession() + conn.connFlowController = fc + runConn() sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) tracer.EXPECT().SentPacket(p.header, p.length, nil, []logging.Frame{}) - sess.scheduleSending() + conn.scheduleSending() Eventually(sent).Should(BeClosed()) - frames, _ := sess.framer.AppendControlFrames(nil, 1000) + frames, _ := conn.framer.AppendControlFrames(nil, 1000) Expect(frames).To(Equal([]ackhandler.Frame{{Frame: &logging.DataBlockedFrame{MaximumData: 1337}}})) }) @@ -1352,9 +1350,9 @@ var _ = Describe("Session", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendNone).AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() - sess.sentPacketHandler = sph - runSession() - sess.scheduleSending() + conn.sentPacketHandler = sph + runConn() + conn.scheduleSending() time.Sleep(50 * time.Millisecond) }) @@ -1370,13 +1368,13 @@ var _ = Describe("Session", func() { switch encLevel { case protocol.EncryptionInitial: sendMode = ackhandler.SendPTOInitial - getFrame = sess.retransmissionQueue.GetInitialFrame + getFrame = conn.retransmissionQueue.GetInitialFrame case protocol.EncryptionHandshake: sendMode = ackhandler.SendPTOHandshake - getFrame = sess.retransmissionQueue.GetHandshakeFrame + getFrame = conn.retransmissionQueue.GetHandshakeFrame case protocol.Encryption1RTT: sendMode = ackhandler.SendPTOAppData - getFrame = sess.retransmissionQueue.GetAppDataFrame + getFrame = conn.retransmissionQueue.GetAppDataFrame } }) @@ -1392,12 +1390,12 @@ var _ = Describe("Session", func() { sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) }) - sess.sentPacketHandler = sph - runSession() + conn.sentPacketHandler = sph + runConn() sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) tracer.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any()) - sess.scheduleSending() + conn.scheduleSending() Eventually(sent).Should(BeClosed()) }) @@ -1413,12 +1411,12 @@ var _ = Describe("Session", func() { sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) }) - sess.sentPacketHandler = sph - runSession() + conn.sentPacketHandler = sph + runConn() sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) tracer.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any()) - sess.scheduleSending() + conn.scheduleSending() Eventually(sent).Should(BeClosed()) // We're using a mock packet packer in this test. // We therefore need to test separately that the PING was actually queued. @@ -1438,12 +1436,12 @@ var _ = Describe("Session", func() { tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() - sess.handshakeConfirmed = true - sess.handshakeComplete = true - sess.sentPacketHandler = sph + conn.handshakeConfirmed = true + conn.handshakeComplete = true + conn.sentPacketHandler = sph sender = NewMockSender(mockCtrl) sender.EXPECT().Run() - sess.sendQueue = sender + conn.sendQueue = sender streamManager.EXPECT().CloseWithError(gomock.Any()) }) @@ -1456,8 +1454,8 @@ var _ = Describe("Session", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() sender.EXPECT().Close() - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("sends multiple packets one by one immediately", func() { @@ -1473,9 +1471,9 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - sess.scheduleSending() + conn.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure that only 2 packets are sent }) @@ -1490,9 +1488,9 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - sess.scheduleSending() + conn.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure that only 1 packet is sent }) @@ -1507,9 +1505,9 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - sess.scheduleSending() + conn.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure that only 1 packet is sent }) @@ -1526,9 +1524,9 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - sess.scheduleSending() + conn.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure that only 1 packet is sent }) @@ -1553,9 +1551,9 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - sess.scheduleSending() + conn.scheduleSending() Eventually(written).Should(HaveLen(1)) Consistently(written, pacingDelay/2).Should(HaveLen(1)) Eventually(written, 2*pacingDelay).Should(HaveLen(2)) @@ -1576,9 +1574,9 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - sess.scheduleSending() + conn.scheduleSending() Eventually(written).Should(HaveLen(3)) }) @@ -1589,9 +1587,9 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - sess.scheduleSending() + conn.scheduleSending() time.Sleep(scaleDuration(50 * time.Millisecond)) written := make(chan struct{}) @@ -1611,14 +1609,14 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() written := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() sph.EXPECT().SentPacket(gomock.Any()).Do(func(*ackhandler.Packet) { sph.EXPECT().ReceivedBytes(gomock.Any()) - sess.handlePacket(&receivedPacket{buffer: getPacketBuffer()}) + conn.handlePacket(&receivedPacket{buffer: getPacketBuffer()}) }) sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -1626,7 +1624,7 @@ var _ = Describe("Session", func() { packer.EXPECT().PackPacket().Return(nil, nil) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) - sess.scheduleSending() + conn.scheduleSending() time.Sleep(scaleDuration(50 * time.Millisecond)) Eventually(written).Should(BeClosed()) @@ -1644,11 +1642,11 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() available := make(chan struct{}, 1) sender.EXPECT().Available().Return(available) - sess.scheduleSending() + conn.scheduleSending() Eventually(written).Should(Receive()) time.Sleep(scaleDuration(50 * time.Millisecond)) @@ -1677,41 +1675,39 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - sess.scheduleSending() // no packet will get sent + conn.scheduleSending() // no packet will get sent time.Sleep(50 * time.Millisecond) }) - if runtime.GOOS != "windows" { // Path MTU Discovery is disabled on Windows - It("sends a Path MTU probe packet", func() { - mtuDiscoverer := NewMockMtuDiscoverer(mockCtrl) - sess.mtuDiscoverer = mtuDiscoverer - sess.config.DisablePathMTUDiscovery = false - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny) - sph.EXPECT().SendMode().Return(ackhandler.SendNone) - written := make(chan struct{}, 1) - sender.EXPECT().WouldBlock().AnyTimes() - sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) - gomock.InOrder( - mtuDiscoverer.EXPECT().NextProbeTime(), - mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true), - mtuDiscoverer.EXPECT().NextProbeTime(), - ) - ping := ackhandler.Frame{Frame: &wire.PingFrame{}} - mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) - packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234)).Return(getPacket(1), nil) - go func() { - defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() - }() - sess.scheduleSending() - Eventually(written).Should(Receive()) - }) - } + It("sends a Path MTU probe packet", func() { + mtuDiscoverer := NewMockMtuDiscoverer(mockCtrl) + conn.mtuDiscoverer = mtuDiscoverer + conn.config.DisablePathMTUDiscovery = false + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny) + sph.EXPECT().SendMode().Return(ackhandler.SendNone) + written := make(chan struct{}, 1) + sender.EXPECT().WouldBlock().AnyTimes() + sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) + gomock.InOrder( + mtuDiscoverer.EXPECT().NextProbeTime(), + mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true), + mtuDiscoverer.EXPECT().NextProbeTime(), + ) + ping := ackhandler.Frame{Frame: &wire.PingFrame{}} + mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234)) + packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234)).Return(getPacket(1), nil) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + conn.run() + }() + conn.scheduleSending() + Eventually(written).Should(Receive()) + }) }) Context("scheduling sending", func() { @@ -1721,8 +1717,8 @@ var _ = Describe("Session", func() { sender = NewMockSender(mockCtrl) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Run() - sess.sendQueue = sender - sess.handshakeConfirmed = true + conn.sendQueue = sender + conn.handshakeConfirmed = true }) AfterEach(func() { @@ -1735,8 +1731,8 @@ var _ = Describe("Session", func() { sender.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("sends when scheduleSending is called", func() { @@ -1746,14 +1742,14 @@ var _ = Describe("Session", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SentPacket(gomock.Any()) - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph packer.EXPECT().PackPacket().Return(getPacket(1), nil) packer.EXPECT().PackPacket().Return(nil, nil) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() // don't EXPECT any calls to mconn.Write() time.Sleep(50 * time.Millisecond) @@ -1761,7 +1757,7 @@ var _ = Describe("Session", func() { written := make(chan struct{}) sender.EXPECT().Send(gomock.Any()).Do(func(*packetBuffer) { close(written) }) tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - sess.scheduleSending() + conn.scheduleSending() Eventually(written).Should(BeClosed()) }) @@ -1775,12 +1771,12 @@ var _ = Describe("Session", func() { sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(1234))) }) - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) // make the run loop wait rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1) - sess.receivedPacketHandler = rph + conn.receivedPacketHandler = rph written := make(chan struct{}) sender.EXPECT().Send(gomock.Any()).Do(func(*packetBuffer) { close(written) }) @@ -1788,17 +1784,17 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() Eventually(written).Should(BeClosed()) }) }) It("sends coalesced packets before the handshake is confirmed", func() { - sess.handshakeComplete = false - sess.handshakeConfirmed = false + conn.handshakeComplete = false + conn.handshakeConfirmed = false sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) packer.EXPECT().PackCoalescedPacket().Return(&coalescedPacket{ @@ -1858,10 +1854,10 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - sess.scheduleSending() + conn.scheduleSending() Eventually(sent).Should(BeClosed()) // make sure the go routine returns @@ -1872,30 +1868,30 @@ var _ = Describe("Session", func() { mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("cancels the HandshakeComplete context when the handshake completes", func() { packer.EXPECT().PackCoalescedPacket().AnyTimes() finishHandshake := make(chan struct{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().TimeUntilSend().AnyTimes() sph.EXPECT().SendMode().AnyTimes() sph.EXPECT().SetHandshakeConfirmed() - sessionRunner.EXPECT().Retire(clientDestConnID) + connRunner.EXPECT().Retire(clientDestConnID) go func() { defer GinkgoRecover() <-finishHandshake cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() - close(sess.handshakeCompleteChan) - sess.run() + close(conn.handshakeCompleteChan) + conn.run() }() - handshakeCtx := sess.HandshakeComplete() + handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) close(finishHandshake) Eventually(handshakeCtx.Done()).Should(BeClosed()) @@ -1907,31 +1903,31 @@ var _ = Describe("Session", func() { mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) }) - It("sends a session ticket when the handshake completes", func() { + It("sends a connection ticket when the handshake completes", func() { const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2 packer.EXPECT().PackCoalescedPacket().AnyTimes() finishHandshake := make(chan struct{}) - sessionRunner.EXPECT().Retire(clientDestConnID) + connRunner.EXPECT().Retire(clientDestConnID) go func() { defer GinkgoRecover() <-finishHandshake cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) - close(sess.handshakeCompleteChan) - sess.run() + close(conn.handshakeCompleteChan) + conn.run() }() - handshakeCtx := sess.HandshakeComplete() + handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) close(finishHandshake) var frames []ackhandler.Frame Eventually(func() []ackhandler.Frame { - frames, _ = sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) + frames, _ = conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) return frames }).ShouldNot(BeEmpty()) var count int @@ -1940,7 +1936,7 @@ var _ = Describe("Session", func() { if cf, ok := f.Frame.(*wire.CryptoFrame); ok { count++ s += len(cf.Data) - Expect(f.Length(sess.version)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize)) + Expect(f.Length(conn.version)).To(BeNumerically("<=", protocol.MaxPostHandshakeCryptoFrameSize)) } } Expect(size).To(BeEquivalentTo(s)) @@ -1952,8 +1948,8 @@ var _ = Describe("Session", func() { mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { @@ -1967,14 +1963,14 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() - sess.run() + conn.run() }() - handshakeCtx := sess.HandshakeComplete() + handshakeCtx := conn.HandshakeComplete() Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) mconn.EXPECT().Write(gomock.Any()) - sess.closeLocal(errors.New("handshake error")) + conn.closeLocal(errors.New("handshake error")) Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) - Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { @@ -1987,11 +1983,11 @@ var _ = Describe("Session", func() { sph.EXPECT().SentPacket(gomock.Any()) mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph done := make(chan struct{}) - sessionRunner.EXPECT().Retire(clientDestConnID) + connRunner.EXPECT().Retire(clientDestConnID) packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { - frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) + frames, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) defer close(done) @@ -2009,8 +2005,8 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() mconn.EXPECT().Write(gomock.Any()) - close(sess.handshakeCompleteChan) - sess.run() + close(conn.handshakeCompleteChan) + conn.run() }() Eventually(done).Should(BeClosed()) // make sure the go routine returns @@ -2020,8 +2016,8 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().Close() tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("doesn't return a run error when closing", func() { @@ -2029,7 +2025,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - Expect(sess.run()).To(Succeed()) + Expect(conn.run()).To(Succeed()) close(done) }() streamManager.EXPECT().CloseWithError(gomock.Any()) @@ -2039,17 +2035,17 @@ var _ = Describe("Session", func() { mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.shutdown() + conn.shutdown() Eventually(done).Should(BeClosed()) }) - It("passes errors to the session runner", func() { + It("passes errors to the connection runner", func() { testErr := errors.New("handshake error") done := make(chan struct{}) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := sess.run() + err := conn.run() Expect(err).To(MatchError(&qerr.ApplicationError{ ErrorCode: 0x1337, ErrorMessage: testErr.Error(), @@ -2063,7 +2059,7 @@ var _ = Describe("Session", func() { mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - Expect(sess.CloseWithError(0x1337, testErr.Error())).To(Succeed()) + Expect(conn.CloseWithError(0x1337, testErr.Error())).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -2081,12 +2077,12 @@ var _ = Describe("Session", func() { streamManager.EXPECT().UpdateLimits(params) packer.EXPECT().HandleTransportParameters(params) packer.EXPECT().PackCoalescedPacket().MaxTimes(3) - Expect(sess.earlySessionReady()).ToNot(BeClosed()) - sessionRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2) - sessionRunner.EXPECT().Add(gomock.Any(), sess).Times(2) + Expect(conn.earlyConnReady()).ToNot(BeClosed()) + connRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2) + connRunner.EXPECT().Add(gomock.Any(), conn).Times(2) tracer.EXPECT().ReceivedTransportParameters(params) - sess.handleTransportParameters(params) - Expect(sess.earlySessionReady()).To(BeClosed()) + conn.handleTransportParameters(params) + Expect(conn.earlyConnReady()).To(BeClosed()) }) }) @@ -2095,24 +2091,24 @@ var _ = Describe("Session", func() { streamManager.EXPECT().UpdateLimits(gomock.Any()) packer.EXPECT().HandleTransportParameters(gomock.Any()) tracer.EXPECT().ReceivedTransportParameters(gomock.Any()) - sess.handleTransportParameters(&wire.TransportParameters{ + conn.handleTransportParameters(&wire.TransportParameters{ MaxIdleTimeout: t, InitialSourceConnectionID: destConnID, }) } - runSession := func() { + runConn := func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() } BeforeEach(func() { - sess.config.MaxIdleTimeout = 30 * time.Second - sess.config.KeepAlive = true - sess.receivedPacketHandler.ReceivedPacket(0, protocol.ECNNon, protocol.EncryptionHandshake, time.Now(), true) + conn.config.MaxIdleTimeout = 30 * time.Second + conn.config.KeepAlive = true + conn.receivedPacketHandler.ReceivedPacket(0, protocol.ECNNon, protocol.EncryptionHandshake, time.Now(), true) }) AfterEach(func() { @@ -2124,51 +2120,51 @@ var _ = Describe("Session", func() { mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("sends a PING as a keep-alive after half the idle timeout", func() { setRemoteIdleTimeout(5 * time.Second) - sess.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2) + conn.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2) sent := make(chan struct{}) packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { close(sent) return nil, nil }) - runSession() + runConn() Eventually(sent).Should(BeClosed()) }) It("sends a PING after a maximum of protocol.MaxKeepAliveInterval", func() { - sess.config.MaxIdleTimeout = time.Hour + conn.config.MaxIdleTimeout = time.Hour setRemoteIdleTimeout(time.Hour) - sess.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond) + conn.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond) sent := make(chan struct{}) packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { close(sent) return nil, nil }) - runSession() + runConn() Eventually(sent).Should(BeClosed()) }) It("doesn't send a PING packet if keep-alive is disabled", func() { setRemoteIdleTimeout(5 * time.Second) - sess.config.KeepAlive = false - sess.lastPacketReceivedTime = time.Now().Add(-time.Second * 5 / 2) - runSession() + conn.config.KeepAlive = false + conn.lastPacketReceivedTime = time.Now().Add(-time.Second * 5 / 2) + runConn() // don't EXPECT() any calls to mconn.Write() time.Sleep(50 * time.Millisecond) }) It("doesn't send a PING if the handshake isn't completed yet", func() { - sess.config.HandshakeIdleTimeout = time.Hour - sess.handshakeComplete = false + conn.config.HandshakeIdleTimeout = time.Hour + conn.handshakeComplete = false // Needs to be shorter than our idle timeout. // Otherwise we'll try to send a CONNECTION_CLOSE. - sess.lastPacketReceivedTime = time.Now().Add(-20 * time.Second) - runSession() + conn.lastPacketReceivedTime = time.Now().Add(-20 * time.Second) + runConn() // don't EXPECT() any calls to mconn.Write() time.Sleep(50 * time.Millisecond) }) @@ -2180,8 +2176,8 @@ var _ = Describe("Session", func() { }) It("times out due to no network activity", func() { - sessionRunner.EXPECT().Remove(gomock.Any()).Times(2) - sess.lastPacketReceivedTime = time.Now().Add(-time.Hour) + connRunner.EXPECT().Remove(gomock.Any()).Times(2) + conn.lastPacketReceivedTime = time.Now().Add(-time.Hour) done := make(chan struct{}) cryptoSetup.EXPECT().Close() gomock.InOrder( @@ -2193,7 +2189,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := sess.run() + err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) @@ -2204,9 +2200,9 @@ var _ = Describe("Session", func() { }) It("times out due to non-completed handshake", func() { - sess.handshakeComplete = false - sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) - sessionRunner.EXPECT().Remove(gomock.Any()).Times(2) + conn.handshakeComplete = false + conn.creationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) + connRunner.EXPECT().Remove(gomock.Any()).Times(2) cryptoSetup.EXPECT().Close() gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { @@ -2218,7 +2214,7 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - err := sess.run() + err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) @@ -2229,10 +2225,10 @@ var _ = Describe("Session", func() { }) It("does not use the idle timeout before the handshake complete", func() { - sess.handshakeComplete = false - sess.config.HandshakeIdleTimeout = 9999 * time.Second - sess.config.MaxIdleTimeout = 9999 * time.Second - sess.lastPacketReceivedTime = time.Now().Add(-time.Minute) + conn.handshakeComplete = false + conn.config.HandshakeIdleTimeout = 9999 * time.Second + conn.config.MaxIdleTimeout = 9999 * time.Second + conn.lastPacketReceivedTime = time.Now().Add(-time.Minute) packer.EXPECT().PackApplicationClose(gomock.Any()).DoAndReturn(func(e *qerr.ApplicationError) (*coalescedPacket, error) { Expect(e.ErrorCode).To(BeZero()) return &coalescedPacket{buffer: getPacketBuffer()}, nil @@ -2246,26 +2242,26 @@ var _ = Describe("Session", func() { }), tracer.EXPECT().Close(), ) - // the handshake timeout is irrelevant here, since it depends on the time the session was created, + // the handshake timeout is irrelevant here, since it depends on the time the connection was created, // and not on the last network activity go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - Consistently(sess.Context().Done()).ShouldNot(BeClosed()) + Consistently(conn.Context().Done()).ShouldNot(BeClosed()) // make the go routine return expectReplaceWithClosed() cryptoSetup.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) }) - It("closes the session due to the idle timeout before handshake", func() { - sess.config.HandshakeIdleTimeout = 0 + It("closes the connection due to the idle timeout before handshake", func() { + conn.config.HandshakeIdleTimeout = 0 packer.EXPECT().PackCoalescedPacket().AnyTimes() - sessionRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() gomock.InOrder( tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { @@ -2274,12 +2270,12 @@ var _ = Describe("Session", func() { tracer.EXPECT().Close(), ) done := make(chan struct{}) - sess.handshakeComplete = false + conn.handshakeComplete = false go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) - err := sess.run() + err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) @@ -2289,11 +2285,11 @@ var _ = Describe("Session", func() { Eventually(done).Should(BeClosed()) }) - It("closes the session due to the idle timeout after handshake", func() { + It("closes the connection due to the idle timeout after handshake", func() { packer.EXPECT().PackCoalescedPacket().AnyTimes() gomock.InOrder( - sessionRunner.EXPECT().Retire(clientDestConnID), - sessionRunner.EXPECT().Remove(gomock.Any()), + connRunner.EXPECT().Retire(clientDestConnID), + connRunner.EXPECT().Remove(gomock.Any()), ) cryptoSetup.EXPECT().Close() gomock.InOrder( @@ -2302,15 +2298,15 @@ var _ = Describe("Session", func() { }), tracer.EXPECT().Close(), ) - sess.idleTimeout = 0 + conn.idleTimeout = 0 done := make(chan struct{}) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) - close(sess.handshakeCompleteChan) - err := sess.run() + close(conn.handshakeCompleteChan) + err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) @@ -2321,15 +2317,15 @@ var _ = Describe("Session", func() { }) It("doesn't time out when it just sent a packet", func() { - sess.lastPacketReceivedTime = time.Now().Add(-time.Hour) - sess.firstAckElicitingPacketAfterIdleSentTime = time.Now().Add(-time.Second) - sess.idleTimeout = 30 * time.Second + conn.lastPacketReceivedTime = time.Now().Add(-time.Hour) + conn.firstAckElicitingPacketAfterIdleSentTime = time.Now().Add(-time.Second) + conn.idleTimeout = 30 * time.Second go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() - Consistently(sess.Context().Done()).ShouldNot(BeClosed()) + Consistently(conn.Context().Done()).ShouldNot(BeClosed()) // make the go routine return packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() @@ -2337,19 +2333,19 @@ var _ = Describe("Session", func() { mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) }) }) - It("stores up to MaxSessionUnprocessedPackets packets", func() { + It("stores up to MaxConnUnprocessedPackets packets", func() { done := make(chan struct{}) tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, logging.ByteCount(6), logging.PacketDropDOSPrevention).Do(func(logging.PacketType, logging.ByteCount, logging.PacketDropReason) { close(done) }) // Nothing here should block - for i := protocol.PacketNumber(0); i < protocol.MaxSessionUnprocessedPackets+1; i++ { - sess.handlePacket(&receivedPacket{data: []byte("foobar")}) + for i := protocol.PacketNumber(0); i < protocol.MaxConnUnprocessedPackets+1; i++ { + conn.handlePacket(&receivedPacket{data: []byte("foobar")}) } Eventually(done).Should(BeClosed()) }) @@ -2358,7 +2354,7 @@ var _ = Describe("Session", func() { It("opens streams", func() { mstr := NewMockStreamI(mockCtrl) streamManager.EXPECT().OpenStream().Return(mstr, nil) - str, err := sess.OpenStream() + str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) @@ -2366,7 +2362,7 @@ var _ = Describe("Session", func() { It("opens streams synchronously", func() { mstr := NewMockStreamI(mockCtrl) streamManager.EXPECT().OpenStreamSync(context.Background()).Return(mstr, nil) - str, err := sess.OpenStreamSync(context.Background()) + str, err := conn.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) @@ -2374,7 +2370,7 @@ var _ = Describe("Session", func() { It("opens unidirectional streams", func() { mstr := NewMockSendStreamI(mockCtrl) streamManager.EXPECT().OpenUniStream().Return(mstr, nil) - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) @@ -2382,7 +2378,7 @@ var _ = Describe("Session", func() { It("opens unidirectional streams synchronously", func() { mstr := NewMockSendStreamI(mockCtrl) streamManager.EXPECT().OpenUniStreamSync(context.Background()).Return(mstr, nil) - str, err := sess.OpenUniStreamSync(context.Background()) + str, err := conn.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) @@ -2392,7 +2388,7 @@ var _ = Describe("Session", func() { defer cancel() mstr := NewMockStreamI(mockCtrl) streamManager.EXPECT().AcceptStream(ctx).Return(mstr, nil) - str, err := sess.AcceptStream(ctx) + str, err := conn.AcceptStream(ctx) Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) @@ -2402,38 +2398,38 @@ var _ = Describe("Session", func() { defer cancel() mstr := NewMockReceiveStreamI(mockCtrl) streamManager.EXPECT().AcceptUniStream(ctx).Return(mstr, nil) - str, err := sess.AcceptUniStream(ctx) + str, err := conn.AcceptUniStream(ctx) Expect(err).ToNot(HaveOccurred()) Expect(str).To(Equal(mstr)) }) }) It("returns the local address", func() { - Expect(sess.LocalAddr()).To(Equal(localAddr)) + Expect(conn.LocalAddr()).To(Equal(localAddr)) }) It("returns the remote address", func() { - Expect(sess.RemoteAddr()).To(Equal(remoteAddr)) + Expect(conn.RemoteAddr()).To(Equal(remoteAddr)) }) }) -var _ = Describe("Client Session", func() { +var _ = Describe("Client Connection", func() { var ( - sess *session - sessionRunner *MockSessionRunner - packer *MockPacker - mconn *MockSendConn - cryptoSetup *mocks.MockCryptoSetup - tracer *mocklogging.MockConnectionTracer - tlsConf *tls.Config - quicConf *Config + conn *connection + connRunner *MockConnRunner + packer *MockPacker + mconn *MockSendConn + cryptoSetup *mocks.MockCryptoSetup + tracer *mocklogging.MockConnectionTracer + tlsConf *tls.Config + quicConf *Config ) srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { buf := &bytes.Buffer{} - Expect(hdr.Write(buf, sess.version)).To(Succeed()) + Expect(hdr.Write(buf, conn.version)).To(Succeed()) return &receivedPacket{ data: append(buf.Bytes(), data...), buffer: getPacketBuffer(), @@ -2441,9 +2437,9 @@ var _ = Describe("Client Session", func() { } expectReplaceWithClosed := func() { - sessionRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { s.shutdown() - Eventually(areClosedSessionsRunning).Should(BeFalse()) + Eventually(areClosedConnsRunning).Should(BeFalse()) }) } @@ -2453,7 +2449,7 @@ var _ = Describe("Client Session", func() { }) JustBeforeEach(func() { - Eventually(areSessionsRunning).Should(BeFalse()) + Eventually(areConnsRunning).Should(BeFalse()) mconn = NewMockSendConn(mockCtrl) mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).AnyTimes() @@ -2461,15 +2457,15 @@ var _ = Describe("Client Session", func() { if tlsConf == nil { tlsConf = &tls.Config{} } - sessionRunner = NewMockSessionRunner(mockCtrl) + connRunner = NewMockConnRunner(mockCtrl) tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentTransportParameters(gomock.Any()) tracer.EXPECT().UpdatedKeyFromTLS(gomock.Any(), gomock.Any()).AnyTimes() tracer.EXPECT().UpdatedCongestionState(gomock.Any()) - sess = newClientSession( + conn = newClientConnection( mconn, - sessionRunner, + connRunner, destConnID, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, quicConf, @@ -2481,11 +2477,11 @@ var _ = Describe("Client Session", func() { 1234, utils.DefaultLogger, protocol.VersionTLS, - ).(*session) + ).(*connection) packer = NewMockPacker(mockCtrl) - sess.packer = packer + conn.packer = packer cryptoSetup = mocks.NewMockCryptoSetup(mockCtrl) - sess.cryptoStreamHandler = cryptoSetup + conn.cryptoStreamHandler = cryptoSetup }) It("changes the connection ID when receiving the first packet from the server", func() { @@ -2497,11 +2493,11 @@ var _ = Describe("Client Session", func() { data: []byte{0}, // one PADDING frame }, nil }) - sess.unpacker = unpacker + conn.unpacker = unpacker go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - sess.run() + conn.run() }() newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} p := getPacket(&wire.ExtendedHeader{ @@ -2511,12 +2507,12 @@ var _ = Describe("Client Session", func() { SrcConnectionID: newConnID, DestConnectionID: srcConnID, Length: 2 + 6, - Version: sess.version, + Version: conn.version, }, PacketNumberLen: protocol.PacketNumberLen2, }, []byte("foobar")) tracer.EXPECT().ReceivedPacket(gomock.Any(), p.Size(), []logging.Frame{}) - Expect(sess.handlePacketImpl(p)).To(BeTrue()) + Expect(conn.handlePacketImpl(p)).To(BeTrue()) // make sure the go routine returns packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) expectReplaceWithClosed() @@ -2524,20 +2520,20 @@ var _ = Describe("Client Session", func() { mconn.EXPECT().Write(gomock.Any()) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) }) It("continues accepting Long Header packets after using a new connection ID", func() { unpacker := NewMockUnpacker(mockCtrl) - sess.unpacker = unpacker - sessionRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any()) - sess.connIDManager.SetHandshakeComplete() - sess.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{ + conn.unpacker = unpacker + connRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any()) + conn.connIDManager.SetHandshakeComplete() + conn.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 1, ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, }) - Expect(sess.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) + Expect(conn.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) // now receive a packet with the original source connection ID unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) { return &unpackedPacket{ @@ -2553,28 +2549,28 @@ var _ = Describe("Client Session", func() { SrcConnectionID: destConnID, } tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(sess.handleSinglePacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) + Expect(conn.handleSinglePacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) }) It("handles HANDSHAKE_DONE frames", func() { - sess.peerParams = &wire.TransportParameters{} + conn.peerParams = &wire.TransportParameters{} sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph sph.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().SetHandshakeConfirmed() - Expect(sess.handleHandshakeDoneFrame()).To(Succeed()) + Expect(conn.handleHandshakeDoneFrame()).To(Succeed()) }) It("interprets an ACK for 1-RTT packets as confirmation of the handshake", func() { - sess.peerParams = &wire.TransportParameters{} + conn.peerParams = &wire.TransportParameters{} sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 3}}} sph.EXPECT().ReceivedAck(ack, protocol.Encryption1RTT, gomock.Any()).Return(true, nil) sph.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().SetLargest1RTTAcked(protocol.PacketNumber(3)) cryptoSetup.EXPECT().SetHandshakeConfirmed() - Expect(sess.handleAckFrame(ack, protocol.Encryption1RTT)).To(Succeed()) + Expect(conn.handleAckFrame(ack, protocol.Encryption1RTT)).To(Succeed()) }) Context("handling tokens", func() { @@ -2590,14 +2586,13 @@ var _ = Describe("Client Session", func() { It("handles NEW_TOKEN frames", func() { mockTokenStore.EXPECT().Put("server", &ClientToken{data: []byte("foobar")}) - Expect(sess.handleNewTokenFrame(&wire.NewTokenFrame{Token: []byte("foobar")})).To(Succeed()) + Expect(conn.handleNewTokenFrame(&wire.NewTokenFrame{Token: []byte("foobar")})).To(Succeed()) }) }) Context("handling Version Negotiation", func() { getVNP := func(versions ...protocol.VersionNumber) *receivedPacket { - b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions) - Expect(err).ToNot(HaveOccurred()) + b := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions) return &receivedPacket{ data: b, buffer: getPacketBuffer(), @@ -2606,17 +2601,17 @@ var _ = Describe("Client Session", func() { It("closes and returns the right error", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph sph.EXPECT().ReceivedBytes(gomock.Any()) sph.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(128), protocol.PacketNumberLen4) - sess.config.Versions = []protocol.VersionNumber{1234, 4321} + conn.config.Versions = []protocol.VersionNumber{1234, 4321} errChan := make(chan error, 1) go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - errChan <- sess.run() + errChan <- conn.run() }() - sessionRunner.EXPECT().Remove(srcConnID) + connRunner.EXPECT().Remove(srcConnID) tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any()).Do(func(hdr *wire.Header, versions []logging.VersionNumber) { Expect(hdr.Version).To(BeZero()) Expect(versions).To(And( @@ -2625,7 +2620,7 @@ var _ = Describe("Client Session", func() { )) }) cryptoSetup.EXPECT().Close() - Expect(sess.handlePacketImpl(getVNP(4321, 1337))).To(BeFalse()) + Expect(conn.handlePacketImpl(getVNP(4321, 1337))).To(BeFalse()) var err error Eventually(errChan).Should(Receive(&err)) Expect(err).To(HaveOccurred()) @@ -2640,9 +2635,9 @@ var _ = Describe("Client Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - errChan <- sess.run() + errChan <- conn.run() }() - sessionRunner.EXPECT().Remove(srcConnID).MaxTimes(1) + connRunner.EXPECT().Remove(srcConnID).MaxTimes(1) gomock.InOrder( tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any()), tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { @@ -2653,7 +2648,7 @@ var _ = Describe("Client Session", func() { tracer.EXPECT().Close(), ) cryptoSetup.EXPECT().Close() - Expect(sess.handlePacketImpl(getVNP(12345678))).To(BeFalse()) + Expect(conn.handlePacketImpl(getVNP(12345678))).To(BeFalse()) var err error Eventually(errChan).Should(Receive(&err)) Expect(err).To(HaveOccurred()) @@ -2662,16 +2657,16 @@ var _ = Describe("Client Session", func() { }) It("ignores Version Negotiation packets that offer the current version", func() { - p := getVNP(sess.version) + p := getVNP(conn.version) tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) - Expect(sess.handlePacketImpl(p)).To(BeFalse()) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("ignores unparseable Version Negotiation packets", func() { - p := getVNP(sess.version) + p := getVNP(conn.version) p.data = p.data[:len(p.data)-2] tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) - Expect(sess.handlePacketImpl(p)).To(BeFalse()) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) }) @@ -2688,20 +2683,20 @@ var _ = Describe("Client Session", func() { SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, Token: []byte("foobar"), - Version: sess.version, + Version: conn.version, }, } }) getRetryTag := func(hdr *wire.ExtendedHeader) []byte { buf := &bytes.Buffer{} - hdr.Write(buf, sess.version) + hdr.Write(buf, conn.version) return handshake.GetRetryIntegrityTag(buf.Bytes(), origDestConnID, hdr.Version)[:] } It("handles Retry packets", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph sph.EXPECT().ResetForRetry() sph.EXPECT().ReceivedBytes(gomock.Any()) cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) @@ -2711,21 +2706,21 @@ var _ = Describe("Client Session", func() { Expect(hdr.SrcConnectionID).To(Equal(retryHdr.SrcConnectionID)) Expect(hdr.Token).To(Equal(retryHdr.Token)) }) - Expect(sess.handlePacketImpl(getPacket(retryHdr, getRetryTag(retryHdr)))).To(BeTrue()) + Expect(conn.handlePacketImpl(getPacket(retryHdr, getRetryTag(retryHdr)))).To(BeTrue()) }) It("ignores Retry packets after receiving a regular packet", func() { - sess.receivedFirstPacket = true + conn.receivedFirstPacket = true p := getPacket(retryHdr, getRetryTag(retryHdr)) tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) - Expect(sess.handlePacketImpl(p)).To(BeFalse()) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("ignores Retry packets if the server didn't change the connection ID", func() { retryHdr.SrcConnectionID = destConnID p := getPacket(retryHdr, getRetryTag(retryHdr)) tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropUnexpectedPacket) - Expect(sess.handlePacketImpl(p)).To(BeFalse()) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) It("ignores Retry packets with the a wrong Integrity tag", func() { @@ -2733,7 +2728,7 @@ var _ = Describe("Client Session", func() { tag[0]++ p := getPacket(retryHdr, tag) tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, p.Size(), logging.PacketDropPayloadDecryptError) - Expect(sess.handlePacketImpl(p)).To(BeFalse()) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) }) @@ -2749,15 +2744,15 @@ var _ = Describe("Client Session", func() { go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) - errChan <- sess.run() + errChan <- conn.run() close(errChan) }() }) expectClose := func(applicationClose bool) { if !closed { - sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedLocalSession{})) + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) { + Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{})) s.shutdown() }) if applicationClose { @@ -2776,8 +2771,8 @@ var _ = Describe("Client Session", func() { } AfterEach(func() { - sess.shutdown() - Eventually(sess.Context().Done()).Should(BeClosed()) + conn.shutdown() + Eventually(conn.Context().Done()).Should(BeClosed()) Eventually(errChan).Should(BeClosed()) }) @@ -2795,20 +2790,20 @@ var _ = Describe("Client Session", func() { packer.EXPECT().HandleTransportParameters(gomock.Any()) packer.EXPECT().PackCoalescedPacket().MaxTimes(1) tracer.EXPECT().ReceivedTransportParameters(params) - sess.handleTransportParameters(params) - sess.handleHandshakeComplete() + conn.handleTransportParameters(params) + conn.handleHandshakeComplete() // make sure the connection ID is not retired - cf, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) + cf, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(cf).To(BeEmpty()) - sessionRunner.EXPECT().AddResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, sess) - Expect(sess.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + connRunner.EXPECT().AddResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, conn) + Expect(conn.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) // shut down - sessionRunner.EXPECT().RemoveResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}) + connRunner.EXPECT().RemoveResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}) expectClose(true) }) It("uses the minimum of the peers' idle timeouts", func() { - sess.config.MaxIdleTimeout = 19 * time.Second + conn.config.MaxIdleTimeout = 19 * time.Second params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: destConnID, @@ -2816,14 +2811,14 @@ var _ = Describe("Client Session", func() { } packer.EXPECT().HandleTransportParameters(gomock.Any()) tracer.EXPECT().ReceivedTransportParameters(params) - sess.handleTransportParameters(params) - sess.handleHandshakeComplete() - Expect(sess.idleTimeout).To(Equal(18 * time.Second)) + conn.handleTransportParameters(params) + conn.handleHandshakeComplete() + Expect(conn.idleTimeout).To(Equal(18 * time.Second)) expectClose(true) }) It("errors if the transport parameters contain a wrong initial_source_connection_id", func() { - sess.handshakeDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + conn.handshakeDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, @@ -2831,7 +2826,7 @@ var _ = Describe("Client Session", func() { } expectClose(false) tracer.EXPECT().ReceivedTransportParameters(params) - sess.handleTransportParameters(params) + conn.handleTransportParameters(params) Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: "expected initial_source_connection_id to equal deadbeef, is decafbad", @@ -2839,7 +2834,7 @@ var _ = Describe("Client Session", func() { }) It("errors if the transport parameters don't contain the retry_source_connection_id, if a Retry was performed", func() { - sess.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + conn.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: destConnID, @@ -2847,7 +2842,7 @@ var _ = Describe("Client Session", func() { } expectClose(false) tracer.EXPECT().ReceivedTransportParameters(params) - sess.handleTransportParameters(params) + conn.handleTransportParameters(params) Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: "missing retry_source_connection_id", @@ -2855,7 +2850,7 @@ var _ = Describe("Client Session", func() { }) It("errors if the transport parameters contain the wrong retry_source_connection_id, if a Retry was performed", func() { - sess.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + conn.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: destConnID, @@ -2864,7 +2859,7 @@ var _ = Describe("Client Session", func() { } expectClose(false) tracer.EXPECT().ReceivedTransportParameters(params) - sess.handleTransportParameters(params) + conn.handleTransportParameters(params) Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: "expected retry_source_connection_id to equal deadbeef, is deadc0de", @@ -2880,7 +2875,7 @@ var _ = Describe("Client Session", func() { } expectClose(false) tracer.EXPECT().ReceivedTransportParameters(params) - sess.handleTransportParameters(params) + conn.handleTransportParameters(params) Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: "received retry_source_connection_id, although no Retry was performed", @@ -2888,15 +2883,15 @@ var _ = Describe("Client Session", func() { }) It("errors if the transport parameters contain a wrong original_destination_connection_id", func() { - sess.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + conn.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} params := &wire.TransportParameters{ OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - InitialSourceConnectionID: sess.handshakeDestConnID, + InitialSourceConnectionID: conn.handshakeDestConnID, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } expectClose(false) tracer.EXPECT().ReceivedTransportParameters(params) - sess.handleTransportParameters(params) + conn.handleTransportParameters(params) Eventually(errChan).Should(Receive(MatchError(&qerr.TransportError{ ErrorCode: qerr.TransportParameterError, ErrorMessage: "expected original_destination_connection_id to equal deadbeef, is decafbad", @@ -2909,7 +2904,7 @@ var _ = Describe("Client Session", func() { getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { buf := &bytes.Buffer{} - Expect(extHdr.Write(buf, sess.version)).To(Succeed()) + Expect(extHdr.Write(buf, conn.version)).To(Succeed()) return &receivedPacket{ data: append(buf.Bytes(), data...), buffer: getPacketBuffer(), @@ -2929,7 +2924,7 @@ var _ = Describe("Client Session", func() { It("ignores Initial packets with a different source connection ID", func() { // Modified from test "ignores packets with a different source connection ID" unpacker = NewMockUnpacker(mockCtrl) - sess.unpacker = unpacker + conn.unpacker = unpacker hdr1 := &wire.ExtendedHeader{ Header: wire.Header{ @@ -2938,7 +2933,7 @@ var _ = Describe("Client Session", func() { DestConnectionID: destConnID, SrcConnectionID: srcConnID, Length: 1, - Version: sess.version, + Version: conn.version, }, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 1, @@ -2950,7 +2945,7 @@ var _ = Describe("Client Session", func() { DestConnectionID: destConnID, SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, Length: 1, - Version: sess.version, + Version: conn.version, }, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 2, @@ -2964,10 +2959,10 @@ var _ = Describe("Client Session", func() { data: []byte{0}, // one PADDING frame }, nil) tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(sess.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue()) + Expect(conn.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(sess.handlePacketImpl(getPacket(hdr2, nil))).To(BeFalse()) + Expect(conn.handlePacketImpl(getPacket(hdr2, nil))).To(BeFalse()) }) It("ignores 0-RTT packets", func() { @@ -2977,22 +2972,22 @@ var _ = Describe("Client Session", func() { Type: protocol.PacketType0RTT, DestConnectionID: srcConnID, Length: 2 + 6, - Version: sess.version, + Version: conn.version, }, PacketNumber: 0x42, PacketNumberLen: protocol.PacketNumberLen2, }, []byte("foobar")) tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, p.Size(), gomock.Any()) - Expect(sess.handlePacketImpl(p)).To(BeFalse()) + Expect(conn.handlePacketImpl(p)).To(BeFalse()) }) // Illustrates that an injected Initial with an ACK frame for an unsent packet causes // the connection to immediately break down It("fails on Initial-level ACK for unsent packet", func() { ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, sess.version, destConnID, []wire.Frame{ack}) + initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{ack}) tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) + Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) }) // Illustrates that an injected Initial with a CONNECTION_CLOSE frame causes @@ -3002,16 +2997,16 @@ var _ = Describe("Client Session", func() { IsApplicationError: true, ReasonPhrase: "mitm attacker", } - initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, sess.version, destConnID, []wire.Frame{connCloseFrame}) + initialPacket := testutils.ComposeInitialPacket(destConnID, srcConnID, conn.version, destConnID, []wire.Frame{connCloseFrame}) tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue()) + Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeTrue()) }) // Illustrates that attacker who injects a Retry packet and changes the connection ID // can cause subsequent real Initial packets to be ignored It("ignores Initial packets which use original source id, after accepting a Retry", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sess.sentPacketHandler = sph + conn.sentPacketHandler = sph sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2) sph.EXPECT().ResetForRetry() newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} @@ -3019,10 +3014,10 @@ var _ = Describe("Client Session", func() { packer.EXPECT().SetToken([]byte("foobar")) tracer.EXPECT().ReceivedRetry(gomock.Any()) - sess.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, destConnID, destConnID, []byte("foobar"), sess.version))) - initialPacket := testutils.ComposeInitialPacket(sess.connIDManager.Get(), srcConnID, sess.version, sess.connIDManager.Get(), nil) + conn.handlePacketImpl(wrapPacket(testutils.ComposeRetryPacket(newSrcConnID, destConnID, destConnID, []byte("foobar"), conn.version))) + initialPacket := testutils.ComposeInitialPacket(conn.connIDManager.Get(), srcConnID, conn.version, conn.connIDManager.Get(), nil) tracer.EXPECT().DroppedPacket(gomock.Any(), gomock.Any(), gomock.Any()) - Expect(sess.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) + Expect(conn.handlePacketImpl(wrapPacket(initialPacket))).To(BeFalse()) }) }) }) diff --git a/example/echo/echo.go b/example/echo/echo.go index 7c1ae28032a..9ad707d2bf4 100644 --- a/example/echo/echo.go +++ b/example/echo/echo.go @@ -12,7 +12,7 @@ import ( "log" "math/big" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" ) const addr = "localhost:4242" @@ -36,11 +36,11 @@ func echoServer() error { if err != nil { return err } - sess, err := listener.Accept(context.Background()) + conn, err := listener.Accept(context.Background()) if err != nil { return err } - stream, err := sess.AcceptStream(context.Background()) + stream, err := conn.AcceptStream(context.Background()) if err != nil { panic(err) } @@ -54,12 +54,12 @@ func clientMain() error { InsecureSkipVerify: true, NextProtos: []string{"quic-echo-example"}, } - session, err := quic.DialAddr(addr, tlsConf, nil) + conn, err := quic.DialAddr(addr, tlsConf, nil) if err != nil { return err } - stream, err := session.OpenStreamSync(context.Background()) + stream, err := conn.OpenStreamSync(context.Background()) if err != nil { return err } diff --git a/fuzzing/handshake/fuzz.go b/fuzzing/handshake/fuzz.go index 8c0d0ce46a8..c761ad0ca41 100644 --- a/fuzzing/handshake/fuzz.go +++ b/fuzzing/handshake/fuzz.go @@ -300,7 +300,6 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. serverConf.ClientAuth = getClientAuth(runConfig[1] & 0b00000111) serverConf.CipherSuites = getSuites(runConfig[1] >> 6) serverConf.SessionTicketsDisabled = helper.NthBit(runConfig[1], 3) - clientConf.PreferServerCipherSuites = helper.NthBit(runConfig[1], 4) if helper.NthBit(runConfig[2], 0) { clientConf.RootCAs = x509.NewCertPool() } diff --git a/fuzzing/header/cmd/corpus.go b/fuzzing/header/cmd/corpus.go index 2422afcc82d..eeb880ce860 100644 --- a/fuzzing/header/cmd/corpus.go +++ b/fuzzing/header/cmd/corpus.go @@ -24,11 +24,7 @@ func getVNP(src, dest protocol.ConnectionID, numVersions int) []byte { for i := 0; i < numVersions; i++ { versions[i] = protocol.VersionNumber(rand.Uint32()) } - data, err := wire.ComposeVersionNegotiation(src, dest, versions) - if err != nil { - log.Fatal(err) - } - return data + return wire.ComposeVersionNegotiation(src, dest, versions) } func main() { diff --git a/fuzzing/header/fuzz.go b/fuzzing/header/fuzz.go index 46bdfa89523..be7564b9e48 100644 --- a/fuzzing/header/fuzz.go +++ b/fuzzing/header/fuzz.go @@ -91,8 +91,6 @@ func fuzzVNP(data []byte) int { if len(versions) == 0 { panic("no versions") } - if _, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions); err != nil { - panic(err) - } + wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) return 1 } diff --git a/go.mod b/go.mod index e868284e17a..75b04de0b01 100644 --- a/go.mod +++ b/go.mod @@ -7,14 +7,14 @@ require ( github.com/francoispqt/gojay v1.2.13 github.com/golang/mock v1.6.0 github.com/marten-seemann/qpack v0.2.1 - github.com/marten-seemann/qtls-go1-16 v0.1.4 - github.com/marten-seemann/qtls-go1-17 v0.1.0 - github.com/marten-seemann/qtls-go1-18 v0.1.0-beta.1 + github.com/marten-seemann/qtls-go1-16 v0.1.5 + github.com/marten-seemann/qtls-go1-17 v0.1.1 + github.com/marten-seemann/qtls-go1-18 v0.1.1 github.com/onsi/ginkgo v1.16.4 github.com/onsi/gomega v1.13.0 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c - golang.org/x/sys v0.0.0-20210510120138-977fb7262007 + golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index bf00cfdbe65..5ecb30e941b 100644 --- a/go.sum +++ b/go.sum @@ -80,13 +80,12 @@ github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= -github.com/marten-seemann/qtls-go1-15 v0.1.4/go.mod h1:GyFwywLKkRt+6mfU99csTEY1joMZz5vmB1WNZH3P81I= -github.com/marten-seemann/qtls-go1-16 v0.1.4 h1:xbHbOGGhrenVtII6Co8akhLEdrawwB2iHl5yhJRpnco= -github.com/marten-seemann/qtls-go1-16 v0.1.4/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= -github.com/marten-seemann/qtls-go1-17 v0.1.0 h1:P9ggrs5xtwiqXv/FHNwntmuLMNq3KaSIG93AtAZ48xk= -github.com/marten-seemann/qtls-go1-17 v0.1.0/go.mod h1:fz4HIxByo+LlWcreM4CZOYNuz3taBQ8rN2X6FqvaWo8= -github.com/marten-seemann/qtls-go1-18 v0.1.0-beta.1 h1:EnzzN9fPUkUck/1CuY1FlzBaIYMoiBsdwTNmNGkwUUM= -github.com/marten-seemann/qtls-go1-18 v0.1.0-beta.1/go.mod h1:PUhIQk19LoFt2174H4+an8TYvWOGjb/hHwphBeaDHwI= +github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= +github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= +github.com/marten-seemann/qtls-go1-17 v0.1.1 h1:DQjHPq+aOzUeh9/lixAGunn6rIOQyWChPSI4+hgW7jc= +github.com/marten-seemann/qtls-go1-17 v0.1.1/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= +github.com/marten-seemann/qtls-go1-18 v0.1.1 h1:qp7p7XXUFL7fpBvSS1sWD+uSqPvzNQK43DH+/qEkj0Y= +github.com/marten-seemann/qtls-go1-18 v0.1.1/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -213,8 +212,9 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007 h1:gG67DSER+11cZvqIMb8S8bt0vZtiN6xWYARwirrOSfE= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/http3/body.go b/http3/body.go index 851eaa1f403..23d4cf556ef 100644 --- a/http3/body.go +++ b/http3/body.go @@ -1,12 +1,31 @@ package http3 import ( + "context" "fmt" "io" + "net" "github.com/lucas-clemente/quic-go" ) +type StreamCreator interface { + OpenStream() (quic.Stream, error) + OpenStreamSync(context.Context) (quic.Stream, error) + OpenUniStream() (quic.SendStream, error) + OpenUniStreamSync(context.Context) (quic.SendStream, error) + LocalAddr() net.Addr + RemoteAddr() net.Addr +} + +var _ StreamCreator = quic.Connection(nil) + +// A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body. +// It is used by WebTransport to create WebTransport streams after a session has been established. +type Hijacker interface { + StreamCreator() StreamCreator +} + // The body of a http.Request or http.Response. type body struct { str quic.Stream @@ -24,6 +43,13 @@ type body struct { var _ io.ReadCloser = &body{} +type hijackableBody struct { + body + conn quic.Connection // only needed to implement Hijacker +} + +var _ Hijacker = &hijackableBody{} + func newRequestBody(str quic.Stream, onFrameError func()) *body { return &body{ str: str, @@ -31,14 +57,21 @@ func newRequestBody(str quic.Stream, onFrameError func()) *body { } } -func newResponseBody(str quic.Stream, done chan<- struct{}, onFrameError func()) *body { - return &body{ - str: str, - onFrameError: onFrameError, - reqDone: done, +func newResponseBody(str quic.Stream, conn quic.Connection, done chan<- struct{}, onFrameError func()) *hijackableBody { + return &hijackableBody{ + body: body{ + str: str, + onFrameError: onFrameError, + reqDone: done, + }, + conn: conn, } } +func (r *hijackableBody) StreamCreator() StreamCreator { + return r.conn +} + func (r *body) Read(b []byte) (int, error) { n, err := r.readImpl(b) if err != nil { @@ -51,7 +84,7 @@ func (r *body) readImpl(b []byte) (int, error) { if r.bytesRemainingInFrame == 0 { parseLoop: for { - frame, err := parseNextFrame(r.str) + frame, err := parseNextFrame(r.str, nil) if err != nil { return 0, err } @@ -90,6 +123,10 @@ func (r *body) requestDone() { r.reqDoneClosed = true } +func (r *body) StreamID() quic.StreamID { + return r.str.StreamID() +} + func (r *body) Close() error { r.requestDone() // If the EOF was read, CancelRead() is a no-op. diff --git a/http3/body_test.go b/http3/body_test.go index d9d5c780099..f50004dc325 100644 --- a/http3/body_test.go +++ b/http3/body_test.go @@ -29,7 +29,7 @@ func (t bodyType) String() string { var _ = Describe("Body", func() { var ( - rb *body + rb io.ReadCloser str *mockquic.MockStream buf *bytes.Buffer reqDone chan struct{} @@ -68,7 +68,7 @@ var _ = Describe("Body", func() { rb = newRequestBody(str, errorCb) case bodyTypeResponse: reqDone = make(chan struct{}) - rb = newResponseBody(str, reqDone, errorCb) + rb = newResponseBody(str, nil, reqDone, errorCb) } }) diff --git a/http3/client.go b/http3/client.go index 861eaf0ab70..43d65b327e4 100644 --- a/http3/client.go +++ b/http3/client.go @@ -34,12 +34,17 @@ var defaultQuicConfig = &quic.Config{ Versions: []protocol.VersionNumber{protocol.VersionTLS}, } -var dialAddr = quic.DialAddrEarly +type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) + +var dialAddr = quic.DialAddrEarlyContext type roundTripperOpts struct { DisableCompression bool EnableDatagram bool MaxHeaderBytes int64 + AdditionalSettings map[uint64]uint64 + StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error) + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool) } // client is a HTTP3 client doing requests @@ -49,7 +54,7 @@ type client struct { opts *roundTripperOpts dialOnce sync.Once - dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) + dialer dialFunc handshakeErr error requestWriter *requestWriter @@ -57,29 +62,25 @@ type client struct { decoder *qpack.Decoder hostname string - session quic.EarlySession + conn quic.EarlyConnection logger utils.Logger } -func newClient( - hostname string, - tlsConf *tls.Config, - opts *roundTripperOpts, - quicConfig *quic.Config, - dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error), -) (*client, error) { - if quicConfig == nil { - quicConfig = defaultQuicConfig.Clone() - } else if len(quicConfig.Versions) == 0 { - quicConfig = quicConfig.Clone() - quicConfig.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} +func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) { + if conf == nil { + conf = defaultQuicConfig.Clone() + } else if len(conf.Versions) == 0 { + conf = conf.Clone() + conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} } - if len(quicConfig.Versions) != 1 { + if len(conf.Versions) != 1 { return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") } - quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams - quicConfig.EnableDatagrams = opts.EnableDatagram + if conf.MaxIncomingStreams == 0 { + conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams + } + conf.EnableDatagrams = opts.EnableDatagram logger := utils.DefaultLogger.WithPrefix("h3 client") if tlsConf == nil { @@ -88,26 +89,26 @@ func newClient( tlsConf = tlsConf.Clone() } // Replace existing ALPNs by H3 - tlsConf.NextProtos = []string{versionToALPN(quicConfig.Versions[0])} + tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} return &client{ hostname: authorityAddr("https", hostname), tlsConf: tlsConf, requestWriter: newRequestWriter(logger), decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), - config: quicConfig, + config: conf, opts: opts, dialer: dialer, logger: logger, }, nil } -func (c *client) dial() error { +func (c *client) dial(ctx context.Context) error { var err error if c.dialer != nil { - c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config) + c.conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) } else { - c.session, err = dialAddr(c.hostname, c.tlsConf, c.config) + c.conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) } if err != nil { return err @@ -115,39 +116,66 @@ func (c *client) dial() error { // send the SETTINGs frame, using 0-RTT data, if possible go func() { - if err := c.setupSession(); err != nil { - c.logger.Debugf("Setting up session failed: %s", err) - c.session.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "") + if err := c.setupConn(); err != nil { + c.logger.Debugf("Setting up connection failed: %s", err) + c.conn.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "") } }() + if c.opts.StreamHijacker != nil { + go c.handleBidirectionalStreams() + } go c.handleUnidirectionalStreams() return nil } -func (c *client) setupSession() error { +func (c *client) setupConn() error { // open the control stream - str, err := c.session.OpenUniStream() + str, err := c.conn.OpenUniStream() if err != nil { return err } buf := &bytes.Buffer{} quicvarint.Write(buf, streamTypeControlStream) // send the SETTINGS frame - (&settingsFrame{Datagram: c.opts.EnableDatagram}).Write(buf) + (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Write(buf) _, err = str.Write(buf.Bytes()) return err } +func (c *client) handleBidirectionalStreams() { + for { + str, err := c.conn.AcceptStream(context.Background()) + if err != nil { + c.logger.Debugf("accepting bidirectional stream failed: %s", err) + return + } + go func(str quic.Stream) { + for { + _, err := parseNextFrame(str, func(ft FrameType) (processed bool, err error) { + return c.opts.StreamHijacker(ft, c.conn, str) + }) + if err == errHijacked { + return + } + if err != nil { + c.logger.Debugf("error handling stream: %s", err) + } + c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream") + } + }(str) + } +} + func (c *client) handleUnidirectionalStreams() { for { - str, err := c.session.AcceptUniStream(context.Background()) + str, err := c.conn.AcceptUniStream(context.Background()) if err != nil { c.logger.Debugf("accepting unidirectional stream failed: %s", err) return } - go func() { + go func(str quic.ReceiveStream) { streamType, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) @@ -162,20 +190,23 @@ func (c *client) handleUnidirectionalStreams() { return case streamTypePushStream: // We never increased the Push ID, so we don't expect any push streams. - c.session.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") + c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") return default: + if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str) { + return + } str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) return } - f, err := parseNextFrame(str) + f, err := parseNextFrame(str, nil) if err != nil { - c.session.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") + c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") return } sf, ok := f.(*settingsFrame) if !ok { - c.session.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") + c.conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") return } if !sf.Datagram { @@ -184,18 +215,18 @@ func (c *client) handleUnidirectionalStreams() { // If datagram support was enabled on our side as well as on the server side, // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). - if c.opts.EnableDatagram && !c.session.ConnectionState().SupportsDatagrams { - c.session.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") + if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams { + c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") } - }() + }(str) } } func (c *client) Close() error { - if c.session == nil { + if c.conn == nil { return nil } - return c.session.CloseWithError(quic.ApplicationErrorCode(errorNoError), "") + return c.conn.CloseWithError(quic.ApplicationErrorCode(errorNoError), "") } func (c *client) maxHeaderBytes() uint64 { @@ -212,7 +243,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { } c.dialOnce.Do(func() { - c.handshakeErr = c.dial() + c.handshakeErr = c.dial(req.Context()) }) if c.handshakeErr != nil { @@ -225,13 +256,13 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { } else { // wait for the handshake to complete select { - case <-c.session.HandshakeComplete().Done(): + case <-c.conn.HandshakeComplete().Done(): case <-req.Context().Done(): return nil, req.Context().Err() } } - str, err := c.session.OpenStreamSync(req.Context()) + str, err := c.conn.OpenStreamSync(req.Context()) if err != nil { return nil, err } @@ -260,7 +291,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { if rerr.err != nil { reason = rerr.err.Error() } - c.session.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) + c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) } } return rsp, rerr.err @@ -279,7 +310,7 @@ func (c *client) doRequest( return nil, newStreamError(errorInternalError, err) } - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) if err != nil { return nil, newStreamError(errorFrameError, err) } @@ -300,7 +331,7 @@ func (c *client) doRequest( return nil, newConnError(errorGeneralProtocolError, err) } - connState := qtls.ToTLSConnectionState(c.session.ConnectionState().TLS) + connState := qtls.ToTLSConnectionState(c.conn.ConnectionState().TLS) res := &http.Response{ Proto: "HTTP/3", ProtoMajor: 3, @@ -320,8 +351,8 @@ func (c *client) doRequest( res.Header.Add(hf.Name, hf.Value) } } - respBody := newResponseBody(str, reqDone, func() { - c.session.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") + respBody := newResponseBody(str, c.conn, reqDone, func() { + c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. diff --git a/http3/client_test.go b/http3/client_test.go index 3f3b4a4b1f9..b13993f5697 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -12,13 +12,13 @@ import ( "net/http" "time" - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go" mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" - "github.com/lucas-clemente/quic-go/quicvarint" - "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/quicvarint" + + "github.com/golang/mock/gomock" "github.com/marten-seemann/qpack" . "github.com/onsi/ginkgo" @@ -65,7 +65,7 @@ var _ = Describe("Client", func() { client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool - dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) { + dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { Expect(quicConf).To(Equal(defaultQuicConfig)) Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3})) Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1})) @@ -80,7 +80,7 @@ var _ = Describe("Client", func() { client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { + dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { Expect(hostname).To(Equal("quic.clemente.io:443")) dialAddrCalled = true return nil, errors.New("test done") @@ -100,12 +100,8 @@ var _ = Describe("Client", func() { client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool - dialAddr = func( - hostname string, - tlsConfP *tls.Config, - quicConfP *quic.Config, - ) (quic.EarlySession, error) { - Expect(hostname).To(Equal("localhost:1337")) + dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { + Expect(host).To(Equal("localhost:1337")) Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3})) Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) @@ -122,9 +118,11 @@ var _ = Describe("Client", func() { testErr := errors.New("test done") tlsConf := &tls.Config{ServerName: "foo.bar"} quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) + defer cancel() var dialerCalled bool - dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) { - Expect(network).To(Equal("udp")) + dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { + Expect(ctxP).To(Equal(ctx)) Expect(address).To(Equal("localhost:1337")) Expect(tlsConfP.ServerName).To(Equal("foo.bar")) Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) @@ -133,7 +131,7 @@ var _ = Describe("Client", func() { } client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) Expect(err).ToNot(HaveOccurred()) - _, err = client.RoundTrip(req) + _, err = client.RoundTrip(req.WithContext(ctx)) Expect(err).To(MatchError(testErr)) Expect(dialerCalled).To(BeTrue()) }) @@ -142,7 +140,7 @@ var _ = Describe("Client", func() { testErr := errors.New("handshake error") client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil) Expect(err).ToNot(HaveOccurred()) - dialAddr = func(hostname string, _ *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) { + dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { Expect(quicConf.EnableDatagrams).To(BeTrue()) return nil, testErr } @@ -154,14 +152,14 @@ var _ = Describe("Client", func() { testErr := errors.New("handshake error") client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return nil, testErr } _, err = client.RoundTrip(req) Expect(err).To(MatchError(testErr)) }) - It("closes correctly if session was not created", func() { + It("closes correctly if connection was not created", func() { client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(client.Close()).To(Succeed()) @@ -179,7 +177,7 @@ var _ = Describe("Client", func() { testErr := errors.New("handshake error") req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { return nil, testErr } _, err = client.RoundTrip(req) @@ -187,10 +185,93 @@ var _ = Describe("Client", func() { }) }) + Context("hijacking unidirectional streams", func() { + var ( + request *http.Request + conn *mockquic.MockEarlyConnection + settingsFrameWritten chan struct{} + ) + testDone := make(chan struct{}) + + BeforeEach(func() { + testDone = make(chan struct{}) + settingsFrameWritten = make(chan struct{}) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + defer GinkgoRecover() + close(settingsFrameWritten) + }) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return conn, nil + } + var err error + request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + testDone <- struct{}{} + Eventually(settingsFrameWritten).Should(BeClosed()) + }) + + It("hijacks an unidirectional stream of unknown stream type", func() { + streamTypeChan := make(chan StreamType, 1) + client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool { + streamTypeChan <- st + return true + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTrip(request) + Expect(err).To(MatchError("done")) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { + streamTypeChan := make(chan StreamType, 1) + client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool { + streamTypeChan <- st + return false + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTrip(request) + Expect(err).To(MatchError("done")) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + }) + Context("control stream handling", func() { var ( request *http.Request - sess *mockquic.MockEarlySession + conn *mockquic.MockEarlyConnection settingsFrameWritten chan struct{} ) testDone := make(chan struct{}) @@ -202,11 +283,13 @@ var _ = Describe("Client", func() { defer GinkgoRecover() close(settingsFrameWritten) }) - sess = mockquic.NewMockEarlySession(mockCtrl) - sess.EXPECT().OpenUniStream().Return(controlStr, nil) - sess.EXPECT().HandshakeComplete().Return(handshakeCtx) - sess.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil } + conn = mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return conn, nil + } var err error request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) @@ -223,16 +306,16 @@ var _ = Describe("Client", func() { (&settingsFrame{}).Write(buf) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) _, err := client.RoundTrip(request) Expect(err).To(MatchError("done")) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to sess.CloseWithError + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { @@ -248,10 +331,10 @@ var _ = Describe("Client", func() { str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return str, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) @@ -261,7 +344,7 @@ var _ = Describe("Client", func() { }) } - It("resets streams other than the control stream and the QPACK streams", func() { + It("resets streams Other than the control stream and the QPACK streams", func() { buf := &bytes.Buffer{} quicvarint.Write(buf, 1337) str := mockquic.NewMockStream(mockCtrl) @@ -271,10 +354,10 @@ var _ = Describe("Client", func() { close(done) }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return str, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) @@ -289,15 +372,15 @@ var _ = Describe("Client", func() { (&dataFrame{}).Write(buf) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) done := make(chan struct{}) - sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { defer GinkgoRecover() Expect(code).To(BeEquivalentTo(errorMissingSettings)) close(done) @@ -315,15 +398,15 @@ var _ = Describe("Client", func() { buf.Write(b.Bytes()[:b.Len()-1]) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) done := make(chan struct{}) - sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { defer GinkgoRecover() Expect(code).To(BeEquivalentTo(errorFrameError)) close(done) @@ -338,15 +421,15 @@ var _ = Describe("Client", func() { quicvarint.Write(buf, streamTypePushStream) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) done := make(chan struct{}) - sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { defer GinkgoRecover() Expect(code).To(BeEquivalentTo(errorIDError)) close(done) @@ -363,16 +446,16 @@ var _ = Describe("Client", func() { (&settingsFrame{Datagram: true}).Write(buf) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) - sess.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) done := make(chan struct{}) - sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { defer GinkgoRecover() Expect(code).To(BeEquivalentTo(errorSettingsError)) Expect(reason).To(Equal("missing QUIC Datagram support")) @@ -388,7 +471,7 @@ var _ = Describe("Client", func() { var ( request *http.Request str *mockquic.MockStream - sess *mockquic.MockEarlySession + conn *mockquic.MockEarlyConnection settingsFrameWritten chan struct{} ) testDone := make(chan struct{}) @@ -410,7 +493,7 @@ var _ = Describe("Client", func() { fields := make(map[string]string) decoder := qpack.NewDecoder(nil) - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) @@ -429,7 +512,7 @@ var _ = Describe("Client", func() { buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, utils.DefaultLogger) + rw := newResponseWriter(rstr, nil, utils.DefaultLogger) rw.WriteHeader(status) rw.Flush() return buf.Bytes() @@ -447,13 +530,15 @@ var _ = Describe("Client", func() { close(settingsFrameWritten) }) // SETTINGS frame str = mockquic.NewMockStream(mockCtrl) - sess = mockquic.NewMockEarlySession(mockCtrl) - sess.EXPECT().OpenUniStream().Return(controlStr, nil) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn = mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil } + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return conn, nil + } var err error request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) @@ -466,9 +551,9 @@ var _ = Describe("Client", func() { It("errors if it can't open a stream", func() { testErr := errors.New("stream open error") - sess.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) - sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) - sess.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) _, err := client.RoundTrip(request) Expect(err).To(MatchError(testErr)) }) @@ -477,7 +562,7 @@ var _ = Describe("Client", func() { testErr := errors.New("stream open error") request.Method = MethodGet0RTT // don't EXPECT any calls to HandshakeComplete() - sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() str.EXPECT().Close() @@ -493,9 +578,9 @@ var _ = Describe("Client", func() { It("returns a response", func() { rspBuf := bytes.NewBuffer(getResponse(418)) gomock.InOrder( - sess.EXPECT().HandshakeComplete().Return(handshakeCtx), - sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), - sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}), + conn.EXPECT().HandshakeComplete().Return(handshakeCtx), + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), ) str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) str.EXPECT().Close() @@ -513,8 +598,8 @@ var _ = Describe("Client", func() { BeforeEach(func() { strBuf = &bytes.Buffer{} gomock.InOrder( - sess.EXPECT().HandshakeComplete().Return(handshakeCtx), - sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), + conn.EXPECT().HandshakeComplete().Return(handshakeCtx), + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), ) body := &mockBody{} body.SetData([]byte("request body")) @@ -571,7 +656,7 @@ var _ = Describe("Client", func() { (&dataFrame{Length: 0x6}).Write(buf) buf.Write([]byte("foobar")) str.EXPECT().Close().Do(func() { close(done) }) - sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors // the response body is sent asynchronously, while already reading the response str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() @@ -584,7 +669,7 @@ var _ = Describe("Client", func() { It("closes the connection when the first frame is not a HEADERS frame", func() { buf := &bytes.Buffer{} (&dataFrame{Length: 0x42}).Write(buf) - sess.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()) closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() @@ -610,7 +695,7 @@ var _ = Describe("Client", func() { It("cancels a request while waiting for the handshake to complete", func() { ctx, cancel := context.WithCancel(context.Background()) req := request.WithContext(ctx) - sess.EXPECT().HandshakeComplete().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(context.Background()) errChan := make(chan error) go func() { @@ -625,8 +710,8 @@ var _ = Describe("Client", func() { It("cancels a request while the request is still in flight", func() { ctx, cancel := context.WithCancel(context.Background()) req := request.WithContext(ctx) - sess.EXPECT().HandshakeComplete().Return(handshakeCtx) - sess.EXPECT().OpenStreamSync(ctx).Return(str, nil) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1) @@ -654,9 +739,9 @@ var _ = Describe("Client", func() { ctx, cancel := context.WithCancel(context.Background()) req := request.WithContext(ctx) - sess.EXPECT().HandshakeComplete().Return(handshakeCtx) - sess.EXPECT().OpenStreamSync(ctx).Return(str, nil) - sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} str.EXPECT().Close().MaxTimes(1) @@ -674,11 +759,11 @@ var _ = Describe("Client", func() { Context("gzip compression", func() { BeforeEach(func() { - sess.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) }) It("adds the gzip header to requests", func() { - sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) gomock.InOrder( @@ -695,7 +780,7 @@ var _ = Describe("Client", func() { It("doesn't add gzip if the header disable it", func() { client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) Expect(err).ToNot(HaveOccurred()) - sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) buf := &bytes.Buffer{} str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) gomock.InOrder( @@ -710,12 +795,12 @@ var _ = Describe("Client", func() { }) It("decompresses the response", func() { - sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) - sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, utils.DefaultLogger) + rw := newResponseWriter(rstr, nil, utils.DefaultLogger) rw.Header().Set("Content-Encoding", "gzip") gz := gzip.NewWriter(rw) gz.Write([]byte("gzipped response")) @@ -736,12 +821,12 @@ var _ = Describe("Client", func() { }) It("only decompresses the response if the response contains the right content-encoding header", func() { - sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) - sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) buf := &bytes.Buffer{} rstr := mockquic.NewMockStream(mockCtrl) rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() - rw := newResponseWriter(rstr, utils.DefaultLogger) + rw := newResponseWriter(rstr, nil, utils.DefaultLogger) rw.Write([]byte("not gzipped")) rw.Flush() str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) diff --git a/http3/error_codes.go b/http3/error_codes.go index dfde76d4326..d87eef4ae07 100644 --- a/http3/error_codes.go +++ b/http3/error_codes.go @@ -26,6 +26,7 @@ const ( errorMessageError errorCode = 0x10e errorConnectError errorCode = 0x10f errorVersionFallback errorCode = 0x110 + errorDatagramError errorCode = 0x4a1268 ) func (e errorCode) String() string { @@ -64,6 +65,8 @@ func (e errorCode) String() string { return "H3_CONNECT_ERROR" case errorVersionFallback: return "H3_VERSION_FALLBACK" + case errorDatagramError: + return "H3_DATAGRAM_ERROR" default: return fmt.Sprintf("unknown error code: %#x", uint16(e)) } diff --git a/http3/frames.go b/http3/frames.go index 679f66c10b9..5fb4f082eba 100644 --- a/http3/frames.go +++ b/http3/frames.go @@ -2,6 +2,7 @@ package http3 import ( "bytes" + "errors" "fmt" "io" "io/ioutil" @@ -10,42 +11,55 @@ import ( "github.com/lucas-clemente/quic-go/quicvarint" ) +// FrameType is the frame type of a HTTP/3 frame +type FrameType uint64 + +type unknownFrameHandlerFunc func(FrameType) (processed bool, err error) + type frame interface{} -func parseNextFrame(r io.Reader) (frame, error) { +var errHijacked = errors.New("hijacked") + +func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) { qr := quicvarint.NewReader(r) - t, err := quicvarint.Read(qr) - if err != nil { - return nil, err - } - l, err := quicvarint.Read(qr) - if err != nil { - return nil, err - } + for { + t, err := quicvarint.Read(qr) + if err != nil { + return nil, err + } + // Call the unknownFrameHandler for frames not defined in the HTTP/3 spec + if t > 0xd && unknownFrameHandler != nil { + hijacked, err := unknownFrameHandler(FrameType(t)) + if err != nil { + return nil, err + } + // If the unknownFrameHandler didn't process the frame, it is our responsibility to skip it. + if hijacked { + return nil, errHijacked + } + continue + } + l, err := quicvarint.Read(qr) + if err != nil { + return nil, err + } - switch t { - case 0x0: - return &dataFrame{Length: l}, nil - case 0x1: - return &headersFrame{Length: l}, nil - case 0x4: - return parseSettingsFrame(r, l) - case 0x3: // CANCEL_PUSH - fallthrough - case 0x5: // PUSH_PROMISE - fallthrough - case 0x7: // GOAWAY - fallthrough - case 0xd: // MAX_PUSH_ID - fallthrough - case 0xe: // DUPLICATE_PUSH - fallthrough - default: + switch t { + case 0x0: + return &dataFrame{Length: l}, nil + case 0x1: + return &headersFrame{Length: l}, nil + case 0x4: + return parseSettingsFrame(r, l) + case 0x3: // CANCEL_PUSH + case 0x5: // PUSH_PROMISE + case 0x7: // GOAWAY + case 0xd: // MAX_PUSH_ID + } // skip over unknown frames if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil { return nil, err } - return parseNextFrame(qr) } } @@ -67,11 +81,11 @@ func (f *headersFrame) Write(b *bytes.Buffer) { quicvarint.Write(b, f.Length) } -const settingDatagram = 0x276 +const settingDatagram = 0xffd277 type settingsFrame struct { Datagram bool - other map[uint64]uint64 // all settings that we don't explicitly recognize + Other map[uint64]uint64 // all settings that we don't explicitly recognize } func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { @@ -109,13 +123,13 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { } frame.Datagram = val == 1 default: - if _, ok := frame.other[id]; ok { + if _, ok := frame.Other[id]; ok { return nil, fmt.Errorf("duplicate setting: %d", id) } - if frame.other == nil { - frame.other = make(map[uint64]uint64) + if frame.Other == nil { + frame.Other = make(map[uint64]uint64) } - frame.other[id] = val + frame.Other[id] = val } } return frame, nil @@ -124,7 +138,7 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) { func (f *settingsFrame) Write(b *bytes.Buffer) { quicvarint.Write(b, 0x4) var l protocol.ByteCount - for id, val := range f.other { + for id, val := range f.Other { l += quicvarint.Len(id) + quicvarint.Len(val) } if f.Datagram { @@ -135,7 +149,7 @@ func (f *settingsFrame) Write(b *bytes.Buffer) { quicvarint.Write(b, settingDatagram) quicvarint.Write(b, 1) } - for id, val := range f.other { + for id, val := range f.Other { quicvarint.Write(b, id) quicvarint.Write(b, val) } diff --git a/http3/frames_test.go b/http3/frames_test.go index 83f65370b92..40ca3c124a4 100644 --- a/http3/frames_test.go +++ b/http3/frames_test.go @@ -24,7 +24,7 @@ var _ = Describe("Frames", func() { data = append(data, make([]byte, 0x42)...) buf := bytes.NewBuffer(data) (&dataFrame{Length: 0x1234}).Write(buf) - frame, err := parseNextFrame(buf) + frame, err := parseNextFrame(buf, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1234))) @@ -34,7 +34,7 @@ var _ = Describe("Frames", func() { It("parses", func() { data := appendVarInt(nil, 0) // type byte data = appendVarInt(data, 0x1337) - frame, err := parseNextFrame(bytes.NewReader(data)) + frame, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame.(*dataFrame).Length).To(Equal(uint64(0x1337))) @@ -43,7 +43,7 @@ var _ = Describe("Frames", func() { It("writes", func() { buf := &bytes.Buffer{} (&dataFrame{Length: 0xdeadbeef}).Write(buf) - frame, err := parseNextFrame(buf) + frame, err := parseNextFrame(buf, nil) Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) @@ -55,7 +55,7 @@ var _ = Describe("Frames", func() { It("parses", func() { data := appendVarInt(nil, 1) // type byte data = appendVarInt(data, 0x1337) - frame, err := parseNextFrame(bytes.NewReader(data)) + frame, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) Expect(frame.(*headersFrame).Length).To(Equal(uint64(0x1337))) @@ -64,7 +64,7 @@ var _ = Describe("Frames", func() { It("writes", func() { buf := &bytes.Buffer{} (&headersFrame{Length: 0xdeadbeef}).Write(buf) - frame, err := parseNextFrame(buf) + frame, err := parseNextFrame(buf, nil) Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) @@ -81,12 +81,12 @@ var _ = Describe("Frames", func() { data := appendVarInt(nil, 4) // type byte data = appendVarInt(data, uint64(len(settings))) data = append(data, settings...) - frame, err := parseNextFrame(bytes.NewReader(data)) + frame, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&settingsFrame{})) sf := frame.(*settingsFrame) - Expect(sf.other).To(HaveKeyWithValue(uint64(13), uint64(37))) - Expect(sf.other).To(HaveKeyWithValue(uint64(0xdead), uint64(0xbeef))) + Expect(sf.Other).To(HaveKeyWithValue(uint64(13), uint64(37))) + Expect(sf.Other).To(HaveKeyWithValue(uint64(0xdead), uint64(0xbeef))) }) It("rejects duplicate settings", func() { @@ -97,25 +97,25 @@ var _ = Describe("Frames", func() { data := appendVarInt(nil, 4) // type byte data = appendVarInt(data, uint64(len(settings))) data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data)) + _, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).To(MatchError("duplicate setting: 13")) }) It("writes", func() { - sf := &settingsFrame{other: map[uint64]uint64{ + sf := &settingsFrame{Other: map[uint64]uint64{ 1: 2, 99: 999, 13: 37, }} buf := &bytes.Buffer{} sf.Write(buf) - frame, err := parseNextFrame(buf) + frame, err := parseNextFrame(buf, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(sf)) }) It("errors on EOF", func() { - sf := &settingsFrame{other: map[uint64]uint64{ + sf := &settingsFrame{Other: map[uint64]uint64{ 13: 37, 0xdeadbeef: 0xdecafbad, }} @@ -123,13 +123,13 @@ var _ = Describe("Frames", func() { sf.Write(buf) data := buf.Bytes() - _, err := parseNextFrame(bytes.NewReader(data)) + _, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).ToNot(HaveOccurred()) for i := range data { b := make([]byte, i) copy(b, data[:i]) - _, err := parseNextFrame(bytes.NewReader(b)) + _, err := parseNextFrame(bytes.NewReader(b), nil) Expect(err).To(MatchError(io.EOF)) } }) @@ -141,7 +141,7 @@ var _ = Describe("Frames", func() { data := appendVarInt(nil, 4) // type byte data = appendVarInt(data, uint64(len(settings))) data = append(data, settings...) - f, err := parseNextFrame(bytes.NewReader(data)) + f, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).ToNot(HaveOccurred()) Expect(f).To(BeAssignableToTypeOf(&settingsFrame{})) sf := f.(*settingsFrame) @@ -156,7 +156,7 @@ var _ = Describe("Frames", func() { data := appendVarInt(nil, 4) // type byte data = appendVarInt(data, uint64(len(settings))) data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data)) + _, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).To(MatchError(fmt.Sprintf("duplicate setting: %d", settingDatagram))) }) @@ -166,7 +166,7 @@ var _ = Describe("Frames", func() { data := appendVarInt(nil, 4) // type byte data = appendVarInt(data, uint64(len(settings))) data = append(data, settings...) - _, err := parseNextFrame(bytes.NewReader(data)) + _, err := parseNextFrame(bytes.NewReader(data), nil) Expect(err).To(MatchError("invalid value for H3_DATAGRAM: 1337")) }) @@ -174,10 +174,55 @@ var _ = Describe("Frames", func() { sf := &settingsFrame{Datagram: true} buf := &bytes.Buffer{} sf.Write(buf) - frame, err := parseNextFrame(buf) + frame, err := parseNextFrame(buf, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(sf)) }) }) }) + + Context("hijacking", func() { + It("reads a frame without hijacking the stream", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, 1337) + customFrameContents := []byte("foobar") + buf.Write(customFrameContents) + + var called bool + _, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) { + Expect(ft).To(BeEquivalentTo(1337)) + called = true + b := make([]byte, 3) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(string(b)).To(Equal("foo")) + return true, nil + }) + Expect(err).To(MatchError(errHijacked)) + Expect(called).To(BeTrue()) + }) + + It("reads a frame without hijacking the stream", func() { + buf := &bytes.Buffer{} + quicvarint.Write(buf, 1337) + customFrameContents := []byte("custom frame") + buf.Write(customFrameContents) + (&dataFrame{Length: 6}).Write(buf) + buf.WriteString("foobar") + + var called bool + frame, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) { + Expect(ft).To(BeEquivalentTo(1337)) + called = true + b := make([]byte, len(customFrameContents)) + _, err = io.ReadFull(buf, b) + Expect(err).ToNot(HaveOccurred()) + Expect(string(b)).To(Equal(string(customFrameContents))) + return false, nil + }) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(&dataFrame{Length: 6})) + Expect(called).To(BeTrue()) + }) + }) }) diff --git a/http3/request.go b/http3/request.go index b5fc5d5aca7..f15e8afc2c0 100644 --- a/http3/request.go +++ b/http3/request.go @@ -12,9 +12,9 @@ import ( ) func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { - var path, authority, method, contentLengthStr string - httpHeaders := http.Header{} + var path, authority, method, protocol, scheme, contentLengthStr string + httpHeaders := http.Header{} for _, h := range headers { switch h.Name { case ":path": @@ -23,6 +23,10 @@ func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { method = h.Value case ":authority": authority = h.Value + case ":protocol": + protocol = h.Value + case ":scheme": + scheme = h.Value case "content-length": contentLengthStr = h.Value default: @@ -38,8 +42,14 @@ func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { } isConnect := method == http.MethodConnect - if isConnect { - if path != "" || authority == "" { + // Extended CONNECT, see https://datatracker.ietf.org/doc/html/rfc8441#section-4 + isExtendedConnected := isConnect && protocol != "" + if isExtendedConnected { + if scheme == "" || path == "" || authority == "" { + return nil, errors.New("extended CONNECT: :scheme, :path and :authority must not be empty") + } + } else if isConnect { + if path != "" || authority == "" { // normal CONNECT return nil, errors.New(":path must be empty and :authority must not be empty") } } else if len(path) == 0 || len(authority) == 0 || len(method) == 0 { @@ -51,9 +61,20 @@ func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { var err error if isConnect { - u = &url.URL{Host: authority} + u = &url.URL{} + if isExtendedConnected { + u, err = url.ParseRequestURI(path) + if err != nil { + return nil, err + } + } else { + u.Path = path + } + u.Scheme = scheme + u.Host = authority requestURI = authority } else { + protocol = "HTTP/3" u, err = url.ParseRequestURI(path) if err != nil { return nil, err @@ -72,7 +93,7 @@ func requestFromHeaders(headers []qpack.HeaderField) (*http.Request, error) { return &http.Request{ Method: method, URL: u, - Proto: "HTTP/3", + Proto: protocol, ProtoMajor: 3, ProtoMinor: 0, Header: httpHeaders, diff --git a/http3/request_test.go b/http3/request_test.go index edaba2a59b4..ec3aa8b3eeb 100644 --- a/http3/request_test.go +++ b/http3/request_test.go @@ -64,7 +64,7 @@ var _ = Describe("Request", func() { })) }) - It("handles other headers", func() { + It("handles Other headers", func() { headers := []qpack.HeaderField{ {Name: ":path", Value: "/foo"}, {Name: ":authority", Value: "quic.clemente.io"}, @@ -81,17 +81,6 @@ var _ = Describe("Request", func() { })) }) - It("handles CONNECT method", func() { - headers := []qpack.HeaderField{ - {Name: ":authority", Value: "quic.clemente.io"}, - {Name: ":method", Value: http.MethodConnect}, - } - req, err := requestFromHeaders(headers) - Expect(err).NotTo(HaveOccurred()) - Expect(req.Method).To(Equal(http.MethodConnect)) - Expect(req.RequestURI).To(Equal("quic.clemente.io")) - }) - It("errors with missing path", func() { headers := []qpack.HeaderField{ {Name: ":authority", Value: "quic.clemente.io"}, @@ -119,22 +108,64 @@ var _ = Describe("Request", func() { Expect(err).To(MatchError(":path, :authority and :method must not be empty")) }) - It("errors with missing authority in CONNECT method", func() { - headers := []qpack.HeaderField{ - {Name: ":method", Value: http.MethodConnect}, - } - _, err := requestFromHeaders(headers) - Expect(err).To(MatchError(":path must be empty and :authority must not be empty")) + Context("regular HTTP CONNECT", func() { + It("handles CONNECT method", func() { + headers := []qpack.HeaderField{ + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":method", Value: http.MethodConnect}, + } + req, err := requestFromHeaders(headers) + Expect(err).NotTo(HaveOccurred()) + Expect(req.Method).To(Equal(http.MethodConnect)) + Expect(req.RequestURI).To(Equal("quic.clemente.io")) + }) + + It("errors with missing authority in CONNECT method", func() { + headers := []qpack.HeaderField{ + {Name: ":method", Value: http.MethodConnect}, + } + _, err := requestFromHeaders(headers) + Expect(err).To(MatchError(":path must be empty and :authority must not be empty")) + }) + + It("errors with extra path in CONNECT method", func() { + headers := []qpack.HeaderField{ + {Name: ":path", Value: "/foo"}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":method", Value: http.MethodConnect}, + } + _, err := requestFromHeaders(headers) + Expect(err).To(MatchError(":path must be empty and :authority must not be empty")) + }) }) - It("errors with extra path in CONNECT method", func() { - headers := []qpack.HeaderField{ - {Name: ":path", Value: "/foo"}, - {Name: ":authority", Value: "quic.clemente.io"}, - {Name: ":method", Value: http.MethodConnect}, - } - _, err := requestFromHeaders(headers) - Expect(err).To(MatchError(":path must be empty and :authority must not be empty")) + Context("Extended CONNECT", func() { + It("handles Extended CONNECT method", func() { + headers := []qpack.HeaderField{ + {Name: ":protocol", Value: "webtransport"}, + {Name: ":scheme", Value: "ftp"}, + {Name: ":method", Value: http.MethodConnect}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":path", Value: "/foo?val=1337"}, + } + req, err := requestFromHeaders(headers) + Expect(err).NotTo(HaveOccurred()) + Expect(req.Method).To(Equal(http.MethodConnect)) + Expect(req.Proto).To(Equal("webtransport")) + Expect(req.URL.String()).To(Equal("ftp://quic.clemente.io/foo?val=1337")) + Expect(req.URL.Query().Get("val")).To(Equal("1337")) + }) + + It("errors with missing scheme", func() { + headers := []qpack.HeaderField{ + {Name: ":protocol", Value: "webtransport"}, + {Name: ":method", Value: http.MethodConnect}, + {Name: ":authority", Value: "quic.clemente.io"}, + {Name: ":path", Value: "/foo"}, + } + _, err := requestFromHeaders(headers) + Expect(err).To(MatchError("extended CONNECT: :scheme, :path and :authority must not be empty")) + }) }) Context("extracting the hostname from a request", func() { diff --git a/http3/request_writer.go b/http3/request_writer.go index 8878c8f19f7..aebb640b17e 100644 --- a/http3/request_writer.go +++ b/http3/request_writer.go @@ -113,7 +113,9 @@ func (w *requestWriter) writeHeaders(wr io.Writer, req *http.Request, gzip bool) } // copied from net/transport.go - +// Modified to support Extended CONNECT: +// Contrary to what the godoc for the http.Request says, +// we do respect the Proto field if the method is CONNECT. func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) error { host := req.Host if host == "" { @@ -124,8 +126,11 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra return err } + // http.NewRequest sets this field to HTTP/1.1 + isExtendedConnect := req.Method == http.MethodConnect && req.Proto != "" && req.Proto != "HTTP/1.1" + var path string - if req.Method != "CONNECT" { + if req.Method != http.MethodConnect || isExtendedConnect { path = req.URL.RequestURI() if !validPseudoPath(path) { orig := path @@ -162,10 +167,13 @@ func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, tra // [RFC3986]). f(":authority", host) f(":method", req.Method) - if req.Method != "CONNECT" { + if req.Method != http.MethodConnect || isExtendedConnect { f(":path", path) f(":scheme", req.URL.Scheme) } + if isExtendedConnect { + f(":protocol", req.Proto) + } if trailers != "" { f("trailer", trailers) } diff --git a/http3/request_writer_test.go b/http3/request_writer_test.go index 83c77204bda..9a1e718e289 100644 --- a/http3/request_writer_test.go +++ b/http3/request_writer_test.go @@ -6,12 +6,12 @@ import ( "net/http" "strconv" - "github.com/marten-seemann/qpack" - - "github.com/golang/mock/gomock" mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/golang/mock/gomock" + "github.com/marten-seemann/qpack" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -30,7 +30,7 @@ var _ = Describe("Request Writer", func() { ) decode := func(str io.Reader) map[string]string { - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) @@ -58,7 +58,7 @@ var _ = Describe("Request Writer", func() { It("writes a GET request", func() { str.EXPECT().Close() - req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil) + req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/index.html?foo=bar", nil) Expect(err).ToNot(HaveOccurred()) Expect(rw.WriteRequest(str, req, false)).To(Succeed()) headerFields := decode(strBuf) @@ -73,7 +73,7 @@ var _ = Describe("Request Writer", func() { closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) postData := bytes.NewReader([]byte("foobar")) - req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", postData) + req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", postData) Expect(err).ToNot(HaveOccurred()) Expect(rw.WriteRequest(str, req, false)).To(Succeed()) @@ -85,7 +85,7 @@ var _ = Describe("Request Writer", func() { Expect(err).ToNot(HaveOccurred()) Expect(contentLength).To(BeNumerically(">", 0)) - frame, err := parseNextFrame(strBuf) + frame, err := parseNextFrame(strBuf, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6)) @@ -94,7 +94,7 @@ var _ = Describe("Request Writer", func() { It("writes a POST request, if the Body returns an EOF immediately", func() { closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) - req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", &foobarReader{}) + req, err := http.NewRequest(http.MethodPost, "https://quic.clemente.io/upload.html", &foobarReader{}) Expect(err).ToNot(HaveOccurred()) Expect(rw.WriteRequest(str, req, false)).To(Succeed()) @@ -102,7 +102,7 @@ var _ = Describe("Request Writer", func() { headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":method", "POST")) - frame, err := parseNextFrame(strBuf) + frame, err := parseNextFrame(strBuf, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) Expect(frame.(*dataFrame).Length).To(BeEquivalentTo(6)) @@ -110,7 +110,7 @@ var _ = Describe("Request Writer", func() { It("sends cookies", func() { str.EXPECT().Close() - req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil) + req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) cookie1 := &http.Cookie{ Name: "Cookie #1", @@ -129,10 +129,37 @@ var _ = Describe("Request Writer", func() { It("adds the header for gzip support", func() { str.EXPECT().Close() - req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil) + req, err := http.NewRequest(http.MethodGet, "https://quic.clemente.io/", nil) Expect(err).ToNot(HaveOccurred()) Expect(rw.WriteRequest(str, req, true)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip")) }) + + It("writes a CONNECT request", func() { + str.EXPECT().Close() + req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(rw.WriteRequest(str, req, false)).To(Succeed()) + headerFields := decode(strBuf) + Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) + Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) + Expect(headerFields).ToNot(HaveKey(":path")) + Expect(headerFields).ToNot(HaveKey(":scheme")) + Expect(headerFields).ToNot(HaveKey(":protocol")) + }) + + It("writes an Extended CONNECT request", func() { + str.EXPECT().Close() + req, err := http.NewRequest(http.MethodConnect, "https://quic.clemente.io/foobar", nil) + Expect(err).ToNot(HaveOccurred()) + req.Proto = "webtransport" + Expect(rw.WriteRequest(str, req, false)).To(Succeed()) + headerFields := decode(strBuf) + Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) + Expect(headerFields).To(HaveKeyWithValue(":method", "CONNECT")) + Expect(headerFields).To(HaveKeyWithValue(":path", "/foobar")) + Expect(headerFields).To(HaveKeyWithValue(":scheme", "https")) + Expect(headerFields).To(HaveKeyWithValue(":protocol", "webtransport")) + }) }) diff --git a/http3/response_writer.go b/http3/response_writer.go index 0a42518cfd1..9f232e0f2bb 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -23,6 +23,7 @@ type DataStreamer interface { } type responseWriter struct { + conn quic.Connection stream quic.Stream // needed for DataStream() bufferedStream *bufio.Writer @@ -38,12 +39,14 @@ var ( _ http.ResponseWriter = &responseWriter{} _ http.Flusher = &responseWriter{} _ DataStreamer = &responseWriter{} + _ Hijacker = &responseWriter{} ) -func newResponseWriter(stream quic.Stream, logger utils.Logger) *responseWriter { +func newResponseWriter(stream quic.Stream, conn quic.Connection, logger utils.Logger) *responseWriter { return &responseWriter{ header: http.Header{}, stream: stream, + conn: conn, bufferedStream: bufio.NewWriter(stream), logger: logger, } @@ -119,6 +122,14 @@ func (w *responseWriter) DataStream() quic.Stream { return w.stream } +func (w *responseWriter) StreamID() quic.StreamID { + return w.stream.StreamID() +} + +func (w *responseWriter) StreamCreator() StreamCreator { + return w.conn +} + // copied from http2/http2.go // bodyAllowedForStatus reports whether a given response status code // permits a body. See RFC 2616, section 4.4. diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index fb2ff186d67..2da3ef014a0 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -25,7 +25,7 @@ var _ = Describe("Response Writer", func() { strBuf = &bytes.Buffer{} str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() - rw = newResponseWriter(str, utils.DefaultLogger) + rw = newResponseWriter(str, nil, utils.DefaultLogger) }) decodeHeader := func(str io.Reader) map[string][]string { @@ -33,7 +33,7 @@ var _ = Describe("Response Writer", func() { fields := make(map[string][]string) decoder := qpack.NewDecoder(nil) - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) @@ -49,7 +49,7 @@ var _ = Describe("Response Writer", func() { } getData := func(str io.Reader) []byte { - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(BeAssignableToTypeOf(&dataFrame{})) df := frame.(*dataFrame) diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 8e6f943e93b..6ba251cbb43 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -1,6 +1,7 @@ package http3 import ( + "context" "crypto/tls" "errors" "fmt" @@ -9,7 +10,7 @@ import ( "strings" "sync" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" "golang.org/x/net/http/httpguts" ) @@ -46,10 +47,24 @@ type RoundTripper struct { // See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html. EnableDatagrams bool + // Additional HTTP/3 settings. + // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. + AdditionalSettings map[uint64]uint64 + + // When set, this callback is called for the first unknown frame parsed on a bidirectional stream. + // It is called right after parsing the frame type. + // Callers can either process the frame and return control of the stream back to HTTP/3 + // (by returning hijacked false). + // Alternatively, callers can take over the QUIC stream (by returning hijacked true). + StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error) + + // When set, this callback is called for unknown unidirectional stream of unknown stream type. + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool) + // Dial specifies an optional dial function for creating QUIC // connections for requests. - // If Dial is nil, quic.DialAddrEarly will be used. - Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) + // If Dial is nil, quic.DialAddrEarlyContext will be used. + Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) // MaxResponseHeaderBytes specifies a limit on how many response bytes are // allowed in the server's response header. @@ -64,9 +79,6 @@ type RoundTripOpt struct { // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. // If set true and no cached connection is available, RoundTrip will return ErrNoCachedConn. OnlyCachedConn bool - // SkipSchemeCheck controls whether we check if the scheme is https. - // This allows the use of different schemes, e.g. masque://target.example.com:443/. - SkipSchemeCheck bool } var _ roundTripCloser = &RoundTripper{} @@ -100,7 +112,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. } } } - } else if !opt.SkipSchemeCheck { + } else { closeRequestBody(req) return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) } @@ -144,6 +156,8 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr EnableDatagram: r.EnableDatagrams, DisableCompression: r.DisableCompression, MaxHeaderBytes: r.MaxResponseHeaderBytes, + StreamHijacker: r.StreamHijacker, + UniStreamHijacker: r.UniStreamHijacker, }, r.QuicConfig, r.Dial, diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index 184889f1c35..a17cf4db087 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -61,7 +61,7 @@ var _ = Describe("RoundTripper", func() { var ( rt *RoundTripper req1 *http.Request - session *mockquic.MockEarlySession + conn *mockquic.MockEarlyConnection handshakeCtx context.Context // an already canceled context ) @@ -80,12 +80,12 @@ var _ = Describe("RoundTripper", func() { origDialAddr := dialAddr BeforeEach(func() { - session = mockquic.NewMockEarlySession(mockCtrl) + conn = mockquic.NewMockEarlyConnection(mockCtrl) origDialAddr = dialAddr - dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) { + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { // return an error when trying to open a stream // we don't want to test all the dial logic here, just that dialing happens at all - return session, nil + return conn, nil } }) @@ -98,14 +98,14 @@ var _ = Describe("RoundTripper", func() { testErr := errors.New("test err") req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) - session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) - session.EXPECT().HandshakeComplete().Return(handshakeCtx) - session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) - session.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-closed return nil, errors.New("test done") }).MaxTimes(1) - session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) }) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) }) _, err = rt.RoundTrip(req) Expect(err).To(MatchError(testErr)) Expect(rt.clients).To(HaveLen(1)) @@ -115,7 +115,7 @@ var _ = Describe("RoundTripper", func() { It("uses the quic.Config, if provided", func() { config := &quic.Config{HandshakeIdleTimeout: time.Millisecond} var receivedConfig *quic.Config - dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) { + dialAddr = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { receivedConfig = config return nil, errors.New("handshake error") } @@ -127,7 +127,7 @@ var _ = Describe("RoundTripper", func() { It("uses the custom dialer, if provided", func() { var dialed bool - dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) { + dialer := func(_ context.Context, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { dialed = true return nil, errors.New("handshake error") } @@ -140,14 +140,14 @@ var _ = Describe("RoundTripper", func() { It("reuses existing clients", func() { closed := make(chan struct{}) testErr := errors.New("test err") - session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) - session.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2) - session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2) - session.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2) + conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-closed return nil, errors.New("test done") }).MaxTimes(1) - session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) }) + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) { close(closed) }) req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) Expect(err).ToNot(HaveOccurred()) _, err = rt.RoundTrip(req) @@ -179,15 +179,6 @@ var _ = Describe("RoundTripper", func() { Expect(req.Body.(*mockBody).closed).To(BeTrue()) }) - It("allow non-https schemes if SkipSchemeCheck is set", func() { - req, err := http.NewRequest("GET", "masque://www.example.org/", nil) - Expect(err).ToNot(HaveOccurred()) - _, err = rt.RoundTrip(req) - Expect(err).To(MatchError("http3: unsupported protocol scheme: masque")) - _, err = rt.RoundTripOpt(req, RoundTripOpt{SkipSchemeCheck: true, OnlyCachedConn: true}) - Expect(err).To(MatchError("http3: no cached connection was available")) - }) - It("rejects requests without a URL", func() { req1.URL = nil req1.Body = &mockBody{} diff --git a/http3/server.go b/http3/server.go index 2ae3fef5af7..e9eafda2fa3 100644 --- a/http3/server.go +++ b/http3/server.go @@ -12,7 +12,6 @@ import ( "runtime" "strings" "sync" - "sync/atomic" "time" "github.com/lucas-clemente/quic-go" @@ -34,6 +33,9 @@ const ( nextProtoH3 = "h3" ) +// StreamType is the stream type of a unidirectional stream. +type StreamType uint64 + const ( streamTypeControlStream = 0 streamTypePushStream = 1 @@ -51,6 +53,44 @@ func versionToALPN(v protocol.VersionNumber) string { return "" } +// ConfigureTLSConfig creates a new tls.Config which can be used +// to create a quic.Listener meant for serving http3. The created +// tls.Config adds the functionality of detecting the used QUIC version +// in order to set the correct ALPN value for the http3 connection. +func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config { + // The tls.Config used to setup the quic.Listener needs to have the GetConfigForClient callback set. + // That way, we can get the QUIC version and set the correct ALPN value. + return &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + // determine the ALPN from the QUIC version used + proto := nextProtoH3Draft29 + if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok { + if qconn.GetQUICVersion() == protocol.Version1 { + proto = nextProtoH3 + } + } + config := tlsConf + if tlsConf.GetConfigForClient != nil { + getConfigForClient := tlsConf.GetConfigForClient + var err error + conf, err := getConfigForClient(ch) + if err != nil { + return nil, err + } + if conf != nil { + config = conf + } + } + if config == nil { + return nil, nil + } + config = config.Clone() + config.NextProtos = []string{proto} + return config, nil + }, + } +} + // contextKey is a value for use with context.WithValue. It's used as // a pointer so it fits in an interface{} without allocation. type contextKey struct { @@ -79,6 +119,11 @@ func newConnError(code errorCode, err error) requestError { return requestError{err: err, connErr: code} } +// listenerInfo contains info about specific listener added with addListener +type listenerInfo struct { + port int // 0 means that no info about port is available +} + // Server is a HTTP/3 server. type Server struct { *http.Server @@ -89,21 +134,37 @@ type Server struct { // Enable support for HTTP/3 datagrams. // If set to true, QuicConfig.EnableDatagram will be set. - // See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html. + // See https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-07. EnableDatagrams bool // The port to use in Alt-Svc response headers. // If needed Port can be manually set when the Server is created. // This is useful when a Layer 4 firewall is redirecting UDP traffic and clients must use // a port different from the port the Server is listening on. - Port uint32 + Port int + + // Additional HTTP/3 settings. + // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. + AdditionalSettings map[uint64]uint64 + + // When set, this callback is called for the first unknown frame parsed on a bidirectional stream. + // It is called right after parsing the frame type. + // Callers can either process the frame and return control of the stream back to HTTP/3 + // (by returning hijacked false). + // Alternatively, callers can take over the QUIC stream (by returning hijacked true). + StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error) + + // When set, this callback is called for unknown unidirectional stream of unknown stream type. + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool) + + mutex sync.RWMutex + listeners map[*quic.EarlyListener]listenerInfo + + closed bool - mutex sync.Mutex - listeners map[*quic.EarlyListener]struct{} - closed utils.AtomicBool + altSvcHeader string - loggerOnce sync.Once - logger utils.Logger + logger utils.Logger } // ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. @@ -111,7 +172,7 @@ func (s *Server) ListenAndServe() error { if s.Server == nil { return errors.New("use of http3.Server without http.Server") } - return s.serveImpl(s.TLSConfig, nil) + return s.serveConn(s.TLSConfig, nil) } // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. @@ -127,61 +188,39 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { config := &tls.Config{ Certificates: certs, } - return s.serveImpl(config, nil) + return s.serveConn(config, nil) } // Serve an existing UDP connection. // It is possible to reuse the same connection for outgoing connections. // Closing the server does not close the packet conn. func (s *Server) Serve(conn net.PacketConn) error { - return s.serveImpl(s.TLSConfig, conn) + return s.serveConn(s.TLSConfig, conn) } -func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { - if s.closed.Get() { - return http.ErrServerClosed - } +// ServeListener serves an existing QUIC listener. +// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config +// and use it to construct a http3-friendly QUIC listener. +// Closing the server does close the listener. +func (s *Server) ServeListener(ln quic.EarlyListener) error { if s.Server == nil { return errors.New("use of http3.Server without http.Server") } - s.loggerOnce.Do(func() { - s.logger = utils.DefaultLogger.WithPrefix("server") - }) - // The tls.Config we pass to Listen needs to have the GetConfigForClient callback set. - // That way, we can get the QUIC version and set the correct ALPN value. - baseConf := &tls.Config{ - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - // determine the ALPN from the QUIC version used - proto := nextProtoH3Draft29 - if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok { - if qconn.GetQUICVersion() == protocol.Version1 { - proto = nextProtoH3 - } - } - config := tlsConf - if tlsConf.GetConfigForClient != nil { - getConfigForClient := tlsConf.GetConfigForClient - var err error - conf, err := getConfigForClient(ch) - if err != nil { - return nil, err - } - if conf != nil { - config = conf - } - } - if config == nil { - return nil, nil - } - config = config.Clone() - config.NextProtos = []string{proto} - return config, nil - }, + if err := s.addListener(&ln); err != nil { + return err } + err := s.serveListener(ln) + s.removeListener(&ln) + return err +} - var ln quic.EarlyListener - var err error +func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { + if s.Server == nil { + return errors.New("use of http3.Server without http.Server") + } + + baseConf := ConfigureTLSConfig(tlsConf) quicConf := s.QuicConfig if quicConf == nil { quicConf = &quic.Config{} @@ -191,6 +230,9 @@ func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { if s.EnableDatagrams { quicConf.EnableDatagrams = true } + + var ln quic.EarlyListener + var err error if conn == nil { ln, err = quicListenAddr(s.Addr, baseConf, quicConf) } else { @@ -199,64 +241,153 @@ func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { if err != nil { return err } - s.addListener(&ln) - defer s.removeListener(&ln) + if err := s.addListener(&ln); err != nil { + return err + } + err = s.serveListener(ln) + s.removeListener(&ln) + return err +} +func (s *Server) serveListener(ln quic.EarlyListener) error { for { - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) if err != nil { return err } - go s.handleConn(sess) + go s.handleConn(conn) } } +func extractPort(addr string) (int, error) { + _, portStr, err := net.SplitHostPort(addr) + if err != nil { + return 0, err + } + + portInt, err := net.LookupPort("tcp", portStr) + if err != nil { + return 0, err + } + return portInt, nil +} + +func (s *Server) generateAltSvcHeader() { + if len(s.listeners) == 0 { + // Don't announce any ports since no one is listening for connections + s.altSvcHeader = "" + return + } + + // This code assumes that we will use protocol.SupportedVersions if no quic.Config is passed. + supportedVersions := protocol.SupportedVersions + if s.QuicConfig != nil && len(s.QuicConfig.Versions) > 0 { + supportedVersions = s.QuicConfig.Versions + } + var versionStrings []string + for _, version := range supportedVersions { + if v := versionToALPN(version); len(v) > 0 { + versionStrings = append(versionStrings, v) + } + } + + var altSvc []string + addPort := func(port int) { + for _, v := range versionStrings { + altSvc = append(altSvc, fmt.Sprintf(`%s=":%d"; ma=2592000`, v, port)) + } + } + + if s.Port != 0 { + // if Port is specified, we must use it instead of the + // listener addresses since there's a reason it's specified. + addPort(s.Port) + } else { + // if we have some listeners assigned, try to find ports + // which we can announce, otherwise nothing should be announced + validPortsFound := false + for _, info := range s.listeners { + if info.port != 0 { + addPort(info.port) + validPortsFound = true + } + } + if !validPortsFound { + if port, err := extractPort(s.Addr); err == nil { + addPort(port) + } + } + } + + s.altSvcHeader = strings.Join(altSvc, ",") +} + // We store a pointer to interface in the map set. This is safe because we only // call trackListener via Serve and can track+defer untrack the same pointer to // local variable there. We never need to compare a Listener from another caller. -func (s *Server) addListener(l *quic.EarlyListener) { +func (s *Server) addListener(l *quic.EarlyListener) error { s.mutex.Lock() + defer s.mutex.Unlock() + + if s.closed { + return http.ErrServerClosed + } + if s.logger == nil { + s.logger = utils.DefaultLogger.WithPrefix("server") + } if s.listeners == nil { - s.listeners = make(map[*quic.EarlyListener]struct{}) + s.listeners = make(map[*quic.EarlyListener]listenerInfo) } - s.listeners[l] = struct{}{} - s.mutex.Unlock() + + if port, err := extractPort((*l).Addr().String()); err == nil { + s.listeners[l] = listenerInfo{port} + } else { + s.logger.Errorf( + "Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err) + s.listeners[l] = listenerInfo{} + } + s.generateAltSvcHeader() + return nil } func (s *Server) removeListener(l *quic.EarlyListener) { s.mutex.Lock() delete(s.listeners, l) + s.generateAltSvcHeader() s.mutex.Unlock() } -func (s *Server) handleConn(sess quic.EarlySession) { +func (s *Server) handleConn(conn quic.EarlyConnection) { decoder := qpack.NewDecoder(nil) // send a SETTINGS frame - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() if err != nil { s.logger.Debugf("Opening the control stream failed.") return } buf := &bytes.Buffer{} quicvarint.Write(buf, streamTypeControlStream) // stream type - (&settingsFrame{Datagram: s.EnableDatagrams}).Write(buf) + (&settingsFrame{Datagram: s.EnableDatagrams, Other: s.AdditionalSettings}).Write(buf) str.Write(buf.Bytes()) - go s.handleUnidirectionalStreams(sess) + go s.handleUnidirectionalStreams(conn) // Process all requests immediately. // It's the client's responsibility to decide which requests are eligible for 0-RTT. for { - str, err := sess.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) if err != nil { s.logger.Debugf("Accepting stream failed: %s", err) return } go func() { - rerr := s.handleRequest(sess, str, decoder, func() { - sess.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") + rerr := s.handleRequest(conn, str, decoder, func() { + conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "") }) + if rerr.err == errHijacked { + return + } if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 { s.logger.Debugf("Handling request failed: %s", err) if rerr.streamErr != 0 { @@ -267,7 +398,7 @@ func (s *Server) handleConn(sess quic.EarlySession) { if rerr.err != nil { reason = rerr.err.Error() } - sess.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) + conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) } return } @@ -276,9 +407,9 @@ func (s *Server) handleConn(sess quic.EarlySession) { } } -func (s *Server) handleUnidirectionalStreams(sess quic.EarlySession) { +func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) { for { - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) if err != nil { s.logger.Debugf("accepting unidirectional stream failed: %s", err) return @@ -298,20 +429,23 @@ func (s *Server) handleUnidirectionalStreams(sess quic.EarlySession) { // TODO: check that only one stream of each type is opened. return case streamTypePushStream: // only the server can push - sess.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "") + conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "") return default: + if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str) { + return + } str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) return } - f, err := parseNextFrame(str) + f, err := parseNextFrame(str, nil) if err != nil { - sess.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") + conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") return } sf, ok := f.(*settingsFrame) if !ok { - sess.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") + conn.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "") return } if !sf.Datagram { @@ -320,8 +454,8 @@ func (s *Server) handleUnidirectionalStreams(sess quic.EarlySession) { // If datagram support was enabled on our side as well as on the client side, // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). - if s.EnableDatagrams && !sess.ConnectionState().SupportsDatagrams { - sess.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") + if s.EnableDatagrams && !conn.ConnectionState().SupportsDatagrams { + conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") } }(str) } @@ -334,9 +468,18 @@ func (s *Server) maxHeaderBytes() uint64 { return uint64(s.Server.MaxHeaderBytes) } -func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { - frame, err := parseNextFrame(str) +func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { + var ufh unknownFrameHandlerFunc + if s.StreamHijacker != nil { + ufh = func(ft FrameType) (processed bool, err error) { + return s.StreamHijacker(ft, conn, str) + } + } + frame, err := parseNextFrame(str, ufh) if err != nil { + if err == errHijacked { + return requestError{err: errHijacked} + } return newStreamError(errorRequestIncomplete, err) } hf, ok := frame.(*headersFrame) @@ -361,7 +504,7 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac return newStreamError(errorGeneralProtocolError, err) } - req.RemoteAddr = sess.RemoteAddr().String() + req.RemoteAddr = conn.RemoteAddr().String() req.Body = newRequestBody(str, onFrameError) if s.logger.Debug() { @@ -372,9 +515,9 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac ctx := str.Context() ctx = context.WithValue(ctx, ServerContextKey, s) - ctx = context.WithValue(ctx, http.LocalAddrContextKey, sess.LocalAddr()) + ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr()) req = req.WithContext(ctx) - r := newResponseWriter(str, s.logger) + r := newResponseWriter(str, conn, s.logger) defer func() { if !r.usedDataStream() { r.Flush() @@ -415,11 +558,11 @@ func (s *Server) handleRequest(sess quic.Session, str quic.Stream, decoder *qpac // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients. // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. func (s *Server) Close() error { - s.closed.Set(true) - s.mutex.Lock() defer s.mutex.Unlock() + s.closed = true + var err error for ln := range s.listeners { if cerr := (*ln).Close(); cerr != nil && err == nil { @@ -436,39 +579,28 @@ func (s *Server) CloseGracefully(timeout time.Duration) error { return nil } -// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC. -// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443): -// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30" +// ErrNoAltSvcPort is the error returned by SetQuicHeaders when no port was found +// for Alt-Svc to announce. This can happen if listening on a PacketConn without a port +// (UNIX socket, for example) and no port is specified in Server.Port or Server.Addr. +var ErrNoAltSvcPort = errors.New("no port can be announced, specify it explicitly using Server.Port or Server.Addr") + +// SetQuicHeaders can be used to set the proper headers that announce that this server supports HTTP/3. +// The values set by default advertise all of the ports the server is listening on, but can be +// changed to a specific port by setting Server.Port before launching the serverr. +// If no listener's Addr().String() returns an address with a valid port, Server.Addr will be used +// to extract the port, if specified. +// For example, a server launched using ListenAndServe on an address with port 443 would set: +// Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 func (s *Server) SetQuicHeaders(hdr http.Header) error { - port := atomic.LoadUint32(&s.Port) + s.mutex.RLock() + defer s.mutex.RUnlock() - if port == 0 { - // Extract port from s.Server.Addr - _, portStr, err := net.SplitHostPort(s.Server.Addr) - if err != nil { - return err - } - portInt, err := net.LookupPort("tcp", portStr) - if err != nil { - return err - } - port = uint32(portInt) - atomic.StoreUint32(&s.Port, port) - } - - // This code assumes that we will use protocol.SupportedVersions if no quic.Config is passed. - supportedVersions := protocol.SupportedVersions - if s.QuicConfig != nil && len(s.QuicConfig.Versions) > 0 { - supportedVersions = s.QuicConfig.Versions - } - altSvc := make([]string, 0, len(supportedVersions)) - for _, version := range supportedVersions { - v := versionToALPN(version) - if len(v) > 0 { - altSvc = append(altSvc, fmt.Sprintf(`%s=":%d"; ma=2592000`, v, port)) - } + if s.altSvcHeader == "" { + return ErrNoAltSvcPort } - hdr.Add("Alt-Svc", strings.Join(altSvc, ",")) + // use the map directly to avoid constant canonicalization + // since the key is already canonicalized + hdr["Alt-Svc"] = append(hdr["Alt-Svc"], s.altSvcHeader) return nil } diff --git a/http3/server_test.go b/http3/server_test.go index 02e9c4166be..4b99775a60c 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -9,6 +9,8 @@ import ( "io" "net" "net/http" + "runtime" + "sync/atomic" "time" "github.com/lucas-clemente/quic-go" @@ -23,6 +25,7 @@ import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + gmtypes "github.com/onsi/gomega/types" ) type mockConn struct { @@ -38,6 +41,49 @@ func (c *mockConn) GetQUICVersion() protocol.VersionNumber { return c.version } +type mockAddr struct { + addr string +} + +func (ma *mockAddr) Network() string { + return "udp" +} + +func (ma *mockAddr) String() string { + return ma.addr +} + +type mockAddrListener struct { + *mockquic.MockEarlyListener + addr *mockAddr +} + +func (m *mockAddrListener) Addr() net.Addr { + _ = m.MockEarlyListener.Addr() + return m.addr +} + +func newMockAddrListener(addr string) *mockAddrListener { + return &mockAddrListener{ + MockEarlyListener: mockquic.NewMockEarlyListener(mockCtrl), + addr: &mockAddr{ + addr: addr, + }, + } +} + +type noPortListener struct { + *mockAddrListener +} + +func (m *noPortListener) Addr() net.Addr { + _ = m.mockAddrListener.Addr() + return &net.UnixAddr{ + Net: "unix", + Name: "/tmp/quic.sock", + } +} + var _ = Describe("Server", func() { var ( s *Server @@ -62,7 +108,7 @@ var _ = Describe("Server", func() { var ( qpackDecoder *qpack.Decoder str *mockquic.MockStream - sess *mockquic.MockEarlySession + conn *mockquic.MockEarlyConnection exampleGetRequest *http.Request examplePostRequest *http.Request ) @@ -72,7 +118,7 @@ var _ = Describe("Server", func() { fields := make(map[string][]string) decoder := qpack.NewDecoder(nil) - frame, err := parseNextFrame(str) + frame, err := parseNextFrame(str, nil) ExpectWithOffset(1, err).ToNot(HaveOccurred()) ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) headersFrame := frame.(*headersFrame) @@ -119,10 +165,10 @@ var _ = Describe("Server", func() { qpackDecoder = qpack.NewDecoder(nil) str = mockquic.NewMockStream(mockCtrl) - sess = mockquic.NewMockEarlySession(mockCtrl) + conn = mockquic.NewMockEarlyConnection(mockCtrl) addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - sess.EXPECT().RemoteAddr().Return(addr).AnyTimes() - sess.EXPECT().LocalAddr().AnyTimes() + conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() + conn.EXPECT().LocalAddr().AnyTimes() }) It("calls the HTTP handler function", func() { @@ -138,7 +184,7 @@ var _ = Describe("Server", func() { }).AnyTimes() str.EXPECT().CancelRead(gomock.Any()) - Expect(s.handleRequest(sess, str, qpackDecoder, nil)).To(Equal(requestError{})) + Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{})) var req *http.Request Eventually(requestChan).Should(Receive(&req)) Expect(req.Host).To(Equal("www.example.com")) @@ -155,7 +201,7 @@ var _ = Describe("Server", func() { str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(gomock.Any()) - serr := s.handleRequest(sess, str, qpackDecoder, nil) + serr := s.handleRequest(conn, str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) @@ -172,7 +218,7 @@ var _ = Describe("Server", func() { str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelRead(gomock.Any()) - serr := s.handleRequest(sess, str, qpackDecoder, nil) + serr := s.handleRequest(conn, str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"})) @@ -189,22 +235,88 @@ var _ = Describe("Server", func() { str.EXPECT().Write([]byte("foobar")) // don't EXPECT CancelRead() - serr := s.handleRequest(sess, str, qpackDecoder, nil) + serr := s.handleRequest(conn, str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) }) + Context("hijacking unidirectional streams", func() { + var conn *mockquic.MockEarlyConnection + testDone := make(chan struct{}) + + BeforeEach(func() { + testDone = make(chan struct{}) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() + conn.EXPECT().LocalAddr().AnyTimes() + }) + + AfterEach(func() { testDone <- struct{}{} }) + + It("hijacks an unidirectional stream of unknown stream type", func() { + streamTypeChan := make(chan StreamType, 1) + s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool { + streamTypeChan <- st + return true + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { + streamTypeChan := make(chan StreamType, 1) + s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool { + streamTypeChan <- st + return false + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)) + + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + }) + Context("control stream handling", func() { - var sess *mockquic.MockEarlySession + var conn *mockquic.MockEarlyConnection testDone := make(chan struct{}) BeforeEach(func() { - sess = mockquic.NewMockEarlySession(mockCtrl) + conn = mockquic.NewMockEarlyConnection(mockCtrl) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Write(gomock.Any()) - sess.EXPECT().OpenUniStream().Return(controlStr, nil) - sess.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) - sess.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() - sess.EXPECT().LocalAddr().AnyTimes() + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() + conn.EXPECT().LocalAddr().AnyTimes() }) AfterEach(func() { testDone <- struct{}{} }) @@ -215,15 +327,15 @@ var _ = Describe("Server", func() { (&settingsFrame{}).Write(buf) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) - s.handleConn(sess) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to sess.CloseWithError + s.handleConn(conn) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError }) for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { @@ -239,19 +351,19 @@ var _ = Describe("Server", func() { str := mockquic.NewMockStream(mockCtrl) str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return str, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) - s.handleConn(sess) + s.handleConn(conn) time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead }) } - It("reset streams other than the control stream and the QPACK streams", func() { + It("reset streams Other than the control stream and the QPACK streams", func() { buf := &bytes.Buffer{} quicvarint.Write(buf, 1337) str := mockquic.NewMockStream(mockCtrl) @@ -261,14 +373,14 @@ var _ = Describe("Server", func() { close(done) }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return str, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) - s.handleConn(sess) + s.handleConn(conn) Eventually(done).Should(BeClosed()) }) @@ -278,20 +390,20 @@ var _ = Describe("Server", func() { (&dataFrame{}).Write(buf) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) done := make(chan struct{}) - sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { defer GinkgoRecover() Expect(code).To(BeEquivalentTo(errorMissingSettings)) close(done) }) - s.handleConn(sess) + s.handleConn(conn) Eventually(done).Should(BeClosed()) }) @@ -303,20 +415,20 @@ var _ = Describe("Server", func() { buf.Write(b.Bytes()[:b.Len()-1]) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) done := make(chan struct{}) - sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { defer GinkgoRecover() Expect(code).To(BeEquivalentTo(errorFrameError)) close(done) }) - s.handleConn(sess) + s.handleConn(conn) Eventually(done).Should(BeClosed()) }) @@ -326,20 +438,20 @@ var _ = Describe("Server", func() { (&dataFrame{}).Write(buf) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) done := make(chan struct{}) - sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { defer GinkgoRecover() Expect(code).To(BeEquivalentTo(errorStreamCreationError)) close(done) }) - s.handleConn(sess) + s.handleConn(conn) Eventually(done).Should(BeClosed()) }) @@ -350,45 +462,45 @@ var _ = Describe("Server", func() { (&settingsFrame{Datagram: true}).Write(buf) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { return controlStr, nil }) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) - sess.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) done := make(chan struct{}) - sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { defer GinkgoRecover() Expect(code).To(BeEquivalentTo(errorSettingsError)) Expect(reason).To(Equal("missing QUIC Datagram support")) close(done) }) - s.handleConn(sess) + s.handleConn(conn) Eventually(done).Should(BeClosed()) }) }) Context("stream- and connection-level errors", func() { - var sess *mockquic.MockEarlySession + var conn *mockquic.MockEarlyConnection testDone := make(chan struct{}) BeforeEach(func() { testDone = make(chan struct{}) addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - sess = mockquic.NewMockEarlySession(mockCtrl) + conn = mockquic.NewMockEarlyConnection(mockCtrl) controlStr := mockquic.NewMockStream(mockCtrl) controlStr.EXPECT().Write(gomock.Any()) - sess.EXPECT().OpenUniStream().Return(controlStr, nil) - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) - sess.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) - sess.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) - sess.EXPECT().RemoteAddr().Return(addr).AnyTimes() - sess.EXPECT().LocalAddr().AnyTimes() + conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() + conn.EXPECT().LocalAddr().AnyTimes() }) AfterEach(func() { testDone <- struct{}{} }) @@ -411,7 +523,7 @@ var _ = Describe("Server", func() { str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) str.EXPECT().Close().Do(func() { close(done) }) - s.handleConn(sess) + s.handleConn(conn) Eventually(done).Should(BeClosed()) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) @@ -433,7 +545,7 @@ var _ = Describe("Server", func() { str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) - s.handleConn(sess) + s.handleConn(conn) Eventually(done).Should(BeClosed()) }) @@ -448,7 +560,7 @@ var _ = Describe("Server", func() { str.EXPECT().Read(gomock.Any()).Return(0, testErr) str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) - s.handleConn(sess) + s.handleConn(conn) Consistently(handlerCalled).ShouldNot(BeClosed()) }) @@ -466,11 +578,11 @@ var _ = Describe("Server", func() { }).AnyTimes() done := make(chan struct{}) - sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { + conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { Expect(code).To(Equal(quic.ApplicationErrorCode(errorFrameUnexpected))) close(done) }) - s.handleConn(sess) + s.handleConn(conn) Eventually(done).Should(BeClosed()) }) @@ -493,7 +605,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) str.EXPECT().CancelWrite(quic.StreamErrorCode(errorFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) - s.handleConn(sess) + s.handleConn(conn) Eventually(done).Should(BeClosed()) }) }) @@ -515,7 +627,7 @@ var _ = Describe("Server", func() { }).AnyTimes() str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) - serr := s.handleRequest(sess, str, qpackDecoder, nil) + serr := s.handleRequest(conn, str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) Eventually(handlerCalled).Should(BeClosed()) }) @@ -538,7 +650,7 @@ var _ = Describe("Server", func() { }).AnyTimes() str.EXPECT().CancelRead(quic.StreamErrorCode(errorNoError)) - serr := s.handleRequest(sess, str, qpackDecoder, nil) + serr := s.handleRequest(conn, str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) Eventually(handlerCalled).Should(BeClosed()) }) @@ -549,55 +661,100 @@ var _ = Describe("Server", func() { s.QuicConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionDraft29}} }) + var ln1 quic.EarlyListener + var ln2 quic.EarlyListener expected := http.Header{ "Alt-Svc": {`h3-29=":443"; ma=2592000`}, } - It("sets proper headers with numeric port", func() { - s.Server.Addr = ":443" + addListener := func(addr string, ln *quic.EarlyListener) { + mln := newMockAddrListener(addr) + mln.EXPECT().Addr() + *ln = mln + s.addListener(ln) + } + + removeListener := func(ln *quic.EarlyListener) { + s.removeListener(ln) + } + + checkSetHeaders := func(expected gmtypes.GomegaMatcher) { hdr := http.Header{} Expect(s.SetQuicHeaders(hdr)).To(Succeed()) - Expect(hdr).To(Equal(expected)) + Expect(hdr).To(expected) + } + + checkSetHeaderError := func() { + hdr := http.Header{} + Expect(s.SetQuicHeaders(hdr)).To(Equal(ErrNoAltSvcPort)) + } + + It("sets proper headers with numeric port", func() { + addListener(":443", &ln1) + checkSetHeaders(Equal(expected)) + removeListener(&ln1) + checkSetHeaderError() }) It("sets proper headers with full addr", func() { - s.Server.Addr = "127.0.0.1:443" - hdr := http.Header{} - Expect(s.SetQuicHeaders(hdr)).To(Succeed()) - Expect(hdr).To(Equal(expected)) + addListener("127.0.0.1:443", &ln1) + checkSetHeaders(Equal(expected)) + removeListener(&ln1) + checkSetHeaderError() }) It("sets proper headers with string port", func() { - s.Server.Addr = ":https" - hdr := http.Header{} - Expect(s.SetQuicHeaders(hdr)).To(Succeed()) - Expect(hdr).To(Equal(expected)) + addListener(":https", &ln1) + checkSetHeaders(Equal(expected)) + removeListener(&ln1) + checkSetHeaderError() }) It("works multiple times", func() { - s.Server.Addr = ":https" - hdr := http.Header{} - Expect(s.SetQuicHeaders(hdr)).To(Succeed()) - Expect(hdr).To(Equal(expected)) - hdr = http.Header{} - Expect(s.SetQuicHeaders(hdr)).To(Succeed()) - Expect(hdr).To(Equal(expected)) + addListener(":https", &ln1) + checkSetHeaders(Equal(expected)) + checkSetHeaders(Equal(expected)) + removeListener(&ln1) + checkSetHeaderError() }) It("works if the quic.Config sets QUIC versions", func() { - s.Server.Addr = ":443" s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.VersionDraft29} - hdr := http.Header{} - Expect(s.SetQuicHeaders(hdr)).To(Succeed()) - Expect(hdr).To(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3-29=":443"; ma=2592000`}})) + addListener(":443", &ln1) + checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3-29=":443"; ma=2592000`}})) + removeListener(&ln1) + checkSetHeaderError() }) It("uses s.Port if set to a non-zero value", func() { - s.Server.Addr = ":443" s.Port = 8443 - hdr := http.Header{} - Expect(s.SetQuicHeaders(hdr)).To(Succeed()) - Expect(hdr).To(Equal(http.Header{"Alt-Svc": {`h3-29=":8443"; ma=2592000`}})) + addListener(":443", &ln1) + checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3-29=":8443"; ma=2592000`}})) + removeListener(&ln1) + checkSetHeaderError() + }) + + It("uses s.Addr if listeners don't have ports available", func() { + s.Addr = ":443" + mln := &noPortListener{newMockAddrListener("")} + mln.EXPECT().Addr() + ln1 = mln + s.addListener(&ln1) + checkSetHeaders(Equal(expected)) + s.removeListener(&ln1) + checkSetHeaderError() + }) + + It("properly announces multiple listeners", func() { + addListener(":443", &ln1) + addListener(":8443", &ln2) + checkSetHeaders(Or( + Equal(http.Header{"Alt-Svc": {`h3-29=":443"; ma=2592000,h3-29=":8443"; ma=2592000`}}), + Equal(http.Header{"Alt-Svc": {`h3-29=":8443"; ma=2592000,h3-29=":443"; ma=2592000`}}), + )) + removeListener(&ln1) + removeListener(&ln2) + checkSetHeaderError() }) }) @@ -619,6 +776,51 @@ var _ = Describe("Server", func() { Expect(serv.ListenAndServe()).To(MatchError(http.ErrServerClosed)) }) + It("handles concurrent Serve and Close", func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + c, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + s.Serve(c) + }() + runtime.Gosched() + s.Close() + Eventually(done).Should(BeClosed()) + }) + + Context("ConfigureTLSConfig", func() { + var tlsConf *tls.Config + var ch *tls.ClientHelloInfo + + BeforeEach(func() { + tlsConf = &tls.Config{} + ch = &tls.ClientHelloInfo{} + }) + + It("advertises draft by default", func() { + tlsConf = ConfigureTLSConfig(tlsConf) + Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) + + config, err := tlsConf.GetConfigForClient(ch) + Expect(err).NotTo(HaveOccurred()) + Expect(config.NextProtos).To(Equal([]string{nextProtoH3Draft29})) + }) + + It("advertises h3 for quic version 1", func() { + tlsConf = ConfigureTLSConfig(tlsConf) + Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) + + ch.Conn = newMockConn(protocol.Version1) + config, err := tlsConf.GetConfigForClient(ch) + Expect(err).NotTo(HaveOccurred()) + Expect(config.NextProtos).To(Equal([]string{nextProtoH3})) + }) + }) + Context("Serve", func() { origQuicListen := quicListen @@ -627,7 +829,7 @@ var _ = Describe("Server", func() { }) It("serves a packet conn", func() { - ln := mockquic.NewMockEarlyListener(mockCtrl) + ln := newMockAddrListener(":443") conn := &net.UDPConn{} quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { Expect(c).To(Equal(conn)) @@ -638,10 +840,11 @@ var _ = Describe("Server", func() { s.TLSConfig = &tls.Config{} stopAccept := make(chan struct{}) - ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { <-stopAccept return nil, errors.New("closed") }) + ln.EXPECT().Addr() // generate alt-svc headers done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -656,8 +859,8 @@ var _ = Describe("Server", func() { }) It("serves two packet conns", func() { - ln1 := mockquic.NewMockEarlyListener(mockCtrl) - ln2 := mockquic.NewMockEarlyListener(mockCtrl) + ln1 := newMockAddrListener(":443") + ln2 := newMockAddrListener(":8443") lns := make(chan quic.EarlyListener, 2) lns <- ln1 lns <- ln2 @@ -671,15 +874,17 @@ var _ = Describe("Server", func() { s.TLSConfig = &tls.Config{} stopAccept1 := make(chan struct{}) - ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { <-stopAccept1 return nil, errors.New("closed") }) + ln1.EXPECT().Addr() // generate alt-svc headers stopAccept2 := make(chan struct{}) - ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { <-stopAccept2 return nil, errors.New("closed") }) + ln2.EXPECT().Addr() done1 := make(chan struct{}) go func() { @@ -704,6 +909,96 @@ var _ = Describe("Server", func() { }) }) + Context("ServeListener", func() { + origQuicListen := quicListen + + AfterEach(func() { + quicListen = origQuicListen + }) + + It("serves a listener", func() { + var called int32 + ln := newMockAddrListener(":443") + quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + atomic.StoreInt32(&called, 1) + return ln, nil + } + + s := &Server{Server: &http.Server{}} + s.TLSConfig = &tls.Config{} + + stopAccept := make(chan struct{}) + ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + <-stopAccept + return nil, errors.New("closed") + }) + ln.EXPECT().Addr() // generate alt-svc headers + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + s.ServeListener(ln) + }() + + Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) + Consistently(done).ShouldNot(BeClosed()) + ln.EXPECT().Close().Do(func() { close(stopAccept) }) + Expect(s.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("serves two listeners", func() { + var called int32 + ln1 := newMockAddrListener(":443") + ln2 := newMockAddrListener(":8443") + lns := make(chan quic.EarlyListener, 2) + lns <- ln1 + lns <- ln2 + quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + atomic.StoreInt32(&called, 1) + return <-lns, nil + } + + s := &Server{Server: &http.Server{}} + s.TLSConfig = &tls.Config{} + + stopAccept1 := make(chan struct{}) + ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + <-stopAccept1 + return nil, errors.New("closed") + }) + ln1.EXPECT().Addr() // generate alt-svc headers + stopAccept2 := make(chan struct{}) + ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { + <-stopAccept2 + return nil, errors.New("closed") + }) + ln2.EXPECT().Addr() + + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done1) + s.ServeListener(ln1) + }() + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done2) + s.ServeListener(ln2) + }() + + Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) + Consistently(done1).ShouldNot(BeClosed()) + Expect(done2).ToNot(BeClosed()) + ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) + ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) + Expect(s.Close()).To(Succeed()) + Eventually(done1).Should(BeClosed()) + Eventually(done2).Should(BeClosed()) + }) + }) + Context("ListenAndServe", func() { BeforeEach(func() { s.Server.Addr = "localhost:0" diff --git a/integrationtests/gomodvendor/go.mod b/integrationtests/gomodvendor/go.mod index 4fc950f9d65..72e681cdf7e 100644 --- a/integrationtests/gomodvendor/go.mod +++ b/integrationtests/gomodvendor/go.mod @@ -1,6 +1,6 @@ module test -go 1.15 +go 1.16 // The version doesn't matter here, as we're replacing it with the currently checked out code anyway. require github.com/lucas-clemente/quic-go v0.21.0 diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index 3732c644cac..6b7ff48aabc 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -16,22 +16,28 @@ github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitf github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= -github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= @@ -40,11 +46,15 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= @@ -63,18 +73,16 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lucas-clemente/quic-go v0.21.0 h1:ZdC8UBxUSBdPlEv1+4y4SqIBy54VA8bRxN7DmkQ0URs= -github.com/lucas-clemente/quic-go v0.21.0/go.mod h1:BWkfkkOSJD1AxFNBqdjBZi6FznZ96bhdcvZiA+LDrY8= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= -github.com/marten-seemann/qtls-go1-15 v0.1.4 h1:RehYMOyRW8hPVEja1KBVsFVNSm35Jj9Mvs5yNoZZ28A= -github.com/marten-seemann/qtls-go1-15 v0.1.4/go.mod h1:GyFwywLKkRt+6mfU99csTEY1joMZz5vmB1WNZH3P81I= -github.com/marten-seemann/qtls-go1-16 v0.1.3 h1:XEZ1xGorVy9u+lJq+WXNE+hiqRYLNvJGYmwfwKQN2gU= -github.com/marten-seemann/qtls-go1-16 v0.1.3/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= -github.com/marten-seemann/qtls-go1-17 v0.1.0-alpha.1 h1:LRFa3YRSlOAf9y56Szfhlh60CQrIMBSK/rneZD1gtuk= -github.com/marten-seemann/qtls-go1-17 v0.1.0-alpha.1/go.mod h1:lQDiKZDfPagLmg1zMtEgoBMSTAORq6M08lBogD5FtBY= +github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= +github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= +github.com/marten-seemann/qtls-go1-17 v0.1.1 h1:DQjHPq+aOzUeh9/lixAGunn6rIOQyWChPSI4+hgW7jc= +github.com/marten-seemann/qtls-go1-17 v0.1.1/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= +github.com/marten-seemann/qtls-go1-18 v0.1.1 h1:qp7p7XXUFL7fpBvSS1sWD+uSqPvzNQK43DH+/qEkj0Y= +github.com/marten-seemann/qtls-go1-18 v0.1.1/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -82,13 +90,21 @@ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3Rllmb github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= +github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= @@ -120,10 +136,15 @@ github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYED github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= @@ -139,6 +160,8 @@ golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTk golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -151,8 +174,11 @@ golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgNrpq9mjcfDemuexIKsU= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -163,7 +189,8 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -176,12 +203,21 @@ golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201231184435-2d18734c6014 h1:joucsQqXmyBVxViHCPFjG3hx8JzIFSaym3l3MM/Jsdg= -golang.org/x/sys v0.0.0-20201231184435-2d18734c6014/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -191,9 +227,14 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.1 h1:wGiQel/hW0NnEkJUk8lbzkX2gFJU6PFxf1v5OlCfuOs= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= @@ -216,15 +257,21 @@ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/integrationtests/self/cancelation_test.go b/integrationtests/self/cancelation_test.go index 651bb06e9c3..56edbbfc233 100644 --- a/integrationtests/self/cancelation_test.go +++ b/integrationtests/self/cancelation_test.go @@ -22,7 +22,7 @@ var _ = Describe("Stream Cancelations", func() { Context("canceling the read side", func() { var server quic.Listener - // The server accepts a single session, and then opens numStreams unidirectional streams. + // The server accepts a single connection, and then opens numStreams unidirectional streams. // On each of these streams, it (tries to) write PRData. // When done, it sends the number of canceled streams on the channel. runServer := func(data []byte) <-chan int32 { @@ -36,13 +36,13 @@ var _ = Describe("Stream Cancelations", func() { defer GinkgoRecover() var wg sync.WaitGroup wg.Add(numStreams) - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.OpenUniStreamSync(context.Background()) + str, err := conn.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) if _, err := str.Write(data); err != nil { Expect(err).To(MatchError(&quic.StreamError{ @@ -71,7 +71,7 @@ var _ = Describe("Stream Cancelations", func() { It("downloads when the client immediately cancels most streams", func() { serverCanceledCounterChan := runServer(PRData) - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -85,7 +85,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel around 2/3 of the streams if rand.Int31()%3 != 0 { @@ -102,7 +102,7 @@ var _ = Describe("Stream Cancelations", func() { var serverCanceledCounter int32 Eventually(serverCanceledCounterChan).Should(Receive(&serverCanceledCounter)) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) clientCanceledCounter := atomic.LoadInt32(&canceledCounter) // The server will only count a stream as being reset if learns about the cancelation before it finished writing all data. @@ -115,7 +115,7 @@ var _ = Describe("Stream Cancelations", func() { It("downloads when the client cancels streams after reading from them for a bit", func() { serverCanceledCounterChan := runServer(PRData) - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -129,7 +129,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) // only read some data from about 1/3 of the streams if rand.Int31()%3 != 0 { @@ -150,7 +150,7 @@ var _ = Describe("Stream Cancelations", func() { var serverCanceledCounter int32 Eventually(serverCanceledCounterChan).Should(Receive(&serverCanceledCounter)) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) clientCanceledCounter := atomic.LoadInt32(&canceledCounter) // The server will only count a stream as being reset if learns about the cancelation before it finished writing all data. @@ -165,7 +165,7 @@ var _ = Describe("Stream Cancelations", func() { // see https://github.com/lucas-clemente/quic-go/issues/3239. serverCanceledCounterChan := runServer(make([]byte, 100)) // make sure the FIN is sent with the STREAM frame - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -179,7 +179,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) @@ -198,7 +198,7 @@ var _ = Describe("Stream Cancelations", func() { }() } wg.Wait() - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) numCanceled := atomic.LoadInt32(&counter) fmt.Fprintf(GinkgoWriter, "canceled %d out of %d streams", numCanceled, numStreams) Expect(numCanceled).ToNot(BeZero()) @@ -208,7 +208,7 @@ var _ = Describe("Stream Cancelations", func() { Context("canceling the write side", func() { runClient := func(server quic.Listener) int32 /* number of canceled streams */ { - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -222,7 +222,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) if err != nil { @@ -242,7 +242,7 @@ var _ = Describe("Stream Cancelations", func() { fmt.Fprintf(GinkgoWriter, "Canceled writing on %d of %d streams\n", streamCount, numStreams) Expect(streamCount).To(BeNumerically(">", numStreams/10)) Expect(numStreams - streamCount).To(BeNumerically(">", numStreams/10)) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(server.Close()).To(Succeed()) return streamCount } @@ -254,12 +254,12 @@ var _ = Describe("Stream Cancelations", func() { var canceledCounter int32 go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { defer GinkgoRecover() - str, err := sess.OpenUniStreamSync(context.Background()) + str, err := conn.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel about 2/3 of the streams if rand.Int31()%3 != 0 { @@ -285,12 +285,12 @@ var _ = Describe("Stream Cancelations", func() { var canceledCounter int32 go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { defer GinkgoRecover() - str, err := sess.OpenUniStreamSync(context.Background()) + str, err := conn.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) // only write some data from about 1/3 of the streams, then cancel if rand.Int31()%3 != 0 { @@ -323,13 +323,13 @@ var _ = Describe("Stream Cancelations", func() { defer GinkgoRecover() var wg sync.WaitGroup wg.Add(numStreams) - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := 0; i < numStreams; i++ { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.OpenUniStreamSync(context.Background()) + str, err := conn.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel about half of the streams if rand.Int31()%2 == 0 { @@ -353,7 +353,7 @@ var _ = Describe("Stream Cancelations", func() { close(done) }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -367,7 +367,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel around half of the streams if rand.Int31()%2 == 0 { @@ -392,7 +392,7 @@ var _ = Describe("Stream Cancelations", func() { Expect(count).To(BeNumerically(">", numStreams/15)) fmt.Fprintf(GinkgoWriter, "Successfully read from %d of %d streams.\n", count, numStreams) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) Eventually(done).Should(BeClosed()) Expect(server.Close()).To(Succeed()) }) @@ -405,7 +405,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer close(done) - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) var wg sync.WaitGroup wg.Add(numStreams) @@ -413,7 +413,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.OpenUniStreamSync(context.Background()) + str, err := conn.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) // cancel about half of the streams length := len(PRData) @@ -438,7 +438,7 @@ var _ = Describe("Stream Cancelations", func() { wg.Wait() }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}), @@ -453,7 +453,7 @@ var _ = Describe("Stream Cancelations", func() { defer GinkgoRecover() defer wg.Done() - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) r := io.Reader(str) @@ -488,7 +488,7 @@ var _ = Describe("Stream Cancelations", func() { Expect(count).To(BeNumerically(">", numStreams/15)) fmt.Fprintf(GinkgoWriter, "Successfully read from %d of %d streams.\n", count, numStreams) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(server.Close()).To(Succeed()) }) }) @@ -500,14 +500,14 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) ticker := time.NewTicker(5 * time.Millisecond) for i := 0; i < numStreams; i++ { <-ticker.C go func() { defer GinkgoRecover() - str, err := sess.OpenUniStreamSync(context.Background()) + str, err := conn.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) _, err = str.Write(PRData) Expect(err).ToNot(HaveOccurred()) @@ -516,7 +516,7 @@ var _ = Describe("Stream Cancelations", func() { } }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 3}), @@ -539,7 +539,7 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() - str, err := sess.AcceptUniStream(ctx) + str, err := conn.AcceptUniStream(ctx) if err != nil { if err.Error() == "context canceled" { atomic.AddInt32(&counter, 1) @@ -557,7 +557,7 @@ var _ = Describe("Stream Cancelations", func() { count := atomic.LoadInt32(&counter) fmt.Fprintf(GinkgoWriter, "Canceled AcceptStream %d times\n", count) Expect(count).To(BeNumerically(">", numStreams/2)) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(server.Close()).To(Succeed()) }) @@ -574,14 +574,14 @@ var _ = Describe("Stream Cancelations", func() { go func() { defer GinkgoRecover() defer close(msg) - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) var numOpened int for numOpened < numStreams { ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond)) defer cancel() - str, err := sess.OpenUniStreamSync(ctx) + str, err := conn.OpenUniStreamSync(ctx) if err != nil { Expect(err).To(MatchError(context.DeadlineExceeded)) atomic.AddInt32(&numCanceled, 1) @@ -601,7 +601,7 @@ var _ = Describe("Stream Cancelations", func() { } }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIncomingUniStreams: maxIncomingStreams}), @@ -612,7 +612,7 @@ var _ = Describe("Stream Cancelations", func() { wg.Add(numStreams) for i := 0; i < numStreams; i++ { <-msg - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) go func(str quic.ReceiveStream) { defer GinkgoRecover() @@ -627,7 +627,7 @@ var _ = Describe("Stream Cancelations", func() { count := atomic.LoadInt32(&numCanceled) fmt.Fprintf(GinkgoWriter, "Canceled OpenStreamSync %d times\n", count) Expect(count).To(BeNumerically(">=", numStreams-maxIncomingStreams)) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(server.Close()).To(Succeed()) }) }) @@ -686,7 +686,7 @@ var _ = Describe("Stream Cancelations", func() { for { str, err := conn.AcceptStream(context.Background()) if err != nil { - // Make sure the session is closed regularly. + // Make sure the connection is closed regularly. Expect(err).To(BeAssignableToTypeOf(&quic.ApplicationError{})) return } @@ -694,7 +694,7 @@ var _ = Describe("Stream Cancelations", func() { } }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{}), @@ -702,21 +702,21 @@ var _ = Describe("Stream Cancelations", func() { Expect(err).ToNot(HaveOccurred()) for i := 0; i < maxIncomingStreams; i++ { - str, err := sess.OpenStreamSync(context.Background()) + str, err := conn.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) handleStream(str) } // We don't expect to accept any stream here. - // We're just making sure the session stays open and there's no error. + // We're just making sure the connection stays open and there's no error. ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() - _, err = sess.AcceptStream(ctx) + _, err = conn.AcceptStream(ctx) Expect(err).To(MatchError(context.DeadlineExceeded)) wg.Wait() - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) Eventually(serverRunning).Should(BeClosed()) }) }) diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index c93a82c1406..6da758b038f 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -7,7 +7,7 @@ import ( "math/rand" "net" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" @@ -26,13 +26,13 @@ var _ = Describe("Connection ID lengths tests", func() { go func() { defer GinkgoRecover() for { - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) if err != nil { return } go func() { defer GinkgoRecover() - str, err := sess.OpenStream() + str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) defer str.Close() _, err = str.Write(PRData) diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index 371f4c65451..12564f6ec0d 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -47,9 +47,9 @@ var _ = Describe("Datagram test", func() { Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(sess.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) var wg sync.WaitGroup wg.Add(num) @@ -59,7 +59,7 @@ var _ = Describe("Datagram test", func() { defer wg.Done() b := make([]byte, 8) binary.BigEndian.PutUint64(b, uint64(i)) - Expect(sess.SendMessage(b)).To(Succeed()) + Expect(conn.SendMessage(b)).To(Succeed()) }(i) } wg.Wait() @@ -102,7 +102,7 @@ var _ = Describe("Datagram test", func() { startServerAndProxy() raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) Expect(err).ToNot(HaveOccurred()) - sess, err := quic.Dial( + conn, err := quic.Dial( clientConn, raddr, fmt.Sprintf("localhost:%d", proxy.LocalPort()), @@ -113,14 +113,14 @@ var _ = Describe("Datagram test", func() { }), ) Expect(err).ToNot(HaveOccurred()) - Expect(sess.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) var counter int for { - // Close the session if no message is received for 100 ms. + // Close the connection if no message is received for 100 ms. timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { - sess.CloseWithError(0, "") + conn.CloseWithError(0, "") }) - if _, err := sess.ReceiveMessage(); err != nil { + if _, err := conn.ReceiveMessage(); err != nil { break } timer.Stop() diff --git a/integrationtests/self/deadline_test.go b/integrationtests/self/deadline_test.go index 963db7aa1f6..188350cae67 100644 --- a/integrationtests/self/deadline_test.go +++ b/integrationtests/self/deadline_test.go @@ -20,22 +20,22 @@ var _ = Describe("Stream deadline tests", func() { strChan := make(chan quic.SendStream) go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) _, err = str.Read([]byte{0}) Expect(err).ToNot(HaveOccurred()) strChan <- str }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) - clientStr, err := sess.OpenStream() + clientStr, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) _, err = clientStr.Write([]byte{0}) // need to write one byte so the server learns about the stream Expect(err).ToNot(HaveOccurred()) @@ -49,7 +49,7 @@ var _ = Describe("Stream deadline tests", func() { server, serverStr, clientStr := setup() defer server.Close() - const timeout = 20 * time.Millisecond + const timeout = time.Millisecond done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -75,7 +75,7 @@ var _ = Describe("Stream deadline tests", func() { bytesRead += n } Expect(data).To(Equal(PRDataLong)) - // make sure the test actually worked an Read actually ran into the deadline a few times + // make sure the test actually worked and Read actually ran into the deadline a few times Expect(timeoutCounter).To(BeNumerically(">=", 10)) Eventually(done).Should(BeClosed()) }) @@ -84,7 +84,7 @@ var _ = Describe("Stream deadline tests", func() { server, serverStr, clientStr := setup() defer server.Close() - const timeout = 20 * time.Millisecond + const timeout = time.Millisecond go func() { defer GinkgoRecover() _, err := serverStr.Write(PRDataLong) @@ -134,7 +134,7 @@ var _ = Describe("Stream deadline tests", func() { server, serverStr, clientStr := setup() defer server.Close() - const timeout = 20 * time.Millisecond + const timeout = time.Millisecond done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -167,7 +167,7 @@ var _ = Describe("Stream deadline tests", func() { server, serverStr, clientStr := setup() defer server.Close() - const timeout = 20 * time.Millisecond + const timeout = time.Millisecond readDone := make(chan struct{}) go func() { defer GinkgoRecover() diff --git a/integrationtests/self/drop_test.go b/integrationtests/self/drop_test.go index 964106df766..d7416d9186c 100644 --- a/integrationtests/self/drop_test.go +++ b/integrationtests/self/drop_test.go @@ -8,7 +8,7 @@ import ( "sync/atomic" "time" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -87,9 +87,9 @@ var _ = Describe("Drop Tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenStream() + str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) for i := uint8(1); i <= numMessages; i++ { n, err := str.Write([]byte{i}) @@ -98,17 +98,17 @@ var _ = Describe("Drop Tests", func() { time.Sleep(messageInterval) } <-done - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), ) Expect(err).ToNot(HaveOccurred()) - defer sess.CloseWithError(0, "") - str, err := sess.AcceptStream(context.Background()) + defer conn.CloseWithError(0, "") + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) for i := uint8(1); i <= numMessages; i++ { b := []byte{0} diff --git a/integrationtests/self/early_data_test.go b/integrationtests/self/early_data_test.go index eb2d6ad8651..efa31f36d62 100644 --- a/integrationtests/self/early_data_test.go +++ b/integrationtests/self/early_data_test.go @@ -7,7 +7,7 @@ import ( "net" "time" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -32,16 +32,16 @@ var _ = Describe("early data", func() { go func() { defer GinkgoRecover() defer close(done) - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write([]byte("early data")) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) // make sure the Write finished before the handshake completed - Expect(sess.HandshakeComplete().Done()).ToNot(BeClosed()) - Eventually(sess.Context().Done()).Should(BeClosed()) + Expect(conn.HandshakeComplete().Done()).ToNot(BeClosed()) + Eventually(conn.Context().Done()).Should(BeClosed()) }() serverPort := ln.Addr().(*net.UDPAddr).Port proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ @@ -53,18 +53,18 @@ var _ = Describe("early data", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal([]byte("early data"))) - sess.CloseWithError(0, "") + conn.CloseWithError(0, "") Eventually(done).Should(BeClosed()) }) }) diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index f766660a371..b788533dfb0 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -10,7 +10,7 @@ import ( "sync/atomic" "time" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -71,20 +71,20 @@ var _ = Describe("Handshake drop tests", func() { clientSpeaksFirst := &applicationProtocol{ name: "client speaks first", run: func(version protocol.VersionNumber) { - serverSessionChan := make(chan quic.Session) + serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - defer sess.CloseWithError(0, "") - str, err := sess.AcceptStream(context.Background()) + defer conn.CloseWithError(0, "") + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout)) Expect(err).ToNot(HaveOccurred()) Expect(b).To(Equal(data)) - serverSessionChan <- sess + serverConnChan <- conn }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -94,35 +94,35 @@ var _ = Describe("Handshake drop tests", func() { }), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenStream() + str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(data) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) - var serverSession quic.Session - Eventually(serverSessionChan, timeout).Should(Receive(&serverSession)) - sess.CloseWithError(0, "") - serverSession.CloseWithError(0, "") + var serverConn quic.Connection + Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) + conn.CloseWithError(0, "") + serverConn.CloseWithError(0, "") }, } serverSpeaksFirst := &applicationProtocol{ name: "server speaks first", run: func(version protocol.VersionNumber) { - serverSessionChan := make(chan quic.Session) + serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenStream() + str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(data) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) - serverSessionChan <- sess + serverConnChan <- conn }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -132,30 +132,30 @@ var _ = Describe("Handshake drop tests", func() { }), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) b, err := io.ReadAll(gbytes.TimeoutReader(str, timeout)) Expect(err).ToNot(HaveOccurred()) Expect(b).To(Equal(data)) - var serverSession quic.Session - Eventually(serverSessionChan, timeout).Should(Receive(&serverSession)) - sess.CloseWithError(0, "") - serverSession.CloseWithError(0, "") + var serverConn quic.Connection + Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) + conn.CloseWithError(0, "") + serverConn.CloseWithError(0, "") }, } nobodySpeaks := &applicationProtocol{ name: "nobody speaks", run: func(version protocol.VersionNumber) { - serverSessionChan := make(chan quic.Session) + serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - serverSessionChan <- sess + serverConnChan <- conn }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -165,11 +165,11 @@ var _ = Describe("Handshake drop tests", func() { }), ) Expect(err).ToNot(HaveOccurred()) - var serverSession quic.Session - Eventually(serverSessionChan, timeout).Should(Receive(&serverSession)) - // both server and client accepted a session. Close now. - sess.CloseWithError(0, "") - serverSession.CloseWithError(0, "") + var serverConn quic.Connection + Eventually(serverConnChan, timeout).Should(Receive(&serverConn)) + // both server and client accepted a connection. Close now. + conn.CloseWithError(0, "") + serverConn.CloseWithError(0, "") }, } diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 75af2ea9ca6..8d3bea4de62 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -131,14 +131,14 @@ var _ = Describe("Handshake tests", func() { runServer(getTLSConfig()) defer server.Close() clientTracer := &versionNegotiationTracer{} - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return clientTracer })}), ) Expect(err).ToNot(HaveOccurred()) - Expect(sess.(versioner).GetVersion()).To(Equal(expectedVersion)) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.(versioner).GetVersion()).To(Equal(expectedVersion)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(clientTracer.chosen).To(Equal(expectedVersion)) Expect(clientTracer.receivedVersionNegotiation).To(BeFalse()) Expect(clientTracer.clientVersions).To(Equal(protocol.SupportedVersions)) @@ -159,7 +159,7 @@ var _ = Describe("Handshake tests", func() { defer server.Close() clientVersions := []protocol.VersionNumber{7, 8, 9, protocol.SupportedVersions[0], 10} clientTracer := &versionNegotiationTracer{} - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -168,8 +168,8 @@ var _ = Describe("Handshake tests", func() { }), ) Expect(err).ToNot(HaveOccurred()) - Expect(sess.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0])) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.(versioner).GetVersion()).To(Equal(protocol.SupportedVersions[0])) + Expect(conn.CloseWithError(0, "")).To(Succeed()) Expect(clientTracer.chosen).To(Equal(expectedVersion)) Expect(clientTracer.receivedVersionNegotiation).To(BeTrue()) Expect(clientTracer.clientVersions).To(Equal(clientVersions)) @@ -199,28 +199,28 @@ var _ = Describe("Handshake tests", func() { go func() { defer GinkgoRecover() - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenStream() + str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) defer str.Close() _, err = str.Write(PRData) Expect(err).ToNot(HaveOccurred()) }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), nil, ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(PRData)) - Expect(sess.ConnectionState().TLS.CipherSuite).To(Equal(suiteID)) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.ConnectionState().TLS.CipherSuite).To(Equal(suiteID)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) }) } }) @@ -280,19 +280,19 @@ var _ = Describe("Handshake tests", func() { tlsConf.ClientAuth = tls.RequireAndVerifyClientCert runServer(tlsConf) - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), clientConfig, ) // Usually, the error will occur after the client already finished the handshake. // However, there's a race condition here. The server's CONNECTION_CLOSE might be - // received before the session is returned, so we might already get the error while dialing. + // received before the connection is returned, so we might already get the error while dialing. if err == nil { errChan := make(chan error) go func() { defer GinkgoRecover() - _, err := sess.AcceptStream(context.Background()) + _, err := conn.AcceptStream(context.Background()) errChan <- err }() Eventually(errChan).Should(Receive(&err)) @@ -329,7 +329,7 @@ var _ = Describe("Handshake tests", func() { pconn net.PacketConn ) - dial := func() (quic.Session, error) { + dial := func() (quic.Connection, error) { remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) raddr, err := net.ResolveUDPAddr("udp", remoteAddr) Expect(err).ToNot(HaveOccurred()) @@ -368,11 +368,11 @@ var _ = Describe("Handshake tests", func() { It("rejects new connection attempts if connections don't get accepted", func() { for i := 0; i < protocol.MaxAcceptQueueSize; i++ { - sess, err := dial() + conn, err := dial() Expect(err).ToNot(HaveOccurred()) - defer sess.CloseWithError(0, "") + defer conn.CloseWithError(0, "") } - time.Sleep(25 * time.Millisecond) // wait a bit for the sessions to be queued + time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued _, err := dial() Expect(err).To(HaveOccurred()) @@ -380,14 +380,14 @@ var _ = Describe("Handshake tests", func() { Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) - // now accept one session, freeing one spot in the queue + // now accept one connection, freeing one spot in the queue _, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) // dial again, and expect that this dial succeeds - sess, err := dial() + conn, err := dial() Expect(err).ToNot(HaveOccurred()) - defer sess.CloseWithError(0, "") - time.Sleep(25 * time.Millisecond) // wait a bit for the session to be queued + defer conn.CloseWithError(0, "") + time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued _, err = dial() Expect(err).To(HaveOccurred()) @@ -396,15 +396,15 @@ var _ = Describe("Handshake tests", func() { }) It("removes closed connections from the accept queue", func() { - firstSess, err := dial() + firstConn, err := dial() Expect(err).ToNot(HaveOccurred()) for i := 1; i < protocol.MaxAcceptQueueSize; i++ { - sess, err := dial() + conn, err := dial() Expect(err).ToNot(HaveOccurred()) - defer sess.CloseWithError(0, "") + defer conn.CloseWithError(0, "") } - time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the sessions to be queued + time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued _, err = dial() Expect(err).To(HaveOccurred()) @@ -412,16 +412,16 @@ var _ = Describe("Handshake tests", func() { Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) - // Now close the one of the session that are waiting to be accepted. + // Now close the one of the connection that are waiting to be accepted. // This should free one spot in the queue. - Expect(firstSess.CloseWithError(0, "")) - Eventually(firstSess.Context().Done()).Should(BeClosed()) + Expect(firstConn.CloseWithError(0, "")) + Eventually(firstConn.Context().Done()).Should(BeClosed()) time.Sleep(scaleDuration(20 * time.Millisecond)) // dial again, and expect that this dial succeeds _, err = dial() Expect(err).ToNot(HaveOccurred()) - time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the session to be queued + time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued _, err = dial() Expect(err).To(HaveOccurred()) @@ -438,21 +438,21 @@ var _ = Describe("Handshake tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - cs := sess.ConnectionState() + cs := conn.ConnectionState() Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) close(done) }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), nil, ) Expect(err).ToNot(HaveOccurred()) - defer sess.CloseWithError(0, "") - cs := sess.ConnectionState() + defer conn.CloseWithError(0, "") + cs := conn.ConnectionState() Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) Eventually(done).Should(BeClosed()) Expect(ln.Close()).To(Succeed()) @@ -489,7 +489,7 @@ var _ = Describe("Handshake tests", func() { server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) - // dial the first session and receive the token + // dial the first connection and receive the token go func() { defer GinkgoRecover() _, err := server.Accept(context.Background()) @@ -500,7 +500,7 @@ var _ = Describe("Handshake tests", func() { puts := make(chan string, 100) tokenStore := newTokenStore(gets, puts) quicConf := getQuicConfig(&quic.Config{TokenStore: tokenStore}) - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), quicConf, @@ -509,10 +509,10 @@ var _ = Describe("Handshake tests", func() { Expect(gets).To(Receive()) Eventually(puts).Should(Receive()) Expect(tokenChan).ToNot(Receive()) - // received a token. Close this session. - Expect(sess.CloseWithError(0, "")).To(Succeed()) + // received a token. Close this connection. + Expect(conn.CloseWithError(0, "")).To(Succeed()) - // dial the second session and verify that the token was used + // dial the second connection and verify that the token was used done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -520,13 +520,13 @@ var _ = Describe("Handshake tests", func() { _, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) }() - sess, err = quic.DialAddr( + conn, err = quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), quicConf, ) Expect(err).ToNot(HaveOccurred()) - defer sess.CloseWithError(0, "") + defer conn.CloseWithError(0, "") Expect(gets).To(Receive()) Expect(tokenChan).To(Receive()) diff --git a/integrationtests/self/hotswap_test.go b/integrationtests/self/hotswap_test.go new file mode 100644 index 00000000000..b76bc772ee5 --- /dev/null +++ b/integrationtests/self/hotswap_test.go @@ -0,0 +1,190 @@ +package self_test + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "strconv" + "sync/atomic" + "time" + + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/http3" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/testdata" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" +) + +type listenerWrapper struct { + quic.EarlyListener + listenerClosed bool + count int32 +} + +func (ln *listenerWrapper) Close() error { + ln.listenerClosed = true + return ln.EarlyListener.Close() +} + +func (ln *listenerWrapper) Faker() *fakeClosingListener { + atomic.AddInt32(&ln.count, 1) + ctx, cancel := context.WithCancel(context.Background()) + return &fakeClosingListener{ln, 0, ctx, cancel} +} + +type fakeClosingListener struct { + *listenerWrapper + closed int32 + ctx context.Context + cancel context.CancelFunc +} + +func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection, error) { + Expect(ctx).To(Equal(context.Background())) + return ln.listenerWrapper.Accept(ln.ctx) +} + +func (ln *fakeClosingListener) Close() error { + if atomic.CompareAndSwapInt32(&ln.closed, 0, 1) { + ln.cancel() + if atomic.AddInt32(&ln.listenerWrapper.count, -1) == 0 { + ln.listenerWrapper.Close() + } + } + return nil +} + +var _ = Describe("HTTP3 Server hotswap test", func() { + var ( + mux1 *http.ServeMux + mux2 *http.ServeMux + client *http.Client + server1 *http3.Server + server2 *http3.Server + ln *listenerWrapper + port string + ) + + versions := protocol.SupportedVersions + + BeforeEach(func() { + mux1 = http.NewServeMux() + mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset. + }) + + mux2 = http.NewServeMux() + mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset. + }) + + server1 = &http3.Server{ + Server: &http.Server{ + Handler: mux1, + TLSConfig: testdata.GetTLSConfig(), + }, + QuicConfig: getQuicConfig(&quic.Config{Versions: versions}), + } + + server2 = &http3.Server{ + Server: &http.Server{ + Handler: mux2, + TLSConfig: testdata.GetTLSConfig(), + }, + QuicConfig: getQuicConfig(&quic.Config{Versions: versions}), + } + + tlsConf := http3.ConfigureTLSConfig(testdata.GetTLSConfig()) + quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(&quic.Config{Versions: versions})) + ln = &listenerWrapper{EarlyListener: quicln} + Expect(err).NotTo(HaveOccurred()) + port = strconv.Itoa(ln.Addr().(*net.UDPAddr).Port) + }) + + AfterEach(func() { + Expect(ln.Close()).NotTo(HaveOccurred()) + }) + + for _, v := range versions { + version := v + + Context(fmt.Sprintf("with QUIC version %s", version), func() { + BeforeEach(func() { + client = &http.Client{ + Transport: &http3.RoundTripper{ + TLSClientConfig: &tls.Config{ + RootCAs: testdata.GetRootCA(), + }, + DisableCompression: true, + QuicConfig: getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + MaxIdleTimeout: 10 * time.Second, + }), + }, + } + }) + + It("hotswap works", func() { + // open first server and make single request to it + fake1 := ln.Faker() + stoppedServing1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + server1.ServeListener(fake1) + close(stoppedServing1) + }() + + resp, err := client.Get("https://localhost:" + port + "/hello1") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World 1!\n")) + + // open second server with same underlying listener, + // make sure it opened and both servers are currently running + fake2 := ln.Faker() + stoppedServing2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + server2.ServeListener(fake2) + close(stoppedServing2) + }() + + Consistently(stoppedServing1).ShouldNot(BeClosed()) + Consistently(stoppedServing2).ShouldNot(BeClosed()) + + // now close first server, no errors should occur here + // and only the fake listener should be closed + Expect(server1.Close()).NotTo(HaveOccurred()) + Eventually(stoppedServing1).Should(BeClosed()) + Expect(fake1.closed).To(Equal(int32(1))) + Expect(fake2.closed).To(Equal(int32(0))) + Expect(ln.listenerClosed).ToNot(BeTrue()) + Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred()) + + // verify that new connections are being initiated from the second server now + resp, err = client.Get("https://localhost:" + port + "/hello2") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err = io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World 2!\n")) + + // close the other server - both the fake and the actual listeners must close now + Expect(server2.Close()).NotTo(HaveOccurred()) + Eventually(stoppedServing2).Should(BeClosed()) + Expect(fake2.closed).To(Equal(int32(1))) + Expect(ln.listenerClosed).To(BeTrue()) + }) + }) + } +}) diff --git a/integrationtests/self/key_update_test.go b/integrationtests/self/key_update_test.go index d9c83e594e1..4012f018184 100644 --- a/integrationtests/self/key_update_test.go +++ b/integrationtests/self/key_update_test.go @@ -6,7 +6,7 @@ import ( "io" "net" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/logging" @@ -66,9 +66,9 @@ var _ = Describe("Key Update tests", func() { go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) defer str.Close() _, err = str.Write(PRDataLong) @@ -82,18 +82,18 @@ var _ = Describe("Key Update tests", func() { handshake.KeyUpdateInterval = 1 // update keys as frequently as possible runServer() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{Tracer: newTracer(func() logging.ConnectionTracer { return &keyUpdateConnTracer{} })}), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(PRDataLong)) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) keyPhasesSent, keyPhasesReceived := countKeyPhases() fmt.Fprintf(GinkgoWriter, "Used %d key phases on outgoing and %d key phases on incoming packets.\n", keyPhasesSent, keyPhasesReceived) diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index b0078c0f54d..34bb14c6a3f 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -29,25 +29,25 @@ var _ = Describe("MITM test", func() { const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it var ( - proxy *quicproxy.QuicProxy - serverConn, clientConn *net.UDPConn - serverSess quic.Session - serverConfig *quic.Config + proxy *quicproxy.QuicProxy + serverUDPConn, clientUDPConn *net.UDPConn + serverConn quic.Connection + serverConfig *quic.Config ) startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) { addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) - serverConn, err = net.ListenUDP("udp", addr) + serverUDPConn, err = net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) - ln, err := quic.Listen(serverConn, getTLSConfig(), serverConfig) + ln, err := quic.Listen(serverUDPConn, getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() var err error - serverSess, err = ln.Accept(context.Background()) + serverConn, err = ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := serverSess.OpenUniStream() + str, err := serverConn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(PRData) Expect(err).ToNot(HaveOccurred()) @@ -69,17 +69,17 @@ var _ = Describe("MITM test", func() { }) addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) - clientConn, err = net.ListenUDP("udp", addr) + clientUDPConn, err = net.ListenUDP("udp", addr) Expect(err).ToNot(HaveOccurred()) }) Context("unsuccessful attacks", func() { AfterEach(func() { - Eventually(serverSess.Context().Done()).Should(BeClosed()) + Eventually(serverConn.Context().Done()).Should(BeClosed()) // Test shutdown is tricky due to the proxy. Just wait for a bit. time.Sleep(50 * time.Millisecond) - Expect(clientConn.Close()).To(Succeed()) - Expect(serverConn.Close()).To(Succeed()) + Expect(clientUDPConn.Close()).To(Succeed()) + Expect(serverUDPConn.Close()).To(Succeed()) Expect(proxy.Close()).To(Succeed()) }) @@ -123,8 +123,8 @@ var _ = Describe("MITM test", func() { startServerAndProxy(delayCb, nil) raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) Expect(err).ToNot(HaveOccurred()) - sess, err := quic.Dial( - clientConn, + conn, err := quic.Dial( + clientUDPConn, raddr, fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), @@ -134,19 +134,19 @@ var _ = Describe("MITM test", func() { }), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(PRData)) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) } It("downloads a message when the packets are injected towards the server", func() { delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { if dir == quicproxy.DirectionIncoming { defer GinkgoRecover() - go sendRandomPacketsOfSameType(clientConn, serverConn.LocalAddr(), raw) + go sendRandomPacketsOfSameType(clientUDPConn, serverUDPConn.LocalAddr(), raw) } return rtt / 2 } @@ -157,7 +157,7 @@ var _ = Describe("MITM test", func() { delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { if dir == quicproxy.DirectionOutgoing { defer GinkgoRecover() - go sendRandomPacketsOfSameType(serverConn, clientConn.LocalAddr(), raw) + go sendRandomPacketsOfSameType(serverUDPConn, clientUDPConn.LocalAddr(), raw) } return rtt / 2 } @@ -169,8 +169,8 @@ var _ = Describe("MITM test", func() { startServerAndProxy(nil, dropCb) raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) Expect(err).ToNot(HaveOccurred()) - sess, err := quic.Dial( - clientConn, + conn, err := quic.Dial( + clientUDPConn, raddr, fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), @@ -180,12 +180,12 @@ var _ = Describe("MITM test", func() { }), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(PRData)) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) } Context("duplicating packets", func() { @@ -193,7 +193,7 @@ var _ = Describe("MITM test", func() { dropCb := func(dir quicproxy.Direction, raw []byte) bool { defer GinkgoRecover() if dir == quicproxy.DirectionIncoming { - _, err := clientConn.WriteTo(raw, serverConn.LocalAddr()) + _, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) } return false @@ -205,7 +205,7 @@ var _ = Describe("MITM test", func() { dropCb := func(dir quicproxy.Direction, raw []byte) bool { defer GinkgoRecover() if dir == quicproxy.DirectionOutgoing { - _, err := serverConn.WriteTo(raw, clientConn.LocalAddr()) + _, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) } return false @@ -230,8 +230,8 @@ var _ = Describe("MITM test", func() { fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, atomic.LoadInt32(&numPackets)) Expect(num).To(BeNumerically(">=", 1)) // If the packet containing the CONNECTION_CLOSE is corrupted, - // we have to wait for the session to time out. - Eventually(serverSess.Context().Done(), 3*idleTimeout).Should(BeClosed()) + // we have to wait for the connection to time out. + Eventually(serverConn.Context().Done(), 3*idleTimeout).Should(BeClosed()) }) It("downloads a message when packet are corrupted towards the server", func() { @@ -243,7 +243,7 @@ var _ = Describe("MITM test", func() { if mrand.Intn(interval) == 0 { pos := mrand.Intn(len(raw)) raw[pos] = byte(mrand.Intn(256)) - _, err := clientConn.WriteTo(raw, serverConn.LocalAddr()) + _, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) atomic.AddInt32(&numCorrupted, 1) return true @@ -263,7 +263,7 @@ var _ = Describe("MITM test", func() { if mrand.Intn(interval) == 0 { pos := mrand.Intn(len(raw)) raw[pos] = byte(mrand.Intn(256)) - _, err := serverConn.WriteTo(raw, clientConn.LocalAddr()) + _, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr()) Expect(err).ToNot(HaveOccurred()) atomic.AddInt32(&numCorrupted, 1) return true @@ -292,12 +292,12 @@ var _ = Describe("MITM test", func() { }) // sendForgedVersionNegotiationPacket sends a fake VN packet with no supported versions - // from serverConn to client's remoteAddr + // from serverUDPConn to client's remoteAddr // expects hdr from an Initial packet intercepted from client sendForgedVersionNegotationPacket := func(conn net.PacketConn, remoteAddr net.Addr, hdr *wire.Header) { // Create fake version negotiation packet with no supported versions versions := []protocol.VersionNumber{} - packet, _ := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) + packet := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) // Send the packet _, err := conn.WriteTo(packet, remoteAddr) @@ -305,7 +305,7 @@ var _ = Describe("MITM test", func() { } // sendForgedRetryPacket sends a fake Retry packet with a modified srcConnID - // from serverConn to client's remoteAddr + // from serverUDPConn to client's remoteAddr // expects hdr from an Initial packet intercepted from client sendForgedRetryPacket := func(conn net.PacketConn, remoteAddr net.Addr, hdr *wire.Header) { var x byte = 0x12 @@ -339,7 +339,7 @@ var _ = Describe("MITM test", func() { raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) Expect(err).ToNot(HaveOccurred()) _, err = quic.Dial( - clientConn, + clientUDPConn, raddr, fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), @@ -365,7 +365,7 @@ var _ = Describe("MITM test", func() { return 0 } - sendForgedVersionNegotationPacket(serverConn, clientConn.LocalAddr(), hdr) + sendForgedVersionNegotationPacket(serverUDPConn, clientUDPConn.LocalAddr(), hdr) } return rtt / 2 } @@ -392,7 +392,7 @@ var _ = Describe("MITM test", func() { } initialPacketIntercepted = true - sendForgedRetryPacket(serverConn, clientConn.LocalAddr(), hdr) + sendForgedRetryPacket(serverUDPConn, clientUDPConn.LocalAddr(), hdr) } return rtt / 2 } @@ -416,7 +416,7 @@ var _ = Describe("MITM test", func() { return 0 } - sendForgedInitialPacket(serverConn, clientConn.LocalAddr(), hdr) + sendForgedInitialPacket(serverUDPConn, clientUDPConn.LocalAddr(), hdr) } return rtt } @@ -427,7 +427,7 @@ var _ = Describe("MITM test", func() { // client connection closes immediately on receiving ack for unsent packet It("fails when a forged initial packet with ack for unsent packet is sent to client", func() { - clientAddr := clientConn.LocalAddr() + clientAddr := clientUDPConn.LocalAddr() delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { if dir == quicproxy.DirectionIncoming { hdr, _, _, err := wire.ParsePacket(raw, connIDLen) @@ -435,7 +435,7 @@ var _ = Describe("MITM test", func() { if hdr.Type != protocol.PacketTypeInitial { return 0 } - sendForgedInitialPacketWithAck(serverConn, clientAddr, hdr) + sendForgedInitialPacketWithAck(serverUDPConn, clientAddr, hdr) } return rtt } diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index 492e91fa53e..1e89ef074bc 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -8,7 +8,7 @@ import ( "runtime" "time" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" @@ -24,13 +24,13 @@ var _ = Describe("Multiplexing", func() { go func() { defer GinkgoRecover() for { - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) if err != nil { return } go func() { defer GinkgoRecover() - str, err := sess.OpenStream() + str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) defer str.Close() _, err = str.Write(PRData) @@ -40,17 +40,17 @@ var _ = Describe("Multiplexing", func() { }() } - dial := func(conn net.PacketConn, addr net.Addr) { - sess, err := quic.Dial( - conn, + dial := func(pconn net.PacketConn, addr net.Addr) { + conn, err := quic.Dial( + pconn, addr, fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), ) Expect(err).ToNot(HaveOccurred()) - defer sess.CloseWithError(0, "") - str, err := sess.AcceptStream(context.Background()) + defer conn.CloseWithError(0, "") + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/packetization_test.go b/integrationtests/self/packetization_test.go index 48297a0a638..d7d1c65908c 100644 --- a/integrationtests/self/packetization_test.go +++ b/integrationtests/self/packetization_test.go @@ -45,7 +45,7 @@ var _ = Describe("Packetization", func() { defer proxy.Close() clientTracer := newPacketTracer() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -57,9 +57,9 @@ var _ = Describe("Packetization", func() { go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) b := make([]byte, 1) // Echo every byte received from the client. @@ -72,7 +72,7 @@ var _ = Describe("Packetization", func() { } }() - str, err := sess.OpenStreamSync(context.Background()) + str, err := conn.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) b := make([]byte, 1) // Send numMsg 1-byte messages. @@ -83,7 +83,7 @@ var _ = Describe("Packetization", func() { Expect(err).ToNot(HaveOccurred()) Expect(b[0]).To(Equal(uint8(i))) } - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) countBundledPackets := func(packets []packet) (numBundled int) { for _, p := range packets { diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index fbccd4a181f..10c554e6397 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -7,7 +7,7 @@ import ( "net" "sync" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -57,7 +57,7 @@ var _ = Describe("TLS session resumption", func() { cache := newClientSessionCache(gets, puts) tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, nil, @@ -65,24 +65,24 @@ var _ = Describe("TLS session resumption", func() { Expect(err).ToNot(HaveOccurred()) var sessionKey string Eventually(puts).Should(Receive(&sessionKey)) - Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) - serverSess, err := server.Accept(context.Background()) + serverConn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) - sess, err = quic.DialAddr( + conn, err = quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, nil, ) Expect(err).ToNot(HaveOccurred()) Expect(gets).To(Receive(Equal(sessionKey))) - Expect(sess.ConnectionState().TLS.DidResume).To(BeTrue()) + Expect(conn.ConnectionState().TLS.DidResume).To(BeTrue()) - serverSess, err = server.Accept(context.Background()) + serverConn, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().TLS.DidResume).To(BeTrue()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue()) }) It("doesn't use session resumption, if the config disables it", func() { @@ -97,29 +97,29 @@ var _ = Describe("TLS session resumption", func() { cache := newClientSessionCache(gets, puts) tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, nil, ) Expect(err).ToNot(HaveOccurred()) Consistently(puts).ShouldNot(Receive()) - Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) - serverSess, err := server.Accept(context.Background()) + serverConn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) - sess, err = quic.DialAddr( + conn, err = quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), tlsConf, nil, ) Expect(err).ToNot(HaveOccurred()) - Expect(sess.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) - serverSess, err = server.Accept(context.Background()) + serverConn, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().TLS.DidResume).To(BeFalse()) + Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) }) }) diff --git a/integrationtests/self/rtt_test.go b/integrationtests/self/rtt_test.go index c2f86eb459e..c6e3324eeb3 100644 --- a/integrationtests/self/rtt_test.go +++ b/integrationtests/self/rtt_test.go @@ -7,7 +7,7 @@ import ( "net" "time" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -28,9 +28,9 @@ var _ = Describe("non-zero RTT", func() { Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenStream() + str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(PRData) Expect(err).ToNot(HaveOccurred()) @@ -40,18 +40,18 @@ var _ = Describe("non-zero RTT", func() { } downloadFile := func(port int) { - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", port), getTLSClientConfig(), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(PRData)) - sess.CloseWithError(0, "") + conn.CloseWithError(0, "") } Context(fmt.Sprintf("with QUIC version %s", version), func() { @@ -76,18 +76,18 @@ var _ = Describe("non-zero RTT", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(PRData)) - sess.CloseWithError(0, "") + conn.CloseWithError(0, "") }) } diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 2f08fc7b894..dba2172172d 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -35,9 +35,9 @@ var _ = Describe("Stateless Resets", func() { go func() { defer GinkgoRecover() - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenStream() + str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) @@ -56,7 +56,7 @@ var _ = Describe("Stateless Resets", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -65,7 +65,7 @@ var _ = Describe("Stateless Resets", func() { }), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data := make([]byte, 6) _, err = str.Read(data) diff --git a/integrationtests/self/stream_test.go b/integrationtests/self/stream_test.go index b67ddc2230b..f5dd917a8f4 100644 --- a/integrationtests/self/stream_test.go +++ b/integrationtests/self/stream_test.go @@ -7,7 +7,7 @@ import ( "net" "sync" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" @@ -42,11 +42,11 @@ var _ = Describe("Bidirectional streams", func() { server.Close() }) - runSendingPeer := func(sess quic.Session) { + runSendingPeer := func(conn quic.Connection) { var wg sync.WaitGroup wg.Add(numStreams) for i := 0; i < numStreams; i++ { - str, err := sess.OpenStreamSync(context.Background()) + str, err := conn.OpenStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) data := GeneratePRData(25 * i) go func() { @@ -66,11 +66,11 @@ var _ = Describe("Bidirectional streams", func() { wg.Wait() } - runReceivingPeer := func(sess quic.Session) { + runReceivingPeer := func(conn quic.Connection) { var wg sync.WaitGroup wg.Add(numStreams) for i := 0; i < numStreams; i++ { - str, err := sess.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() @@ -88,13 +88,13 @@ var _ = Describe("Bidirectional streams", func() { } It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() { - var sess quic.Session + var conn quic.Connection go func() { defer GinkgoRecover() var err error - sess, err = server.Accept(context.Background()) + conn, err = server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - runReceivingPeer(sess) + runReceivingPeer(conn) }() client, err := quic.DialAddr( @@ -109,10 +109,10 @@ var _ = Describe("Bidirectional streams", func() { It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - runSendingPeer(sess) - sess.CloseWithError(0, "") + runSendingPeer(conn) + conn.CloseWithError(0, "") }() client, err := quic.DialAddr( @@ -129,15 +129,15 @@ var _ = Describe("Bidirectional streams", func() { done1 := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) go func() { defer GinkgoRecover() - runReceivingPeer(sess) + runReceivingPeer(conn) close(done) }() - runSendingPeer(sess) + runSendingPeer(conn) <-done close(done1) }() diff --git a/integrationtests/self/timeout_test.go b/integrationtests/self/timeout_test.go index 52a10bc22ac..cd7b82ca2f6 100644 --- a/integrationtests/self/timeout_test.go +++ b/integrationtests/self/timeout_test.go @@ -125,9 +125,9 @@ var _ = Describe("Timeout tests", func() { go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenStream() + str, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) @@ -144,15 +144,15 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}), ) Expect(err).ToNot(HaveOccurred()) - strIn, err := sess.AcceptStream(context.Background()) + strIn, err := conn.AcceptStream(context.Background()) Expect(err).ToNot(HaveOccurred()) - strOut, err := sess.OpenStream() + strOut, err := conn.OpenStream() Expect(err).ToNot(HaveOccurred()) _, err = strIn.Read(make([]byte, 6)) Expect(err).ToNot(HaveOccurred()) @@ -167,13 +167,13 @@ var _ = Describe("Timeout tests", func() { checkTimeoutError(err) _, err = strOut.Read([]byte{0}) checkTimeoutError(err) - _, err = sess.OpenStream() + _, err = conn.OpenStream() checkTimeoutError(err) - _, err = sess.OpenUniStream() + _, err = conn.OpenUniStream() checkTimeoutError(err) - _, err = sess.AcceptStream(context.Background()) + _, err = conn.AcceptStream(context.Background()) checkTimeoutError(err) - _, err = sess.AcceptUniStream(context.Background()) + _, err = conn.AcceptUniStream(context.Background()) checkTimeoutError(err) }) @@ -193,17 +193,17 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) defer server.Close() - serverSessionClosed := make(chan struct{}) + serverConnClosed := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - sess.AcceptStream(context.Background()) // blocks until the session is closed - close(serverSessionClosed) + conn.AcceptStream(context.Background()) // blocks until the connection is closed + close(serverConnClosed) }() tr := newPacketTracer() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -216,7 +216,7 @@ var _ = Describe("Timeout tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := sess.AcceptStream(context.Background()) + _, err := conn.AcceptStream(context.Background()) checkTimeoutError(err) close(done) }() @@ -244,11 +244,11 @@ var _ = Describe("Timeout tests", func() { BeNumerically(">=", idleTimeout), BeNumerically("<", idleTimeout*6/5), )) - Consistently(serverSessionClosed).ShouldNot(BeClosed()) + Consistently(serverConnClosed).ShouldNot(BeClosed()) // make the go routine return Expect(server.Close()).To(Succeed()) - Eventually(serverSessionClosed).Should(BeClosed()) + Eventually(serverConnClosed).Should(BeClosed()) }) It("times out after sending a packet", func() { @@ -273,16 +273,16 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - serverSessionClosed := make(chan struct{}) + serverConnClosed := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - <-sess.Context().Done() // block until the session is closed - close(serverSessionClosed) + <-conn.Context().Done() // block until the connection is closed + close(serverConnClosed) }() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}), @@ -292,7 +292,7 @@ var _ = Describe("Timeout tests", func() { // wait half the idle timeout, then send a packet time.Sleep(idleTimeout / 2) drop.Set(true) - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) @@ -302,7 +302,7 @@ var _ = Describe("Timeout tests", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - _, err := sess.AcceptStream(context.Background()) + _, err := conn.AcceptStream(context.Background()) checkTimeoutError(err) close(done) }() @@ -312,11 +312,11 @@ var _ = Describe("Timeout tests", func() { BeNumerically(">=", idleTimeout), BeNumerically("<", idleTimeout*12/10), )) - Consistently(serverSessionClosed).ShouldNot(BeClosed()) + Consistently(serverConnClosed).ShouldNot(BeClosed()) // make the go routine return Expect(server.Close()).To(Succeed()) - Eventually(serverSessionClosed).Should(BeClosed()) + Eventually(serverConnClosed).Should(BeClosed()) }) }) @@ -331,13 +331,13 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) defer server.Close() - serverSessionClosed := make(chan struct{}) + serverConnClosed := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - sess.AcceptStream(context.Background()) // blocks until the session is closed - close(serverSessionClosed) + conn.AcceptStream(context.Background()) // blocks until the connection is closed + close(serverConnClosed) }() drop := utils.AtomicBool{} @@ -350,7 +350,7 @@ var _ = Describe("Timeout tests", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -363,11 +363,11 @@ var _ = Describe("Timeout tests", func() { // wait longer than the idle timeout time.Sleep(3 * idleTimeout) - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) - Consistently(serverSessionClosed).ShouldNot(BeClosed()) + Consistently(serverConnClosed).ShouldNot(BeClosed()) // idle timeout will still kick in if pings are dropped drop.Set(true) @@ -376,7 +376,7 @@ var _ = Describe("Timeout tests", func() { checkTimeoutError(err) Expect(server.Close()).To(Succeed()) - Eventually(serverSessionClosed).Should(BeClosed()) + Eventually(serverConnClosed).Should(BeClosed()) }) Context("faulty packet conns", func() { @@ -391,11 +391,11 @@ var _ = Describe("Timeout tests", func() { }) runServer := func(ln quic.Listener) error { - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) if err != nil { return err } - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() if err != nil { return err } @@ -404,8 +404,8 @@ var _ = Describe("Timeout tests", func() { return err } - runClient := func(sess quic.Session) error { - str, err := sess.AcceptUniStream(context.Background()) + runClient := func(conn quic.Connection) error { + str, err := conn.AcceptUniStream(context.Background()) if err != nil { return err } @@ -414,7 +414,7 @@ var _ = Describe("Timeout tests", func() { return err } Expect(data).To(Equal(PRData)) - return sess.CloseWithError(0, "done") + return conn.CloseWithError(0, "done") } It("deals with an erroring packet conn, on the server side", func() { @@ -440,7 +440,7 @@ var _ = Describe("Timeout tests", func() { clientErrChan := make(chan error, 1) go func() { defer GinkgoRecover() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), getQuicConfig(&quic.Config{ @@ -453,7 +453,7 @@ var _ = Describe("Timeout tests", func() { clientErrChan <- err return } - clientErrChan <- runClient(sess) + clientErrChan <- runClient(conn) }() var clientErr error @@ -501,7 +501,7 @@ var _ = Describe("Timeout tests", func() { clientErrChan := make(chan error, 1) go func() { defer GinkgoRecover() - sess, err := quic.Dial( + conn, err := quic.Dial( &faultyConn{PacketConn: conn, MaxPackets: maxPackets}, ln.Addr(), "localhost", @@ -512,7 +512,7 @@ var _ = Describe("Timeout tests", func() { clientErrChan <- err return } - clientErrChan <- runClient(sess) + clientErrChan <- runClient(conn) }() var clientErr error diff --git a/integrationtests/self/tracer_test.go b/integrationtests/self/tracer_test.go index b0903e8d5b5..f244381c8aa 100644 --- a/integrationtests/self/tracer_test.go +++ b/integrationtests/self/tracer_test.go @@ -115,9 +115,9 @@ var _ = Describe("Handshake tests", func() { ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), quicServerConf) Expect(err).ToNot(HaveOccurred()) serverChan <- ln - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(PRData) Expect(err).ToNot(HaveOccurred()) @@ -127,14 +127,14 @@ var _ = Describe("Handshake tests", func() { ln := <-serverChan defer ln.Close() - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), quicClientConf, ) Expect(err).ToNot(HaveOccurred()) - defer sess.CloseWithError(0, "") - str, err := sess.AcceptUniStream(context.Background()) + defer conn.CloseWithError(0, "") + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index db63da80377..9251d8c9b8f 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -7,7 +7,7 @@ import ( "net" "sync" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" @@ -39,9 +39,9 @@ var _ = Describe("Unidirectional Streams", func() { return GeneratePRData(10 * int(id)) } - runSendingPeer := func(sess quic.Session) { + runSendingPeer := func(conn quic.Connection) { for i := 0; i < numStreams; i++ { - str, err := sess.OpenUniStreamSync(context.Background()) + str, err := conn.OpenUniStreamSync(context.Background()) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() @@ -52,11 +52,11 @@ var _ = Describe("Unidirectional Streams", func() { } } - runReceivingPeer := func(sess quic.Session) { + runReceivingPeer := func(conn quic.Connection) { var wg sync.WaitGroup wg.Add(numStreams) for i := 0; i < numStreams; i++ { - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() @@ -72,10 +72,10 @@ var _ = Describe("Unidirectional Streams", func() { It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() { go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - runReceivingPeer(sess) - sess.CloseWithError(0, "") + runReceivingPeer(conn) + conn.CloseWithError(0, "") }() client, err := quic.DialAddr( @@ -91,9 +91,9 @@ var _ = Describe("Unidirectional Streams", func() { It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - runSendingPeer(sess) + runSendingPeer(conn) }() client, err := quic.DialAddr( @@ -109,15 +109,15 @@ var _ = Describe("Unidirectional Streams", func() { done1 := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := server.Accept(context.Background()) + conn, err := server.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) go func() { defer GinkgoRecover() - runReceivingPeer(sess) + runReceivingPeer(conn) close(done) }() - runSendingPeer(sess) + runSendingPeer(conn) <-done close(done1) }() diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index ec7a8575913..274a9183932 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -75,21 +75,21 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) defer proxy.Close() - // dial the first session in order to receive a session ticket + // dial the first connection in order to receive a session ticket done := make(chan struct{}) go func() { defer GinkgoRecover() defer close(done) - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - <-sess.Context().Done() + <-conn.Context().Done() }() clientConf := getTLSClientConfig() gets := make(chan string, 100) puts := make(chan string, 100) clientConf.ClientSessionCache = newClientSessionCache(gets, puts) - sess, err := quic.DialAddr( + conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalPort()), clientConf, getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), @@ -97,7 +97,7 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) Eventually(puts).Should(Receive()) // received the session ticket. We're done here. - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) Eventually(done).Should(BeClosed()) return tlsConf, clientConf } @@ -109,40 +109,40 @@ var _ = Describe("0-RTT", func() { clientConf *quic.Config, testdata []byte, // data to transfer ) { - // now dial the second session, and use 0-RTT to send some data + // now dial the second connection, and use 0-RTT to send some data done := make(chan struct{}) go func() { defer GinkgoRecover() - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(testdata)) - Expect(sess.ConnectionState().TLS.Used0RTT).To(BeTrue()) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) close(done) }() if clientConf == nil { clientConf = getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}) } - sess, err := quic.DialAddrEarly( + conn, err := quic.DialAddrEarly( fmt.Sprintf("localhost:%d", proxyPort), clientTLSConf, clientConf, ) Expect(err).ToNot(HaveOccurred()) - defer sess.CloseWithError(0, "") - str, err := sess.OpenUniStream() + defer conn.CloseWithError(0, "") + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(testdata) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) - Expect(sess.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) Eventually(done).Should(BeClosed()) - Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(conn.Context().Done()).Should(BeClosed()) } check0RTTRejected := func( @@ -150,29 +150,29 @@ var _ = Describe("0-RTT", func() { proxyPort int, clientConf *tls.Config, ) { - sess, err := quic.DialAddrEarly( + conn, err := quic.DialAddrEarly( fmt.Sprintf("localhost:%d", proxyPort), clientConf, getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(make([]byte, 3000)) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) - Expect(sess.ConnectionState().TLS.Used0RTT).To(BeFalse()) + Expect(conn.ConnectionState().TLS.Used0RTT).To(BeFalse()) // make sure the server doesn't process the data ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) defer cancel() - serverSess, err := ln.Accept(ctx) + serverConn, err := ln.Accept(ctx) Expect(err).ToNot(HaveOccurred()) - Expect(serverSess.ConnectionState().TLS.Used0RTT).To(BeFalse()) - _, err = serverSess.AcceptUniStream(ctx) + Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeFalse()) + _, err = serverConn.AcceptUniStream(ctx) Expect(err).To(Equal(context.DeadlineExceeded)) - Expect(serverSess.CloseWithError(0, "")).To(Succeed()) - Eventually(sess.Context().Done()).Should(BeClosed()) + Expect(serverConn.CloseWithError(0, "")).To(Succeed()) + Eventually(conn.Context().Done()).Should(BeClosed()) } // can be used to extract 0-RTT from a packetTracer @@ -244,7 +244,7 @@ var _ = Describe("0-RTT", func() { } // Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets. - It("waits until a session until the handshake is done", func() { + It("waits for a connection until the handshake is done", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) zeroRTTData := GeneratePRData(2 * 1100) // 2 packets @@ -263,28 +263,28 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) defer ln.Close() - // now dial the second session, and use 0-RTT to send some data + // now dial the second connection, and use 0-RTT to send some data go func() { defer GinkgoRecover() - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - str, err := sess.AcceptUniStream(context.Background()) + str, err := conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(zeroRTTData)) - str, err = sess.AcceptUniStream(context.Background()) + str, err = conn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err = io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(oneRTTData)) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) }() proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) defer proxy.Close() - sess, err := quic.DialAddrEarly( + conn, err := quic.DialAddrEarly( fmt.Sprintf("localhost:%d", proxy.LocalPort()), clientConf, getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), @@ -294,7 +294,7 @@ var _ = Describe("0-RTT", func() { go func() { defer GinkgoRecover() defer close(sent0RTT) - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(zeroRTTData) Expect(err).ToNot(HaveOccurred()) @@ -303,13 +303,13 @@ var _ = Describe("0-RTT", func() { Eventually(sent0RTT).Should(BeClosed()) // wait for the handshake to complete - Eventually(sess.HandshakeComplete().Done()).Should(BeClosed()) - str, err := sess.OpenUniStream() + Eventually(conn.HandshakeComplete().Done()).Should(BeClosed()) + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write(PRData) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) - <-sess.Context().Done() + <-conn.Context().Done() num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) @@ -472,27 +472,27 @@ var _ = Describe("0-RTT", func() { proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) defer proxy.Close() - sess, err := quic.DialAddrEarly( + conn, err := quic.DialAddrEarly( fmt.Sprintf("localhost:%d", proxy.LocalPort()), clientConf, getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) _, err = str.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) // The client remembers the old limit and refuses to open a new stream. - _, err = sess.OpenUniStream() + _, err = conn.OpenUniStream() Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("too many open streams")) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - _, err = sess.OpenUniStreamSync(ctx) + _, err = conn.OpenUniStreamSync(ctx) Expect(err).ToNot(HaveOccurred()) - Expect(sess.ConnectionState().TLS.Used0RTT).To(BeTrue()) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) }) It("rejects 0-RTT when the server's stream limit decreased", func() { @@ -582,13 +582,13 @@ var _ = Describe("0-RTT", func() { proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) defer proxy.Close() - sess, err := quic.DialAddrEarly( + conn, err := quic.DialAddrEarly( fmt.Sprintf("localhost:%d", proxy.LocalPort()), clientConf, getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), ) Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenUniStream() + str, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) written := make(chan struct{}) go func() { @@ -601,16 +601,16 @@ var _ = Describe("0-RTT", func() { Eventually(written).Should(BeClosed()) - serverSess, err := ln.Accept(context.Background()) + serverConn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - rstr, err := serverSess.AcceptUniStream(context.Background()) + rstr, err := serverConn.AcceptUniStream(context.Background()) Expect(err).ToNot(HaveOccurred()) data, err := io.ReadAll(rstr) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal([]byte("foobar"))) - Expect(serverSess.ConnectionState().TLS.Used0RTT).To(BeTrue()) - Expect(serverSess.CloseWithError(0, "")).To(Succeed()) - Eventually(sess.Context().Done()).Should(BeClosed()) + Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(serverConn.CloseWithError(0, "")).To(Succeed()) + Eventually(conn.Context().Done()).Should(BeClosed()) var processedFirst bool for _, p := range tracer.getRcvdPackets() { @@ -656,14 +656,14 @@ var _ = Describe("0-RTT", func() { proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) defer proxy.Close() - sess, err := quic.DialAddrEarly( + conn, err := quic.DialAddrEarly( fmt.Sprintf("localhost:%d", proxy.LocalPort()), clientConf, getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), ) Expect(err).ToNot(HaveOccurred()) // The client remembers that it was allowed to open 2 uni-directional streams. - firstStr, err := sess.OpenUniStream() + firstStr, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) written := make(chan struct{}, 2) go func() { @@ -672,7 +672,7 @@ var _ = Describe("0-RTT", func() { _, err := firstStr.Write([]byte("first flight")) Expect(err).ToNot(HaveOccurred()) }() - secondStr, err := sess.OpenUniStream() + secondStr, err := conn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() @@ -683,28 +683,28 @@ var _ = Describe("0-RTT", func() { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - _, err = sess.AcceptStream(ctx) + _, err = conn.AcceptStream(ctx) Expect(err).To(MatchError(quic.Err0RTTRejected)) Eventually(written).Should(Receive()) Eventually(written).Should(Receive()) _, err = firstStr.Write([]byte("foobar")) Expect(err).To(MatchError(quic.Err0RTTRejected)) - _, err = sess.OpenUniStream() + _, err = conn.OpenUniStream() Expect(err).To(MatchError(quic.Err0RTTRejected)) - _, err = sess.AcceptStream(ctx) + _, err = conn.AcceptStream(ctx) Expect(err).To(Equal(quic.Err0RTTRejected)) - newSess := sess.NextSession() - str, err := newSess.OpenUniStream() + newConn := conn.NextConnection() + str, err := newConn.OpenUniStream() Expect(err).ToNot(HaveOccurred()) - _, err = newSess.OpenUniStream() + _, err = newConn.OpenUniStream() Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("too many open streams")) _, err = str.Write([]byte("second flight")) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) - Expect(sess.CloseWithError(0, "")).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) // The client should send 0-RTT packets, but the server doesn't process them. num0RTT := atomic.LoadUint32(num0RTTPackets) diff --git a/interface.go b/interface.go index 6381e9eddf9..cb1c1de3a39 100644 --- a/interface.go +++ b/interface.go @@ -59,15 +59,15 @@ type TokenStore interface { // when the server rejects a 0-RTT connection attempt. var Err0RTTRejected = errors.New("0-RTT rejected") -// SessionTracingKey can be used to associate a ConnectionTracer with a Session. -// It is set on the Session.Context() context, +// ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection. +// It is set on the Connection.Context() context, // as well as on the context passed to logging.Tracer.NewConnectionTracer. -var SessionTracingKey = sessionTracingCtxKey{} +var ConnectionTracingKey = connTracingCtxKey{} -type sessionTracingCtxKey struct{} +type connTracingCtxKey struct{} // Stream is the interface implemented by QUIC streams -// In addition to the errors listed on the Session, +// In addition to the errors listed on the Connection, // calls to stream functions can return a StreamError if the stream is canceled. type Stream interface { ReceiveStream @@ -87,7 +87,7 @@ type ReceiveStream interface { // after a fixed time limit; see SetDeadline and SetReadDeadline. // If the stream was canceled by the peer, the error implements the StreamError // interface, and Canceled() == true. - // If the session was closed due to a timeout, the error satisfies + // If the connection was closed due to a timeout, the error satisfies // the net.Error interface, and Timeout() will be true. io.Reader // CancelRead aborts receiving on this stream. @@ -111,7 +111,7 @@ type SendStream interface { // after a fixed time limit; see SetDeadline and SetWriteDeadline. // If the stream was canceled by the peer, the error implements the StreamError // interface, and Canceled() == true. - // If the session was closed due to a timeout, the error satisfies + // If the connection was closed due to a timeout, the error satisfies // the net.Error interface, and Timeout() will be true. io.Writer // Close closes the write-direction of the stream. @@ -127,7 +127,6 @@ type SendStream interface { // The Context is canceled as soon as the write-side of the stream is closed. // This happens when Close() or CancelWrite() is called, or when the peer // cancels the read-side of their stream. - // Warning: This API should not be considered stable and might change soon. Context() context.Context // SetWriteDeadline sets the deadline for future Write calls // and any currently-blocked Write call. @@ -137,21 +136,21 @@ type SendStream interface { SetWriteDeadline(t time.Time) error } -// A Session is a QUIC connection between two peers. -// Calls to the session (and to streams) can return the following types of errors: +// A Connection is a QUIC connection between two peers. +// Calls to the connection (and to streams) can return the following types of errors: // * ApplicationError: for errors triggered by the application running on top of QUIC // * TransportError: for errors triggered by the QUIC transport (in many cases a misbehaving peer) // * IdleTimeoutError: when the peer goes away unexpectedly (this is a net.Error timeout error) // * HandshakeTimeoutError: when the cryptographic handshake takes too long (this is a net.Error timeout error) // * StatelessResetError: when we receive a stateless reset (this is a net.Error temporary error) // * VersionNegotiationError: returned by the client, when there's no version overlap between the peers -type Session interface { +type Connection interface { // AcceptStream returns the next stream opened by the peer, blocking until one is available. - // If the session was closed due to a timeout, the error satisfies + // If the connection was closed due to a timeout, the error satisfies // the net.Error interface, and Timeout() will be true. AcceptStream(context.Context) (Stream, error) // AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available. - // If the session was closed due to a timeout, the error satisfies + // If the connection was closed due to a timeout, the error satisfies // the net.Error interface, and Timeout() will be true. AcceptUniStream(context.Context) (ReceiveStream, error) // OpenStream opens a new bidirectional QUIC stream. @@ -159,22 +158,22 @@ type Session interface { // The peer can only accept the stream after data has been sent on the stream. // If the error is non-nil, it satisfies the net.Error interface. // When reaching the peer's stream limit, err.Temporary() will be true. - // If the session was closed due to a timeout, Timeout() will be true. + // If the connection was closed due to a timeout, Timeout() will be true. OpenStream() (Stream, error) // OpenStreamSync opens a new bidirectional QUIC stream. // It blocks until a new stream can be opened. // If the error is non-nil, it satisfies the net.Error interface. - // If the session was closed due to a timeout, Timeout() will be true. + // If the connection was closed due to a timeout, Timeout() will be true. OpenStreamSync(context.Context) (Stream, error) // OpenUniStream opens a new outgoing unidirectional QUIC stream. // If the error is non-nil, it satisfies the net.Error interface. // When reaching the peer's stream limit, Temporary() will be true. - // If the session was closed due to a timeout, Timeout() will be true. + // If the connection was closed due to a timeout, Timeout() will be true. OpenUniStream() (SendStream, error) // OpenUniStreamSync opens a new outgoing unidirectional QUIC stream. // It blocks until a new stream can be opened. // If the error is non-nil, it satisfies the net.Error interface. - // If the session was closed due to a timeout, Timeout() will be true. + // If the connection was closed due to a timeout, Timeout() will be true. OpenUniStreamSync(context.Context) (SendStream, error) // LocalAddr returns the local address. LocalAddr() net.Addr @@ -183,42 +182,38 @@ type Session interface { // CloseWithError closes the connection with an error. // The error string will be sent to the peer. CloseWithError(ApplicationErrorCode, string) error - // The context is cancelled when the session is closed. - // Warning: This API should not be considered stable and might change soon. + // The context is cancelled when the connection is closed. Context() context.Context // ConnectionState returns basic details about the QUIC connection. // It blocks until the handshake completes. // Warning: This API should not be considered stable and might change soon. ConnectionState() ConnectionState - // SendMessage sends a message as a datagram. - // See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. + // SendMessage sends a message as a datagram, as specified in RFC 9221. SendMessage([]byte) error - // ReceiveMessage gets a message received in a datagram. - // See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. + // ReceiveMessage gets a message received in a datagram, as specified in RFC 9221. ReceiveMessage() ([]byte, error) } -// An EarlySession is a session that is handshaking. +// An EarlyConnection is a connection that is handshaking. // Data sent during the handshake is encrypted using the forward secure keys. // When using client certificates, the client's identity is only verified // after completion of the handshake. -type EarlySession interface { - Session +type EarlyConnection interface { + Connection // HandshakeComplete blocks until the handshake completes (or fails). // Data sent before completion of the handshake is encrypted with 1-RTT keys. // Note that the client's identity hasn't been verified yet. HandshakeComplete() context.Context - NextSession() Session + NextConnection() Connection } // Config contains all configuration data needed for a QUIC server or client. type Config struct { // The QUIC versions that can be negotiated. // If not set, it uses all versions available. - // Warning: This API should not be considered stable and will change soon. Versions []VersionNumber // The length of the connection ID in bytes. // It can be 0, or any value between 4 and 18. @@ -270,9 +265,9 @@ type Config struct { // to increase the connection flow control window. // If set, the caller can prevent an increase of the window. Typically, it would do so to // limit the memory usage. - // To avoid deadlocks, it is not valid to call other functions on the session or on streams + // To avoid deadlocks, it is not valid to call other functions on the connection or on streams // in this callback. - AllowConnectionWindowIncrease func(sess Session, delta uint64) bool + AllowConnectionWindowIncrease func(sess Connection, delta uint64) bool // MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open. // Values above 2^60 are invalid. // If not set, it will default to 100. @@ -290,7 +285,7 @@ type Config struct { KeepAlive bool // DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899). // Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. - // Note that Path MTU discovery is always disabled on Windows, see https://github.com/lucas-clemente/quic-go/issues/3273. + // Note that if Path MTU discovery is causing issues on your system, please open a new issue DisablePathMTUDiscovery bool // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. // This can be useful if version information is exchanged out-of-band. @@ -310,21 +305,21 @@ type ConnectionState struct { // A Listener for incoming QUIC connections type Listener interface { - // Close the server. All active sessions will be closed. + // Close the server. All active connections will be closed. Close() error // Addr returns the local network addr that the server is listening on. Addr() net.Addr - // Accept returns new sessions. It should be called in a loop. - Accept(context.Context) (Session, error) + // Accept returns new connections. It should be called in a loop. + Accept(context.Context) (Connection, error) } // An EarlyListener listens for incoming QUIC connections, // and returns them before the handshake completes. type EarlyListener interface { - // Close the server. All active sessions will be closed. + // Close the server. All active connections will be closed. Close() error // Addr returns the local network addr that the server is listening on. Addr() net.Addr - // Accept returns new early sessions. It should be called in a loop. - Accept(context.Context) (EarlySession, error) + // Accept returns new early connections. It should be called in a loop. + Accept(context.Context) (EarlyConnection, error) } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index ec0e4e1e40d..5720321a8f8 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -260,7 +260,8 @@ var _ = Describe("Crypto Setup TLS", func() { } handshake := func(client CryptoSetup, cChunkChan <-chan chunk, - server CryptoSetup, sChunkChan <-chan chunk) { + server CryptoSetup, sChunkChan <-chan chunk, + ) { done := make(chan struct{}) go func() { defer GinkgoRecover() diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 7221396d939..84372a1a00b 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -1,7 +1,7 @@ package mocks //go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream" -//go:generate sh -c "mockgen -package mockquic -destination quic/early_session_tmp.go github.com/lucas-clemente/quic-go EarlySession && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_session_tmp.go > quic/early_session.go && rm quic/early_session_tmp.go && goimports -w quic/early_session.go" +//go:generate sh -c "mockgen -package mockquic -destination quic/early_conn_tmp.go github.com/lucas-clemente/quic-go EarlyConnection && sed 's/qtls.ConnectionState/quic.ConnectionState/g' quic/early_conn_tmp.go > quic/early_conn.go && rm quic/early_conn_tmp.go && goimports -w quic/early_conn.go" //go:generate sh -c "mockgen -package mockquic -destination quic/early_listener.go github.com/lucas-clemente/quic-go EarlyListener" //go:generate sh -c "mockgen -package mocklogging -destination logging/tracer.go github.com/lucas-clemente/quic-go/logging Tracer" //go:generate sh -c "mockgen -package mocklogging -destination logging/connection_tracer.go github.com/lucas-clemente/quic-go/logging ConnectionTracer" diff --git a/internal/mocks/quic/early_session.go b/internal/mocks/quic/early_conn.go similarity index 54% rename from internal/mocks/quic/early_session.go rename to internal/mocks/quic/early_conn.go index ef09723344f..6db02300112 100644 --- a/internal/mocks/quic/early_session.go +++ b/internal/mocks/quic/early_conn.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go (interfaces: EarlySession) +// Source: github.com/lucas-clemente/quic-go (interfaces: EarlyConnection) // Package mockquic is a generated GoMock package. package mockquic @@ -14,31 +14,31 @@ import ( qerr "github.com/lucas-clemente/quic-go/internal/qerr" ) -// MockEarlySession is a mock of EarlySession interface. -type MockEarlySession struct { +// MockEarlyConnection is a mock of EarlyConnection interface. +type MockEarlyConnection struct { ctrl *gomock.Controller - recorder *MockEarlySessionMockRecorder + recorder *MockEarlyConnectionMockRecorder } -// MockEarlySessionMockRecorder is the mock recorder for MockEarlySession. -type MockEarlySessionMockRecorder struct { - mock *MockEarlySession +// MockEarlyConnectionMockRecorder is the mock recorder for MockEarlyConnection. +type MockEarlyConnectionMockRecorder struct { + mock *MockEarlyConnection } -// NewMockEarlySession creates a new mock instance. -func NewMockEarlySession(ctrl *gomock.Controller) *MockEarlySession { - mock := &MockEarlySession{ctrl: ctrl} - mock.recorder = &MockEarlySessionMockRecorder{mock} +// NewMockEarlyConnection creates a new mock instance. +func NewMockEarlyConnection(ctrl *gomock.Controller) *MockEarlyConnection { + mock := &MockEarlyConnection{ctrl: ctrl} + mock.recorder = &MockEarlyConnectionMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockEarlySession) EXPECT() *MockEarlySessionMockRecorder { +func (m *MockEarlyConnection) EXPECT() *MockEarlyConnectionMockRecorder { return m.recorder } // AcceptStream mocks base method. -func (m *MockEarlySession) AcceptStream(arg0 context.Context) (quic.Stream, error) { +func (m *MockEarlyConnection) AcceptStream(arg0 context.Context) (quic.Stream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AcceptStream", arg0) ret0, _ := ret[0].(quic.Stream) @@ -47,13 +47,13 @@ func (m *MockEarlySession) AcceptStream(arg0 context.Context) (quic.Stream, erro } // AcceptStream indicates an expected call of AcceptStream. -func (mr *MockEarlySessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockEarlySession)(nil).AcceptStream), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptStream), arg0) } // AcceptUniStream mocks base method. -func (m *MockEarlySession) AcceptUniStream(arg0 context.Context) (quic.ReceiveStream, error) { +func (m *MockEarlyConnection) AcceptUniStream(arg0 context.Context) (quic.ReceiveStream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AcceptUniStream", arg0) ret0, _ := ret[0].(quic.ReceiveStream) @@ -62,13 +62,13 @@ func (m *MockEarlySession) AcceptUniStream(arg0 context.Context) (quic.ReceiveSt } // AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockEarlySessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockEarlySession)(nil).AcceptUniStream), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).AcceptUniStream), arg0) } // CloseWithError mocks base method. -func (m *MockEarlySession) CloseWithError(arg0 qerr.ApplicationErrorCode, arg1 string) error { +func (m *MockEarlyConnection) CloseWithError(arg0 qerr.ApplicationErrorCode, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1) ret0, _ := ret[0].(error) @@ -76,13 +76,13 @@ func (m *MockEarlySession) CloseWithError(arg0 qerr.ApplicationErrorCode, arg1 s } // CloseWithError indicates an expected call of CloseWithError. -func (mr *MockEarlySessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockEarlySession)(nil).CloseWithError), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockEarlyConnection)(nil).CloseWithError), arg0, arg1) } // ConnectionState mocks base method. -func (m *MockEarlySession) ConnectionState() quic.ConnectionState { +func (m *MockEarlyConnection) ConnectionState() quic.ConnectionState { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ConnectionState") ret0, _ := ret[0].(quic.ConnectionState) @@ -90,13 +90,13 @@ func (m *MockEarlySession) ConnectionState() quic.ConnectionState { } // ConnectionState indicates an expected call of ConnectionState. -func (mr *MockEarlySessionMockRecorder) ConnectionState() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) ConnectionState() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockEarlySession)(nil).ConnectionState)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockEarlyConnection)(nil).ConnectionState)) } // Context mocks base method. -func (m *MockEarlySession) Context() context.Context { +func (m *MockEarlyConnection) Context() context.Context { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Context") ret0, _ := ret[0].(context.Context) @@ -104,13 +104,13 @@ func (m *MockEarlySession) Context() context.Context { } // Context indicates an expected call of Context. -func (mr *MockEarlySessionMockRecorder) Context() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) Context() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockEarlySession)(nil).Context)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockEarlyConnection)(nil).Context)) } // HandshakeComplete mocks base method. -func (m *MockEarlySession) HandshakeComplete() context.Context { +func (m *MockEarlyConnection) HandshakeComplete() context.Context { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HandshakeComplete") ret0, _ := ret[0].(context.Context) @@ -118,13 +118,13 @@ func (m *MockEarlySession) HandshakeComplete() context.Context { } // HandshakeComplete indicates an expected call of HandshakeComplete. -func (mr *MockEarlySessionMockRecorder) HandshakeComplete() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) HandshakeComplete() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockEarlySession)(nil).HandshakeComplete)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockEarlyConnection)(nil).HandshakeComplete)) } // LocalAddr mocks base method. -func (m *MockEarlySession) LocalAddr() net.Addr { +func (m *MockEarlyConnection) LocalAddr() net.Addr { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LocalAddr") ret0, _ := ret[0].(net.Addr) @@ -132,27 +132,27 @@ func (m *MockEarlySession) LocalAddr() net.Addr { } // LocalAddr indicates an expected call of LocalAddr. -func (mr *MockEarlySessionMockRecorder) LocalAddr() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) LocalAddr() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockEarlySession)(nil).LocalAddr)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockEarlyConnection)(nil).LocalAddr)) } -// NextSession mocks base method. -func (m *MockEarlySession) NextSession() quic.Session { +// NextConnection mocks base method. +func (m *MockEarlyConnection) NextConnection() quic.Connection { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NextSession") - ret0, _ := ret[0].(quic.Session) + ret := m.ctrl.Call(m, "NextConnection") + ret0, _ := ret[0].(quic.Connection) return ret0 } -// NextSession indicates an expected call of NextSession. -func (mr *MockEarlySessionMockRecorder) NextSession() *gomock.Call { +// NextConnection indicates an expected call of NextConnection. +func (mr *MockEarlyConnectionMockRecorder) NextConnection() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextSession", reflect.TypeOf((*MockEarlySession)(nil).NextSession)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockEarlyConnection)(nil).NextConnection)) } // OpenStream mocks base method. -func (m *MockEarlySession) OpenStream() (quic.Stream, error) { +func (m *MockEarlyConnection) OpenStream() (quic.Stream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OpenStream") ret0, _ := ret[0].(quic.Stream) @@ -161,13 +161,13 @@ func (m *MockEarlySession) OpenStream() (quic.Stream, error) { } // OpenStream indicates an expected call of OpenStream. -func (mr *MockEarlySessionMockRecorder) OpenStream() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenStream() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockEarlySession)(nil).OpenStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStream)) } // OpenStreamSync mocks base method. -func (m *MockEarlySession) OpenStreamSync(arg0 context.Context) (quic.Stream, error) { +func (m *MockEarlyConnection) OpenStreamSync(arg0 context.Context) (quic.Stream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OpenStreamSync", arg0) ret0, _ := ret[0].(quic.Stream) @@ -176,13 +176,13 @@ func (m *MockEarlySession) OpenStreamSync(arg0 context.Context) (quic.Stream, er } // OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockEarlySessionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockEarlySession)(nil).OpenStreamSync), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenStreamSync), arg0) } // OpenUniStream mocks base method. -func (m *MockEarlySession) OpenUniStream() (quic.SendStream, error) { +func (m *MockEarlyConnection) OpenUniStream() (quic.SendStream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OpenUniStream") ret0, _ := ret[0].(quic.SendStream) @@ -191,13 +191,13 @@ func (m *MockEarlySession) OpenUniStream() (quic.SendStream, error) { } // OpenUniStream indicates an expected call of OpenUniStream. -func (mr *MockEarlySessionMockRecorder) OpenUniStream() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenUniStream() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockEarlySession)(nil).OpenUniStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStream)) } // OpenUniStreamSync mocks base method. -func (m *MockEarlySession) OpenUniStreamSync(arg0 context.Context) (quic.SendStream, error) { +func (m *MockEarlyConnection) OpenUniStreamSync(arg0 context.Context) (quic.SendStream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) ret0, _ := ret[0].(quic.SendStream) @@ -206,13 +206,13 @@ func (m *MockEarlySession) OpenUniStreamSync(arg0 context.Context) (quic.SendStr } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockEarlySessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockEarlySession)(nil).OpenUniStreamSync), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockEarlyConnection)(nil).OpenUniStreamSync), arg0) } // ReceiveMessage mocks base method. -func (m *MockEarlySession) ReceiveMessage() ([]byte, error) { +func (m *MockEarlyConnection) ReceiveMessage() ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReceiveMessage") ret0, _ := ret[0].([]byte) @@ -221,13 +221,13 @@ func (m *MockEarlySession) ReceiveMessage() ([]byte, error) { } // ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockEarlySessionMockRecorder) ReceiveMessage() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) ReceiveMessage() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlySession)(nil).ReceiveMessage)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockEarlyConnection)(nil).ReceiveMessage)) } // RemoteAddr mocks base method. -func (m *MockEarlySession) RemoteAddr() net.Addr { +func (m *MockEarlyConnection) RemoteAddr() net.Addr { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RemoteAddr") ret0, _ := ret[0].(net.Addr) @@ -235,13 +235,13 @@ func (m *MockEarlySession) RemoteAddr() net.Addr { } // RemoteAddr indicates an expected call of RemoteAddr. -func (mr *MockEarlySessionMockRecorder) RemoteAddr() *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) RemoteAddr() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockEarlySession)(nil).RemoteAddr)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockEarlyConnection)(nil).RemoteAddr)) } // SendMessage mocks base method. -func (m *MockEarlySession) SendMessage(arg0 []byte) error { +func (m *MockEarlyConnection) SendMessage(arg0 []byte) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendMessage", arg0) ret0, _ := ret[0].(error) @@ -249,7 +249,7 @@ func (m *MockEarlySession) SendMessage(arg0 []byte) error { } // SendMessage indicates an expected call of SendMessage. -func (mr *MockEarlySessionMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { +func (mr *MockEarlyConnectionMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockEarlySession)(nil).SendMessage), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockEarlyConnection)(nil).SendMessage), arg0) } diff --git a/internal/mocks/quic/early_listener.go b/internal/mocks/quic/early_listener.go index 395a1b196ae..279096b86d1 100644 --- a/internal/mocks/quic/early_listener.go +++ b/internal/mocks/quic/early_listener.go @@ -37,10 +37,10 @@ func (m *MockEarlyListener) EXPECT() *MockEarlyListenerMockRecorder { } // Accept mocks base method. -func (m *MockEarlyListener) Accept(arg0 context.Context) (quic.EarlySession, error) { +func (m *MockEarlyListener) Accept(arg0 context.Context) (quic.EarlyConnection, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Accept", arg0) - ret0, _ := ret[0].(quic.EarlySession) + ret0, _ := ret[0].(quic.EarlyConnection) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/internal/protocol/params.go b/internal/protocol/params.go index 4bc33e27461..83137113977 100644 --- a/internal/protocol/params.go +++ b/internal/protocol/params.go @@ -14,7 +14,7 @@ const InitialPacketSizeIPv6 = 1232 // MaxCongestionWindowPackets is the maximum congestion window in packet. const MaxCongestionWindowPackets = 10000 -// MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the session. +// MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the connection. const MaxUndecryptablePackets = 32 // ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window @@ -45,8 +45,8 @@ const DefaultMaxIncomingUniStreams = 100 // MaxServerUnprocessedPackets is the max number of packets stored in the server that are not yet processed. const MaxServerUnprocessedPackets = 1024 -// MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed. -const MaxSessionUnprocessedPackets = 256 +// MaxConnUnprocessedPackets is the max number of packets stored in each connection that are not yet processed. +const MaxConnUnprocessedPackets = 256 // SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack. // Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod. @@ -55,7 +55,7 @@ const SkipPacketInitialPeriod PacketNumber = 256 // SkipPacketMaxPeriod is the maximum period length used for packet number skipping. const SkipPacketMaxPeriod PacketNumber = 128 * 1024 -// MaxAcceptQueueSize is the maximum number of sessions that the server queues for accepting. +// MaxAcceptQueueSize is the maximum number of connections that the server queues for accepting. // If the queue is full, new connection attempts will be rejected. const MaxAcceptQueueSize = 32 @@ -112,7 +112,7 @@ const DefaultHandshakeTimeout = 10 * time.Second // It should be shorter than the time that NATs clear their mapping. const MaxKeepAliveInterval = 20 * time.Second -// RetiredConnectionIDDeleteTimeout is the time we keep closed sessions around in order to retransmit the CONNECTION_CLOSE. +// RetiredConnectionIDDeleteTimeout is the time we keep closed connections around in order to retransmit the CONNECTION_CLOSE. // after this time all information about the old connection will be deleted const RetiredConnectionIDDeleteTimeout = 5 * time.Second @@ -132,13 +132,11 @@ const MaxPostHandshakeCryptoFrameSize = 1000 // but must ensure that a maximum size ACK frame fits into one packet. const MaxAckFrameSize ByteCount = 1000 -// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame as defined in -// https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. +// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame (RFC 9221). // The size is chosen such that a DATAGRAM frame fits into a QUIC packet. const MaxDatagramFrameSize ByteCount = 1220 -// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames. -// See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/. +// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221) const DatagramRcvQueueLen = 128 // MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame. @@ -189,7 +187,7 @@ const Max0RTTQueueingDuration = 100 * time.Millisecond const Max0RTTQueues = 32 // Max0RTTQueueLen is the maximum number of 0-RTT packets that we buffer for each connection. -// When a new session is created, all buffered packets are passed to the session immediately. -// To avoid blocking, this value has to be smaller than MaxSessionUnprocessedPackets. -// To avoid packets being dropped as undecryptable by the session, this value has to be smaller than MaxUndecryptablePackets. +// When a new connection is created, all buffered packets are passed to the connection immediately. +// To avoid blocking, this value has to be smaller than MaxConnUnprocessedPackets. +// To avoid packets being dropped as undecryptable by the connection, this value has to be smaller than MaxUndecryptablePackets. const Max0RTTQueueLen = 31 diff --git a/internal/protocol/params_test.go b/internal/protocol/params_test.go index b144054a6b4..50a260d2742 100644 --- a/internal/protocol/params_test.go +++ b/internal/protocol/params_test.go @@ -7,7 +7,7 @@ import ( var _ = Describe("Parameters", func() { It("can queue more packets in the session than in the 0-RTT queue", func() { - Expect(MaxSessionUnprocessedPackets).To(BeNumerically(">", Max0RTTQueueLen)) + Expect(MaxConnUnprocessedPackets).To(BeNumerically(">", Max0RTTQueueLen)) Expect(MaxUndecryptablePackets).To(BeNumerically(">", Max0RTTQueueLen)) }) }) diff --git a/internal/qerr/errors_test.go b/internal/qerr/errors_test.go index b2a3cca5faf..5a825b8c4be 100644 --- a/internal/qerr/errors_test.go +++ b/internal/qerr/errors_test.go @@ -73,7 +73,6 @@ var _ = Describe("QUIC Errors", func() { nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) - Expect(nerr.Temporary()).To(BeFalse()) Expect(err.Error()).To(Equal("timeout: handshake did not complete in time")) }) @@ -84,7 +83,6 @@ var _ = Describe("QUIC Errors", func() { nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) Expect(nerr.Timeout()).To(BeTrue()) - Expect(nerr.Temporary()).To(BeFalse()) Expect(err.Error()).To(Equal("timeout: no recent network activity")) }) }) @@ -112,7 +110,6 @@ var _ = Describe("QUIC Errors", func() { nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) Expect(nerr.Timeout()).To(BeFalse()) - Expect(nerr.Temporary()).To(BeTrue()) }) }) diff --git a/internal/qtls/go119.go b/internal/qtls/go119.go index 2c648639e3c..87e7132e5ab 100644 --- a/internal/qtls/go119.go +++ b/internal/qtls/go119.go @@ -3,4 +3,4 @@ package qtls -var _ int = "quic-go doesn't build on Go 1.19 yet." +var _ int = "The version of quic-go you're using can't be built on Go 1.19 yet. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." diff --git a/internal/qtls/go_oldversion.go b/internal/qtls/go_oldversion.go new file mode 100644 index 00000000000..384d719c6e8 --- /dev/null +++ b/internal/qtls/go_oldversion.go @@ -0,0 +1,7 @@ +//go:build (go1.9 || go1.10 || go1.11 || go1.12 || go1.13 || go1.14 || go1.15) && !go1.16 +// +build go1.9 go1.10 go1.11 go1.12 go1.13 go1.14 go1.15 +// +build !go1.16 + +package qtls + +var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/lucas-clemente/quic-go/wiki/quic-go-and-Go-versions." diff --git a/internal/utils/rand_test.go b/internal/utils/rand_test.go index 4e865c371f8..f15a644e432 100644 --- a/internal/utils/rand_test.go +++ b/internal/utils/rand_test.go @@ -9,7 +9,7 @@ var _ = Describe("Rand", func() { It("generates random numbers", func() { const ( num = 1000 - max = 123456 + max = 12345678 ) var values [num]int32 diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index b7e0a8c9dc1..e1f83cd6564 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -42,7 +42,7 @@ const ( activeConnectionIDLimitParameterID transportParameterID = 0xe initialSourceConnectionIDParameterID transportParameterID = 0xf retrySourceConnectionIDParameterID transportParameterID = 0x10 - // https://datatracker.ietf.org/doc/draft-ietf-quic-datagram/ + // RFC 9221 maxDatagramFrameSizeParameterID transportParameterID = 0x20 ) diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index bcae87d17f7..196853e0fc4 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -35,7 +35,7 @@ func ParseVersionNegotiationPacket(b *bytes.Reader) (*Header, []protocol.Version } // ComposeVersionNegotiation composes a Version Negotiation -func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) ([]byte, error) { +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { greasedVersions := protocol.GetGreasedVersions(versions) expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) @@ -50,5 +50,5 @@ func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, vers for _, v := range greasedVersions { utils.BigEndian.WriteUint32(buf, uint32(v)) } - return buf.Bytes(), nil + return buf.Bytes() } diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index d3a0b7e6215..31ad5d93f86 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -36,20 +36,18 @@ var _ = Describe("Version Negotiation Packets", func() { It("errors if it contains versions of the wrong length", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x22334455, 0x33445566} - data, err := ComposeVersionNegotiation(connID, connID, versions) - Expect(err).ToNot(HaveOccurred()) - _, _, err = ParseVersionNegotiationPacket(bytes.NewReader(data[:len(data)-2])) + data := ComposeVersionNegotiation(connID, connID, versions) + _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data[:len(data)-2])) Expect(err).To(MatchError("Version Negotiation packet has a version list with an invalid length")) }) It("errors if the version list is empty", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x22334455} - data, err := ComposeVersionNegotiation(connID, connID, versions) - Expect(err).ToNot(HaveOccurred()) + data := ComposeVersionNegotiation(connID, connID, versions) // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number data = data[:len(data)-8] - _, _, err = ParseVersionNegotiationPacket(bytes.NewReader(data)) + _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) Expect(err).To(MatchError("Version Negotiation packet has empty version list")) }) @@ -57,8 +55,7 @@ var _ = Describe("Version Negotiation Packets", func() { srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{1001, 1003} - data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) - Expect(err).ToNot(HaveOccurred()) + data := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(data[0] & 0x80).ToNot(BeZero()) hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) Expect(err).ToNot(HaveOccurred()) diff --git a/interop/Dockerfile b/interop/Dockerfile index 206d4f8952c..f961d90e4fb 100644 --- a/interop/Dockerfile +++ b/interop/Dockerfile @@ -2,9 +2,9 @@ FROM martenseemann/quic-network-simulator-endpoint:latest AS builder RUN apt-get update && apt-get install -y wget tar git -RUN wget https://dl.google.com/go/go1.17.linux-amd64.tar.gz && \ - tar xfz go1.17.linux-amd64.tar.gz && \ - rm go1.17.linux-amd64.tar.gz +RUN wget https://dl.google.com/go/go1.18.linux-amd64.tar.gz && \ + tar xfz go1.18.linux-amd64.tar.gz && \ + rm go1.18.linux-amd64.tar.gz ENV PATH="/go/bin:${PATH}" diff --git a/interop/http09/client.go b/interop/http09/client.go index 14f45d3e022..cf2a03276bf 100644 --- a/interop/http09/client.go +++ b/interop/http09/client.go @@ -84,25 +84,25 @@ type client struct { quicConf *quic.Config once sync.Once - sess quic.EarlySession + conn quic.EarlyConnection dialErr error } func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { c.once.Do(func() { - c.sess, c.dialErr = quic.DialAddrEarly(c.hostname, c.tlsConf, c.quicConf) + c.conn, c.dialErr = quic.DialAddrEarly(c.hostname, c.tlsConf, c.quicConf) }) if c.dialErr != nil { return nil, c.dialErr } if req.Method != MethodGet0RTT { - <-c.sess.HandshakeComplete().Done() + <-c.conn.HandshakeComplete().Done() } return c.doRequest(req) } func (c *client) doRequest(req *http.Request) (*http.Response, error) { - str, err := c.sess.OpenStreamSync(context.Background()) + str, err := c.conn.OpenStreamSync(context.Background()) if err != nil { return nil, err } @@ -124,10 +124,10 @@ func (c *client) doRequest(req *http.Request) (*http.Response, error) { } func (c *client) Close() error { - if c.sess == nil { + if c.conn == nil { return nil } - return c.sess.CloseWithError(0, "") + return c.conn.CloseWithError(0, "") } func hostnameFromRequest(req *http.Request) string { diff --git a/interop/http09/server.go b/interop/http09/server.go index a30e85c7b35..59665bd64ca 100644 --- a/interop/http09/server.go +++ b/interop/http09/server.go @@ -78,17 +78,17 @@ func (s *Server) ListenAndServe() error { s.mutex.Unlock() for { - sess, err := ln.Accept(context.Background()) + conn, err := ln.Accept(context.Background()) if err != nil { return err } - go s.handleConn(sess) + go s.handleConn(conn) } } -func (s *Server) handleConn(sess quic.Session) { +func (s *Server) handleConn(conn quic.Connection) { for { - str, err := sess.AcceptStream(context.Background()) + str, err := conn.AcceptStream(context.Background()) if err != nil { log.Printf("Error accepting stream: %s\n", err.Error()) return diff --git a/logging/types.go b/logging/types.go index e18865033ec..ad800692353 100644 --- a/logging/types.go +++ b/logging/types.go @@ -68,14 +68,14 @@ const ( TimerTypePTO ) -// TimeoutReason is the reason why a session is closed +// TimeoutReason is the reason why a connection is closed type TimeoutReason uint8 const ( - // TimeoutReasonHandshake is used when the session is closed due to a handshake timeout + // TimeoutReasonHandshake is used when the connection is closed due to a handshake timeout // This reason is not defined in the qlog draft, but very useful for debugging. TimeoutReasonHandshake TimeoutReason = iota - // TimeoutReasonIdle is used when the session is closed due to an idle timeout + // TimeoutReasonIdle is used when the connection is closed due to an idle timeout // This reason is not defined in the qlog draft, but very useful for debugging. TimeoutReasonIdle ) @@ -87,7 +87,7 @@ const ( CongestionStateSlowStart CongestionState = iota // CongestionStateCongestionAvoidance is the slow start phase of Reno / Cubic CongestionStateCongestionAvoidance - // CongestionStateCongestionAvoidance is the recovery phase of Reno / Cubic + // CongestionStateRecovery is the recovery phase of Reno / Cubic CongestionStateRecovery // CongestionStateApplicationLimited means that the congestion controller is application limited CongestionStateApplicationLimited diff --git a/mock_batch_conn_test.go b/mock_batch_conn_test.go index e3e0db676de..74032900089 100644 --- a/mock_batch_conn_test.go +++ b/mock_batch_conn_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: conn_oob.go +// Source: sys_conn_oob.go // Package quic is a generated GoMock package. package quic diff --git a/mock_session_runner_test.go b/mock_conn_runner_test.go similarity index 50% rename from mock_session_runner_test.go rename to mock_conn_runner_test.go index e51a1f809fa..607bd027444 100644 --- a/mock_session_runner_test.go +++ b/mock_conn_runner_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: session.go +// Source: connection.go // Package quic is a generated GoMock package. package quic @@ -11,31 +11,31 @@ import ( protocol "github.com/lucas-clemente/quic-go/internal/protocol" ) -// MockSessionRunner is a mock of SessionRunner interface. -type MockSessionRunner struct { +// MockConnRunner is a mock of ConnRunner interface. +type MockConnRunner struct { ctrl *gomock.Controller - recorder *MockSessionRunnerMockRecorder + recorder *MockConnRunnerMockRecorder } -// MockSessionRunnerMockRecorder is the mock recorder for MockSessionRunner. -type MockSessionRunnerMockRecorder struct { - mock *MockSessionRunner +// MockConnRunnerMockRecorder is the mock recorder for MockConnRunner. +type MockConnRunnerMockRecorder struct { + mock *MockConnRunner } -// NewMockSessionRunner creates a new mock instance. -func NewMockSessionRunner(ctrl *gomock.Controller) *MockSessionRunner { - mock := &MockSessionRunner{ctrl: ctrl} - mock.recorder = &MockSessionRunnerMockRecorder{mock} +// NewMockConnRunner creates a new mock instance. +func NewMockConnRunner(ctrl *gomock.Controller) *MockConnRunner { + mock := &MockConnRunner{ctrl: ctrl} + mock.recorder = &MockConnRunnerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder { +func (m *MockConnRunner) EXPECT() *MockConnRunnerMockRecorder { return m.recorder } // Add mocks base method. -func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool { +func (m *MockConnRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Add", arg0, arg1) ret0, _ := ret[0].(bool) @@ -43,25 +43,25 @@ func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) } // Add indicates an expected call of Add. -func (mr *MockSessionRunnerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSessionRunner)(nil).Add), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockConnRunner)(nil).Add), arg0, arg1) } // AddResetToken mocks base method. -func (m *MockSessionRunner) AddResetToken(arg0 protocol.StatelessResetToken, arg1 packetHandler) { +func (m *MockConnRunner) AddResetToken(arg0 protocol.StatelessResetToken, arg1 packetHandler) { m.ctrl.T.Helper() m.ctrl.Call(m, "AddResetToken", arg0, arg1) } // AddResetToken indicates an expected call of AddResetToken. -func (mr *MockSessionRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockConnRunner)(nil).AddResetToken), arg0, arg1) } // GetStatelessResetToken mocks base method. -func (m *MockSessionRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken { +func (m *MockConnRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) ret0, _ := ret[0].(protocol.StatelessResetToken) @@ -69,55 +69,55 @@ func (m *MockSessionRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) p } // GetStatelessResetToken indicates an expected call of GetStatelessResetToken. -func (mr *MockSessionRunnerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockSessionRunner)(nil).GetStatelessResetToken), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockConnRunner)(nil).GetStatelessResetToken), arg0) } // Remove mocks base method. -func (m *MockSessionRunner) Remove(arg0 protocol.ConnectionID) { +func (m *MockConnRunner) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() m.ctrl.Call(m, "Remove", arg0) } // Remove indicates an expected call of Remove. -func (mr *MockSessionRunnerMockRecorder) Remove(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Remove(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionRunner)(nil).Remove), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockConnRunner)(nil).Remove), arg0) } // RemoveResetToken mocks base method. -func (m *MockSessionRunner) RemoveResetToken(arg0 protocol.StatelessResetToken) { +func (m *MockConnRunner) RemoveResetToken(arg0 protocol.StatelessResetToken) { m.ctrl.T.Helper() m.ctrl.Call(m, "RemoveResetToken", arg0) } // RemoveResetToken indicates an expected call of RemoveResetToken. -func (mr *MockSessionRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockSessionRunner)(nil).RemoveResetToken), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockConnRunner)(nil).RemoveResetToken), arg0) } // ReplaceWithClosed mocks base method. -func (m *MockSessionRunner) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { +func (m *MockConnRunner) ReplaceWithClosed(arg0 protocol.ConnectionID, arg1 packetHandler) { m.ctrl.T.Helper() m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockSessionRunnerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockSessionRunner)(nil).ReplaceWithClosed), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1) } // Retire mocks base method. -func (m *MockSessionRunner) Retire(arg0 protocol.ConnectionID) { +func (m *MockConnRunner) Retire(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() m.ctrl.Call(m, "Retire", arg0) } // Retire indicates an expected call of Retire. -func (mr *MockSessionRunnerMockRecorder) Retire(arg0 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) Retire(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockSessionRunner)(nil).Retire), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockConnRunner)(nil).Retire), arg0) } diff --git a/mock_quic_session_test.go b/mock_quic_conn_test.go similarity index 55% rename from mock_quic_session_test.go rename to mock_quic_conn_test.go index 08c03c3072d..880f1dd1446 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_conn_test.go @@ -13,31 +13,31 @@ import ( protocol "github.com/lucas-clemente/quic-go/internal/protocol" ) -// MockQuicSession is a mock of QuicSession interface. -type MockQuicSession struct { +// MockQuicConn is a mock of QuicConn interface. +type MockQuicConn struct { ctrl *gomock.Controller - recorder *MockQuicSessionMockRecorder + recorder *MockQuicConnMockRecorder } -// MockQuicSessionMockRecorder is the mock recorder for MockQuicSession. -type MockQuicSessionMockRecorder struct { - mock *MockQuicSession +// MockQuicConnMockRecorder is the mock recorder for MockQuicConn. +type MockQuicConnMockRecorder struct { + mock *MockQuicConn } -// NewMockQuicSession creates a new mock instance. -func NewMockQuicSession(ctrl *gomock.Controller) *MockQuicSession { - mock := &MockQuicSession{ctrl: ctrl} - mock.recorder = &MockQuicSessionMockRecorder{mock} +// NewMockQuicConn creates a new mock instance. +func NewMockQuicConn(ctrl *gomock.Controller) *MockQuicConn { + mock := &MockQuicConn{ctrl: ctrl} + mock.recorder = &MockQuicConnMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockQuicSession) EXPECT() *MockQuicSessionMockRecorder { +func (m *MockQuicConn) EXPECT() *MockQuicConnMockRecorder { return m.recorder } // AcceptStream mocks base method. -func (m *MockQuicSession) AcceptStream(arg0 context.Context) (Stream, error) { +func (m *MockQuicConn) AcceptStream(arg0 context.Context) (Stream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AcceptStream", arg0) ret0, _ := ret[0].(Stream) @@ -46,13 +46,13 @@ func (m *MockQuicSession) AcceptStream(arg0 context.Context) (Stream, error) { } // AcceptStream indicates an expected call of AcceptStream. -func (mr *MockQuicSessionMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { +func (mr *MockQuicConnMockRecorder) AcceptStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptStream), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptStream", reflect.TypeOf((*MockQuicConn)(nil).AcceptStream), arg0) } // AcceptUniStream mocks base method. -func (m *MockQuicSession) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) { +func (m *MockQuicConn) AcceptUniStream(arg0 context.Context) (ReceiveStream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AcceptUniStream", arg0) ret0, _ := ret[0].(ReceiveStream) @@ -61,13 +61,13 @@ func (m *MockQuicSession) AcceptUniStream(arg0 context.Context) (ReceiveStream, } // AcceptUniStream indicates an expected call of AcceptUniStream. -func (mr *MockQuicSessionMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { +func (mr *MockQuicConnMockRecorder) AcceptUniStream(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicSession)(nil).AcceptUniStream), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptUniStream", reflect.TypeOf((*MockQuicConn)(nil).AcceptUniStream), arg0) } // CloseWithError mocks base method. -func (m *MockQuicSession) CloseWithError(arg0 ApplicationErrorCode, arg1 string) error { +func (m *MockQuicConn) CloseWithError(arg0 ApplicationErrorCode, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CloseWithError", arg0, arg1) ret0, _ := ret[0].(error) @@ -75,13 +75,13 @@ func (m *MockQuicSession) CloseWithError(arg0 ApplicationErrorCode, arg1 string) } // CloseWithError indicates an expected call of CloseWithError. -func (mr *MockQuicSessionMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockQuicConnMockRecorder) CloseWithError(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicSession)(nil).CloseWithError), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseWithError", reflect.TypeOf((*MockQuicConn)(nil).CloseWithError), arg0, arg1) } // ConnectionState mocks base method. -func (m *MockQuicSession) ConnectionState() ConnectionState { +func (m *MockQuicConn) ConnectionState() ConnectionState { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ConnectionState") ret0, _ := ret[0].(ConnectionState) @@ -89,13 +89,13 @@ func (m *MockQuicSession) ConnectionState() ConnectionState { } // ConnectionState indicates an expected call of ConnectionState. -func (mr *MockQuicSessionMockRecorder) ConnectionState() *gomock.Call { +func (mr *MockQuicConnMockRecorder) ConnectionState() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockQuicSession)(nil).ConnectionState)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionState", reflect.TypeOf((*MockQuicConn)(nil).ConnectionState)) } // Context mocks base method. -func (m *MockQuicSession) Context() context.Context { +func (m *MockQuicConn) Context() context.Context { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Context") ret0, _ := ret[0].(context.Context) @@ -103,13 +103,13 @@ func (m *MockQuicSession) Context() context.Context { } // Context indicates an expected call of Context. -func (mr *MockQuicSessionMockRecorder) Context() *gomock.Call { +func (mr *MockQuicConnMockRecorder) Context() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockQuicSession)(nil).Context)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockQuicConn)(nil).Context)) } // GetVersion mocks base method. -func (m *MockQuicSession) GetVersion() protocol.VersionNumber { +func (m *MockQuicConn) GetVersion() protocol.VersionNumber { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetVersion") ret0, _ := ret[0].(protocol.VersionNumber) @@ -117,13 +117,13 @@ func (m *MockQuicSession) GetVersion() protocol.VersionNumber { } // GetVersion indicates an expected call of GetVersion. -func (mr *MockQuicSessionMockRecorder) GetVersion() *gomock.Call { +func (mr *MockQuicConnMockRecorder) GetVersion() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockQuicSession)(nil).GetVersion)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVersion", reflect.TypeOf((*MockQuicConn)(nil).GetVersion)) } // HandshakeComplete mocks base method. -func (m *MockQuicSession) HandshakeComplete() context.Context { +func (m *MockQuicConn) HandshakeComplete() context.Context { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HandshakeComplete") ret0, _ := ret[0].(context.Context) @@ -131,13 +131,13 @@ func (m *MockQuicSession) HandshakeComplete() context.Context { } // HandshakeComplete indicates an expected call of HandshakeComplete. -func (mr *MockQuicSessionMockRecorder) HandshakeComplete() *gomock.Call { +func (mr *MockQuicConnMockRecorder) HandshakeComplete() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockQuicSession)(nil).HandshakeComplete)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandshakeComplete", reflect.TypeOf((*MockQuicConn)(nil).HandshakeComplete)) } // LocalAddr mocks base method. -func (m *MockQuicSession) LocalAddr() net.Addr { +func (m *MockQuicConn) LocalAddr() net.Addr { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LocalAddr") ret0, _ := ret[0].(net.Addr) @@ -145,27 +145,27 @@ func (m *MockQuicSession) LocalAddr() net.Addr { } // LocalAddr indicates an expected call of LocalAddr. -func (mr *MockQuicSessionMockRecorder) LocalAddr() *gomock.Call { +func (mr *MockQuicConnMockRecorder) LocalAddr() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockQuicSession)(nil).LocalAddr)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockQuicConn)(nil).LocalAddr)) } -// NextSession mocks base method. -func (m *MockQuicSession) NextSession() Session { +// NextConnection mocks base method. +func (m *MockQuicConn) NextConnection() Connection { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NextSession") - ret0, _ := ret[0].(Session) + ret := m.ctrl.Call(m, "NextConnection") + ret0, _ := ret[0].(Connection) return ret0 } -// NextSession indicates an expected call of NextSession. -func (mr *MockQuicSessionMockRecorder) NextSession() *gomock.Call { +// NextConnection indicates an expected call of NextConnection. +func (mr *MockQuicConnMockRecorder) NextConnection() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextSession", reflect.TypeOf((*MockQuicSession)(nil).NextSession)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextConnection", reflect.TypeOf((*MockQuicConn)(nil).NextConnection)) } // OpenStream mocks base method. -func (m *MockQuicSession) OpenStream() (Stream, error) { +func (m *MockQuicConn) OpenStream() (Stream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OpenStream") ret0, _ := ret[0].(Stream) @@ -174,13 +174,13 @@ func (m *MockQuicSession) OpenStream() (Stream, error) { } // OpenStream indicates an expected call of OpenStream. -func (mr *MockQuicSessionMockRecorder) OpenStream() *gomock.Call { +func (mr *MockQuicConnMockRecorder) OpenStream() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockQuicSession)(nil).OpenStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockQuicConn)(nil).OpenStream)) } // OpenStreamSync mocks base method. -func (m *MockQuicSession) OpenStreamSync(arg0 context.Context) (Stream, error) { +func (m *MockQuicConn) OpenStreamSync(arg0 context.Context) (Stream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OpenStreamSync", arg0) ret0, _ := ret[0].(Stream) @@ -189,13 +189,13 @@ func (m *MockQuicSession) OpenStreamSync(arg0 context.Context) (Stream, error) { } // OpenStreamSync indicates an expected call of OpenStreamSync. -func (mr *MockQuicSessionMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockQuicConnMockRecorder) OpenStreamSync(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenStreamSync), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStreamSync", reflect.TypeOf((*MockQuicConn)(nil).OpenStreamSync), arg0) } // OpenUniStream mocks base method. -func (m *MockQuicSession) OpenUniStream() (SendStream, error) { +func (m *MockQuicConn) OpenUniStream() (SendStream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OpenUniStream") ret0, _ := ret[0].(SendStream) @@ -204,13 +204,13 @@ func (m *MockQuicSession) OpenUniStream() (SendStream, error) { } // OpenUniStream indicates an expected call of OpenUniStream. -func (mr *MockQuicSessionMockRecorder) OpenUniStream() *gomock.Call { +func (mr *MockQuicConnMockRecorder) OpenUniStream() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStream)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStream", reflect.TypeOf((*MockQuicConn)(nil).OpenUniStream)) } // OpenUniStreamSync mocks base method. -func (m *MockQuicSession) OpenUniStreamSync(arg0 context.Context) (SendStream, error) { +func (m *MockQuicConn) OpenUniStreamSync(arg0 context.Context) (SendStream, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "OpenUniStreamSync", arg0) ret0, _ := ret[0].(SendStream) @@ -219,13 +219,13 @@ func (m *MockQuicSession) OpenUniStreamSync(arg0 context.Context) (SendStream, e } // OpenUniStreamSync indicates an expected call of OpenUniStreamSync. -func (mr *MockQuicSessionMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { +func (mr *MockQuicConnMockRecorder) OpenUniStreamSync(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicSession)(nil).OpenUniStreamSync), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenUniStreamSync", reflect.TypeOf((*MockQuicConn)(nil).OpenUniStreamSync), arg0) } // ReceiveMessage mocks base method. -func (m *MockQuicSession) ReceiveMessage() ([]byte, error) { +func (m *MockQuicConn) ReceiveMessage() ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReceiveMessage") ret0, _ := ret[0].([]byte) @@ -234,13 +234,13 @@ func (m *MockQuicSession) ReceiveMessage() ([]byte, error) { } // ReceiveMessage indicates an expected call of ReceiveMessage. -func (mr *MockQuicSessionMockRecorder) ReceiveMessage() *gomock.Call { +func (mr *MockQuicConnMockRecorder) ReceiveMessage() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQuicSession)(nil).ReceiveMessage)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveMessage", reflect.TypeOf((*MockQuicConn)(nil).ReceiveMessage)) } // RemoteAddr mocks base method. -func (m *MockQuicSession) RemoteAddr() net.Addr { +func (m *MockQuicConn) RemoteAddr() net.Addr { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RemoteAddr") ret0, _ := ret[0].(net.Addr) @@ -248,13 +248,13 @@ func (m *MockQuicSession) RemoteAddr() net.Addr { } // RemoteAddr indicates an expected call of RemoteAddr. -func (mr *MockQuicSessionMockRecorder) RemoteAddr() *gomock.Call { +func (mr *MockQuicConnMockRecorder) RemoteAddr() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicSession)(nil).RemoteAddr)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockQuicConn)(nil).RemoteAddr)) } // SendMessage mocks base method. -func (m *MockQuicSession) SendMessage(arg0 []byte) error { +func (m *MockQuicConn) SendMessage(arg0 []byte) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendMessage", arg0) ret0, _ := ret[0].(error) @@ -262,39 +262,39 @@ func (m *MockQuicSession) SendMessage(arg0 []byte) error { } // SendMessage indicates an expected call of SendMessage. -func (mr *MockQuicSessionMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { +func (mr *MockQuicConnMockRecorder) SendMessage(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockQuicSession)(nil).SendMessage), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockQuicConn)(nil).SendMessage), arg0) } // destroy mocks base method. -func (m *MockQuicSession) destroy(arg0 error) { +func (m *MockQuicConn) destroy(arg0 error) { m.ctrl.T.Helper() m.ctrl.Call(m, "destroy", arg0) } // destroy indicates an expected call of destroy. -func (mr *MockQuicSessionMockRecorder) destroy(arg0 interface{}) *gomock.Call { +func (mr *MockQuicConnMockRecorder) destroy(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQuicSession)(nil).destroy), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQuicConn)(nil).destroy), arg0) } -// earlySessionReady mocks base method. -func (m *MockQuicSession) earlySessionReady() <-chan struct{} { +// earlyConnReady mocks base method. +func (m *MockQuicConn) earlyConnReady() <-chan struct{} { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "earlySessionReady") + ret := m.ctrl.Call(m, "earlyConnReady") ret0, _ := ret[0].(<-chan struct{}) return ret0 } -// earlySessionReady indicates an expected call of earlySessionReady. -func (mr *MockQuicSessionMockRecorder) earlySessionReady() *gomock.Call { +// earlyConnReady indicates an expected call of earlyConnReady. +func (mr *MockQuicConnMockRecorder) earlyConnReady() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "earlySessionReady", reflect.TypeOf((*MockQuicSession)(nil).earlySessionReady)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "earlyConnReady", reflect.TypeOf((*MockQuicConn)(nil).earlyConnReady)) } // getPerspective mocks base method. -func (m *MockQuicSession) getPerspective() protocol.Perspective { +func (m *MockQuicConn) getPerspective() protocol.Perspective { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "getPerspective") ret0, _ := ret[0].(protocol.Perspective) @@ -302,25 +302,25 @@ func (m *MockQuicSession) getPerspective() protocol.Perspective { } // getPerspective indicates an expected call of getPerspective. -func (mr *MockQuicSessionMockRecorder) getPerspective() *gomock.Call { +func (mr *MockQuicConnMockRecorder) getPerspective() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockQuicSession)(nil).getPerspective)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockQuicConn)(nil).getPerspective)) } // handlePacket mocks base method. -func (m *MockQuicSession) handlePacket(arg0 *receivedPacket) { +func (m *MockQuicConn) handlePacket(arg0 *receivedPacket) { m.ctrl.T.Helper() m.ctrl.Call(m, "handlePacket", arg0) } // handlePacket indicates an expected call of handlePacket. -func (mr *MockQuicSessionMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { +func (mr *MockQuicConnMockRecorder) handlePacket(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockQuicSession)(nil).handlePacket), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockQuicConn)(nil).handlePacket), arg0) } // run mocks base method. -func (m *MockQuicSession) run() error { +func (m *MockQuicConn) run() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "run") ret0, _ := ret[0].(error) @@ -328,19 +328,19 @@ func (m *MockQuicSession) run() error { } // run indicates an expected call of run. -func (mr *MockQuicSessionMockRecorder) run() *gomock.Call { +func (mr *MockQuicConnMockRecorder) run() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "run", reflect.TypeOf((*MockQuicSession)(nil).run)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "run", reflect.TypeOf((*MockQuicConn)(nil).run)) } // shutdown mocks base method. -func (m *MockQuicSession) shutdown() { +func (m *MockQuicConn) shutdown() { m.ctrl.T.Helper() m.ctrl.Call(m, "shutdown") } // shutdown indicates an expected call of shutdown. -func (mr *MockQuicSessionMockRecorder) shutdown() *gomock.Call { +func (mr *MockQuicConnMockRecorder) shutdown() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "shutdown", reflect.TypeOf((*MockQuicSession)(nil).shutdown)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "shutdown", reflect.TypeOf((*MockQuicConn)(nil).shutdown)) } diff --git a/mock_stream_getter_test.go b/mock_stream_getter_test.go index 50934898713..d4d08b4a829 100644 --- a/mock_stream_getter_test.go +++ b/mock_stream_getter_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: session.go +// Source: connection.go // Package quic is a generated GoMock package. package quic diff --git a/mock_stream_manager_test.go b/mock_stream_manager_test.go index 92c31da94a8..9c86e6d1311 100644 --- a/mock_stream_manager_test.go +++ b/mock_stream_manager_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: session.go +// Source: connection.go // Package quic is a generated GoMock package. package quic diff --git a/mock_unpacker_test.go b/mock_unpacker_test.go index 0703c111f60..22da001b215 100644 --- a/mock_unpacker_test.go +++ b/mock_unpacker_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: session.go +// Source: connection.go // Package quic is a generated GoMock package. package quic diff --git a/mockgen.go b/mockgen.go index 053cfa9ab57..22c2c0e74ee 100644 --- a/mockgen.go +++ b/mockgen.go @@ -16,8 +16,8 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker" //go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer" //go:generate sh -c "./mockgen_private.sh quic mock_mtu_discoverer_test.go github.com/lucas-clemente/quic-go mtuDiscoverer" -//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner" -//go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession" +//go:generate sh -c "./mockgen_private.sh quic mock_conn_runner_test.go github.com/lucas-clemente/quic-go connRunner" +//go:generate sh -c "./mockgen_private.sh quic mock_quic_conn_test.go github.com/lucas-clemente/quic-go quicConn" //go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler" //go:generate sh -c "./mockgen_private.sh quic mock_unknown_packet_handler_test.go github.com/lucas-clemente/quic-go unknownPacketHandler" //go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager" diff --git a/multiplexer.go b/multiplexer.go index 006305af18d..2271b551722 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -32,7 +32,7 @@ type connManager struct { } // The connMultiplexer listens on multiple net.PacketConns and dispatches -// incoming packets to the session handler. +// incoming packets to the connection handler. type connMultiplexer struct { mutex sync.Mutex diff --git a/packet_handler_map.go b/packet_handler_map.go index 5b4659d20a4..2d55a95ef86 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -7,8 +7,12 @@ import ( "errors" "fmt" "hash" + "io" "log" "net" + "os" + "strconv" + "strings" "sync" "time" @@ -45,6 +49,14 @@ func (h *zeroRTTQueue) Clear() { } } +// rawConn is a connection that allow reading of a receivedPacket. +type rawConn interface { + ReadPacket() (*receivedPacket, error) + WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) + LocalAddr() net.Addr + io.Closer +} + type packetHandlerMapEntry struct { packetHandler packetHandler is0RTTQueue bool @@ -52,12 +64,12 @@ type packetHandlerMapEntry struct { // The packetHandlerMap stores packetHandlers, identified by connection ID. // It is used: -// * by the server to store sessions +// * by the server to store connections // * when multiplexing outgoing connections to store clients type packetHandlerMap struct { mutex sync.Mutex - conn connection + conn rawConn connIDLen int handlers map[string] /* string(ConnectionID)*/ packetHandlerMapEntry @@ -68,8 +80,8 @@ type packetHandlerMap struct { listening chan struct{} // is closed when listen returns closed bool - deleteRetiredSessionsAfter time.Duration - zeroRTTQueueDuration time.Duration + deleteRetiredConnsAfter time.Duration + zeroRTTQueueDuration time.Duration statelessResetEnabled bool statelessResetMutex sync.Mutex @@ -92,6 +104,7 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { } if size >= protocol.DesiredReceiveBufferSize { logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) + return nil } if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { return fmt.Errorf("failed to increase receive buffer size: %w", err) @@ -110,7 +123,7 @@ func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { return nil } -// only print warnings about the UPD receive buffer size once +// only print warnings about the UDP receive buffer size once var receiveBufferWarningOnce sync.Once func newPacketHandlerMap( @@ -121,26 +134,31 @@ func newPacketHandlerMap( logger utils.Logger, ) (packetHandlerManager, error) { if err := setReceiveBuffer(c, logger); err != nil { - receiveBufferWarningOnce.Do(func() { - log.Printf("%s. See https://github.com/lucas-clemente/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) - }) + if !strings.Contains(err.Error(), "use of closed network connection") { + receiveBufferWarningOnce.Do(func() { + if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { + return + } + log.Printf("%s. See https://github.com/lucas-clemente/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) + }) + } } conn, err := wrapConn(c) if err != nil { return nil, err } m := &packetHandlerMap{ - conn: conn, - connIDLen: connIDLen, - listening: make(chan struct{}), - handlers: make(map[string]packetHandlerMapEntry), - resetTokens: make(map[protocol.StatelessResetToken]packetHandler), - deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, - zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, - statelessResetEnabled: len(statelessResetKey) > 0, - statelessResetHasher: hmac.New(sha256.New, statelessResetKey), - tracer: tracer, - logger: logger, + conn: conn, + connIDLen: connIDLen, + listening: make(chan struct{}), + handlers: make(map[string]packetHandlerMapEntry), + resetTokens: make(map[protocol.StatelessResetToken]packetHandler), + deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, + zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, + statelessResetEnabled: len(statelessResetKey) > 0, + statelessResetHasher: hmac.New(sha256.New, statelessResetKey), + tracer: tracer, + logger: logger, } go m.listen() @@ -196,7 +214,7 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co var q *zeroRTTQueue if entry, ok := h.handlers[string(clientDestConnID)]; ok { if !entry.is0RTTQueue { - h.logger.Debugf("Not adding connection ID %s for a new session, as it already exists.", clientDestConnID) + h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) return false } q = entry.packetHandler.(*zeroRTTQueue) @@ -212,7 +230,7 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co } h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess} h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess} - h.logger.Debugf("Adding connection IDs %s and %s for a new session.", clientDestConnID, newConnID) + h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) return true } @@ -224,8 +242,8 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { } func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { - h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredSessionsAfter) - time.AfterFunc(h.deleteRetiredSessionsAfter, func() { + h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter) + time.AfterFunc(h.deleteRetiredConnsAfter, func() { h.mutex.Lock() delete(h.handlers, string(id)) h.mutex.Unlock() @@ -237,14 +255,14 @@ func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler p h.mutex.Lock() h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} h.mutex.Unlock() - h.logger.Debugf("Replacing session for connection ID %s with a closed session.", id) + h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id) - time.AfterFunc(h.deleteRetiredSessionsAfter, func() { + time.AfterFunc(h.deleteRetiredConnsAfter, func() { h.mutex.Lock() handler.shutdown() delete(h.handlers, string(id)) h.mutex.Unlock() - h.logger.Debugf("Removing connection ID %s for a closed session after it has been retired.", id) + h.logger.Debugf("Removing connection ID %s for a closed connection after it has been retired.", id) }) } @@ -289,7 +307,7 @@ func (h *packetHandlerMap) CloseServer() { } // Destroy closes the underlying connection and waits until listen() has returned. -// It does not close active sessions. +// It does not close active connections. func (h *packetHandlerMap) Destroy() error { if err := h.conn.Close(); err != nil { return err @@ -327,6 +345,10 @@ func (h *packetHandlerMap) listen() { defer close(h.listening) for { p, err := h.conn.ReadPacket() + //nolint:staticcheck // SA1019 ignore this! + // TODO: This code is used to ignore wsa errors on Windows. + // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. + // See https://github.com/lucas-clemente/quic-go/issues/1737 for details. if nerr, ok := err.(net.Error); ok && nerr.Temporary() { h.logger.Debugf("Temporary error reading from conn: %w", err) continue @@ -363,7 +385,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { entry.packetHandler.handlePacket(p) return } - } else { // existing session + } else { // existing connection entry.packetHandler.handlePacket(p) return } @@ -389,7 +411,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { h.mutex.Lock() defer h.mutex.Unlock() - // The entry might have been replaced by an actual session. + // The entry might have been replaced by an actual connection. // Only delete it if it's still a 0-RTT queue. if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue { delete(h.handlers, string(connID)) @@ -421,7 +443,7 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { var token protocol.StatelessResetToken copy(token[:], data[len(data)-16:]) if sess, ok := h.resetTokens[token]; ok { - h.logger.Debugf("Received a stateless reset with token %#x. Closing session.", token) + h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) go sess.destroy(&StatelessResetError{Token: token}) return true } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 48f1b91cb6e..d678d6dbf82 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -89,12 +89,12 @@ var _ = Describe("Packet Handler Map", func() { }() testErr := errors.New("test error ") - sess1 := NewMockPacketHandler(mockCtrl) - sess1.EXPECT().destroy(testErr) - sess2 := NewMockPacketHandler(mockCtrl) - sess2.EXPECT().destroy(testErr) - handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1) - handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2) + conn1 := NewMockPacketHandler(mockCtrl) + conn1.EXPECT().destroy(testErr) + conn2 := NewMockPacketHandler(mockCtrl) + conn2.EXPECT().destroy(testErr) + handler.Add(protocol.ConnectionID{1, 1, 1, 1}, conn1) + handler.Add(protocol.ConnectionID{2, 2, 2, 2}, conn2) mockMultiplexer.EXPECT().RemoveConn(gomock.Any()) handler.close(testErr) close(packetChan) @@ -103,7 +103,7 @@ var _ = Describe("Packet Handler Map", func() { Context("other operations", func() { AfterEach(func() { - // delete sessions and the server before closing + // delete connections and the server before closing // They might be mock implementations, and we'd have to register the expected calls before otherwise. handler.mutex.Lock() for connID := range handler.handlers { @@ -160,8 +160,8 @@ var _ = Describe("Packet Handler Map", func() { }) }) - It("deletes removed sessions immediately", func() { - handler.deleteRetiredSessionsAfter = time.Hour + It("deletes removed connections immediately", func() { + handler.deleteRetiredConnsAfter = time.Hour connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Remove(connID) @@ -169,19 +169,19 @@ var _ = Describe("Packet Handler Map", func() { // don't EXPECT any calls to handlePacket of the MockPacketHandler }) - It("deletes retired session entries after a wait time", func() { - handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) + It("deletes retired connection entries after a wait time", func() { + handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - sess := NewMockPacketHandler(mockCtrl) - handler.Add(connID, sess) + conn := NewMockPacketHandler(mockCtrl) + handler.Add(connID, conn) handler.Retire(connID) time.Sleep(scaleDuration(30 * time.Millisecond)) handler.handlePacket(&receivedPacket{data: getPacket(connID)}) // don't EXPECT any calls to handlePacket of the MockPacketHandler }) - It("passes packets arriving late for closed sessions to that session", func() { - handler.deleteRetiredSessionsAfter = time.Hour + It("passes packets arriving late for closed connections to that connection", func() { + handler.deleteRetiredConnsAfter = time.Hour connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} packetHandler := NewMockPacketHandler(mockCtrl) handled := make(chan struct{}) @@ -250,16 +250,16 @@ var _ = Describe("Packet Handler Map", func() { handler.handlePacket(&receivedPacket{data: p}) }) - It("closes all server sessions", func() { + It("closes all server connections", func() { handler.SetServer(NewMockUnknownPacketHandler(mockCtrl)) - clientSess := NewMockPacketHandler(mockCtrl) - clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient) - serverSess := NewMockPacketHandler(mockCtrl) - serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer) - serverSess.EXPECT().shutdown() - - handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess) - handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess) + clientConn := NewMockPacketHandler(mockCtrl) + clientConn.EXPECT().getPerspective().Return(protocol.PerspectiveClient) + serverConn := NewMockPacketHandler(mockCtrl) + serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer) + serverConn.EXPECT().shutdown() + + handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientConn) + handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverConn) handler.CloseServer() }) @@ -293,23 +293,23 @@ var _ = Describe("Packet Handler Map", func() { handler.handlePacket(p1) handler.handlePacket(p2) handler.handlePacket(p3) - sess := NewMockPacketHandler(mockCtrl) + conn := NewMockPacketHandler(mockCtrl) done := make(chan struct{}) gomock.InOrder( - sess.EXPECT().handlePacket(p1), - sess.EXPECT().handlePacket(p2), - sess.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }), + conn.EXPECT().handlePacket(p1), + conn.EXPECT().handlePacket(p2), + conn.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }), ) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) Eventually(done).Should(BeClosed()) }) - It("directs 0-RTT packets to existing sessions", func() { + It("directs 0-RTT packets to existing connections", func() { connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} - sess := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) + conn := NewMockPacketHandler(mockCtrl) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} - sess.EXPECT().handlePacket(p1) + conn.EXPECT().handlePacket(p1) handler.handlePacket(p1) }) @@ -324,12 +324,12 @@ var _ = Describe("Packet Handler Map", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9} handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}) // Don't EXPECT any handlePacket() calls. - sess := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) + conn := NewMockPacketHandler(mockCtrl) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) time.Sleep(20 * time.Millisecond) }) - It("deletes queues if no session is created for this connection ID", func() { + It("deletes queues if no connection is created for this connection ID", func() { queueDuration := scaleDuration(10 * time.Millisecond) handler.zeroRTTQueueDuration = queueDuration @@ -350,8 +350,8 @@ var _ = Describe("Packet Handler Map", func() { // wait a bit. The queue should now already be deleted. time.Sleep(queueDuration * 3) // Don't EXPECT any handlePacket() calls. - sess := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) + conn := NewMockPacketHandler(mockCtrl) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) time.Sleep(20 * time.Millisecond) }) }) diff --git a/receive_stream.go b/receive_stream.go index f9a1e066ff5..ae6a449b575 100644 --- a/receive_stream.go +++ b/receive_stream.go @@ -47,6 +47,7 @@ type receiveStream struct { resetRemotely bool // set when HandleResetStreamFrame() is called readChan chan struct{} + readOnce chan struct{} // cap: 1, to protect against concurrent use of Read deadline time.Time flowController flowcontrol.StreamFlowController @@ -70,6 +71,7 @@ func newReceiveStream( flowController: flowController, frameQueue: newFrameSorter(), readChan: make(chan struct{}, 1), + readOnce: make(chan struct{}, 1), finalOffset: protocol.MaxByteCount, version: version, } @@ -81,6 +83,12 @@ func (s *receiveStream) StreamID() protocol.StreamID { // Read implements io.Reader. It is not thread safe! func (s *receiveStream) Read(p []byte) (int, error) { + // Concurrent use of Read is not permitted (and doesn't make any sense), + // but sometimes people do it anyway. + // Make sure that we only execute one call at any given time to avoid hard to debug failures. + s.readOnce <- struct{}{} + defer func() { <-s.readOnce }() + s.mutex.Lock() completed, n, err := s.readImpl(p) s.mutex.Unlock() @@ -105,7 +113,7 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err return false, 0, s.closeForShutdownErr } - bytesRead := 0 + var bytesRead int var deadlineTimer *utils.Timer for bytesRead < len(p) { if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) { diff --git a/receive_stream_test.go b/receive_stream_test.go index 088be4297d1..20157bb408c 100644 --- a/receive_stream_test.go +++ b/receive_stream_test.go @@ -4,6 +4,8 @@ import ( "errors" "io" "runtime" + "sync" + "sync/atomic" "time" "github.com/golang/mock/gomock" @@ -218,7 +220,6 @@ var _ = Describe("Receive Stream", func() { Context("deadlines", func() { It("the deadline error has the right net.Error properties", func() { - Expect(errDeadline.Temporary()).To(BeTrue()) Expect(errDeadline.Timeout()).To(BeTrue()) Expect(errDeadline).To(MatchError("deadline exceeded")) }) @@ -404,6 +405,43 @@ var _ = Describe("Receive Stream", func() { Expect(n).To(BeZero()) Expect(err).To(MatchError(io.EOF)) }) + + // Calling Read concurrently doesn't make any sense (and is forbidden), + // but we still want to make sure that we don't complete the stream more than once + // if the user misuses our API. + // This would lead to an INTERNAL_ERROR ("tried to delete unknown outgoing stream"), + // which can be hard to debug. + // Note that even without the protection built into the receiveStream, this test + // is very timing-dependent, and would need to run a few hundred times to trigger the failure. + It("handles concurrent reads", func() { + mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any()).AnyTimes() + var bytesRead protocol.ByteCount + mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) { bytesRead += n }).AnyTimes() + + var numCompleted int32 + mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) { + atomic.AddInt32(&numCompleted, 1) + }).AnyTimes() + const num = 3 + var wg sync.WaitGroup + wg.Add(num) + for i := 0; i < num; i++ { + go func() { + defer wg.Done() + defer GinkgoRecover() + _, err := str.Read(make([]byte, 8)) + Expect(err).To(MatchError(io.EOF)) + }() + } + str.handleStreamFrame(&wire.StreamFrame{ + Offset: 0, + Data: []byte("foobar"), + Fin: true, + }) + wg.Wait() + Expect(bytesRead).To(BeEquivalentTo(6)) + Expect(atomic.LoadInt32(&numCompleted)).To(BeEquivalentTo(1)) + }) }) It("closes when CloseRemote is called", func() { diff --git a/send_conn.go b/send_conn.go index b276af11388..c53ebdfab1c 100644 --- a/send_conn.go +++ b/send_conn.go @@ -13,7 +13,7 @@ type sendConn interface { } type sconn struct { - connection + rawConn remoteAddr net.Addr info *packetInfo @@ -22,9 +22,9 @@ type sconn struct { var _ sendConn = &sconn{} -func newSendConn(c connection, remote net.Addr, info *packetInfo) sendConn { +func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn { return &sconn{ - connection: c, + rawConn: c, remoteAddr: remote, info: info, oob: info.OOB(), @@ -41,7 +41,7 @@ func (c *sconn) RemoteAddr() net.Addr { } func (c *sconn) LocalAddr() net.Addr { - addr := c.connection.LocalAddr() + addr := c.rawConn.LocalAddr() if c.info != nil { if udpAddr, ok := addr.(*net.UDPAddr); ok { addrCopy := *udpAddr diff --git a/send_queue.go b/send_queue.go index bf25dded60e..1fc8c1bf893 100644 --- a/send_queue.go +++ b/send_queue.go @@ -64,7 +64,13 @@ func (h *sendQueue) Run() error { shouldClose = true case p := <-h.queue: if err := h.conn.Write(p.Data); err != nil { - return err + // This additional check enables: + // 1. Checking for "datagram too large" message from the kernel, as such, + // 2. Path MTU discovery,and + // 3. Eventual detection of loss PingFrame. + if !isMsgSizeErr(err) { + return err + } } p.Release() select { diff --git a/send_stream.go b/send_stream.go index 946243ca12f..b23df00b377 100644 --- a/send_stream.go +++ b/send_stream.go @@ -50,6 +50,7 @@ type sendStream struct { nextFrame *wire.StreamFrame writeChan chan struct{} + writeOnce chan struct{} deadline time.Time flowController flowcontrol.StreamFlowController @@ -73,6 +74,7 @@ func newSendStream( sender: sender, flowController: flowController, writeChan: make(chan struct{}, 1), + writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write version: version, } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) @@ -84,6 +86,12 @@ func (s *sendStream) StreamID() protocol.StreamID { } func (s *sendStream) Write(p []byte) (int, error) { + // Concurrent use of Write is not permitted (and doesn't make any sense), + // but sometimes people do it anyway. + // Make sure that we only execute one call at any given time to avoid hard to debug failures. + s.writeOnce <- struct{}{} + defer func() { <-s.writeOnce }() + s.mutex.Lock() defer s.mutex.Unlock() diff --git a/server.go b/server.go index feb2c79ab53..33f5af91608 100644 --- a/server.go +++ b/server.go @@ -20,6 +20,9 @@ import ( "github.com/lucas-clemente/quic-go/logging" ) +// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close. +var ErrServerClosed = errors.New("quic: Server closed") + // packetHandler handles packets type packetHandler interface { handlePacket(*receivedPacket) @@ -36,14 +39,14 @@ type unknownPacketHandler interface { type packetHandlerManager interface { AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool Destroy() error - sessionRunner + connRunner SetServer(unknownPacketHandler) CloseServer() } -type quicSession interface { - EarlySession - earlySessionReady() <-chan struct{} +type quicConn interface { + EarlyConnection + earlyConnReady() <-chan struct{} handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber getPerspective() protocol.Perspective @@ -56,26 +59,26 @@ type quicSession interface { type baseServer struct { mutex sync.Mutex - acceptEarlySessions bool + acceptEarlyConns bool tlsConf *tls.Config config *Config - conn connection + conn rawConn // If the server is started with ListenAddr, we create a packet conn. // If it is started with Listen, we take a packet conn as a parameter. createdPacketConn bool tokenGenerator *handshake.TokenGenerator - sessionHandler packetHandlerManager + connHandler packetHandlerManager receivedPackets chan *receivedPacket // set as a member, so they can be set in the tests - newSession func( + newConn func( sendConn, - sessionRunner, + connRunner, protocol.ConnectionID, /* original dest connection ID */ *protocol.ConnectionID, /* retry src connection ID */ protocol.ConnectionID, /* client dest connection ID */ @@ -91,15 +94,15 @@ type baseServer struct { uint64, utils.Logger, protocol.VersionNumber, - ) quicSession + ) quicConn serverError error errorChan chan struct{} closed bool running chan struct{} // closed as soon as run() returns - sessionQueue chan quicSession - sessionQueueLen int32 // to be used as an atomic + connQueue chan quicConn + connQueueLen int32 // to be used as an atomic logger utils.Logger } @@ -113,7 +116,7 @@ type earlyServer struct{ *baseServer } var _ EarlyListener = &earlyServer{} -func (s *earlyServer) Accept(ctx context.Context) (EarlySession, error) { +func (s *earlyServer) Accept(ctx context.Context) (EarlyConnection, error) { return s.baseServer.accept(ctx) } @@ -124,7 +127,7 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err return listenAddr(addr, tlsConf, config, false) } -// ListenAddrEarly works like ListenAddr, but it returns sessions before the handshake completes. +// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) { s, err := listenAddr(addr, tlsConf, config, true) if err != nil { @@ -165,7 +168,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, return listen(conn, tlsConf, config, false) } -// ListenEarly works like Listen, but it returns sessions before the handshake completes. +// ListenEarly works like Listen, but it returns connections before the handshake completes. func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) { s, err := listen(conn, tlsConf, config, true) if err != nil { @@ -188,7 +191,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl } } - sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) + connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer) if err != nil { return nil, err } @@ -201,21 +204,21 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl return nil, err } s := &baseServer{ - conn: c, - tlsConf: tlsConf, - config: config, - tokenGenerator: tokenGenerator, - sessionHandler: sessionHandler, - sessionQueue: make(chan quicSession), - errorChan: make(chan struct{}), - running: make(chan struct{}), - receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), - newSession: newSession, - logger: utils.DefaultLogger.WithPrefix("server"), - acceptEarlySessions: acceptEarly, + conn: c, + tlsConf: tlsConf, + config: config, + tokenGenerator: tokenGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), + newConn: newConnection, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, } go s.run() - sessionHandler.SetServer(s) + connHandler.SetServer(s) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } @@ -261,19 +264,19 @@ var isClientAddressValidationTokenValid = func(clientAddr net.Addr, token *Token return sourceAddr == token.RemoteAddr } -// Accept returns sessions that already completed the handshake. -// It is only valid if acceptEarlySessions is false. -func (s *baseServer) Accept(ctx context.Context) (Session, error) { +// Accept returns connections that already completed the handshake. +// It is only valid if acceptEarlyConns is false. +func (s *baseServer) Accept(ctx context.Context) (Connection, error) { return s.accept(ctx) } -func (s *baseServer) accept(ctx context.Context) (quicSession, error) { +func (s *baseServer) accept(ctx context.Context) (quicConn, error) { select { case <-ctx.Done(): return nil, ctx.Err() - case sess := <-s.sessionQueue: - atomic.AddInt32(&s.sessionQueueLen, -1) - return sess, nil + case conn := <-s.connQueue: + atomic.AddInt32(&s.connQueueLen, -1) + return conn, nil case <-s.errorChan: return nil, s.serverError } @@ -287,7 +290,7 @@ func (s *baseServer) Close() error { return nil } if s.serverError == nil { - s.serverError = errors.New("server closed") + s.serverError = ErrServerClosed } // If the server was started with ListenAddr, we created the packet conn. // We need to close it in order to make the go routine reading from that conn return. @@ -297,9 +300,9 @@ func (s *baseServer) Close() error { s.mutex.Unlock() <-s.running - s.sessionHandler.CloseServer() + s.connHandler.CloseServer() if createdPacketConn { - return s.sessionHandler.Destroy() + return s.connHandler.Destroy() } return nil } @@ -339,7 +342,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s } return false } - // If we're creating a new session, the packet will be passed to the session. + // If we're creating a new connection, the packet will be passed to the connection. // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength) if err != nil && err != wire.ErrUnsupportedVersion { @@ -439,7 +442,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro return nil } - if queueLen := atomic.LoadInt32(&s.sessionQueueLen); queueLen >= protocol.MaxAcceptQueueSize { + if queueLen := atomic.LoadInt32(&s.connQueueLen); queueLen >= protocol.MaxAcceptQueueSize { s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize) go func() { defer p.buffer.Release() @@ -455,9 +458,9 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro return err } s.logger.Debugf("Changing connection ID to %s.", connID) - var sess quicSession - tracingID := nextSessionTracingID() - if added := s.sessionHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { + var conn quicConn + tracingID := nextConnTracingID() + if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { var tracer logging.ConnectionTracer if s.config.Tracer != nil { // Use the same connection ID that is passed to the client's GetLogWriter callback. @@ -466,75 +469,75 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro connID = origDestConnID } tracer = s.config.Tracer.TracerForConnection( - context.WithValue(context.Background(), SessionTracingKey, tracingID), + context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID, ) } - sess = s.newSession( + conn = s.newConn( newSendConn(s.conn, p.remoteAddr, p.info), - s.sessionHandler, + s.connHandler, origDestConnID, retrySrcConnID, hdr.DestConnectionID, hdr.SrcConnectionID, connID, - s.sessionHandler.GetStatelessResetToken(connID), + s.connHandler.GetStatelessResetToken(connID), s.config, s.tlsConf, s.tokenGenerator, - s.acceptEarlySessions, + s.acceptEarlyConns, isClientAddressValidationTokenValid(p.remoteAddr, token), tracer, tracingID, s.logger, hdr.Version, ) - sess.handlePacket(p) - return sess + conn.handlePacket(p) + return conn }); !added { return nil } - go sess.run() - go s.handleNewSession(sess) - if sess == nil { + go conn.run() + go s.handleNewConn(conn) + if conn == nil { p.buffer.Release() return nil } return nil } -func (s *baseServer) handleNewSession(sess quicSession) { - sessCtx := sess.Context() - if s.acceptEarlySessions { - // wait until the early session is ready (or the handshake fails) +func (s *baseServer) handleNewConn(conn quicConn) { + connCtx := conn.Context() + if s.acceptEarlyConns { + // wait until the early connection is ready (or the handshake fails) select { - case <-sess.earlySessionReady(): - case <-sessCtx.Done(): + case <-conn.earlyConnReady(): + case <-connCtx.Done(): return } } else { // wait until the handshake is complete (or fails) select { - case <-sess.HandshakeComplete().Done(): - case <-sessCtx.Done(): + case <-conn.HandshakeComplete().Done(): + case <-connCtx.Done(): return } } - atomic.AddInt32(&s.sessionQueueLen, 1) + atomic.AddInt32(&s.connQueueLen, 1) select { - case s.sessionQueue <- sess: - // blocks until the session is accepted - case <-sessCtx.Done(): - atomic.AddInt32(&s.sessionQueueLen, -1) - // don't pass sessions that were already closed to Accept() + case s.connQueue <- conn: + // blocks until the connection is accepted + case <-connCtx.Done(): + atomic.AddInt32(&s.connQueueLen, -1) + // don't pass connections that were already closed to Accept() } } func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { // Log the Initial packet now. - // If no Retry is sent, the packet will be logged by the session. + // If no Retry is sent, the packet will be logged by the connection. (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) srcConnID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength) if err != nil { @@ -652,11 +655,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) - data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) - if err != nil { - s.logger.Debugf("Error composing Version Negotiation: %s", err) - return - } + data := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) if s.config.Tracer != nil { s.config.Tracer.SentPacket( p.remoteAddr, diff --git a/server_test.go b/server_test.go index ad47e33aaa2..9315a67ff09 100644 --- a/server_test.go +++ b/server_test.go @@ -146,7 +146,7 @@ var _ = Describe("Server", func() { ln, err := Listen(conn, tlsConf, &config) Expect(err).ToNot(HaveOccurred()) server := ln.(*baseServer) - Expect(server.sessionHandler).ToNot(BeNil()) + Expect(server.connHandler).ToNot(BeNil()) Expect(server.config.Versions).To(Equal(supportedVersions)) Expect(server.config.HandshakeIdleTimeout).To(Equal(1337 * time.Hour)) Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) @@ -178,7 +178,7 @@ var _ = Describe("Server", func() { Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) }) - Context("server accepting sessions that completed the handshake", func() { + Context("server accepting connections that completed the handshake", func() { var ( serv *baseServer phm *MockPacketHandlerManager @@ -191,7 +191,7 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) serv = ln.(*baseServer) phm = NewMockPacketHandlerManager(mockCtrl) - serv.sessionHandler = phm + serv.connHandler = phm }) AfterEach(func() { @@ -291,7 +291,7 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) - It("creates a session when the token is accepted", func() { + It("creates a connection when the token is accepted", func() { serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true } retryToken, err := serv.tokenGenerator.NewRetryToken( &net.UDPAddr{}, @@ -323,10 +323,10 @@ var _ = Describe("Server", func() { return true }) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}) - sess := NewMockQuicSession(mockCtrl) - serv.newSession = func( + conn := NewMockQuicConn(mockCtrl) + serv.newConn = func( _ sendConn, - _ sessionRunner, + _ connRunner, origDestConnID protocol.ConnectionID, retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, @@ -342,7 +342,7 @@ var _ = Describe("Server", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { + ) quicConn { Expect(enable0RTT).To(BeFalse()) Expect(origDestConnID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) Expect(retrySrcConnID).To(Equal(&protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) @@ -353,18 +353,18 @@ var _ = Describe("Server", func() { Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) Expect(srcConnID).To(Equal(newConnID)) Expect(tokenP).To(Equal(token)) - sess.EXPECT().handlePacket(p) - sess.EXPECT().run().Do(func() { close(run) }) - sess.EXPECT().Context().Return(context.Background()) - sess.EXPECT().HandshakeComplete().Return(context.Background()) - return sess + conn.EXPECT().handlePacket(p) + conn.EXPECT().run().Do(func() { close(run) }) + conn.EXPECT().Context().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn } done := make(chan struct{}) go func() { defer GinkgoRecover() serv.handlePacket(p) - // the Handshake packet is written by the session. + // the Handshake packet is written by the connection. // Make sure there are no Write calls on the packet conn. time.Sleep(50 * time.Millisecond) close(done) @@ -427,12 +427,11 @@ var _ = Describe("Server", func() { }) It("ignores Version Negotiation packets", func() { - data, err := wire.ComposeVersionNegotiation( + data := wire.ComposeVersionNegotiation( protocol.ConnectionID{1, 2, 3, 4}, protocol.ConnectionID{4, 3, 2, 1}, []protocol.VersionNumber{1, 2, 3}, ) - Expect(err).ToNot(HaveOccurred()) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} done := make(chan struct{}) tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { @@ -577,7 +576,7 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) - It("creates a session, if no Token is required", func() { + It("creates a connection, if no Token is required", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } hdr := &wire.Header{ IsLongHeader: true, @@ -603,10 +602,10 @@ var _ = Describe("Server", func() { }) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - sess := NewMockQuicSession(mockCtrl) - serv.newSession = func( + conn := NewMockQuicConn(mockCtrl) + serv.newConn = func( _ sendConn, - _ sessionRunner, + _ connRunner, origDestConnID protocol.ConnectionID, retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, @@ -622,7 +621,7 @@ var _ = Describe("Server", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { + ) quicConn { Expect(enable0RTT).To(BeFalse()) Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) Expect(retrySrcConnID).To(BeNil()) @@ -633,18 +632,18 @@ var _ = Describe("Server", func() { Expect(srcConnID).ToNot(Equal(hdr.SrcConnectionID)) Expect(srcConnID).To(Equal(newConnID)) Expect(tokenP).To(Equal(token)) - sess.EXPECT().handlePacket(p) - sess.EXPECT().run().Do(func() { close(run) }) - sess.EXPECT().Context().Return(context.Background()) - sess.EXPECT().HandshakeComplete().Return(context.Background()) - return sess + conn.EXPECT().handlePacket(p) + conn.EXPECT().run().Do(func() { close(run) }) + conn.EXPECT().Context().Return(context.Background()) + conn.EXPECT().HandshakeComplete().Return(context.Background()) + return conn } done := make(chan struct{}) go func() { defer GinkgoRecover() serv.handlePacket(p) - // the Handshake packet is written by the session + // the Handshake packet is written by the connection // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) close(done) @@ -663,11 +662,11 @@ var _ = Describe("Server", func() { tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } - acceptSession := make(chan struct{}) + acceptConn := make(chan struct{}) var counter uint32 // to be used as an atomic, so we query it in Eventually - serv.newSession = func( + serv.newConn = func( _ sendConn, - runner sessionRunner, + runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, @@ -683,15 +682,15 @@ var _ = Describe("Server", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { - <-acceptSession + ) quicConn { + <-acceptConn atomic.AddUint32(&counter, 1) - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) - sess.EXPECT().run().MaxTimes(1) - sess.EXPECT().Context().Return(context.Background()).MaxTimes(1) - sess.EXPECT().HandshakeComplete().Return(context.Background()).MaxTimes(1) - return sess + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().handlePacket(gomock.Any()).MaxTimes(1) + conn.EXPECT().run().MaxTimes(1) + conn.EXPECT().Context().Return(context.Background()).MaxTimes(1) + conn.EXPECT().HandshakeComplete().Return(context.Background()).MaxTimes(1) + return conn } p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}) @@ -708,7 +707,7 @@ var _ = Describe("Server", func() { } wg.Wait() - close(acceptSession) + close(acceptConn) Eventually( func() uint32 { return atomic.LoadUint32(&counter) }, scaleDuration(100*time.Millisecond), @@ -716,13 +715,13 @@ var _ = Describe("Server", func() { Consistently(func() uint32 { return atomic.LoadUint32(&counter) }).Should(BeEquivalentTo(protocol.MaxServerUnprocessedPackets + 1)) }) - It("only creates a single session for a duplicate Initial", func() { + It("only creates a single connection for a duplicate Initial", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - var createdSession bool - sess := NewMockQuicSession(mockCtrl) - serv.newSession = func( + var createdConn bool + conn := NewMockQuicConn(mockCtrl) + serv.newConn = func( _ sendConn, - runner sessionRunner, + runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, @@ -738,23 +737,23 @@ var _ = Describe("Server", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { - createdSession = true - return sess + ) quicConn { + createdConn = true + return conn } p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false) Expect(serv.handlePacketImpl(p)).To(BeTrue()) - Expect(createdSession).To(BeFalse()) + Expect(createdConn).To(BeFalse()) }) It("rejects new connection attempts if the accept queue is full", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - serv.newSession = func( + serv.newConn = func( _ sendConn, - runner sessionRunner, + runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, @@ -770,15 +769,15 @@ var _ = Describe("Server", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()) - sess.EXPECT().run() - sess.EXPECT().Context().Return(context.Background()) + ) quicConn { + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().run() + conn.EXPECT().Context().Return(context.Background()) ctx, cancel := context.WithCancel(context.Background()) cancel() - sess.EXPECT().HandshakeComplete().Return(ctx) - return sess + conn.EXPECT().HandshakeComplete().Return(ctx) + return conn } phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { @@ -818,16 +817,16 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) - It("doesn't accept new sessions if they were closed in the mean time", func() { + It("doesn't accept new connections if they were closed in the mean time", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) ctx, cancel := context.WithCancel(context.Background()) - sessionCreated := make(chan struct{}) - sess := NewMockQuicSession(mockCtrl) - serv.newSession = func( + connCreated := make(chan struct{}) + conn := NewMockQuicConn(mockCtrl) + serv.newConn = func( _ sendConn, - runner sessionRunner, + runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, @@ -843,15 +842,15 @@ var _ = Describe("Server", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { - sess.EXPECT().handlePacket(p) - sess.EXPECT().run() - sess.EXPECT().Context().Return(ctx) + ) quicConn { + conn.EXPECT().handlePacket(p) + conn.EXPECT().run() + conn.EXPECT().Context().Return(ctx) ctx, cancel := context.WithCancel(context.Background()) cancel() - sess.EXPECT().HandshakeComplete().Return(ctx) - close(sessionCreated) - return sess + conn.EXPECT().HandshakeComplete().Return(ctx) + close(connCreated) + return conn } phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { @@ -864,7 +863,7 @@ var _ = Describe("Server", func() { serv.handlePacket(p) // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) - Eventually(sessionCreated).Should(BeClosed()) + Eventually(connCreated).Should(BeClosed()) cancel() time.Sleep(scaleDuration(200 * time.Millisecond)) @@ -878,13 +877,13 @@ var _ = Describe("Server", func() { // make the go routine return phm.EXPECT().CloseServer() - sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID + conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) }) - Context("accepting sessions", func() { + Context("accepting connections", func() { It("returns Accept when an error occurs", func() { testErr := errors.New("test err") @@ -924,23 +923,23 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) - It("accepts new sessions when the handshake completes", func() { - sess := NewMockQuicSession(mockCtrl) + It("accepts new connections when the handshake completes", func() { + conn := NewMockQuicConn(mockCtrl) done := make(chan struct{}) go func() { defer GinkgoRecover() s, err := serv.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(s).To(Equal(sess)) + Expect(s).To(Equal(conn)) close(done) }() ctx, cancel := context.WithCancel(context.Background()) // handshake context serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - serv.newSession = func( + serv.newConn = func( _ sendConn, - runner sessionRunner, + runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, @@ -956,12 +955,12 @@ var _ = Describe("Server", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { - sess.EXPECT().handlePacket(gomock.Any()) - sess.EXPECT().HandshakeComplete().Return(ctx) - sess.EXPECT().run().Do(func() {}) - sess.EXPECT().Context().Return(context.Background()) - return sess + ) quicConn { + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().HandshakeComplete().Return(ctx) + conn.EXPECT().run().Do(func() {}) + conn.EXPECT().Context().Return(context.Background()) + return conn } phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) @@ -980,7 +979,7 @@ var _ = Describe("Server", func() { }) }) - Context("server accepting sessions that haven't completed the handshake", func() { + Context("server accepting connections that haven't completed the handshake", func() { var ( serv *earlyServer phm *MockPacketHandlerManager @@ -991,7 +990,7 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) serv = ln.(*earlyServer) phm = NewMockPacketHandlerManager(mockCtrl) - serv.sessionHandler = phm + serv.connHandler = phm }) AfterEach(func() { @@ -999,23 +998,23 @@ var _ = Describe("Server", func() { serv.Close() }) - It("accepts new sessions when they become ready", func() { - sess := NewMockQuicSession(mockCtrl) + It("accepts new connections when they become ready", func() { + conn := NewMockQuicConn(mockCtrl) done := make(chan struct{}) go func() { defer GinkgoRecover() s, err := serv.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(s).To(Equal(sess)) + Expect(s).To(Equal(conn)) close(done) }() ready := make(chan struct{}) serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - serv.newSession = func( + serv.newConn = func( _ sendConn, - runner sessionRunner, + runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, @@ -1031,13 +1030,13 @@ var _ = Describe("Server", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { + ) quicConn { Expect(enable0RTT).To(BeTrue()) - sess.EXPECT().handlePacket(gomock.Any()) - sess.EXPECT().run().Do(func() {}) - sess.EXPECT().earlySessionReady().Return(ready) - sess.EXPECT().Context().Return(context.Background()) - return sess + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().run().Do(func() {}) + conn.EXPECT().earlyConnReady().Return(ready) + conn.EXPECT().Context().Return(context.Background()) + return conn } phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) @@ -1057,9 +1056,9 @@ var _ = Describe("Server", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} - serv.newSession = func( + serv.newConn = func( _ sendConn, - runner sessionRunner, + runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, @@ -1075,15 +1074,15 @@ var _ = Describe("Server", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { + ) quicConn { ready := make(chan struct{}) close(ready) - sess := NewMockQuicSession(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()) - sess.EXPECT().run() - sess.EXPECT().earlySessionReady().Return(ready) - sess.EXPECT().Context().Return(context.Background()) - return sess + conn := NewMockQuicConn(mockCtrl) + conn.EXPECT().handlePacket(gomock.Any()) + conn.EXPECT().run() + conn.EXPECT().earlyConnReady().Return(ready) + conn.EXPECT().Context().Return(context.Background()) + return conn } phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { @@ -1095,7 +1094,7 @@ var _ = Describe("Server", func() { serv.handlePacket(getInitialWithRandomDestConnID()) } - Eventually(func() int32 { return atomic.LoadInt32(&serv.sessionQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize)) + Eventually(func() int32 { return atomic.LoadInt32(&serv.connQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize)) // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) @@ -1115,16 +1114,16 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) - It("doesn't accept new sessions if they were closed in the mean time", func() { + It("doesn't accept new connections if they were closed in the mean time", func() { serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) ctx, cancel := context.WithCancel(context.Background()) - sessionCreated := make(chan struct{}) - sess := NewMockQuicSession(mockCtrl) - serv.newSession = func( + connCreated := make(chan struct{}) + conn := NewMockQuicConn(mockCtrl) + serv.newConn = func( _ sendConn, - runner sessionRunner, + runner connRunner, _ protocol.ConnectionID, _ *protocol.ConnectionID, _ protocol.ConnectionID, @@ -1140,13 +1139,13 @@ var _ = Describe("Server", func() { _ uint64, _ utils.Logger, _ protocol.VersionNumber, - ) quicSession { - sess.EXPECT().handlePacket(p) - sess.EXPECT().run() - sess.EXPECT().earlySessionReady() - sess.EXPECT().Context().Return(ctx) - close(sessionCreated) - return sess + ) quicConn { + conn.EXPECT().handlePacket(p) + conn.EXPECT().run() + conn.EXPECT().earlyConnReady() + conn.EXPECT().Context().Return(ctx) + close(connCreated) + return conn } phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { @@ -1157,7 +1156,7 @@ var _ = Describe("Server", func() { serv.handlePacket(p) // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) - Eventually(sessionCreated).Should(BeClosed()) + Eventually(connCreated).Should(BeClosed()) cancel() time.Sleep(scaleDuration(200 * time.Millisecond)) @@ -1171,7 +1170,7 @@ var _ = Describe("Server", func() { // make the go routine return phm.EXPECT().CloseServer() - sess.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID + conn.EXPECT().getPerspective().MaxTimes(2) // once for every conn ID Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) diff --git a/stream_test.go b/stream_test.go index c3d0f016208..b382efee99b 100644 --- a/stream_test.go +++ b/stream_test.go @@ -99,7 +99,6 @@ var _ = Describe("Stream", func() { var _ = Describe("Deadline Error", func() { It("is a net.Error that wraps os.ErrDeadlineError", func() { err := deadlineError{} - Expect(err.Temporary()).To(BeTrue()) Expect(err.Timeout()).To(BeTrue()) Expect(errors.Is(err, os.ErrDeadlineExceeded)).To(BeTrue()) Expect(errors.Unwrap(err)).To(Equal(os.ErrDeadlineExceeded)) diff --git a/streams_map_test.go b/streams_map_test.go index ffce136b750..29ce7efc1f8 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -38,7 +38,6 @@ func expectTooManyStreamsError(err error) { ExpectWithOffset(1, err.Error()).To(Equal(errTooManyOpenStreams.Error())) nerr, ok := err.(net.Error) ExpectWithOffset(1, ok).To(BeTrue()) - ExpectWithOffset(1, nerr.Temporary()).To(BeTrue()) ExpectWithOffset(1, nerr.Timeout()).To(BeFalse()) } diff --git a/conn.go b/sys_conn.go similarity index 69% rename from conn.go rename to sys_conn.go index 2f4e3a2398c..d73b01d2341 100644 --- a/conn.go +++ b/sys_conn.go @@ -1,7 +1,6 @@ package quic import ( - "io" "net" "syscall" "time" @@ -10,14 +9,8 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -type connection interface { - ReadPacket() (*receivedPacket, error) - WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) - LocalAddr() net.Addr - io.Closer -} - -// If the PacketConn passed to Dial or Listen satisfies this interface, quic-go will read the ECN bits from the IP header. +// OOBCapablePacketConn is a connection that allows the reading of ECN bits from the IP header. +// If the PacketConn passed to Dial or Listen satisfies this interface, quic-go will use it. // In this case, ReadMsgUDP() will be used instead of ReadFrom() to read packets. type OOBCapablePacketConn interface { net.PacketConn @@ -28,7 +21,20 @@ type OOBCapablePacketConn interface { var _ OOBCapablePacketConn = &net.UDPConn{} -func wrapConn(pc net.PacketConn) (connection, error) { +func wrapConn(pc net.PacketConn) (rawConn, error) { + conn, ok := pc.(interface { + SyscallConn() (syscall.RawConn, error) + }) + if ok { + rawConn, err := conn.SyscallConn() + if err != nil { + return nil, err + } + err = setDF(rawConn) + if err != nil { + return nil, err + } + } c, ok := pc.(OOBCapablePacketConn) if !ok { utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.") @@ -37,11 +43,16 @@ func wrapConn(pc net.PacketConn) (connection, error) { return newConn(c) } +// The basicConn is the most trivial implementation of a connection. +// It reads a single packet from the underlying net.PacketConn. +// It is used when +// * the net.PacketConn is not a OOBCapablePacketConn, and +// * when the OS doesn't support OOB. type basicConn struct { net.PacketConn } -var _ connection = &basicConn{} +var _ rawConn = &basicConn{} func (c *basicConn) ReadPacket() (*receivedPacket, error) { buffer := getPacketBuffer() diff --git a/sys_conn_df.go b/sys_conn_df.go new file mode 100644 index 00000000000..ae9274d97fa --- /dev/null +++ b/sys_conn_df.go @@ -0,0 +1,16 @@ +//go:build !linux && !windows +// +build !linux,!windows + +package quic + +import "syscall" + +func setDF(rawConn syscall.RawConn) error { + // no-op on unsupported platforms + return nil +} + +func isMsgSizeErr(err error) bool { + // to be implemented for more specific platforms + return false +} diff --git a/sys_conn_df_linux.go b/sys_conn_df_linux.go new file mode 100644 index 00000000000..17ac67f12ad --- /dev/null +++ b/sys_conn_df_linux.go @@ -0,0 +1,40 @@ +//go:build linux +// +build linux + +package quic + +import ( + "errors" + "syscall" + + "github.com/lucas-clemente/quic-go/internal/utils" + "golang.org/x/sys/unix" +) + +func setDF(rawConn syscall.RawConn) error { + // Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long" + // and the datagram will not be fragmented + var errDFIPv4, errDFIPv6 error + if err := rawConn.Control(func(fd uintptr) { + errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO) + errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_DO) + }); err != nil { + return err + } + switch { + case errDFIPv4 == nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.") + case errDFIPv4 == nil && errDFIPv6 != nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4.") + case errDFIPv4 != nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv6.") + case errDFIPv4 != nil && errDFIPv6 != nil: + return errors.New("setting DF failed for both IPv4 and IPv6") + } + return nil +} + +func isMsgSizeErr(err error) bool { + // https://man7.org/linux/man-pages/man7/udp.7.html + return errors.Is(err, unix.EMSGSIZE) +} diff --git a/sys_conn_df_windows.go b/sys_conn_df_windows.go new file mode 100644 index 00000000000..4649f6463d2 --- /dev/null +++ b/sys_conn_df_windows.go @@ -0,0 +1,46 @@ +//go:build windows +// +build windows + +package quic + +import ( + "errors" + "syscall" + + "github.com/lucas-clemente/quic-go/internal/utils" + "golang.org/x/sys/windows" +) + +const ( + // same for both IPv4 and IPv6 on Windows + // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAG.html + // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html + IP_DONTFRAGMENT = 14 + IPV6_DONTFRAG = 14 +) + +func setDF(rawConn syscall.RawConn) error { + var errDFIPv4, errDFIPv6 error + if err := rawConn.Control(func(fd uintptr) { + errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1) + errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_DONTFRAG, 1) + }); err != nil { + return err + } + switch { + case errDFIPv4 == nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4 and IPv6.") + case errDFIPv4 == nil && errDFIPv6 != nil: + utils.DefaultLogger.Debugf("Setting DF for IPv4.") + case errDFIPv4 != nil && errDFIPv6 == nil: + utils.DefaultLogger.Debugf("Setting DF for IPv6.") + case errDFIPv4 != nil && errDFIPv6 != nil: + return errors.New("setting DF failed for both IPv4 and IPv6") + } + return nil +} + +func isMsgSizeErr(err error) bool { + // https://docs.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2 + return errors.Is(err, windows.WSAEMSGSIZE) +} diff --git a/conn_helper_darwin.go b/sys_conn_helper_darwin.go similarity index 81% rename from conn_helper_darwin.go rename to sys_conn_helper_darwin.go index fdab73b6159..eabf489f109 100644 --- a/conn_helper_darwin.go +++ b/sys_conn_helper_darwin.go @@ -5,10 +5,7 @@ package quic import "golang.org/x/sys/unix" -const ( - msgTypeIPTOS = unix.IP_RECVTOS - disablePathMTUDiscovery = false -) +const msgTypeIPTOS = unix.IP_RECVTOS const ( ipv4RECVPKTINFO = unix.IP_RECVPKTINFO diff --git a/conn_helper_freebsd.go b/sys_conn_helper_freebsd.go similarity index 75% rename from conn_helper_freebsd.go rename to sys_conn_helper_freebsd.go index e22f986127c..0b3e8434b8a 100644 --- a/conn_helper_freebsd.go +++ b/sys_conn_helper_freebsd.go @@ -6,8 +6,7 @@ package quic import "golang.org/x/sys/unix" const ( - msgTypeIPTOS = unix.IP_RECVTOS - disablePathMTUDiscovery = false + msgTypeIPTOS = unix.IP_RECVTOS ) const ( diff --git a/conn_helper_linux.go b/sys_conn_helper_linux.go similarity index 81% rename from conn_helper_linux.go rename to sys_conn_helper_linux.go index 4aa04dc9c2e..51bec900242 100644 --- a/conn_helper_linux.go +++ b/sys_conn_helper_linux.go @@ -5,10 +5,7 @@ package quic import "golang.org/x/sys/unix" -const ( - msgTypeIPTOS = unix.IP_TOS - disablePathMTUDiscovery = false -) +const msgTypeIPTOS = unix.IP_TOS const ( ipv4RECVPKTINFO = unix.IP_PKTINFO diff --git a/conn_generic.go b/sys_conn_no_oob.go similarity index 75% rename from conn_generic.go rename to sys_conn_no_oob.go index 526778c1ccc..e3b0d11f685 100644 --- a/conn_generic.go +++ b/sys_conn_no_oob.go @@ -5,9 +5,7 @@ package quic import "net" -const disablePathMTUDiscovery = false - -func newConn(c net.PacketConn) (connection, error) { +func newConn(c net.PacketConn) (rawConn, error) { return &basicConn{PacketConn: c}, nil } diff --git a/conn_oob.go b/sys_conn_oob.go similarity index 99% rename from conn_oob.go rename to sys_conn_oob.go index b46781377d2..acd74d023c1 100644 --- a/conn_oob.go +++ b/sys_conn_oob.go @@ -64,7 +64,7 @@ type oobConn struct { buffers [batchSize]*packetBuffer } -var _ connection = &oobConn{} +var _ rawConn = &oobConn{} func newConn(c OOBCapablePacketConn) (*oobConn, error) { rawConn, err := c.SyscallConn() diff --git a/conn_oob_test.go b/sys_conn_oob_test.go similarity index 100% rename from conn_oob_test.go rename to sys_conn_oob_test.go diff --git a/conn_test.go b/sys_conn_test.go similarity index 100% rename from conn_test.go rename to sys_conn_test.go diff --git a/conn_windows.go b/sys_conn_windows.go similarity index 56% rename from conn_windows.go rename to sys_conn_windows.go index a6e591b62aa..f2cc22ab7c4 100644 --- a/conn_windows.go +++ b/sys_conn_windows.go @@ -12,24 +12,7 @@ import ( "golang.org/x/sys/windows" ) -const ( - disablePathMTUDiscovery = true - IP_DONTFRAGMENT = 14 -) - -func newConn(c OOBCapablePacketConn) (connection, error) { - rawConn, err := c.SyscallConn() - if err != nil { - return nil, fmt.Errorf("couldn't get syscall.RawConn: %w", err) - } - if err := rawConn.Control(func(fd uintptr) { - // This should succeed if the connection is a IPv4 or a dual-stack connection. - // It will fail for IPv6 connections. - // TODO: properly handle error. - _ = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1) - }); err != nil { - return nil, err - } +func newConn(c OOBCapablePacketConn) (rawConn, error) { return &basicConn{PacketConn: c}, nil } diff --git a/conn_windows_test.go b/sys_conn_windows_test.go similarity index 100% rename from conn_windows_test.go rename to sys_conn_windows_test.go