diff --git a/connection.go b/connection.go index 24760a788cd..fae89e593d0 100644 --- a/connection.go +++ b/connection.go @@ -314,6 +314,8 @@ var newConnection = func( } if s.config.EnableDatagrams { params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + } else { + params.MaxDatagramFrameSize = protocol.InvalidByteCount } if s.tracer != nil { s.tracer.SentTransportParameters(params) @@ -438,6 +440,8 @@ var newClientConnection = func( } if s.config.EnableDatagrams { params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + } else { + params.MaxDatagramFrameSize = protocol.InvalidByteCount } if s.tracer != nil { s.tracer.SentTransportParameters(params) @@ -532,9 +536,7 @@ func (s *connection) preSetup() { s.creationTime = now s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) - if s.config.EnableDatagrams { - s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) - } + s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) } // run the connection main loop @@ -724,7 +726,7 @@ func (s *connection) Context() context.Context { } func (s *connection) supportsDatagrams() bool { - return s.peerParams.MaxDatagramFrameSize != protocol.InvalidByteCount + return s.peerParams.MaxDatagramFrameSize > 0 } func (s *connection) ConnectionState() ConnectionState { @@ -1975,6 +1977,10 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) { } func (s *connection) SendMessage(p []byte) error { + if !s.supportsDatagrams() { + return errors.New("datagram support disabled") + } + f := &wire.DatagramFrame{DataLenPresent: true} if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) { return errors.New("message too large") @@ -1985,6 +1991,9 @@ func (s *connection) SendMessage(p []byte) error { } func (s *connection) ReceiveMessage() ([]byte, error) { + if !s.config.EnableDatagrams { + return nil, errors.New("datagram support disabled") + } return s.datagramQueue.Receive() } diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index 3dd47ac3ee1..c750b3c3381 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -32,7 +32,7 @@ var _ = Describe("Datagram test", func() { dropped, total int32 ) - startServerAndProxy := func() { + startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) { addr, err := net.ResolveUDPAddr("udp", "localhost:0") Expect(err).ToNot(HaveOccurred()) serverConn, err = net.ListenUDP("udp", addr) @@ -41,30 +41,39 @@ var _ = Describe("Datagram test", func() { serverConn, getTLSConfig(), getQuicConfig(&quic.Config{ - EnableDatagrams: true, + EnableDatagrams: enableDatagram, Versions: []protocol.VersionNumber{version}, }), ) Expect(err).ToNot(HaveOccurred()) + go func() { defer GinkgoRecover() conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) - - var wg sync.WaitGroup - wg.Add(num) - for i := 0; i < num; i++ { - go func(i int) { - defer GinkgoRecover() - defer wg.Done() - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, uint64(i)) - Expect(conn.SendMessage(b)).To(Succeed()) - }(i) + + if expectDatagramSupport { + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + + if enableDatagram { + var wg sync.WaitGroup + wg.Add(num) + for i := 0; i < num; i++ { + go func(i int) { + defer GinkgoRecover() + defer wg.Done() + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, uint64(i)) + Expect(conn.SendMessage(b)).To(Succeed()) + }(i) + } + wg.Wait() + } + } else { + Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) } - wg.Wait() }() + serverPort := ln.Addr().(*net.UDPAddr).Port proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), @@ -100,7 +109,7 @@ var _ = Describe("Datagram test", func() { }) It("sends datagrams", func() { - startServerAndProxy() + startServerAndProxy(true, true) raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) Expect(err).ToNot(HaveOccurred()) conn, err := quic.Dial( @@ -137,6 +146,49 @@ var _ = Describe("Datagram test", func() { BeNumerically("<", num), )) }) + + It("server can disable datagram", func() { + startServerAndProxy(false, true) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + Expect(err).ToNot(HaveOccurred()) + conn, err := quic.Dial( + clientConn, + raddr, + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + Versions: []protocol.VersionNumber{version}, + }), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + + conn.CloseWithError(0, "") + <-time.After(10 * time.Millisecond) + }) + + It("client can disable datagram", func() { + startServerAndProxy(false, true) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + Expect(err).ToNot(HaveOccurred()) + conn, err := quic.Dial( + clientConn, + raddr, + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + Versions: []protocol.VersionNumber{version}, + }), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + + Expect(conn.SendMessage([]byte{0})).To(HaveOccurred()) + conn.CloseWithError(0, "") + <-time.After(10 * time.Millisecond) + }) }) } })