diff --git a/http3/server.go b/http3/server.go index 2ae3fef5af7..c568decbdf3 100644 --- a/http3/server.go +++ b/http3/server.go @@ -51,6 +51,44 @@ func versionToALPN(v protocol.VersionNumber) string { return "" } +// ConfigureTLSConfig creates a new tls.Config which can be used +// to create a quic.Listener meant for serving http3. The created +// tls.Config adds the functionality of detecting the used QUIC version +// in order to set the correct ALPN value for the http3 connection. +func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config { + // The tls.Config used to setup the quic.Listener needs to have the GetConfigForClient callback set. + // That way, we can get the QUIC version and set the correct ALPN value. + return &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + // determine the ALPN from the QUIC version used + proto := nextProtoH3Draft29 + if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok { + if qconn.GetQUICVersion() == protocol.Version1 { + proto = nextProtoH3 + } + } + config := tlsConf + if tlsConf.GetConfigForClient != nil { + getConfigForClient := tlsConf.GetConfigForClient + var err error + conf, err := getConfigForClient(ch) + if err != nil { + return nil, err + } + if conf != nil { + config = conf + } + } + if config == nil { + return nil, nil + } + config = config.Clone() + config.NextProtos = []string{proto} + return config, nil + }, + } +} + // contextKey is a value for use with context.WithValue. It's used as // a pointer so it fits in an interface{} without allocation. type contextKey struct { @@ -111,7 +149,7 @@ func (s *Server) ListenAndServe() error { if s.Server == nil { return errors.New("use of http3.Server without http.Server") } - return s.serveImpl(s.TLSConfig, nil) + return s.serveConn(s.TLSConfig, nil) } // ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/3 requests on incoming connections. @@ -127,17 +165,52 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { config := &tls.Config{ Certificates: certs, } - return s.serveImpl(config, nil) + return s.serveConn(config, nil) } // Serve an existing UDP connection. // It is possible to reuse the same connection for outgoing connections. // Closing the server does not close the packet conn. func (s *Server) Serve(conn net.PacketConn) error { - return s.serveImpl(s.TLSConfig, conn) + return s.serveConn(s.TLSConfig, conn) +} + +// Serve an existing QUIC listener. +// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config +// and use it to construct a http3-friendly QUIC listener. +// Closing the server does close the listener. +func (s *Server) ServeListener(listener quic.EarlyListener) error { + return s.serveImpl(func() (quic.EarlyListener, error) { return listener, nil }) +} + +func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { + return s.serveImpl(func() (quic.EarlyListener, error) { + baseConf := ConfigureTLSConfig(tlsConf) + quicConf := s.QuicConfig + if quicConf == nil { + quicConf = &quic.Config{} + } else { + quicConf = s.QuicConfig.Clone() + } + if s.EnableDatagrams { + quicConf.EnableDatagrams = true + } + + var ln quic.EarlyListener + var err error + if conn == nil { + ln, err = quicListenAddr(s.Addr, baseConf, quicConf) + } else { + ln, err = quicListen(conn, baseConf, quicConf) + } + if err != nil { + return nil, err + } + return ln, nil + }) } -func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { +func (s *Server) serveImpl(startListener func() (quic.EarlyListener, error)) error { if s.closed.Get() { return http.ErrServerClosed } @@ -148,54 +221,7 @@ func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { s.logger = utils.DefaultLogger.WithPrefix("server") }) - // The tls.Config we pass to Listen needs to have the GetConfigForClient callback set. - // That way, we can get the QUIC version and set the correct ALPN value. - baseConf := &tls.Config{ - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - // determine the ALPN from the QUIC version used - proto := nextProtoH3Draft29 - if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok { - if qconn.GetQUICVersion() == protocol.Version1 { - proto = nextProtoH3 - } - } - config := tlsConf - if tlsConf.GetConfigForClient != nil { - getConfigForClient := tlsConf.GetConfigForClient - var err error - conf, err := getConfigForClient(ch) - if err != nil { - return nil, err - } - if conf != nil { - config = conf - } - } - if config == nil { - return nil, nil - } - config = config.Clone() - config.NextProtos = []string{proto} - return config, nil - }, - } - - var ln quic.EarlyListener - var err error - quicConf := s.QuicConfig - if quicConf == nil { - quicConf = &quic.Config{} - } else { - quicConf = s.QuicConfig.Clone() - } - if s.EnableDatagrams { - quicConf.EnableDatagrams = true - } - if conn == nil { - ln, err = quicListenAddr(s.Addr, baseConf, quicConf) - } else { - ln, err = quicListen(conn, baseConf, quicConf) - } + ln, err := startListener() if err != nil { return err } diff --git a/http3/server_test.go b/http3/server_test.go index 02e9c4166be..0ff6bdf8e95 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/http" + "sync/atomic" "time" "github.com/lucas-clemente/quic-go" @@ -619,6 +620,35 @@ var _ = Describe("Server", func() { Expect(serv.ListenAndServe()).To(MatchError(http.ErrServerClosed)) }) + Context("ConfigureTLSConfig", func() { + var tlsConf *tls.Config + var ch *tls.ClientHelloInfo + + BeforeEach(func() { + tlsConf = &tls.Config{} + ch = &tls.ClientHelloInfo{} + }) + + It("advertises draft by default", func() { + tlsConf = ConfigureTLSConfig(tlsConf) + Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) + + config, err := tlsConf.GetConfigForClient(ch) + Expect(err).NotTo(HaveOccurred()) + Expect(config.NextProtos).To(Equal([]string{nextProtoH3Draft29})) + }) + + It("advertises h3 for quic version 1", func() { + tlsConf = ConfigureTLSConfig(tlsConf) + Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) + + ch.Conn = newMockConn(protocol.Version1) + config, err := tlsConf.GetConfigForClient(ch) + Expect(err).NotTo(HaveOccurred()) + Expect(config.NextProtos).To(Equal([]string{nextProtoH3})) + }) + }) + Context("Serve", func() { origQuicListen := quicListen @@ -704,6 +734,93 @@ var _ = Describe("Server", func() { }) }) + Context("ServeListener", func() { + origQuicListen := quicListen + + AfterEach(func() { + quicListen = origQuicListen + }) + + It("serves a listener", func() { + var called int32 + ln := mockquic.NewMockEarlyListener(mockCtrl) + quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + atomic.StoreInt32(&called, 1) + return ln, nil + } + + s := &Server{Server: &http.Server{}} + s.TLSConfig = &tls.Config{} + + stopAccept := make(chan struct{}) + ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + <-stopAccept + return nil, errors.New("closed") + }) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + s.ServeListener(ln) + }() + + Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) + Consistently(done).ShouldNot(BeClosed()) + ln.EXPECT().Close().Do(func() { close(stopAccept) }) + Expect(s.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("serves two listeners", func() { + var called int32 + ln1 := mockquic.NewMockEarlyListener(mockCtrl) + ln2 := mockquic.NewMockEarlyListener(mockCtrl) + lns := make(chan quic.EarlyListener, 2) + lns <- ln1 + lns <- ln2 + quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { + atomic.StoreInt32(&called, 1) + return <-lns, nil + } + + s := &Server{Server: &http.Server{}} + s.TLSConfig = &tls.Config{} + + stopAccept1 := make(chan struct{}) + ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + <-stopAccept1 + return nil, errors.New("closed") + }) + stopAccept2 := make(chan struct{}) + ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + <-stopAccept2 + return nil, errors.New("closed") + }) + + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done1) + s.ServeListener(ln1) + }() + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done2) + s.ServeListener(ln2) + }() + + Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) + Consistently(done1).ShouldNot(BeClosed()) + Expect(done2).ToNot(BeClosed()) + ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) + ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) + Expect(s.Close()).To(Succeed()) + Eventually(done1).Should(BeClosed()) + Eventually(done2).Should(BeClosed()) + }) + }) + Context("ListenAndServe", func() { BeforeEach(func() { s.Server.Addr = "localhost:0" diff --git a/integrationtests/self/hotswap_test.go b/integrationtests/self/hotswap_test.go new file mode 100644 index 00000000000..112e631662d --- /dev/null +++ b/integrationtests/self/hotswap_test.go @@ -0,0 +1,190 @@ +package self_test + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "strconv" + "sync/atomic" + "time" + + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/http3" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/testdata" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" +) + +type listenerWrapper struct { + quic.EarlyListener + listenerClosed bool + count int32 +} + +func (ln *listenerWrapper) Close() error { + ln.listenerClosed = true + return ln.EarlyListener.Close() +} + +func (ln *listenerWrapper) Faker() *fakeClosingListener { + atomic.AddInt32(&ln.count, 1) + ctx, cancel := context.WithCancel(context.Background()) + return &fakeClosingListener{ln, 0, ctx, cancel} +} + +type fakeClosingListener struct { + *listenerWrapper + closed int32 + ctx context.Context + cancel context.CancelFunc +} + +func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlySession, error) { + Expect(ctx).To(Equal(context.Background())) + return ln.listenerWrapper.Accept(ln.ctx) +} + +func (ln *fakeClosingListener) Close() error { + if atomic.CompareAndSwapInt32(&ln.closed, 0, 1) { + ln.cancel() + if atomic.AddInt32(&ln.listenerWrapper.count, -1) == 0 { + ln.listenerWrapper.Close() + } + } + return nil +} + +var _ = Describe("HTTP3 Server hotswap test", func() { + var ( + mux1 *http.ServeMux + mux2 *http.ServeMux + client *http.Client + server1 *http3.Server + server2 *http3.Server + ln *listenerWrapper + port string + ) + + versions := protocol.SupportedVersions + + BeforeEach(func() { + mux1 = http.NewServeMux() + mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset. + }) + + mux2 = http.NewServeMux() + mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset. + }) + + server1 = &http3.Server{ + Server: &http.Server{ + Handler: mux1, + TLSConfig: testdata.GetTLSConfig(), + }, + QuicConfig: getQuicConfig(&quic.Config{Versions: versions}), + } + + server2 = &http3.Server{ + Server: &http.Server{ + Handler: mux2, + TLSConfig: testdata.GetTLSConfig(), + }, + QuicConfig: getQuicConfig(&quic.Config{Versions: versions}), + } + + tlsConf := http3.ConfigureTLSConfig(testdata.GetTLSConfig()) + quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(&quic.Config{Versions: versions})) + ln = &listenerWrapper{EarlyListener: quicln} + Expect(err).NotTo(HaveOccurred()) + port = strconv.Itoa(ln.Addr().(*net.UDPAddr).Port) + }) + + AfterEach(func() { + Expect(ln.Close()).NotTo(HaveOccurred()) + }) + + for _, v := range versions { + version := v + + Context(fmt.Sprintf("with QUIC version %s", version), func() { + BeforeEach(func() { + client = &http.Client{ + Transport: &http3.RoundTripper{ + TLSClientConfig: &tls.Config{ + RootCAs: testdata.GetRootCA(), + }, + DisableCompression: true, + QuicConfig: getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + MaxIdleTimeout: 10 * time.Second, + }), + }, + } + }) + + It("hotswap works", func() { + // open first server and make single request to it + fake1 := ln.Faker() + stoppedServing1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + server1.ServeListener(fake1) + close(stoppedServing1) + }() + + resp, err := client.Get("https://localhost:" + port + "/hello1") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World 1!\n")) + + // open second server with same underlying listener, + // make sure it opened and both servers are currently running + fake2 := ln.Faker() + stoppedServing2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + server2.ServeListener(fake2) + close(stoppedServing2) + }() + + Consistently(stoppedServing1).ShouldNot(BeClosed()) + Consistently(stoppedServing2).ShouldNot(BeClosed()) + + // now close first server, no errors should occur here + // and only the fake listener should be closed + Expect(server1.Close()).NotTo(HaveOccurred()) + Eventually(stoppedServing1).Should(BeClosed()) + Expect(fake1.closed).To(Equal(int32(1))) + Expect(fake2.closed).To(Equal(int32(0))) + Expect(ln.listenerClosed).ToNot(BeTrue()) + Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred()) + + // verify that new sessions are being initiated from the second server now + resp, err = client.Get("https://localhost:" + port + "/hello2") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err = io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World 2!\n")) + + // close the other server - both the fake and the actual listeners must close now + Expect(server2.Close()).NotTo(HaveOccurred()) + Eventually(stoppedServing2).Should(BeClosed()) + Expect(fake2.closed).To(Equal(int32(1))) + Expect(ln.listenerClosed).To(BeTrue()) + }) + }) + } +})