diff --git a/http3/server.go b/http3/server.go index 645b2b3ea27..8d50f5b12a7 100644 --- a/http3/server.go +++ b/http3/server.go @@ -44,7 +44,7 @@ const ( ) func versionToALPN(v protocol.VersionNumber) string { - if v == protocol.Version1 { + if v == protocol.Version1 || v == protocol.Version2 { return nextProtoH3 } if v == protocol.VersionTLS || v == protocol.VersionDraft29 { @@ -63,11 +63,9 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config { return &tls.Config{ GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { // determine the ALPN from the QUIC version used - proto := nextProtoH3Draft29 + proto := nextProtoH3 if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok { - if qconn.GetQUICVersion() == protocol.Version1 { - proto = nextProtoH3 - } + proto = versionToALPN(qconn.GetQUICVersion()) } config := tlsConf if tlsConf.GetConfigForClient != nil { diff --git a/http3/server_test.go b/http3/server_test.go index b5e23f775aa..e481867ef1c 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -819,23 +819,23 @@ var _ = Describe("Server", func() { ch = &tls.ClientHelloInfo{} }) - It("advertises draft by default", func() { + It("advertises v1 by default", func() { tlsConf = ConfigureTLSConfig(tlsConf) Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) config, err := tlsConf.GetConfigForClient(ch) Expect(err).NotTo(HaveOccurred()) - Expect(config.NextProtos).To(Equal([]string{nextProtoH3Draft29})) + Expect(config.NextProtos).To(Equal([]string{nextProtoH3})) }) - It("advertises h3 for quic version 1", func() { + It("advertises h3-29 for draft-29", func() { tlsConf = ConfigureTLSConfig(tlsConf) Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) - ch.Conn = newMockConn(protocol.Version1) + ch.Conn = newMockConn(protocol.VersionDraft29) config, err := tlsConf.GetConfigForClient(ch) Expect(err).NotTo(HaveOccurred()) - Expect(config.NextProtos).To(Equal([]string{nextProtoH3})) + Expect(config.NextProtos).To(Equal([]string{nextProtoH3Draft29})) }) }) diff --git a/interface.go b/interface.go index cb1c1de3a39..80c7517e92e 100644 --- a/interface.go +++ b/interface.go @@ -23,6 +23,7 @@ const ( VersionDraft29 = protocol.VersionDraft29 // Version1 is RFC 9000 Version1 = protocol.Version1 + Version2 = protocol.Version2 ) // A Token can be used to verify the ownership of the client address. diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index 54eda9b7d86..03b039289ce 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -9,9 +9,15 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD { - key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic key", suite.KeyLen) - iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic iv", suite.IVLen()) +func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD { + keyLabel := hkdfLabelKeyV1 + ivLabel := hkdfLabelIVV1 + if v == protocol.Version2 { + keyLabel = hkdfLabelKeyV2 + ivLabel = hkdfLabelIVV2 + } + key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, keyLabel, suite.KeyLen) + iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, ivLabel, suite.IVLen()) return suite.AEAD(key, iv) } diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index 922f068d39a..d035d7240b7 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -15,184 +15,190 @@ import ( ) var _ = Describe("Long Header AEAD", func() { - for i := range cipherSuites { - cs := cipherSuites[i] + for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { + v := ver + + Context(fmt.Sprintf("using version %s", v), func() { + for i := range cipherSuites { + cs := cipherSuites[i] + + Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { + getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { + key := make([]byte, 16) + hpKey := make([]byte, 16) + rand.Read(key) + rand.Read(hpKey) + block, err := aes.NewCipher(key) + Expect(err).ToNot(HaveOccurred()) + aead, err := cipher.NewGCM(block) + Expect(err).ToNot(HaveOccurred()) + + return newLongHeaderSealer(aead, newHeaderProtector(cs, hpKey, true, v)), + newLongHeaderOpener(aead, newHeaderProtector(cs, hpKey, true, v)) + } + + Context("message encryption", func() { + msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad := []byte("Donec in velit neque.") + + It("encrypts and decrypts a message", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + opened, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opened).To(Equal(msg)) + }) + + It("fails to open a message if the associated data is not the same", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("fails to open a message if the packet number is not the same", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x42, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("decodes the packet number", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) + }) + + It("ignores packets it can't decrypt for packet number derivation", func() { + sealer, opener := getSealerAndOpener() + encrypted := sealer.Seal(nil, msg, 0x1337, ad) + _, err := opener.Open(nil, encrypted[:len(encrypted)-1], 0x1337, ad) + Expect(err).To(HaveOccurred()) + Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) + }) + }) + + Context("header encryption", func() { + It("encrypts and encrypts the header", func() { + sealer, opener := getSealerAndOpener() + var lastFourBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sealer.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0xf != 0xb5&0xf { + lastFourBitsDifferent++ + } + Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) + }) + + It("encrypts and encrypts the header, for a 0xfff..fff sample", func() { + sealer, opener := getSealerAndOpener() + var lastFourBitsDifferent int + for i := 0; i < 100; i++ { + sample := bytes.Repeat([]byte{0xff}, 16) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sealer.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0xf != 0xb5&0xf { + lastFourBitsDifferent++ + } + Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + }) + + It("fails to decrypt the header when using a different sample", func() { + sealer, opener := getSealerAndOpener() + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + sample := make([]byte, 16) + rand.Read(sample) + sealer.EncryptHeader(sample, &header[0], header[9:13]) + rand.Read(sample) // use a different sample + opener.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + }) + }) + }) + } + }) - Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { - getSealerAndOpener := func() (LongHeaderSealer, LongHeaderOpener) { + Describe("Long Header AEAD", func() { + var ( + dropped chan struct{} // use a chan because closing it twice will panic + aead cipher.AEAD + hp headerProtector + ) + dropCb := func() { close(dropped) } + msg := []byte("Lorem ipsum dolor sit amet.") + ad := []byte("Donec in velit neque.") + + BeforeEach(func() { + dropped = make(chan struct{}) key := make([]byte, 16) hpKey := make([]byte, 16) rand.Read(key) rand.Read(hpKey) block, err := aes.NewCipher(key) Expect(err).ToNot(HaveOccurred()) - aead, err := cipher.NewGCM(block) + aead, err = cipher.NewGCM(block) Expect(err).ToNot(HaveOccurred()) + hp = newHeaderProtector(cipherSuites[0], hpKey, true, protocol.Version1) + }) - return newLongHeaderSealer(aead, newHeaderProtector(cs, hpKey, true)), - newLongHeaderOpener(aead, newHeaderProtector(cs, hpKey, true)) - } - - Context("message encryption", func() { - msg := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad := []byte("Donec in velit neque.") - - It("encrypts and decrypts a message", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - opened, err := opener.Open(nil, encrypted, 0x1337, ad) + Context("for the server", func() { + It("drops keys when first successfully processing a Handshake packet", func() { + serverOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveServer) + // first try to open an invalid message + _, err := serverOpener.Open(nil, []byte("invalid"), 0, []byte("invalid")) + Expect(err).To(HaveOccurred()) + Expect(dropped).ToNot(BeClosed()) + // then open a valid message + enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 10, ad) + _, err = serverOpener.Open(nil, enc, 10, ad) Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - - It("fails to open a message if the associated data is not the same", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x1337, []byte("wrong ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("fails to open a message if the packet number is not the same", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x42, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("decodes the packet number", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted, 0x1337, ad) + Expect(dropped).To(BeClosed()) + // now open the same message again to make sure the callback is only called once + _, err = serverOpener.Open(nil, enc, 10, ad) Expect(err).ToNot(HaveOccurred()) - Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) }) - It("ignores packets it can't decrypt for packet number derivation", func() { - sealer, opener := getSealerAndOpener() - encrypted := sealer.Seal(nil, msg, 0x1337, ad) - _, err := opener.Open(nil, encrypted[:len(encrypted)-1], 0x1337, ad) - Expect(err).To(HaveOccurred()) - Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) + It("doesn't drop keys when sealing a Handshake packet", func() { + serverSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveServer) + serverSealer.Seal(nil, msg, 1, ad) + Expect(dropped).ToNot(BeClosed()) }) }) - Context("header encryption", func() { - It("encrypts and encrypts the header", func() { - sealer, opener := getSealerAndOpener() - var lastFourBitsDifferent int - for i := 0; i < 100; i++ { - sample := make([]byte, 16) - rand.Read(sample) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sealer.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0xf != 0xb5&0xf { - lastFourBitsDifferent++ - } - Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - } - Expect(lastFourBitsDifferent).To(BeNumerically(">", 75)) - }) - - It("encrypts and encrypts the header, for a 0xfff..fff sample", func() { - sealer, opener := getSealerAndOpener() - var lastFourBitsDifferent int - for i := 0; i < 100; i++ { - sample := bytes.Repeat([]byte{0xff}, 16) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sealer.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0xf != 0xb5&0xf { - lastFourBitsDifferent++ - } - Expect(header[0] & 0xf0).To(Equal(byte(0xb5 & 0xf0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - } + Context("for the client", func() { + It("drops keys when first sealing a Handshake packet", func() { + clientSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveClient) + // seal the first message + clientSealer.Seal(nil, msg, 1, ad) + Expect(dropped).To(BeClosed()) + // seal another message to make sure the callback is only called once + clientSealer.Seal(nil, msg, 2, ad) }) - It("fails to decrypt the header when using a different sample", func() { - sealer, opener := getSealerAndOpener() - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - sample := make([]byte, 16) - rand.Read(sample) - sealer.EncryptHeader(sample, &header[0], header[9:13]) - rand.Read(sample) // use a different sample - opener.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).ToNot(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + It("doesn't drop keys when processing a Handshake packet", func() { + enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 42, ad) + clientOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveClient) + _, err := clientOpener.Open(nil, enc, 42, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(dropped).ToNot(BeClosed()) }) }) }) } }) - -var _ = Describe("Long Header AEAD", func() { - var ( - dropped chan struct{} // use a chan because closing it twice will panic - aead cipher.AEAD - hp headerProtector - ) - dropCb := func() { close(dropped) } - msg := []byte("Lorem ipsum dolor sit amet.") - ad := []byte("Donec in velit neque.") - - BeforeEach(func() { - dropped = make(chan struct{}) - key := make([]byte, 16) - hpKey := make([]byte, 16) - rand.Read(key) - rand.Read(hpKey) - block, err := aes.NewCipher(key) - Expect(err).ToNot(HaveOccurred()) - aead, err = cipher.NewGCM(block) - Expect(err).ToNot(HaveOccurred()) - hp = newHeaderProtector(cipherSuites[0], hpKey, true) - }) - - Context("for the server", func() { - It("drops keys when first successfully processing a Handshake packet", func() { - serverOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveServer) - // first try to open an invalid message - _, err := serverOpener.Open(nil, []byte("invalid"), 0, []byte("invalid")) - Expect(err).To(HaveOccurred()) - Expect(dropped).ToNot(BeClosed()) - // then open a valid message - enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 10, ad) - _, err = serverOpener.Open(nil, enc, 10, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(dropped).To(BeClosed()) - // now open the same message again to make sure the callback is only called once - _, err = serverOpener.Open(nil, enc, 10, ad) - Expect(err).ToNot(HaveOccurred()) - }) - - It("doesn't drop keys when sealing a Handshake packet", func() { - serverSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveServer) - serverSealer.Seal(nil, msg, 1, ad) - Expect(dropped).ToNot(BeClosed()) - }) - }) - - Context("for the client", func() { - It("drops keys when first sealing a Handshake packet", func() { - clientSealer := newHandshakeSealer(aead, hp, dropCb, protocol.PerspectiveClient) - // seal the first message - clientSealer.Seal(nil, msg, 1, ad) - Expect(dropped).To(BeClosed()) - // seal another message to make sure the callback is only called once - clientSealer.Seal(nil, msg, 2, ad) - }) - - It("doesn't drop keys when processing a Handshake packet", func() { - enc := newLongHeaderSealer(aead, hp).Seal(nil, msg, 42, ad) - clientOpener := newHandshakeOpener(aead, hp, dropCb, protocol.PerspectiveClient) - _, err := clientOpener.Open(nil, enc, 42, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(dropped).ToNot(BeClosed()) - }) - }) -}) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 25b35cffcca..31d9bf0aa40 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -246,7 +246,7 @@ func newCryptoSetup( initialSealer: initialSealer, initialOpener: initialOpener, handshakeStream: handshakeStream, - aead: newUpdatableAEAD(rttStats, tracer, logger), + aead: newUpdatableAEAD(rttStats, tracer, logger, version), readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, runner: runner, @@ -572,8 +572,8 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph panic("Received 0-RTT read key for the client") } h.zeroRTTOpener = newLongHeaderOpener( - createAEAD(suite, trafficSecret), - newHeaderProtector(suite, trafficSecret, true), + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), ) h.mutex.Unlock() h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) @@ -584,8 +584,8 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph case qtls.EncryptionHandshake: h.readEncLevel = protocol.EncryptionHandshake h.handshakeOpener = newHandshakeOpener( - createAEAD(suite, trafficSecret), - newHeaderProtector(suite, trafficSecret, true), + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), h.dropInitialKeys, h.perspective, ) @@ -612,8 +612,8 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip panic("Received 0-RTT write key for the server") } h.zeroRTTSealer = newLongHeaderSealer( - createAEAD(suite, trafficSecret), - newHeaderProtector(suite, trafficSecret, true), + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), ) h.mutex.Unlock() h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) @@ -624,8 +624,8 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip case qtls.EncryptionHandshake: h.writeEncLevel = protocol.EncryptionHandshake h.handshakeSealer = newHandshakeSealer( - createAEAD(suite, trafficSecret), - newHeaderProtector(suite, trafficSecret, true), + createAEAD(suite, trafficSecret, h.version), + newHeaderProtector(suite, trafficSecret, true, h.version), h.dropInitialKeys, h.perspective, ) diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go index e1c72c3b69e..1f800c50fec 100644 --- a/internal/handshake/header_protector.go +++ b/internal/handshake/header_protector.go @@ -9,6 +9,7 @@ import ( "golang.org/x/crypto/chacha20" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qtls" ) @@ -17,12 +18,20 @@ type headerProtector interface { DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) } -func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector { +func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string { + if v == protocol.Version2 { + return "quicv2 hp" + } + return "quic hp" +} + +func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector { + hkdfLabel := hkdfHeaderProtectionLabel(v) switch suite.ID { case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: - return newAESHeaderProtector(suite, trafficSecret, isLongHeader) + return newAESHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) case tls.TLS_CHACHA20_POLY1305_SHA256: - return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader) + return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) default: panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) } @@ -36,8 +45,8 @@ type aesHeaderProtector struct { var _ headerProtector = &aesHeaderProtector{} -func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector { - hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic hp", suite.KeyLen) +func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { + hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) block, err := aes.NewCipher(hpKey) if err != nil { panic(fmt.Sprintf("error creating new AES cipher: %s", err)) @@ -81,8 +90,8 @@ type chachaHeaderProtector struct { var _ headerProtector = &chachaHeaderProtector{} -func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector { - hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic hp", suite.KeyLen) +func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { + hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) p := &chachaHeaderProtector{ isLongHeader: isLongHeader, diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index 2880acf3238..00ed243c75f 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -12,12 +12,23 @@ import ( var ( quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99} - quicSalt = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} + quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a} + quicSaltV2 = []byte{0xa7, 0x07, 0xc2, 0x03, 0xa5, 0x9b, 0x47, 0x18, 0x4a, 0x1d, 0x62, 0xca, 0x57, 0x04, 0x06, 0xea, 0x7a, 0xe3, 0xe5, 0xd3} +) + +const ( + hkdfLabelKeyV1 = "quic key" + hkdfLabelKeyV2 = "quicv2 key" + hkdfLabelIVV1 = "quic iv" + hkdfLabelIVV2 = "quicv2 iv" ) func getSalt(v protocol.VersionNumber) []byte { + if v == protocol.Version2 { + return quicSaltV2 + } if v == protocol.Version1 { - return quicSalt + return quicSaltV1 } return quicSaltOld } @@ -40,14 +51,14 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p mySecret = serverSecret otherSecret = clientSecret } - myKey, myIV := computeInitialKeyAndIV(mySecret) - otherKey, otherIV := computeInitialKeyAndIV(otherSecret) + myKey, myIV := computeInitialKeyAndIV(mySecret, v) + otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v) encrypter := qtls.AEADAESGCMTLS13(myKey, myIV) decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) - return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true)), - newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true)) + return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)), + newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v))) } func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) { @@ -57,8 +68,14 @@ func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (cli return } -func computeInitialKeyAndIV(secret []byte) (key, iv []byte) { - key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16) - iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12) +func computeInitialKeyAndIV(secret []byte, v protocol.VersionNumber) (key, iv []byte) { + keyLabel := hkdfLabelKeyV1 + ivLabel := hkdfLabelIVV1 + if v == protocol.Version2 { + keyLabel = hkdfLabelKeyV2 + ivLabel = hkdfLabelIVV2 + } + key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16) + iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12) return } diff --git a/internal/handshake/initial_aead_test.go b/internal/handshake/initial_aead_test.go index acabd920f5f..bb8c4a156ac 100644 --- a/internal/handshake/initial_aead_test.go +++ b/internal/handshake/initial_aead_test.go @@ -7,6 +7,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" ) @@ -17,115 +18,144 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { Expect(splitHexString("dead beef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) }) - // values taken from the Appendix of the draft - Context("using the test vector from the QUIC draft, for old draft version", func() { - const version = protocol.VersionDraft29 - var connID protocol.ConnectionID - - BeforeEach(func() { - connID = protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) - }) - - It("computes the client key and IV", func() { - clientSecret, _ := computeSecrets(connID, version) - Expect(clientSecret).To(Equal(splitHexString("0088119288f1d866733ceeed15ff9d50 902cf82952eee27e9d4d4918ea371d87"))) - key, iv := computeInitialKeyAndIV(clientSecret) - Expect(key).To(Equal(splitHexString("175257a31eb09dea9366d8bb79ad80ba"))) - Expect(iv).To(Equal(splitHexString("6b26114b9cba2b63a9e8dd4f"))) - }) - - It("computes the server key and IV", func() { - _, serverSecret := computeSecrets(connID, version) - Expect(serverSecret).To(Equal(splitHexString("006f881359244dd9ad1acf85f595bad6 7c13f9f5586f5e64e1acae1d9ea8f616"))) - key, iv := computeInitialKeyAndIV(serverSecret) - Expect(key).To(Equal(splitHexString("149d0b1662ab871fbe63c49b5e655a5d"))) - Expect(iv).To(Equal(splitHexString("bab2b12a4c76016ace47856d"))) - }) - - It("encrypts the client's Initial", func() { - sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveClient, version) - header := splitHexString("c3ff00001d088394c8f03e5157080000449e00000002") - data := splitHexString("060040c4010000c003036660261ff947 cea49cce6cfad687f457cf1b14531ba1 4131a0e8f309a1d0b9c4000006130113 031302010000910000000b0009000006 736572766572ff01000100000a001400 12001d00170018001901000101010201 03010400230000003300260024001d00 204cfdfcd178b784bf328cae793b136f 2aedce005ff183d7bb14952072366470 37002b0003020304000d0020001e0403 05030603020308040805080604010501 060102010402050206020202002d0002 0101001c00024001") + connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + + DescribeTable("computes the client key and IV", + func(v protocol.VersionNumber, expectedClientSecret, expectedKey, expectedIV []byte) { + clientSecret, _ := computeSecrets(connID, v) + Expect(clientSecret).To(Equal(expectedClientSecret)) + key, iv := computeInitialKeyAndIV(clientSecret, v) + Expect(key).To(Equal(expectedKey)) + Expect(iv).To(Equal(expectedIV)) + }, + Entry("draft-29", + protocol.VersionDraft29, + splitHexString("0088119288f1d866733ceeed15ff9d50 902cf82952eee27e9d4d4918ea371d87"), + splitHexString("175257a31eb09dea9366d8bb79ad80ba"), + splitHexString("6b26114b9cba2b63a9e8dd4f"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("c00cf151ca5be075ed0ebfb5c80323c4 2d6b7db67881289af4008f1f6c357aea"), + splitHexString("1f369613dd76d5467730efcbe3b1a22d"), + splitHexString("fa044b2f42a3fd3b46fb255c"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("9fe72e1452e91f551b770005054034e4 7575d4a0fb4c27b7c6cb303a338423ae"), + splitHexString("95df2be2e8d549c82e996fc9339f4563"), + splitHexString("ea5e3c95f933db14b7020ad8"), + ), + ) + + DescribeTable("computes the server key and IV", + func(v protocol.VersionNumber, expectedServerSecret, expectedKey, expectedIV []byte) { + _, serverSecret := computeSecrets(connID, v) + Expect(serverSecret).To(Equal(expectedServerSecret)) + key, iv := computeInitialKeyAndIV(serverSecret, v) + Expect(key).To(Equal(expectedKey)) + Expect(iv).To(Equal(expectedIV)) + }, + Entry("draft 29", + protocol.VersionDraft29, + splitHexString("006f881359244dd9ad1acf85f595bad6 7c13f9f5586f5e64e1acae1d9ea8f616"), + splitHexString("149d0b1662ab871fbe63c49b5e655a5d"), + splitHexString("bab2b12a4c76016ace47856d"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("3c199828fd139efd216c155ad844cc81 fb82fa8d7446fa7d78be803acdda951b"), + splitHexString("cf3a5331653c364c88f0f379b6067e37"), + splitHexString("0ac1493ca1905853b0bba03e"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("3c9bf6a9c1c8c71819876967bd8b979e fd98ec665edf27f22c06e9845ba0ae2f"), + splitHexString("15d5b4d9a2b8916aa39b1bfe574d2aad"), + splitHexString("a85e7ac31cd275cbb095c626"), + ), + ) + + DescribeTable("encrypts the client's Initial", + func(v protocol.VersionNumber, header, data, expectedSample []byte, expectedHdrFirstByte byte, expectedHdr, expectedPacket []byte) { + sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveClient, v) data = append(data, make([]byte, 1162-len(data))...) // add PADDING sealed := sealer.Seal(nil, data, 2, header) sample := sealed[0:16] - Expect(sample).To(Equal(splitHexString("fb66bc5f93032b7ddd89fe0ff15d9c4f"))) + Expect(sample).To(Equal(expectedSample)) sealer.EncryptHeader(sample, &header[0], header[len(header)-4:]) - Expect(header[0]).To(Equal(byte(0xc5))) - Expect(header[len(header)-4:]).To(Equal(splitHexString("4a95245b"))) + Expect(header[0]).To(Equal(expectedHdrFirstByte)) + Expect(header[len(header)-4:]).To(Equal(expectedHdr)) packet := append(header, sealed...) - Expect(packet).To(Equal(splitHexString("c5ff00001d088394c8f03e5157080000 449e4a95245bfb66bc5f93032b7ddd89 fe0ff15d9c4f7050fccdb71c1cd80512 d4431643a53aafa1b0b518b44968b18b 8d3e7a4d04c30b3ed9410325b2abb2da fb1c12f8b70479eb8df98abcaf95dd8f 3d1c78660fbc719f88b23c8aef6771f3 d50e10fdfb4c9d92386d44481b6c52d5 9e5538d3d3942de9f13a7f8b702dc317 24180da9df22714d01003fc5e3d165c9 50e630b8540fbd81c9df0ee63f949970 26c4f2e1887a2def79050ac2d86ba318 e0b3adc4c5aa18bcf63c7cf8e85f5692 49813a2236a7e72269447cd1c755e451 f5e77470eb3de64c8849d29282069802 9cfa18e5d66176fe6e5ba4ed18026f90 900a5b4980e2f58e39151d5cd685b109 29636d4f02e7fad2a5a458249f5c0298 a6d53acbe41a7fc83fa7cc01973f7a74 d1237a51974e097636b6203997f921d0 7bc1940a6f2d0de9f5a11432946159ed 6cc21df65c4ddd1115f86427259a196c 7148b25b6478b0dc7766e1c4d1b1f515 9f90eabc61636226244642ee148b464c 9e619ee50a5e3ddc836227cad938987c 4ea3c1fa7c75bbf88d89e9ada642b2b8 8fe8107b7ea375b1b64889a4e9e5c38a 1c896ce275a5658d250e2d76e1ed3a34 ce7e3a3f383d0c996d0bed106c2899ca 6fc263ef0455e74bb6ac1640ea7bfedc 59f03fee0e1725ea150ff4d69a7660c5 542119c71de270ae7c3ecfd1af2c4ce5 51986949cc34a66b3e216bfe18b347e6 c05fd050f85912db303a8f054ec23e38 f44d1c725ab641ae929fecc8e3cefa56 19df4231f5b4c009fa0c0bbc60bc75f7 6d06ef154fc8577077d9d6a1d2bd9bf0 81dc783ece60111bea7da9e5a9748069 d078b2bef48de04cabe3755b197d52b3 2046949ecaa310274b4aac0d008b1948 c1082cdfe2083e386d4fd84c0ed0666d 3ee26c4515c4fee73433ac703b690a9f 7bf278a77486ace44c489a0c7ac8dfe4 d1a58fb3a730b993ff0f0d61b4d89557 831eb4c752ffd39c10f6b9f46d8db278 da624fd800e4af85548a294c1518893a 8778c4f6d6d73c93df200960104e062b 388ea97dcf4016bced7f62b4f062cb6c 04c20693d9a0e3b74ba8fe74cc012378 84f40d765ae56a51688d985cf0ceaef4 3045ed8c3f0c33bced08537f6882613a cd3b08d665fce9dd8aa73171e2d3771a 61dba2790e491d413d93d987e2745af2 9418e428be34941485c93447520ffe23 1da2304d6a0fd5d07d08372202369661 59bef3cf904d722324dd852513df39ae 030d8173908da6364786d3c1bfcb19ea 77a63b25f1e7fc661def480c5d00d444 56269ebd84efd8e3a8b2c257eec76060 682848cbf5194bc99e49ee75e4d0d254 bad4bfd74970c30e44b65511d4ad0e6e c7398e08e01307eeeea14e46ccd87cf3 6b285221254d8fc6a6765c524ded0085 dca5bd688ddf722e2c0faf9d0fb2ce7a 0c3f2cee19ca0ffba461ca8dc5d2c817 8b0762cf67135558494d2a96f1a139f0 edb42d2af89a9c9122b07acbc29e5e72 2df8615c343702491098478a389c9872 a10b0c9875125e257c7bfdf27eef4060 bd3d00f4c14fd3e3496c38d3c5d1a566 8c39350effbc2d16ca17be4ce29f02ed 969504dda2a8c6b9ff919e693ee79e09 089316e7d1d89ec099db3b2b268725d8 88536a4b8bf9aee8fb43e82a4d919d48 43b1ca70a2d8d3f725ead1391377dcc0"))) - }) - - It("encrypt the server's Initial", func() { - sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveServer, version) - header := splitHexString("c1ff00001d0008f067a5502a4262b50040740001") - data := splitHexString("0d0000000018410a020000560303eefc e7f7b37ba1d1632e96677825ddf73988 cfc79825df566dc5430b9a045a120013 0100002e00330024001d00209d3c940d 89690b84d08a60993c144eca684d1081 287c834d5311bcf32bb9da1a002b0002 0304") + Expect(packet).To(Equal(expectedPacket)) + }, + Entry("draft 29", + protocol.VersionDraft29, + splitHexString("c3ff00001d088394c8f03e5157080000449e00000002"), + splitHexString("060040c4010000c003036660261ff947 cea49cce6cfad687f457cf1b14531ba1 4131a0e8f309a1d0b9c4000006130113 031302010000910000000b0009000006 736572766572ff01000100000a001400 12001d00170018001901000101010201 03010400230000003300260024001d00 204cfdfcd178b784bf328cae793b136f 2aedce005ff183d7bb14952072366470 37002b0003020304000d0020001e0403 05030603020308040805080604010501 060102010402050206020202002d0002 0101001c00024001"), + splitHexString("fb66bc5f93032b7ddd89fe0ff15d9c4f"), + byte(0xc5), + splitHexString("4a95245b"), + splitHexString("c5ff00001d088394c8f03e5157080000 449e4a95245bfb66bc5f93032b7ddd89 fe0ff15d9c4f7050fccdb71c1cd80512 d4431643a53aafa1b0b518b44968b18b 8d3e7a4d04c30b3ed9410325b2abb2da fb1c12f8b70479eb8df98abcaf95dd8f 3d1c78660fbc719f88b23c8aef6771f3 d50e10fdfb4c9d92386d44481b6c52d5 9e5538d3d3942de9f13a7f8b702dc317 24180da9df22714d01003fc5e3d165c9 50e630b8540fbd81c9df0ee63f949970 26c4f2e1887a2def79050ac2d86ba318 e0b3adc4c5aa18bcf63c7cf8e85f5692 49813a2236a7e72269447cd1c755e451 f5e77470eb3de64c8849d29282069802 9cfa18e5d66176fe6e5ba4ed18026f90 900a5b4980e2f58e39151d5cd685b109 29636d4f02e7fad2a5a458249f5c0298 a6d53acbe41a7fc83fa7cc01973f7a74 d1237a51974e097636b6203997f921d0 7bc1940a6f2d0de9f5a11432946159ed 6cc21df65c4ddd1115f86427259a196c 7148b25b6478b0dc7766e1c4d1b1f515 9f90eabc61636226244642ee148b464c 9e619ee50a5e3ddc836227cad938987c 4ea3c1fa7c75bbf88d89e9ada642b2b8 8fe8107b7ea375b1b64889a4e9e5c38a 1c896ce275a5658d250e2d76e1ed3a34 ce7e3a3f383d0c996d0bed106c2899ca 6fc263ef0455e74bb6ac1640ea7bfedc 59f03fee0e1725ea150ff4d69a7660c5 542119c71de270ae7c3ecfd1af2c4ce5 51986949cc34a66b3e216bfe18b347e6 c05fd050f85912db303a8f054ec23e38 f44d1c725ab641ae929fecc8e3cefa56 19df4231f5b4c009fa0c0bbc60bc75f7 6d06ef154fc8577077d9d6a1d2bd9bf0 81dc783ece60111bea7da9e5a9748069 d078b2bef48de04cabe3755b197d52b3 2046949ecaa310274b4aac0d008b1948 c1082cdfe2083e386d4fd84c0ed0666d 3ee26c4515c4fee73433ac703b690a9f 7bf278a77486ace44c489a0c7ac8dfe4 d1a58fb3a730b993ff0f0d61b4d89557 831eb4c752ffd39c10f6b9f46d8db278 da624fd800e4af85548a294c1518893a 8778c4f6d6d73c93df200960104e062b 388ea97dcf4016bced7f62b4f062cb6c 04c20693d9a0e3b74ba8fe74cc012378 84f40d765ae56a51688d985cf0ceaef4 3045ed8c3f0c33bced08537f6882613a cd3b08d665fce9dd8aa73171e2d3771a 61dba2790e491d413d93d987e2745af2 9418e428be34941485c93447520ffe23 1da2304d6a0fd5d07d08372202369661 59bef3cf904d722324dd852513df39ae 030d8173908da6364786d3c1bfcb19ea 77a63b25f1e7fc661def480c5d00d444 56269ebd84efd8e3a8b2c257eec76060 682848cbf5194bc99e49ee75e4d0d254 bad4bfd74970c30e44b65511d4ad0e6e c7398e08e01307eeeea14e46ccd87cf3 6b285221254d8fc6a6765c524ded0085 dca5bd688ddf722e2c0faf9d0fb2ce7a 0c3f2cee19ca0ffba461ca8dc5d2c817 8b0762cf67135558494d2a96f1a139f0 edb42d2af89a9c9122b07acbc29e5e72 2df8615c343702491098478a389c9872 a10b0c9875125e257c7bfdf27eef4060 bd3d00f4c14fd3e3496c38d3c5d1a566 8c39350effbc2d16ca17be4ce29f02ed 969504dda2a8c6b9ff919e693ee79e09 089316e7d1d89ec099db3b2b268725d8 88536a4b8bf9aee8fb43e82a4d919d48 43b1ca70a2d8d3f725ead1391377dcc0"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("c300000001088394c8f03e5157080000449e00000002"), + splitHexString("060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), + splitHexString("d1b1c98dd7689fb8ec11d242b123dc9b"), + byte(0xc0), + splitHexString("7b9aec34"), + splitHexString("c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11 d242b123dc9bd8bab936b47d92ec356c 0bab7df5976d27cd449f63300099f399 1c260ec4c60d17b31f8429157bb35a12 82a643a8d2262cad67500cadb8e7378c 8eb7539ec4d4905fed1bee1fc8aafba1 7c750e2c7ace01e6005f80fcb7df6212 30c83711b39343fa028cea7f7fb5ff89 eac2308249a02252155e2347b63d58c5 457afd84d05dfffdb20392844ae81215 4682e9cf012f9021a6f0be17ddd0c208 4dce25ff9b06cde535d0f920a2db1bf3 62c23e596d11a4f5a6cf3948838a3aec 4e15daf8500a6ef69ec4e3feb6b1d98e 610ac8b7ec3faf6ad760b7bad1db4ba3 485e8a94dc250ae3fdb41ed15fb6a8e5 eba0fc3dd60bc8e30c5c4287e53805db 059ae0648db2f64264ed5e39be2e20d8 2df566da8dd5998ccabdae053060ae6c 7b4378e846d29f37ed7b4ea9ec5d82e7 961b7f25a9323851f681d582363aa5f8 9937f5a67258bf63ad6f1a0b1d96dbd4 faddfcefc5266ba6611722395c906556 be52afe3f565636ad1b17d508b73d874 3eeb524be22b3dcbc2c7468d54119c74 68449a13d8e3b95811a198f3491de3e7 fe942b330407abf82a4ed7c1b311663a c69890f4157015853d91e923037c227a 33cdd5ec281ca3f79c44546b9d90ca00 f064c99e3dd97911d39fe9c5d0b23a22 9a234cb36186c4819e8b9c5927726632 291d6a418211cc2962e20fe47feb3edf 330f2c603a9d48c0fcb5699dbfe58964 25c5bac4aee82e57a85aaf4e2513e4f0 5796b07ba2ee47d80506f8d2c25e50fd 14de71e6c418559302f939b0e1abd576 f279c4b2e0feb85c1f28ff18f58891ff ef132eef2fa09346aee33c28eb130ff2 8f5b766953334113211996d20011a198 e3fc433f9f2541010ae17c1bf202580f 6047472fb36857fe843b19f5984009dd c324044e847a4f4a0ab34f719595de37 252d6235365e9b84392b061085349d73 203a4a13e96f5432ec0fd4a1ee65accd d5e3904df54c1da510b0ff20dcc0c77f cb2c0e0eb605cb0504db87632cf3d8b4 dae6e705769d1de354270123cb11450e fc60ac47683d7b8d0f811365565fd98c 4c8eb936bcab8d069fc33bd801b03ade a2e1fbc5aa463d08ca19896d2bf59a07 1b851e6c239052172f296bfb5e724047 90a2181014f3b94a4e97d117b4381303 68cc39dbb2d198065ae3986547926cd2 162f40a29f0c3c8745c0f50fba3852e5 66d44575c29d39a03f0cda721984b6f4 40591f355e12d439ff150aab7613499d bd49adabc8676eef023b15b65bfc5ca0 6948109f23f350db82123535eb8a7433 bdabcb909271a6ecbcb58b936a88cd4e 8f2e6ff5800175f113253d8fa9ca8885 c2f552e657dc603f252e1a8e308f76f0 be79e2fb8f5d5fbbe2e30ecadd220723 c8c0aea8078cdfcb3868263ff8f09400 54da48781893a7e49ad5aff4af300cd8 04a6b6279ab3ff3afb64491c85194aab 760d58a606654f9f4400e8b38591356f bf6425aca26dc85244259ff2b19c41b9 f96f3ca9ec1dde434da7d2d392b905dd f3d1f9af93d1af5950bd493f5aa731b4 056df31bd267b6b90a079831aaf579be 0a39013137aac6d404f518cfd4684064 7e78bfe706ca4cf5e9c5453e9f7cfd2b 8b4c8d169a44e55c88d4a9a7f9474241 e221af44860018ab0856972e194cd934"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("d3709a50c4088394c8f03e5157080000449e00000002"), + splitHexString("060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff"), + splitHexString("23b8e610589c83c92d0e97eb7a6e5003"), + byte(0xdd), + splitHexString("4391d848"), + splitHexString("dd709a50c4088394c8f03e5157080000 449e4391d84823b8e610589c83c92d0e 97eb7a6e5003f57764c5c7f0095ba54b 90818f1bfeecc1c97c54fc731edbd2a2 44e3b1e639a9bc75ed545b98649343b2 53615ec6b3e4df0fd2e7fe9d691a09e6 a144b436d8a2c088a404262340dfd995 ec3865694e3026ecd8c6d2561a5a3667 2a1005018168c0f081c10e2bf14d550c 977e28bb9a759c57d0f7ffb1cdfb40bd 774dec589657542047dffefa56fc8089 a4d1ef379c81ba3df71a05ddc7928340 775910feb3ce4cbcfd8d253edd05f161 458f9dc44bea017c3117cca7065a315d eda9464e672ec80c3f79ac993437b441 ef74227ecc4dc9d597f66ab0ab8d214b 55840c70349d7616cbe38e5e1d052d07 f1fedb3dd3c4d8ce295724945e67ed2e efcd9fb52472387f318e3d9d233be7df c79d6bf6080dcbbb41feb180d7858849 7c3e439d38c334748d2b56fd19ab364d 057a9bd5a699ae145d7fdbc8f5777518 1b0a97c3bdedc91a555d6c9b8634e106 d8c9ca45a9d5450a7679edc545da9102 5bc93a7cf9a023a066ffadb9717ffaf3 414c3b646b5738b3cc4116502d18d79d 8227436306d9b2b3afc6c785ce3c817f eb703a42b9c83b59f0dcef1245d0b3e4 0299821ec19549ce489714fe2611e72c d882f4f70dce7d3671296fc045af5c9f 630d7b49a3eb821bbca60f1984dce664 91713bfe06001a56f51bb3abe92f7960 547c4d0a70f4a962b3f05dc25a34bbe8 30a7ea4736d3b0161723500d82beda9b e3327af2aa413821ff678b2a876ec4b0 0bb605ffcc3917ffdc279f187daa2fce 8cde121980bba8ec8f44ca562b0f1319 14c901cfbd847408b778e6738c7bb5b1 b3f97d01b0a24dcca40e3bed29411b1b a8f60843c4a241021b23132b9500509b 9a3516d4a9dd41d3bacbcd426b451393 521828afedcf20fa46ac24f44a8e2973 30b16705d5d5f798eff9e9134a065979 87a1db4617caa2d93837730829d4d89e 16413be4d8a8a38a7e6226623b64a820 178ec3a66954e10710e043ae73dd3fb2 715a0525a46343fb7590e5eac7ee55fc 810e0d8b4b8f7be82cd5a214575a1b99 629d47a9b281b61348c8627cab38e2a6 4db6626e97bb8f77bdcb0fee476aedd7 ba8f5441acaab00f4432edab3791047d 9091b2a753f035648431f6d12f7d6a68 1e64c861f4ac911a0f7d6ec0491a78c9 f192f96b3a5e7560a3f056bc1ca85983 67ad6acb6f2e034c7f37beeb9ed470c4 304af0107f0eb919be36a86f68f37fa6 1dae7aff14decd67ec3157a11488a14f ed0142828348f5f608b0fe03e1f3c0af 3acca0ce36852ed42e220ae9abf8f890 6f00f1b86bff8504c8f16c784fd52d25 e013ff4fda903e9e1eb453c1464b1196 6db9b28e8f26a3fc419e6a60a48d4c72 14ee9c6c6a12b68a32cac8f61580c64f 29cb6922408783c6d12e725b014fe485 cd17e484c5952bf99bc94941d4b1919d 04317b8aa1bd3754ecbaa10ec227de85 40695bf2fb8ee56f6dc526ef366625b9 1aa4970b6ffa5c8284b9b5ab852b905f 9d83f5669c0535bc377bcc05ad5e48e2 81ec0e1917ca3c6a471f8da0894bc82a c2a8965405d6eef3b5e293a88fda203f 09bdc72757b107ab14880eaa3ef7045b 580f4821ce6dd325b5a90655d8c5b55f 76fb846279a9b518c5e9b9a21165c509 3ed49baaacadf1f21873266c767f6769"), + ), + ) + + DescribeTable("encrypts the server's Initial", + func(v protocol.VersionNumber, header, data, expectedSample, expectedHdr, expectedPacket []byte) { + sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveServer, v) sealed := sealer.Seal(nil, data, 1, header) sample := sealed[2 : 2+16] - Expect(sample).To(Equal(splitHexString("823a5d3a1207c86ee49132824f046524"))) + Expect(sample).To(Equal(expectedSample)) sealer.EncryptHeader(sample, &header[0], header[len(header)-2:]) - Expect(header).To(Equal(splitHexString("caff00001d0008f067a5502a4262b5004074aaf2"))) + Expect(header).To(Equal(expectedHdr)) packet := append(header, sealed...) - Expect(packet).To(Equal(splitHexString("caff00001d0008f067a5502a4262b500 4074aaf2f007823a5d3a1207c86ee491 32824f0465243d082d868b107a38092b c80528664cbf9456ebf27673fb5fa506 1ab573c9f001b81da028a00d52ab00b1 5bebaa70640e106cf2acd043e9c6b441 1c0a79637134d8993701fe779e58c2fe 753d14b0564021565ea92e57bc6faf56 dfc7a40870e6"))) - }) - }) - - // values taken from the Appendix of the draft - Context("using the test vector from the QUIC draft, for QUIC v1", func() { - const version = protocol.Version1 - var connID protocol.ConnectionID - - BeforeEach(func() { - connID = protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) - }) - - It("computes the client key and IV", func() { - clientSecret, _ := computeSecrets(connID, version) - Expect(clientSecret).To(Equal(splitHexString("c00cf151ca5be075ed0ebfb5c80323c4 2d6b7db67881289af4008f1f6c357aea"))) - key, iv := computeInitialKeyAndIV(clientSecret) - Expect(key).To(Equal(splitHexString("1f369613dd76d5467730efcbe3b1a22d"))) - Expect(iv).To(Equal(splitHexString("fa044b2f42a3fd3b46fb255c"))) - }) - - It("computes the server key and IV", func() { - _, serverSecret := computeSecrets(connID, version) - Expect(serverSecret).To(Equal(splitHexString("3c199828fd139efd216c155ad844cc81 fb82fa8d7446fa7d78be803acdda951b"))) - key, iv := computeInitialKeyAndIV(serverSecret) - Expect(key).To(Equal(splitHexString("cf3a5331653c364c88f0f379b6067e37"))) - Expect(iv).To(Equal(splitHexString("0ac1493ca1905853b0bba03e"))) - }) - - It("encrypts the client's Initial", func() { - sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveClient, version) - header := splitHexString("c300000001088394c8f03e5157080000449e00000002") - data := splitHexString("060040f1010000ed0303ebf8fa56f129 39b9584a3896472ec40bb863cfd3e868 04fe3a47f06a2b69484c000004130113 02010000c000000010000e00000b6578 616d706c652e636f6dff01000100000a 00080006001d00170018001000070005 04616c706e0005000501000000000033 00260024001d00209370b2c9caa47fba baf4559fedba753de171fa71f50f1ce1 5d43e994ec74d748002b000302030400 0d0010000e0403050306030203080408 050806002d00020101001c0002400100 3900320408ffffffffffffffff050480 00ffff07048000ffff08011001048000 75300901100f088394c8f03e51570806 048000ffff") - data = append(data, make([]byte, 1162-len(data))...) // add PADDING - sealed := sealer.Seal(nil, data, 2, header) - sample := sealed[0:16] - Expect(sample).To(Equal(splitHexString("d1b1c98dd7689fb8ec11d242b123dc9b"))) - sealer.EncryptHeader(sample, &header[0], header[len(header)-4:]) - Expect(header[0]).To(Equal(byte(0xc0))) - Expect(header[len(header)-4:]).To(Equal(splitHexString("7b9aec34"))) - packet := append(header, sealed...) - Expect(packet).To(Equal(splitHexString("c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11 d242b123dc9bd8bab936b47d92ec356c 0bab7df5976d27cd449f63300099f399 1c260ec4c60d17b31f8429157bb35a12 82a643a8d2262cad67500cadb8e7378c 8eb7539ec4d4905fed1bee1fc8aafba1 7c750e2c7ace01e6005f80fcb7df6212 30c83711b39343fa028cea7f7fb5ff89 eac2308249a02252155e2347b63d58c5 457afd84d05dfffdb20392844ae81215 4682e9cf012f9021a6f0be17ddd0c208 4dce25ff9b06cde535d0f920a2db1bf3 62c23e596d11a4f5a6cf3948838a3aec 4e15daf8500a6ef69ec4e3feb6b1d98e 610ac8b7ec3faf6ad760b7bad1db4ba3 485e8a94dc250ae3fdb41ed15fb6a8e5 eba0fc3dd60bc8e30c5c4287e53805db 059ae0648db2f64264ed5e39be2e20d8 2df566da8dd5998ccabdae053060ae6c 7b4378e846d29f37ed7b4ea9ec5d82e7 961b7f25a9323851f681d582363aa5f8 9937f5a67258bf63ad6f1a0b1d96dbd4 faddfcefc5266ba6611722395c906556 be52afe3f565636ad1b17d508b73d874 3eeb524be22b3dcbc2c7468d54119c74 68449a13d8e3b95811a198f3491de3e7 fe942b330407abf82a4ed7c1b311663a c69890f4157015853d91e923037c227a 33cdd5ec281ca3f79c44546b9d90ca00 f064c99e3dd97911d39fe9c5d0b23a22 9a234cb36186c4819e8b9c5927726632 291d6a418211cc2962e20fe47feb3edf 330f2c603a9d48c0fcb5699dbfe58964 25c5bac4aee82e57a85aaf4e2513e4f0 5796b07ba2ee47d80506f8d2c25e50fd 14de71e6c418559302f939b0e1abd576 f279c4b2e0feb85c1f28ff18f58891ff ef132eef2fa09346aee33c28eb130ff2 8f5b766953334113211996d20011a198 e3fc433f9f2541010ae17c1bf202580f 6047472fb36857fe843b19f5984009dd c324044e847a4f4a0ab34f719595de37 252d6235365e9b84392b061085349d73 203a4a13e96f5432ec0fd4a1ee65accd d5e3904df54c1da510b0ff20dcc0c77f cb2c0e0eb605cb0504db87632cf3d8b4 dae6e705769d1de354270123cb11450e fc60ac47683d7b8d0f811365565fd98c 4c8eb936bcab8d069fc33bd801b03ade a2e1fbc5aa463d08ca19896d2bf59a07 1b851e6c239052172f296bfb5e724047 90a2181014f3b94a4e97d117b4381303 68cc39dbb2d198065ae3986547926cd2 162f40a29f0c3c8745c0f50fba3852e5 66d44575c29d39a03f0cda721984b6f4 40591f355e12d439ff150aab7613499d bd49adabc8676eef023b15b65bfc5ca0 6948109f23f350db82123535eb8a7433 bdabcb909271a6ecbcb58b936a88cd4e 8f2e6ff5800175f113253d8fa9ca8885 c2f552e657dc603f252e1a8e308f76f0 be79e2fb8f5d5fbbe2e30ecadd220723 c8c0aea8078cdfcb3868263ff8f09400 54da48781893a7e49ad5aff4af300cd8 04a6b6279ab3ff3afb64491c85194aab 760d58a606654f9f4400e8b38591356f bf6425aca26dc85244259ff2b19c41b9 f96f3ca9ec1dde434da7d2d392b905dd f3d1f9af93d1af5950bd493f5aa731b4 056df31bd267b6b90a079831aaf579be 0a39013137aac6d404f518cfd4684064 7e78bfe706ca4cf5e9c5453e9f7cfd2b 8b4c8d169a44e55c88d4a9a7f9474241 e221af44860018ab0856972e194cd934"))) - }) - - It("encrypt the server's Initial", func() { - sealer, _ := NewInitialAEAD(connID, protocol.PerspectiveServer, version) - header := splitHexString("c1000000010008f067a5502a4262b50040750001") - data := splitHexString("02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304") - sealed := sealer.Seal(nil, data, 1, header) - sample := sealed[2 : 2+16] - Expect(sample).To(Equal(splitHexString("2cd0991cd25b0aac406a5816b6394100"))) - sealer.EncryptHeader(sample, &header[0], header[len(header)-2:]) - Expect(header).To(Equal(splitHexString("cf000000010008f067a5502a4262b5004075c0d9"))) - packet := append(header, sealed...) - Expect(packet).To(Equal(splitHexString("cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a 5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3 dbcba3f6ea46c5b7684df3548e7ddeb9 c3bf9c73cc3f3bded74b562bfb19fb84 022f8ef4cdd93795d77d06edbb7aaf2f 58891850abbdca3d20398c276456cbc4 2158407dd074ee"))) - }) - }) - - for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1} { + Expect(packet).To(Equal(expectedPacket)) + }, + Entry("draft 29", + protocol.VersionDraft29, + splitHexString("c1ff00001d0008f067a5502a4262b50040740001"), + splitHexString("0d0000000018410a020000560303eefc e7f7b37ba1d1632e96677825ddf73988 cfc79825df566dc5430b9a045a120013 0100002e00330024001d00209d3c940d 89690b84d08a60993c144eca684d1081 287c834d5311bcf32bb9da1a002b0002 0304"), + splitHexString("823a5d3a1207c86ee49132824f046524"), + splitHexString("caff00001d0008f067a5502a4262b5004074aaf2"), + splitHexString("caff00001d0008f067a5502a4262b500 4074aaf2f007823a5d3a1207c86ee491 32824f0465243d082d868b107a38092b c80528664cbf9456ebf27673fb5fa506 1ab573c9f001b81da028a00d52ab00b1 5bebaa70640e106cf2acd043e9c6b441 1c0a79637134d8993701fe779e58c2fe 753d14b0564021565ea92e57bc6faf56 dfc7a40870e6"), + ), + Entry("QUIC v1", + protocol.Version1, + splitHexString("c1000000010008f067a5502a4262b50040750001"), + splitHexString("02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), + splitHexString("2cd0991cd25b0aac406a5816b6394100"), + splitHexString("cf000000010008f067a5502a4262b5004075c0d9"), + splitHexString("cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a 5816b6394100f37a1c69797554780bb3 8cc5a99f5ede4cf73c3ec2493a1839b3 dbcba3f6ea46c5b7684df3548e7ddeb9 c3bf9c73cc3f3bded74b562bfb19fb84 022f8ef4cdd93795d77d06edbb7aaf2f 58891850abbdca3d20398c276456cbc4 2158407dd074ee"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("d1709a50c40008f067a5502a4262b50040750001"), + splitHexString("02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739 88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94 0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00 020304"), + splitHexString("ebb7972fdce59d50e7e49ff2a7e8de76"), + splitHexString("d0709a50c40008f067a5502a4262b5004075103e"), + splitHexString("d0709a50c40008f067a5502a4262b500 4075103e63b4ebb7972fdce59d50e7e4 9ff2a7e8de76b0cd8c10100a1f13d549 dd6fe801588fb14d279bef8d7c53ef62 66a9a7a1a5f2fa026c236a5bf8df5aa0 f9d74773aeccfffe910b0f76814b5e33 f7b7f8ec278d23fd8c7a9e66856b8bbe 72558135bca27c54d63fcc902253461c fc089d4e6b9b19"), + ), + ) + + for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { v := ver Context(fmt.Sprintf("using version %s", v), func() { diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index 094e650468a..1532e7b5a5b 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -55,8 +55,9 @@ type updatableAEAD struct { rttStats *utils.RTTStats - tracer logging.ConnectionTracer - logger utils.Logger + tracer logging.ConnectionTracer + logger utils.Logger + version protocol.VersionNumber // use a single slice to avoid allocations nonceBuf []byte @@ -67,7 +68,7 @@ var ( _ ShortHeaderSealer = &updatableAEAD{} ) -func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger) *updatableAEAD { +func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { return &updatableAEAD{ firstPacketNumber: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, @@ -77,6 +78,7 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, rttStats: rttStats, tracer: tracer, logger: logger, + version: version, } } @@ -100,8 +102,8 @@ func (a *updatableAEAD) rollKeys() { a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret) a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret) - a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret) - a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret) + a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version) + a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version) } func (a *updatableAEAD) startKeyDropTimer(now time.Time) { @@ -117,27 +119,27 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte // For the client, this function is called before SetWriteKey. // For the server, this function is called after SetWriteKey. func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { - a.rcvAEAD = createAEAD(suite, trafficSecret) - a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false) + a.rcvAEAD = createAEAD(suite, trafficSecret, a.version) + a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version) if a.suite == nil { a.setAEADParameters(a.rcvAEAD, suite) } a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) - a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret) + a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) } // For the client, this function is called after SetReadKey. // For the server, this function is called before SetWriteKey. func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { - a.sendAEAD = createAEAD(suite, trafficSecret) - a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false) + a.sendAEAD = createAEAD(suite, trafficSecret, a.version) + a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) if a.suite == nil { a.setAEADParameters(a.sendAEAD, suite) } a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) - a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret) + a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version) } func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) { diff --git a/internal/handshake/updatable_aead_test.go b/internal/handshake/updatable_aead_test.go index 3e6b8ad22fb..88b89f334cc 100644 --- a/internal/handshake/updatable_aead_test.go +++ b/internal/handshake/updatable_aead_test.go @@ -14,496 +14,515 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" ) var _ = Describe("Updatable AEAD", func() { - It("ChaCha test vector from the draft", func() { - secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") - aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil) - chacha := cipherSuites[2] - Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256)) - aead.SetWriteKey(chacha, secret) - header := splitHexString("4200bff4") - const pnOffset = 1 - payloadOffset := len(header) - plaintext := splitHexString("01") - payload := aead.Seal(nil, plaintext, 654360564, header) - Expect(payload).To(Equal(splitHexString("655e5cd55c41f69080575d7999c25a5bfb"))) - packet := append(header, payload...) - aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset]) - Expect(packet).To(Equal(splitHexString("4cfe4189655e5cd55c41f69080575d7999c25a5bfb"))) - }) - - for i := range cipherSuites { - cs := cipherSuites[i] - - Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { - var ( - client, server *updatableAEAD - serverTracer *mocklogging.MockConnectionTracer - rttStats *utils.RTTStats - ) - - BeforeEach(func() { - serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) - trafficSecret1 := make([]byte, 16) - trafficSecret2 := make([]byte, 16) - rand.Read(trafficSecret1) - rand.Read(trafficSecret2) - - rttStats = utils.NewRTTStats() - client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger) - server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger) - client.SetReadKey(cs, trafficSecret2) - client.SetWriteKey(cs, trafficSecret1) - server.SetReadKey(cs, trafficSecret1) - server.SetWriteKey(cs, trafficSecret2) - }) - - Context("header protection", func() { - It("encrypts and decrypts the header", func() { - var lastFiveBitsDifferent int - for i := 0; i < 100; i++ { - sample := make([]byte, 16) - rand.Read(sample) - header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} - client.EncryptHeader(sample, &header[0], header[9:13]) - if header[0]&0x1f != 0xb5&0x1f { - lastFiveBitsDifferent++ - } - Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) - Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) - Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) - server.DecryptHeader(sample, &header[0], header[9:13]) - Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) - } - Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) - }) - }) - - Context("message encryption", func() { - var msg, ad []byte - - BeforeEach(func() { - msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - ad = []byte("Donec in velit neque.") - }) - - It("encrypts and decrypts a message", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(opened).To(Equal(msg)) - }) - - It("saves the first packet number", func() { - client.Seal(nil, msg, 0x1337, ad) - Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) - client.Seal(nil, msg, 0x1338, ad) - Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) - }) - - It("fails to open a message if the associated data is not the same", func() { - encrypted := client.Seal(nil, msg, 0x1337, ad) - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("fails to open a message if the packet number is not the same", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) - - It("decodes the packet number", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - _, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) - }) - - It("ignores packets it can't decrypt for packet number derivation", func() { - encrypted := server.Seal(nil, msg, 0x1337, ad) - _, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).To(HaveOccurred()) - Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) - }) - - It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() { - client.invalidPacketLimit = 10 - for i := 0; i < 9; i++ { - _, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad")) - Expect(err).To(MatchError(ErrDecryptionFailed)) - } - _, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).To(HaveOccurred()) - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached)) - }) - - Context("key updates", func() { - Context("receiving key updates", func() { - It("updates keys", func() { - now := time.Now() - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - encrypted0 := server.Seal(nil, msg, 0x1337, ad) - server.rollKeys() - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - encrypted1 := server.Seal(nil, msg, 0x1337, ad) - Expect(encrypted0).ToNot(Equal(encrypted1)) - // expect opening to fail. The client didn't roll keys yet - _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - client.rollKeys() - decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - }) - - It("updates the keys when receiving a packet with the next key phase", func() { - now := time.Now() - // receive the first packet at key phase zero - encrypted0 := client.Seal(nil, msg, 0x42, ad) - decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - // send one packet at key phase zero - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - _ = server.Seal(nil, msg, 0x1, ad) - // now received a message at key phase one - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x43, ad) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("opens a reordered packet with the old keys after an update", func() { - now := time.Now() - encrypted01 := client.Seal(nil, msg, 0x42, ad) - encrypted02 := client.Seal(nil, msg, 0x43, ad) - // receive the first packet with key phase 0 - _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // send one packet at key phase zero - _ = server.Seal(nil, msg, 0x1, ad) - // now receive a packet with key phase 1 - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x44, ad) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // now receive a reordered packet with key phase 0 - decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(decrypted).To(Equal(msg)) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("drops keys 3 PTOs after a key update", func() { - now := time.Now() - rttStats.UpdateRTT(10*time.Millisecond, 0, now) - pto := rttStats.PTO(true) - encrypted01 := client.Seal(nil, msg, 0x42, ad) - encrypted02 := client.Seal(nil, msg, 0x43, ad) - // receive the first packet with key phase 0 - _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - // send one packet at key phase zero - _ = server.Seal(nil, msg, 0x1, ad) - // now receive a packet with key phase 1 - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x44, ad) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // now receive a reordered packet with key phase 0 - _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrKeysDropped)) - }) - - It("allows the first key update immediately", func() { - // receive a packet at key phase one, before having sent or received any packets at key phase 0 - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x1337, ad) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - _, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - }) - - It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { - client.rollKeys() - encrypted := client.Seal(nil, msg, 0x1337, ad) - encrypted = encrypted[:len(encrypted)-1] - _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) - Expect(err).To(MatchError(ErrDecryptionFailed)) - }) + DescribeTable("ChaCha test vector", + func(v protocol.VersionNumber, expectedPayload, expectedPacket []byte) { + secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") + aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil, v) + chacha := cipherSuites[2] + Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256)) + aead.SetWriteKey(chacha, secret) + const pnOffset = 1 + header := splitHexString("4200bff4") + payloadOffset := len(header) + plaintext := splitHexString("01") + payload := aead.Seal(nil, plaintext, 654360564, header) + Expect(payload).To(Equal(expectedPayload)) + packet := append(header, payload...) + aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset]) + Expect(packet).To(Equal(expectedPacket)) + }, + Entry("QUIC v1", + protocol.Version1, + splitHexString("655e5cd55c41f69080575d7999c25a5bfb"), + splitHexString("4cfe4189655e5cd55c41f69080575d7999c25a5bfb"), + ), + Entry("QUIC v2", + protocol.Version2, + splitHexString("0ae7b6b932bc27d786f4bc2bb20f2162ba"), + splitHexString("5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"), + ), + ) + + for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1, protocol.Version2} { + v := ver + + Context(fmt.Sprintf("using version %s", v), func() { + for i := range cipherSuites { + cs := cipherSuites[i] + + Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { + var ( + client, server *updatableAEAD + serverTracer *mocklogging.MockConnectionTracer + rttStats *utils.RTTStats + ) + + BeforeEach(func() { + serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) + trafficSecret1 := make([]byte, 16) + trafficSecret2 := make([]byte, 16) + rand.Read(trafficSecret1) + rand.Read(trafficSecret2) + + rttStats = utils.NewRTTStats() + client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v) + server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger, v) + client.SetReadKey(cs, trafficSecret2) + client.SetWriteKey(cs, trafficSecret1) + server.SetReadKey(cs, trafficSecret1) + server.SetWriteKey(cs, trafficSecret2) + }) - It("errors when the peer updates keys too frequently", func() { - server.rollKeys() - client.rollKeys() - // receive the first packet at key phase one - encrypted0 := client.Seal(nil, msg, 0x42, ad) - _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - // now receive a packet at key phase two, before having sent any packets - client.rollKeys() - encrypted1 := client.Seal(nil, msg, 0x42, ad) - _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.KeyUpdateError, - ErrorMessage: "keys updated too quickly", - })) + Context("header protection", func() { + It("encrypts and decrypts the header", func() { + var lastFiveBitsDifferent int + for i := 0; i < 100; i++ { + sample := make([]byte, 16) + rand.Read(sample) + header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} + client.EncryptHeader(sample, &header[0], header[9:13]) + if header[0]&0x1f != 0xb5&0x1f { + lastFiveBitsDifferent++ + } + Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) + Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) + Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) + server.DecryptHeader(sample, &header[0], header[9:13]) + Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) + } + Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) }) }) - Context("initiating key updates", func() { - const keyUpdateInterval = 20 + Context("message encryption", func() { + var msg, ad []byte BeforeEach(func() { - Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) - server.keyUpdateInterval = keyUpdateInterval - server.SetHandshakeConfirmed() + msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + ad = []byte("Donec in velit neque.") }) - It("initiates a key update after sealing the maximum number of packets, for the first update", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - // the first update is allowed without receiving an acknowledgement - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - }) - - It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() { - server.rollKeys() - client.rollKeys() - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - server.Seal(nil, msg, pn, ad) - } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // receive an ACK for a packet sent in key phase 0 - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad")) + It("encrypts and decrypts a message", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + Expect(opened).To(Equal(msg)) }) - It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { - // First make sure that we update our keys. - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // Now that our keys are updated, send a packet using the new keys. - const nextPN = keyUpdateInterval + 1 - server.Seal(nil, msg, nextPN, ad) - // We haven't decrypted any packet in the new key phase yet. - // This means that the ACK must have been sent in the old key phase. - Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.KeyUpdateError, - ErrorMessage: "received ACK for key phase 1, but peer didn't update keys", - })) + It("saves the first packet number", func() { + client.Seal(nil, msg, 0x1337, ad) + Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) + client.Seal(nil, msg, 0x1338, ad) + Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) }) - It("doesn't error before actually sending a packet in the new key phase", func() { - // First make sure that we update our keys. - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - // Now that our keys are updated, send a packet using the new keys. - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // We haven't decrypted any packet in the new key phase yet. - // This means that the ACK must have been sent in the old key phase. - Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) - }) - - It("initiates a key update after opening the maximum number of packets, for the first update", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - encrypted := client.Seal(nil, msg, pn, ad) - _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - } - // the first update is allowed without receiving an acknowledgement - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + It("fails to open a message if the associated data is not the same", func() { + encrypted := client.Seal(nil, msg, 0x1337, ad) + _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) }) - It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() { - server.rollKeys() - client.rollKeys() - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - encrypted := client.Seal(nil, msg, pn, ad) - _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - } - // no update allowed before receiving an acknowledgement for the current key phase - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - server.Seal(nil, msg, 1, ad) - Expect(server.SetLargestAcked(1)).To(Succeed()) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + It("fails to open a message if the packet number is not the same", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) }) - It("drops keys 3 PTOs after a key update", func() { - now := time.Now() - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - Expect(server.SetLargestAcked(0)).To(Succeed()) - // Now we've initiated the first key update. - // Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there - threePTO := 3 * rttStats.PTO(false) - dataKeyPhaseZero := client.Seal(nil, msg, 1, ad) - _, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad) + It("decodes the packet number", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) Expect(err).ToNot(HaveOccurred()) - // Now receive a packet with key phase 1. - // This should start the timer to drop the keys after 3 PTOs. - client.rollKeys() - dataKeyPhaseOne := client.Seal(nil, msg, 10, ad) - t := now.Add(threePTO).Add(time.Second) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) - _, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - // Make sure the keys are still here. - _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) - _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad) - Expect(err).To(MatchError(ErrKeysDropped)) + Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) }) - It("doesn't drop the first key generation too early", func() { - now := time.Now() - data1 := client.Seal(nil, msg, 1, ad) - _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - Expect(server.SetLargestAcked(pn)).To(Succeed()) - } - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // The server never received a packet at key phase 1. - // Make sure the key phase 0 is still there at a much later point. - data2 := client.Seal(nil, msg, 1, ad) - _, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) + It("ignores packets it can't decrypt for packet number derivation", func() { + encrypted := server.Seal(nil, msg, 0x1337, ad) + _, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).To(HaveOccurred()) + Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) }) - It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) + It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() { + client.invalidPacketLimit = 10 + for i := 0; i < 9; i++ { + _, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad")) + Expect(err).To(MatchError(ErrDecryptionFailed)) } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - const nextPN = keyUpdateInterval + 1 - // Send and receive an acknowledgement for a packet in key phase 1. - // We are now running a timer to drop the keys with 3 PTO. - server.Seal(nil, msg, nextPN, ad) - client.rollKeys() - dataKeyPhaseOne := client.Seal(nil, msg, 2, ad) - now := time.Now() - _, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.SetLargestAcked(nextPN)) - // Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over. - // This mean that we need to drop the keys for key phase 0 immediately. - client.rollKeys() - dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad) - gomock.InOrder( - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true), - ) - _, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad) - Expect(err).ToNot(HaveOccurred()) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).To(HaveOccurred()) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached)) }) - It("drops keys early when we initiate another key update within the 3 PTO period", func() { - server.SetHandshakeConfirmed() - // send so many packets that we initiate the first key update - for i := 0; i < keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - server.Seal(nil, msg, pn, ad) - } - b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) - _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - // send so many packets that we initiate the next key update - for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ { - pn := protocol.PacketNumber(i) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) - server.Seal(nil, msg, pn, ad) - } - client.rollKeys() - b = client.Seal(nil, []byte("foobar"), 2, []byte("ad")) - now := time.Now() - _, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) - ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed()) - gomock.InOrder( - serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), - serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false), - ) - Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) - // We haven't received an ACK for a packet sent in key phase 2 yet. - // Make sure we canceled the timer to drop the previous key phase. - b = client.Seal(nil, []byte("foobar"), 3, []byte("ad")) - _, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad")) - Expect(err).ToNot(HaveOccurred()) + Context("key updates", func() { + Context("receiving key updates", func() { + It("updates keys", func() { + now := time.Now() + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted0 := server.Seal(nil, msg, 0x1337, ad) + server.rollKeys() + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted1 := server.Seal(nil, msg, 0x1337, ad) + Expect(encrypted0).ToNot(Equal(encrypted1)) + // expect opening to fail. The client didn't roll keys yet + _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + client.rollKeys() + decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + }) + + It("updates the keys when receiving a packet with the next key phase", func() { + now := time.Now() + // receive the first packet at key phase zero + encrypted0 := client.Seal(nil, msg, 0x42, ad) + decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + // send one packet at key phase zero + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + _ = server.Seal(nil, msg, 0x1, ad) + // now received a message at key phase one + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x43, ad) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("opens a reordered packet with the old keys after an update", func() { + now := time.Now() + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(decrypted).To(Equal(msg)) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("drops keys 3 PTOs after a key update", func() { + now := time.Now() + rttStats.UpdateRTT(10*time.Millisecond, 0, now) + pto := rttStats.PTO(true) + encrypted01 := client.Seal(nil, msg, 0x42, ad) + encrypted02 := client.Seal(nil, msg, 0x43, ad) + // receive the first packet with key phase 0 + _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // send one packet at key phase zero + _ = server.Seal(nil, msg, 0x1, ad) + // now receive a packet with key phase 1 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x44, ad) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // now receive a reordered packet with key phase 0 + _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrKeysDropped)) + }) + + It("allows the first key update immediately", func() { + // receive a packet at key phase one, before having sent or received any packets at key phase 0 + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x1337, ad) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + }) + + It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { + client.rollKeys() + encrypted := client.Seal(nil, msg, 0x1337, ad) + encrypted = encrypted[:len(encrypted)-1] + _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) + Expect(err).To(MatchError(ErrDecryptionFailed)) + }) + + It("errors when the peer updates keys too frequently", func() { + server.rollKeys() + client.rollKeys() + // receive the first packet at key phase one + encrypted0 := client.Seal(nil, msg, 0x42, ad) + _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + // now receive a packet at key phase two, before having sent any packets + client.rollKeys() + encrypted1 := client.Seal(nil, msg, 0x42, ad) + _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "keys updated too quickly", + })) + }) + }) + + Context("initiating key updates", func() { + const keyUpdateInterval = 20 + + BeforeEach(func() { + Expect(server.keyUpdateInterval).To(BeEquivalentTo(protocol.KeyUpdateInterval)) + server.keyUpdateInterval = keyUpdateInterval + server.SetHandshakeConfirmed() + }) + + It("initiates a key update after sealing the maximum number of packets, for the first update", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + // the first update is allowed without receiving an acknowledgement + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() { + server.rollKeys() + client.rollKeys() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, pn, ad) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // receive an ACK for a packet sent in key phase 0 + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + + It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { + // First make sure that we update our keys. + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // Now that our keys are updated, send a packet using the new keys. + const nextPN = keyUpdateInterval + 1 + server.Seal(nil, msg, nextPN, ad) + // We haven't decrypted any packet in the new key phase yet. + // This means that the ACK must have been sent in the old key phase. + Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{ + ErrorCode: qerr.KeyUpdateError, + ErrorMessage: "received ACK for key phase 1, but peer didn't update keys", + })) + }) + + It("doesn't error before actually sending a packet in the new key phase", func() { + // First make sure that we update our keys. + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + // Now that our keys are updated, send a packet using the new keys. + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // We haven't decrypted any packet in the new key phase yet. + // This means that the ACK must have been sent in the old key phase. + Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) + }) + + It("initiates a key update after opening the maximum number of packets, for the first update", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + encrypted := client.Seal(nil, msg, pn, ad) + _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + } + // the first update is allowed without receiving an acknowledgement + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + }) + + It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() { + server.rollKeys() + client.rollKeys() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + encrypted := client.Seal(nil, msg, pn, ad) + _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + } + // no update allowed before receiving an acknowledgement for the current key phase + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, 1, ad) + Expect(server.SetLargestAcked(1)).To(Succeed()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + + It("drops keys 3 PTOs after a key update", func() { + now := time.Now() + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + Expect(server.SetLargestAcked(0)).To(Succeed()) + // Now we've initiated the first key update. + // Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there + threePTO := 3 * rttStats.PTO(false) + dataKeyPhaseZero := client.Seal(nil, msg, 1, ad) + _, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + // Now receive a packet with key phase 1. + // This should start the timer to drop the keys after 3 PTOs. + client.rollKeys() + dataKeyPhaseOne := client.Seal(nil, msg, 10, ad) + t := now.Add(threePTO).Add(time.Second) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) + _, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + // Make sure the keys are still here. + _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) + _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad) + Expect(err).To(MatchError(ErrKeysDropped)) + }) + + It("doesn't drop the first key generation too early", func() { + now := time.Now() + data1 := client.Seal(nil, msg, 1, ad) + _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + Expect(server.SetLargestAcked(pn)).To(Succeed()) + } + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // The server never received a packet at key phase 1. + // Make sure the key phase 0 is still there at a much later point. + data2 := client.Seal(nil, msg, 1, ad) + _, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + }) + + It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + const nextPN = keyUpdateInterval + 1 + // Send and receive an acknowledgement for a packet in key phase 1. + // We are now running a timer to drop the keys with 3 PTO. + server.Seal(nil, msg, nextPN, ad) + client.rollKeys() + dataKeyPhaseOne := client.Seal(nil, msg, 2, ad) + now := time.Now() + _, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.SetLargestAcked(nextPN)) + // Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over. + // This mean that we need to drop the keys for key phase 0 immediately. + client.rollKeys() + dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad) + gomock.InOrder( + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true), + ) + _, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad) + Expect(err).ToNot(HaveOccurred()) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + }) + + It("drops keys early when we initiate another key update within the 3 PTO period", func() { + server.SetHandshakeConfirmed() + // send so many packets that we initiate the first key update + for i := 0; i < keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + server.Seal(nil, msg, pn, ad) + } + b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) + _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + // send so many packets that we initiate the next key update + for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ { + pn := protocol.PacketNumber(i) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) + server.Seal(nil, msg, pn, ad) + } + client.rollKeys() + b = client.Seal(nil, []byte("foobar"), 2, []byte("ad")) + now := time.Now() + _, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed()) + gomock.InOrder( + serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), + serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false), + ) + Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) + // We haven't received an ACK for a packet sent in key phase 2 yet. + // Make sure we canceled the timer to drop the previous key phase. + b = client.Seal(nil, []byte("foobar"), 3, []byte("ad")) + _, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad")) + Expect(err).ToNot(HaveOccurred()) + }) + }) }) }) }) - }) + } }) } }) diff --git a/internal/protocol/version.go b/internal/protocol/version.go index b5276303a2e..dd54dbd3cb5 100644 --- a/internal/protocol/version.go +++ b/internal/protocol/version.go @@ -23,11 +23,12 @@ const ( VersionUnknown VersionNumber = math.MaxUint32 VersionDraft29 VersionNumber = 0xff00001d Version1 VersionNumber = 0x1 + Version2 VersionNumber = 0x709a50c4 ) // SupportedVersions lists the versions that the server supports // must be in sorted descending order -var SupportedVersions = []VersionNumber{Version1, VersionDraft29} +var SupportedVersions = []VersionNumber{Version1, Version2, VersionDraft29} // IsValidVersion says if the version is known to quic-go func IsValidVersion(v VersionNumber) bool { @@ -50,6 +51,8 @@ func (vn VersionNumber) String() string { return "draft-29" case Version1: return "v1" + case Version2: + return "v2" default: if vn.isGQUIC() { return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion()) diff --git a/internal/protocol/version_test.go b/internal/protocol/version_test.go index 17c067b8cc6..33c6598b445 100644 --- a/internal/protocol/version_test.go +++ b/internal/protocol/version_test.go @@ -16,6 +16,7 @@ var _ = Describe("Version", func() { Expect(IsValidVersion(VersionUnknown)).To(BeFalse()) Expect(IsValidVersion(VersionDraft29)).To(BeTrue()) Expect(IsValidVersion(Version1)).To(BeTrue()) + Expect(IsValidVersion(Version2)).To(BeTrue()) Expect(IsValidVersion(1234)).To(BeFalse()) }) @@ -28,6 +29,7 @@ var _ = Describe("Version", func() { Expect(VersionUnknown.String()).To(Equal("unknown")) Expect(VersionDraft29.String()).To(Equal("draft-29")) Expect(Version1.String()).To(Equal("v1")) + Expect(Version2.String()).To(Equal("v2")) // check with unsupported version numbers from the wiki Expect(VersionNumber(0x51303039).String()).To(Equal("gQUIC 9")) Expect(VersionNumber(0x51303133).String()).To(Equal("gQUIC 13")) @@ -42,13 +44,6 @@ var _ = Describe("Version", func() { Expect(IsSupportedVersion(SupportedVersions, SupportedVersions[len(SupportedVersions)-1])).To(BeTrue()) }) - It("has supported versions in sorted order", func() { - Expect(SupportedVersions[0]).To(Equal(Version1)) - for i := 1; i < len(SupportedVersions)-1; i++ { - Expect(SupportedVersions[i]).To(BeNumerically(">", SupportedVersions[i+1])) - } - }) - Context("highest supported version", func() { It("finds the supported version", func() { supportedVersions := []VersionNumber{1, 2, 3} diff --git a/internal/testutils/testutils.go b/internal/testutils/testutils.go index b4ed9147121..87e2b975c83 100644 --- a/internal/testutils/testutils.go +++ b/internal/testutils/testutils.go @@ -14,7 +14,7 @@ import ( // writePacket returns a new raw packet with the specified header and payload func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte { buf := &bytes.Buffer{} - hdr.Write(buf, protocol.VersionTLS) + hdr.Write(buf, hdr.Version) return append(buf.Bytes(), data...) } diff --git a/internal/wire/ack_frame_test.go b/internal/wire/ack_frame_test.go index 15c0f00ba18..aa00f92ac0e 100644 --- a/internal/wire/ack_frame_test.go +++ b/internal/wire/ack_frame_test.go @@ -20,7 +20,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(10)...) // first ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) @@ -35,7 +35,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(0)...) // first ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(55))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(55))) @@ -50,7 +50,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(20)...) // first ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(20))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(0))) @@ -65,7 +65,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(21)...) // first ack block b := bytes.NewReader(data) - _, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + _, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).To(MatchError("invalid first ACK range")) }) @@ -78,7 +78,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(98)...) // gap data = append(data, encodeVarInt(50)...) // ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(1000))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(750))) @@ -101,7 +101,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(1)...) // gap data = append(data, encodeVarInt(1)...) // ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(94))) @@ -121,10 +121,10 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, DelayTime: delayTime, } - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) for i := uint8(0); i < 8; i++ { b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent+i, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent+i, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.DelayTime).To(Equal(delayTime * (1 << i))) } @@ -137,7 +137,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0)...) // num blocks data = append(data, encodeVarInt(0)...) // first ack block b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.DelayTime).To(BeNumerically(">", 0)) // The maximum encodable duration is ~292 years. @@ -152,10 +152,10 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(100)...) // first ack block data = append(data, encodeVarInt(98)...) // gap data = append(data, encodeVarInt(50)...) // ack block - _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, versionIETFFrames) + _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, versionIETFFrames) + _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) @@ -171,7 +171,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0x12345)...) // ECT(1) data = append(data, encodeVarInt(0x12345678)...) // ECN-CE b := bytes.NewReader(data) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.LargestAcked()).To(Equal(protocol.PacketNumber(100))) Expect(frame.LowestAcked()).To(Equal(protocol.PacketNumber(90))) @@ -190,10 +190,10 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { data = append(data, encodeVarInt(0x42)...) // ECT(0) data = append(data, encodeVarInt(0x12345)...) // ECT(1) data = append(data, encodeVarInt(0x12345678)...) // ECN-CE - _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, versionIETFFrames) + _, err := parseAckFrame(bytes.NewReader(data), protocol.AckDelayExponent, protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, versionIETFFrames) + _, err := parseAckFrame(bytes.NewReader(data[0:i]), protocol.AckDelayExponent, protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) @@ -206,7 +206,7 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { f := &AckFrame{ AckRanges: []AckRange{{Smallest: 100, Largest: 1337}}, } - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) expected := []byte{0x2} expected = append(expected, encodeVarInt(1337)...) // largest acked expected = append(expected, 0) // delay @@ -223,8 +223,8 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { ECT1: 37, ECNCE: 12345, } - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - Expect(f.Length(versionIETFFrames)).To(BeEquivalentTo(buf.Len())) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) expected := []byte{0x3} expected = append(expected, encodeVarInt(2000)...) // largest acked expected = append(expected, 0) // delay @@ -242,10 +242,10 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { AckRanges: []AckRange{{Smallest: 0x2eadbeef, Largest: 0x2eadbeef}}, DelayTime: 18 * time.Millisecond, } - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - Expect(f.Length(versionIETFFrames)).To(BeEquivalentTo(buf.Len())) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeFalse()) @@ -258,10 +258,10 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { f := &AckFrame{ AckRanges: []AckRange{{Smallest: 0x1337, Largest: 0x2eadbeef}}, } - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - Expect(f.Length(versionIETFFrames)).To(BeEquivalentTo(buf.Len())) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeFalse()) @@ -277,11 +277,11 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { }, } Expect(f.validateAckRanges()).To(BeTrue()) - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(f.Length(versionIETFFrames)).To(BeEquivalentTo(buf.Len())) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeTrue()) @@ -299,10 +299,10 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { }, } Expect(f.validateAckRanges()).To(BeTrue()) - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - Expect(f.Length(versionIETFFrames)).To(BeEquivalentTo(buf.Len())) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) Expect(frame.HasMissingRanges()).To(BeTrue()) @@ -318,13 +318,13 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() { } f := &AckFrame{AckRanges: ackRanges} Expect(f.validateAckRanges()).To(BeTrue()) - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) - Expect(f.Length(versionIETFFrames)).To(BeEquivalentTo(buf.Len())) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len())) // make sure the ACK frame is *a little bit* smaller than the MaxAckFrameSize Expect(buf.Len()).To(BeNumerically(">", protocol.MaxAckFrameSize-5)) Expect(buf.Len()).To(BeNumerically("<=", protocol.MaxAckFrameSize)) b := bytes.NewReader(buf.Bytes()) - frame, err := parseAckFrame(b, protocol.AckDelayExponent, versionIETFFrames) + frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.HasMissingRanges()).To(BeTrue()) Expect(b.Len()).To(BeZero()) diff --git a/internal/wire/connection_close_frame_test.go b/internal/wire/connection_close_frame_test.go index c116454a35e..9c5e6661560 100644 --- a/internal/wire/connection_close_frame_test.go +++ b/internal/wire/connection_close_frame_test.go @@ -20,7 +20,7 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length data = append(data, []byte(reason)...) b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, versionIETFFrames) + frame, err := parseConnectionCloseFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.IsApplicationError).To(BeFalse()) Expect(frame.ErrorCode).To(BeEquivalentTo(0x19)) @@ -36,7 +36,7 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length data = append(data, reason...) b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, versionIETFFrames) + frame, err := parseConnectionCloseFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.IsApplicationError).To(BeTrue()) Expect(frame.ErrorCode).To(BeEquivalentTo(0xcafe)) @@ -50,7 +50,7 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { data = append(data, encodeVarInt(0x42)...) // frame type data = append(data, encodeVarInt(0xffff)...) // reason phrase length b := bytes.NewReader(data) - _, err := parseConnectionCloseFrame(b, versionIETFFrames) + _, err := parseConnectionCloseFrame(b, protocol.Version1) Expect(err).To(MatchError(io.EOF)) }) @@ -61,10 +61,10 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { data = append(data, encodeVarInt(0x1337)...) // frame type data = append(data, encodeVarInt(uint64(len(reason)))...) // reason phrase length data = append(data, []byte(reason)...) - _, err := parseConnectionCloseFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseConnectionCloseFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseConnectionCloseFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseConnectionCloseFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(HaveOccurred()) } }) @@ -75,7 +75,7 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { data = append(data, encodeVarInt(0x42)...) // frame type data = append(data, encodeVarInt(0)...) b := bytes.NewReader(data) - frame, err := parseConnectionCloseFrame(b, versionIETFFrames) + frame, err := parseConnectionCloseFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.ReasonPhrase).To(BeEmpty()) Expect(b.Len()).To(BeZero()) @@ -89,7 +89,7 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { ErrorCode: 0xbeef, FrameType: 0x12345, } - Expect(frame.Write(b, versionIETFFrames)).To(Succeed()) + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) expected := []byte{0x1c} expected = append(expected, encodeVarInt(0xbeef)...) expected = append(expected, encodeVarInt(0x12345)...) // frame type @@ -103,7 +103,7 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { ErrorCode: 0xdead, ReasonPhrase: "foobar", } - Expect(frame.Write(b, versionIETFFrames)).To(Succeed()) + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) expected := []byte{0x1c} expected = append(expected, encodeVarInt(0xdead)...) expected = append(expected, encodeVarInt(0)...) // frame type @@ -119,7 +119,7 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { ErrorCode: 0xdead, ReasonPhrase: "foobar", } - Expect(frame.Write(b, versionIETFFrames)).To(Succeed()) + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) expected := []byte{0x1d} expected = append(expected, encodeVarInt(0xdead)...) expected = append(expected, encodeVarInt(6)...) // reason phrase length @@ -134,8 +134,8 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { FrameType: 0xdeadbeef, ReasonPhrase: "foobar", } - Expect(f.Write(b, versionIETFFrames)).To(Succeed()) - Expect(f.Length(versionIETFFrames)).To(Equal(protocol.ByteCount(b.Len()))) + Expect(f.Write(b, protocol.Version1)).To(Succeed()) + Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) }) It("has proper min length, for a frame containing an application error code", func() { @@ -145,9 +145,9 @@ var _ = Describe("CONNECTION_CLOSE Frame", func() { ErrorCode: 0xcafe, ReasonPhrase: "foobar", } - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) - Expect(f.Length(versionIETFFrames)).To(Equal(protocol.ByteCount(b.Len()))) + Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(b.Len()))) }) }) }) diff --git a/internal/wire/crypto_frame_test.go b/internal/wire/crypto_frame_test.go index d72ea527622..c37981012ad 100644 --- a/internal/wire/crypto_frame_test.go +++ b/internal/wire/crypto_frame_test.go @@ -18,7 +18,7 @@ var _ = Describe("CRYPTO frame", func() { data = append(data, encodeVarInt(6)...) // length data = append(data, []byte("foobar")...) r := bytes.NewReader(data) - frame, err := parseCryptoFrame(r, versionIETFFrames) + frame, err := parseCryptoFrame(r, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.Offset).To(Equal(protocol.ByteCount(0xdecafbad))) Expect(frame.Data).To(Equal([]byte("foobar"))) @@ -30,10 +30,10 @@ var _ = Describe("CRYPTO frame", func() { data = append(data, encodeVarInt(0xdecafbad)...) // offset data = append(data, encodeVarInt(6)...) // data length data = append(data, []byte("foobar")...) - _, err := parseCryptoFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseCryptoFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseCryptoFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseCryptoFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(HaveOccurred()) } }) @@ -46,7 +46,7 @@ var _ = Describe("CRYPTO frame", func() { Data: []byte("foobar"), } b := &bytes.Buffer{} - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x6} expected = append(expected, encodeVarInt(0x123456)...) // offset @@ -73,13 +73,13 @@ var _ = Describe("CRYPTO frame", func() { if maxDataLen == 0 { // 0 means that no valid CRYTPO frame can be written // check that writing a minimal size CRYPTO frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(BeNumerically(">", i)) continue } f.Data = data[:int(maxDataLen)] - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) // There's *one* pathological case, where a data length of x can be encoded into 1 byte // but a data lengths of x+1 needs 2 bytes @@ -100,7 +100,7 @@ var _ = Describe("CRYPTO frame", func() { Offset: 0x1337, Data: []byte("foobar"), } - Expect(f.Length(versionIETFFrames)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(6) + 6)) + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(6) + 6)) }) }) @@ -110,8 +110,8 @@ var _ = Describe("CRYPTO frame", func() { Offset: 0x1337, Data: []byte("foobar"), } - hdrLen := f.Length(versionIETFFrames) - 6 - new, needsSplit := f.MaybeSplitOffFrame(hdrLen+3, versionIETFFrames) + hdrLen := f.Length(protocol.Version1) - 6 + new, needsSplit := f.MaybeSplitOffFrame(hdrLen+3, protocol.Version1) Expect(needsSplit).To(BeTrue()) Expect(new.Data).To(Equal([]byte("foo"))) Expect(new.Offset).To(Equal(protocol.ByteCount(0x1337))) @@ -124,7 +124,7 @@ var _ = Describe("CRYPTO frame", func() { Offset: 0x1337, Data: []byte("foobar"), } - f, needsSplit := f.MaybeSplitOffFrame(f.Length(versionIETFFrames), versionIETFFrames) + f, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) Expect(needsSplit).To(BeFalse()) Expect(f).To(BeNil()) }) @@ -134,13 +134,13 @@ var _ = Describe("CRYPTO frame", func() { Offset: 0x1337, Data: []byte("foobar"), } - length := f.Length(versionIETFFrames) - 6 + length := f.Length(protocol.Version1) - 6 for i := protocol.ByteCount(0); i <= length; i++ { - f, needsSplit := f.MaybeSplitOffFrame(i, versionIETFFrames) + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) Expect(needsSplit).To(BeTrue()) Expect(f).To(BeNil()) } - f, needsSplit := f.MaybeSplitOffFrame(length+1, versionIETFFrames) + f, needsSplit := f.MaybeSplitOffFrame(length+1, protocol.Version1) Expect(needsSplit).To(BeTrue()) Expect(f).ToNot(BeNil()) }) diff --git a/internal/wire/data_blocked_frame_test.go b/internal/wire/data_blocked_frame_test.go index 154d08f1d40..57ffd9a801b 100644 --- a/internal/wire/data_blocked_frame_test.go +++ b/internal/wire/data_blocked_frame_test.go @@ -17,7 +17,7 @@ var _ = Describe("DATA_BLOCKED frame", func() { data := []byte{0x14} data = append(data, encodeVarInt(0x12345678)...) b := bytes.NewReader(data) - frame, err := parseDataBlockedFrame(b, versionIETFFrames) + frame, err := parseDataBlockedFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0x12345678))) Expect(b.Len()).To(BeZero()) @@ -26,10 +26,10 @@ var _ = Describe("DATA_BLOCKED frame", func() { It("errors on EOFs", func() { data := []byte{0x14} data = append(data, encodeVarInt(0x12345678)...) - _, err := parseDataBlockedFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseDataBlockedFrame(bytes.NewReader(data), protocol.Version1) Expect(err).ToNot(HaveOccurred()) for i := range data { - _, err := parseDataBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames) + _, err := parseDataBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) @@ -48,7 +48,7 @@ var _ = Describe("DATA_BLOCKED frame", func() { It("has the correct min length", func() { frame := DataBlockedFrame{MaximumData: 0x12345} - Expect(frame.Length(versionIETFFrames)).To(Equal(1 + quicvarint.Len(0x12345))) + Expect(frame.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x12345))) }) }) }) diff --git a/internal/wire/datagram_frame_test.go b/internal/wire/datagram_frame_test.go index 500bd9122bd..4431eb8b0c0 100644 --- a/internal/wire/datagram_frame_test.go +++ b/internal/wire/datagram_frame_test.go @@ -18,7 +18,7 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(0x6)...) // length data = append(data, []byte("foobar")...) r := bytes.NewReader(data) - frame, err := parseDatagramFrame(r, versionIETFFrames) + frame, err := parseDatagramFrame(r, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.Data).To(Equal([]byte("foobar"))) Expect(frame.DataLenPresent).To(BeTrue()) @@ -29,7 +29,7 @@ var _ = Describe("STREAM frame", func() { data := []byte{0x30} data = append(data, []byte("Lorem ipsum dolor sit amet")...) r := bytes.NewReader(data) - frame, err := parseDatagramFrame(r, versionIETFFrames) + frame, err := parseDatagramFrame(r, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.Data).To(Equal([]byte("Lorem ipsum dolor sit amet"))) Expect(frame.DataLenPresent).To(BeFalse()) @@ -41,7 +41,7 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(0x6)...) // length data = append(data, []byte("fooba")...) r := bytes.NewReader(data) - _, err := parseDatagramFrame(r, versionIETFFrames) + _, err := parseDatagramFrame(r, protocol.Version1) Expect(err).To(MatchError(io.EOF)) }) @@ -49,10 +49,10 @@ var _ = Describe("STREAM frame", func() { data := []byte{0x30 ^ 0x1} data = append(data, encodeVarInt(6)...) // length data = append(data, []byte("foobar")...) - _, err := parseDatagramFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseDatagramFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseDatagramFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseDatagramFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) @@ -65,7 +65,7 @@ var _ = Describe("STREAM frame", func() { Data: []byte("foobar"), } buf := &bytes.Buffer{} - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) expected := []byte{0x30 ^ 0x1} expected = append(expected, encodeVarInt(0x6)...) expected = append(expected, []byte("foobar")...) @@ -75,7 +75,7 @@ var _ = Describe("STREAM frame", func() { It("writes a frame without length", func() { f := &DatagramFrame{Data: []byte("Lorem ipsum")} buf := &bytes.Buffer{} - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) expected := []byte{0x30} expected = append(expected, []byte("Lorem ipsum")...) Expect(buf.Bytes()).To(Equal(expected)) @@ -88,12 +88,12 @@ var _ = Describe("STREAM frame", func() { DataLenPresent: true, Data: []byte("foobar"), } - Expect(f.Length(versionIETFFrames)).To(Equal(1 + quicvarint.Len(6) + 6)) + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(6) + 6)) }) It("has the right length for a frame without length", func() { f := &DatagramFrame{Data: []byte("foobar")} - Expect(f.Length(versionIETFFrames)).To(Equal(protocol.ByteCount(1 + 6))) + Expect(f.Length(protocol.Version1)).To(Equal(protocol.ByteCount(1 + 6))) }) }) @@ -107,16 +107,16 @@ var _ = Describe("STREAM frame", func() { for i := 1; i < 3000; i++ { b.Reset() f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), versionIETFFrames) + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} - Expect(f.Write(b, versionIETFFrames)).To(Succeed()) + Expect(f.Write(b, protocol.Version1)).To(Succeed()) Expect(b.Len()).To(BeNumerically(">", i)) continue } f.Data = data[:int(maxDataLen)] - Expect(f.Write(b, versionIETFFrames)).To(Succeed()) + Expect(f.Write(b, protocol.Version1)).To(Succeed()) Expect(b.Len()).To(Equal(i)) } }) @@ -129,16 +129,16 @@ var _ = Describe("STREAM frame", func() { for i := 1; i < 3000; i++ { b.Reset() f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), versionIETFFrames) + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} - Expect(f.Write(b, versionIETFFrames)).To(Succeed()) + Expect(f.Write(b, protocol.Version1)).To(Succeed()) Expect(b.Len()).To(BeNumerically(">", i)) continue } f.Data = data[:int(maxDataLen)] - Expect(f.Write(b, versionIETFFrames)).To(Succeed()) + Expect(f.Write(b, protocol.Version1)).To(Succeed()) // There's *one* pathological case, where a data length of x can be encoded into 1 byte // but a data lengths of x+1 needs 2 bytes // In that case, it's impossible to create a STREAM frame of the desired size diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index b8938acafb2..9d9edab25c4 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -127,18 +127,32 @@ func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) erro return h.writeShortHeader(b, ver) } -func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { +func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error { var packetType uint8 - //nolint:exhaustive - switch h.Type { - case protocol.PacketTypeInitial: - packetType = 0x0 - case protocol.PacketType0RTT: - packetType = 0x1 - case protocol.PacketTypeHandshake: - packetType = 0x2 - case protocol.PacketTypeRetry: - packetType = 0x3 + if version == protocol.Version2 { + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0b01 + case protocol.PacketType0RTT: + packetType = 0b10 + case protocol.PacketTypeHandshake: + packetType = 0b11 + case protocol.PacketTypeRetry: + packetType = 0b00 + } + } else { + //nolint:exhaustive + switch h.Type { + case protocol.PacketTypeInitial: + packetType = 0b00 + case protocol.PacketType0RTT: + packetType = 0b01 + case protocol.PacketTypeHandshake: + packetType = 0b10 + case protocol.PacketTypeRetry: + packetType = 0b11 + } } firstByte := 0xc0 | packetType<<4 if h.Type != protocol.PacketTypeRetry { diff --git a/internal/wire/extended_header_test.go b/internal/wire/extended_header_test.go index d9c6369ce75..51719e83cd7 100644 --- a/internal/wire/extended_header_test.go +++ b/internal/wire/extended_header_test.go @@ -95,6 +95,7 @@ var _ = Describe("Header", func() { PacketNumber: 0xdecafbad, PacketNumberLen: protocol.PacketNumberLen4, }).Write(buf, versionIETFHeader)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0) expectedSubstring := append(encodeVarInt(uint64(len(token))), token...) Expect(buf.Bytes()).To(ContainSubstring(string(expectedSubstring))) }) @@ -119,19 +120,74 @@ var _ = Describe("Header", func() { token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") Expect((&ExtendedHeader{Header: Header{ IsLongHeader: true, - Version: 0x1020304, + Version: protocol.Version1, Type: protocol.PacketTypeRetry, Token: token, }}).Write(buf, versionIETFHeader)).To(Succeed()) - expected := []byte{ - 0xc0 | 0x3<<4, - 0x1, 0x2, 0x3, 0x4, // version number - 0x0, // dest connection ID length - 0x0, // src connection ID length - } + expected := []byte{0xc0 | 0b11<<4} + expected = appendVersion(expected, protocol.Version1) + expected = append(expected, 0x0) // dest connection ID length + expected = append(expected, 0x0) // src connection ID length + expected = append(expected, token...) + Expect(buf.Bytes()).To(Equal(expected)) + }) + }) + + Context("long header, version 2", func() { + It("writes an Initial", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketTypeInitial, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, protocol.Version2)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0b01) + }) + + It("writes a Retry packet", func() { + token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.") + Expect((&ExtendedHeader{Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketTypeRetry, + Token: token, + }}).Write(buf, versionIETFHeader)).To(Succeed()) + expected := []byte{0xc0 | 0b11<<4} + expected = appendVersion(expected, protocol.Version2) + expected = append(expected, 0x0) // dest connection ID length + expected = append(expected, 0x0) // src connection ID length expected = append(expected, token...) Expect(buf.Bytes()).To(Equal(expected)) }) + + It("writes a Handshake Packet", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketTypeHandshake, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, protocol.Version2)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0b11) + }) + + It("writes a 0-RTT Packet", func() { + Expect((&ExtendedHeader{ + Header: Header{ + IsLongHeader: true, + Version: protocol.Version2, + Type: protocol.PacketType0RTT, + }, + PacketNumber: 0xdecafbad, + PacketNumberLen: protocol.PacketNumberLen4, + }).Write(buf, protocol.Version2)).To(Succeed()) + Expect(buf.Bytes()[0]>>4&0b11 == 0b10) + }) }) Context("short header", func() { diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index 86e970a1bf0..8d52ebc78f6 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -18,7 +18,7 @@ var _ = Describe("Frame parsing", func() { BeforeEach(func() { buf = &bytes.Buffer{} - parser = NewFrameParser(true, versionIETFFrames) + parser = NewFrameParser(true, protocol.Version1) }) It("returns nil if there's nothing more to read", func() { @@ -29,7 +29,7 @@ var _ = Describe("Frame parsing", func() { It("skips PADDING frames", func() { buf.Write([]byte{0}) // PADDING frame - (&PingFrame{}).Write(buf, versionIETFFrames) + (&PingFrame{}).Write(buf, protocol.Version1) f, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(f).To(Equal(&PingFrame{})) @@ -45,7 +45,7 @@ var _ = Describe("Frame parsing", func() { It("unpacks ACK frames", func() { f := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 0x13}}} - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -60,7 +60,7 @@ var _ = Describe("Frame parsing", func() { AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, DelayTime: time.Second, } - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) // The ACK frame is always written using the protocol.AckDelayExponent. @@ -74,7 +74,7 @@ var _ = Describe("Frame parsing", func() { AckRanges: []AckRange{{Smallest: 1, Largest: 1}}, DelayTime: time.Second, } - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.EncryptionHandshake) Expect(err).ToNot(HaveOccurred()) Expect(frame.(*AckFrame).DelayTime).To(Equal(time.Second)) @@ -86,7 +86,7 @@ var _ = Describe("Frame parsing", func() { FinalSize: 0xdecafbad1234, ErrorCode: 0x1337, } - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -96,7 +96,7 @@ var _ = Describe("Frame parsing", func() { It("unpacks STOP_SENDING frames", func() { f := &StopSendingFrame{StreamID: 0x42} buf := &bytes.Buffer{} - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -108,7 +108,7 @@ var _ = Describe("Frame parsing", func() { Offset: 0x1337, Data: []byte("lorem ipsum"), } - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -118,7 +118,7 @@ var _ = Describe("Frame parsing", func() { It("unpacks NEW_TOKEN frames", func() { f := &NewTokenFrame{Token: []byte("foobar")} - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -133,7 +133,7 @@ var _ = Describe("Frame parsing", func() { Fin: true, Data: []byte("foobar"), } - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -146,7 +146,7 @@ var _ = Describe("Frame parsing", func() { MaximumData: 0xcafe, } buf := &bytes.Buffer{} - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -159,7 +159,7 @@ var _ = Describe("Frame parsing", func() { MaximumStreamData: 0xdecafbad, } buf := &bytes.Buffer{} - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -172,7 +172,7 @@ var _ = Describe("Frame parsing", func() { MaxStreamNum: 0x1337, } buf := &bytes.Buffer{} - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -182,7 +182,7 @@ var _ = Describe("Frame parsing", func() { It("unpacks DATA_BLOCKED frames", func() { f := &DataBlockedFrame{MaximumData: 0x1234} buf := &bytes.Buffer{} - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -194,7 +194,7 @@ var _ = Describe("Frame parsing", func() { StreamID: 0xdeadbeef, MaximumStreamData: 0xdead, } - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -207,7 +207,7 @@ var _ = Describe("Frame parsing", func() { StreamLimit: 0x1234567, } buf := &bytes.Buffer{} - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -221,7 +221,7 @@ var _ = Describe("Frame parsing", func() { StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, } buf := &bytes.Buffer{} - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) @@ -230,7 +230,7 @@ var _ = Describe("Frame parsing", func() { It("unpacks RETIRE_CONNECTION_ID frames", func() { f := &RetireConnectionIDFrame{SequenceNumber: 0x1337} buf := &bytes.Buffer{} - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) @@ -238,7 +238,7 @@ var _ = Describe("Frame parsing", func() { It("unpacks PATH_CHALLENGE frames", func() { f := &PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -249,7 +249,7 @@ var _ = Describe("Frame parsing", func() { It("unpacks PATH_RESPONSE frames", func() { f := &PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}} - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -264,7 +264,7 @@ var _ = Describe("Frame parsing", func() { ReasonPhrase: "foobar", } buf := &bytes.Buffer{} - err := f.Write(buf, versionIETFFrames) + err := f.Write(buf, protocol.Version1) Expect(err).ToNot(HaveOccurred()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) @@ -274,7 +274,7 @@ var _ = Describe("Frame parsing", func() { It("unpacks HANDSHAKE_DONE frames", func() { f := &HandshakeDoneFrame{} buf := &bytes.Buffer{} - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) @@ -283,17 +283,17 @@ var _ = Describe("Frame parsing", func() { It("unpacks DATAGRAM frames", func() { f := &DatagramFrame{Data: []byte("foobar")} buf := &bytes.Buffer{} - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) frame, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(frame).To(Equal(f)) }) It("errors when DATAGRAM frames are not supported", func() { - parser = NewFrameParser(false, versionIETFFrames) + parser = NewFrameParser(false, protocol.Version1) f := &DatagramFrame{Data: []byte("foobar")} buf := &bytes.Buffer{} - Expect(f.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(f.Write(buf, protocol.Version1)).To(Succeed()) _, err := parser.ParseNext(bytes.NewReader(buf.Bytes()), protocol.Encryption1RTT) Expect(err).To(MatchError(&qerr.TransportError{ ErrorCode: qerr.FrameEncodingError, @@ -317,7 +317,7 @@ var _ = Describe("Frame parsing", func() { MaximumStreamData: 0xdeadbeef, } b := &bytes.Buffer{} - f.Write(b, versionIETFFrames) + f.Write(b, protocol.Version1) _, err := parser.ParseNext(bytes.NewReader(b.Bytes()[:b.Len()-2]), protocol.Encryption1RTT) Expect(err).To(HaveOccurred()) Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.FrameEncodingError)) @@ -353,7 +353,7 @@ var _ = Describe("Frame parsing", func() { framesSerialized = nil for _, frame := range frames { buf := &bytes.Buffer{} - Expect(frame.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) framesSerialized = append(framesSerialized, buf.Bytes()) } }) diff --git a/internal/wire/header.go b/internal/wire/header.go index 07ca9f05605..f6a31ee0ec4 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -53,10 +53,14 @@ func Is0RTTPacket(b []byte) bool { if b[0]&0x80 == 0 { return false } - if !protocol.IsSupportedVersion(protocol.SupportedVersions, protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))) { + version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5])) + if !protocol.IsSupportedVersion(protocol.SupportedVersions, version) { return false } - return b[0]&0x30>>4 == 0x1 + if version == protocol.Version2 { + return b[0]>>4&0b11 == 0b10 + } + return b[0]>>4&0b11 == 0b01 } var ErrUnsupportedVersion = errors.New("unsupported version") @@ -179,15 +183,28 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error { return ErrUnsupportedVersion } - switch (h.typeByte & 0x30) >> 4 { - case 0x0: - h.Type = protocol.PacketTypeInitial - case 0x1: - h.Type = protocol.PacketType0RTT - case 0x2: - h.Type = protocol.PacketTypeHandshake - case 0x3: - h.Type = protocol.PacketTypeRetry + if h.Version == protocol.Version2 { + switch h.typeByte >> 4 & 0b11 { + case 0b00: + h.Type = protocol.PacketTypeRetry + case 0b01: + h.Type = protocol.PacketTypeInitial + case 0b10: + h.Type = protocol.PacketType0RTT + case 0b11: + h.Type = protocol.PacketTypeHandshake + } + } else { + switch h.typeByte >> 4 & 0b11 { + case 0b00: + h.Type = protocol.PacketTypeInitial + case 0b01: + h.Type = protocol.PacketType0RTT + case 0b10: + h.Type = protocol.PacketTypeHandshake + case 0b11: + h.Type = protocol.PacketTypeRetry + } } if h.Type == protocol.PacketTypeRetry { diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 782e292d9ac..77b196e3603 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -11,13 +11,6 @@ import ( ) var _ = Describe("Header Parsing", func() { - appendVersion := func(data []byte, v protocol.VersionNumber) []byte { - offset := len(data) - data = append(data, []byte{0, 0, 0, 0}...) - binary.BigEndian.PutUint32(data[offset:], uint32(v)) - return data - } - Context("Parsing the Connection ID", func() { It("parses the connection ID of a long header packet", func() { buf := &bytes.Buffer{} @@ -27,10 +20,10 @@ var _ = Describe("Header Parsing", func() { Type: protocol.PacketTypeHandshake, DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, - Version: versionIETFFrames, + Version: protocol.Version1, }, PacketNumberLen: 2, - }).Write(buf, versionIETFFrames)).To(Succeed()) + }).Write(buf, protocol.Version1)).To(Succeed()) connID, err := ParseConnectionID(buf.Bytes(), 8) Expect(err).ToNot(HaveOccurred()) Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) @@ -43,7 +36,7 @@ var _ = Describe("Header Parsing", func() { DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, }, PacketNumberLen: 2, - }).Write(buf, versionIETFFrames)).To(Succeed()) + }).Write(buf, protocol.Version1)).To(Succeed()) buf.Write([]byte("foobar")) connID, err := ParseConnectionID(buf.Bytes(), 4) Expect(err).ToNot(HaveOccurred()) @@ -57,7 +50,7 @@ var _ = Describe("Header Parsing", func() { DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, }, PacketNumberLen: 2, - }).Write(buf, versionIETFFrames)).To(Succeed()) + }).Write(buf, protocol.Version1)).To(Succeed()) data := buf.Bytes()[:buf.Len()-2] // cut the packet number _, err := ParseConnectionID(data, 8) Expect(err).ToNot(HaveOccurred()) @@ -77,10 +70,10 @@ var _ = Describe("Header Parsing", func() { Type: protocol.PacketTypeHandshake, DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}, SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 8, 9}, - Version: versionIETFFrames, + Version: protocol.Version1, }, PacketNumberLen: 2, - }).Write(buf, versionIETFFrames)).To(Succeed()) + }).Write(buf, protocol.Version1)).To(Succeed()) data := buf.Bytes()[:buf.Len()-2] // cut the packet number _, err := ParseConnectionID(data, 8) Expect(err).ToNot(HaveOccurred()) @@ -94,19 +87,27 @@ var _ = Describe("Header Parsing", func() { }) Context("identifying 0-RTT packets", func() { - var zeroRTTHeader []byte + It("recognizes 0-RTT packets, for QUIC v1", func() { + zeroRTTHeader := make([]byte, 5) + zeroRTTHeader[0] = 0x80 | 0b01<<4 + binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version1)) - BeforeEach(func() { - zeroRTTHeader = make([]byte, 5) - zeroRTTHeader[0] = 0x80 | 0x1<<4 - binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(versionIETFFrames)) + Expect(Is0RTTPacket(zeroRTTHeader)).To(BeTrue()) + Expect(Is0RTTPacket(zeroRTTHeader[:4])).To(BeFalse()) // too short + Expect(Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})).To(BeFalse()) // unknown version + Expect(Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})).To(BeFalse()) // short header + Expect(Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))).To(BeTrue()) }) - It("recognizes 0-RTT packets", func() { + It("recognizes 0-RTT packets, for QUIC v2", func() { + zeroRTTHeader := make([]byte, 5) + zeroRTTHeader[0] = 0x80 | 0b10<<4 + binary.BigEndian.PutUint32(zeroRTTHeader[1:], uint32(protocol.Version2)) + + Expect(Is0RTTPacket(zeroRTTHeader)).To(BeTrue()) Expect(Is0RTTPacket(zeroRTTHeader[:4])).To(BeFalse()) // too short Expect(Is0RTTPacket([]byte{zeroRTTHeader[0], 1, 2, 3, 4})).To(BeFalse()) // unknown version Expect(Is0RTTPacket([]byte{zeroRTTHeader[0] | 0x80, 1, 2, 3, 4})).To(BeFalse()) // short header - Expect(Is0RTTPacket(zeroRTTHeader)).To(BeTrue()) Expect(Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))).To(BeTrue()) }) }) @@ -134,7 +135,7 @@ var _ = Describe("Header Parsing", func() { destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} data := []byte{0xc0 ^ 0x3} - data = appendVersion(data, versionIETFFrames) + data = appendVersion(data, protocol.Version1) data = append(data, 0x9) // dest conn id length data = append(data, destConnID...) data = append(data, 0x4) // src conn id length @@ -156,10 +157,10 @@ var _ = Describe("Header Parsing", func() { Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial)) Expect(hdr.Token).To(Equal([]byte("foobar"))) Expect(hdr.Length).To(Equal(protocol.ByteCount(10))) - Expect(hdr.Version).To(Equal(versionIETFFrames)) + Expect(hdr.Version).To(Equal(protocol.Version1)) Expect(rest).To(BeEmpty()) b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0xbeef))) Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) @@ -200,7 +201,7 @@ var _ = Describe("Header Parsing", func() { It("parses a Long Header without a destination connection ID", func() { data := []byte{0xc0 ^ 0x1<<4} - data = appendVersion(data, versionIETFFrames) + data = appendVersion(data, protocol.Version1) data = append(data, 0x0) // dest conn ID len data = append(data, 0x4) // src conn ID len data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID @@ -215,7 +216,7 @@ var _ = Describe("Header Parsing", func() { It("parses a Long Header without a source connection ID", func() { data := []byte{0xc0 ^ 0x2<<4} - data = appendVersion(data, versionIETFFrames) + data = appendVersion(data, protocol.Version1) data = append(data, 0xa) // dest conn ID len data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // dest connection ID data = append(data, 0x0) // src conn ID len @@ -229,7 +230,7 @@ var _ = Describe("Header Parsing", func() { It("parses a Long Header with a 2 byte packet number", func() { data := []byte{0xc0 ^ 0x1} - data = appendVersion(data, versionIETFFrames) // version number + data = appendVersion(data, protocol.Version1) // version number data = append(data, []byte{0x0, 0x0}...) // connection ID lengths data = append(data, encodeVarInt(0)...) // token length data = append(data, encodeVarInt(0)...) // length @@ -238,16 +239,36 @@ var _ = Describe("Header Parsing", func() { hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x123))) Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) Expect(b.Len()).To(BeZero()) }) - It("parses a Retry packet", func() { - data := []byte{0xc0 | 0x3<<4 | (10 - 3) /* connection ID length */} - data = appendVersion(data, versionIETFFrames) + It("parses a Retry packet, for QUIC v1", func() { + data := []byte{0xc0 | 0b11<<4 | (10 - 3) /* connection ID length */} + data = appendVersion(data, protocol.Version1) + data = append(data, []byte{6}...) // dest conn ID len + data = append(data, []byte{6, 5, 4, 3, 2, 1}...) // dest conn ID + data = append(data, []byte{10}...) // src conn ID len + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID + data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token + data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...) + hdr, pdata, rest, err := ParsePacket(data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) + Expect(hdr.Version).To(Equal(protocol.Version1)) + Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(hdr.Token).To(Equal([]byte("foobar"))) + Expect(pdata).To(Equal(data)) + Expect(rest).To(BeEmpty()) + }) + + It("parses a Retry packet, for QUIC v2", func() { + data := []byte{0xc0 | 0b00<<4 | (10 - 3) /* connection ID length */} + data = appendVersion(data, protocol.Version2) data = append(data, []byte{6}...) // dest conn ID len data = append(data, []byte{6, 5, 4, 3, 2, 1}...) // dest conn ID data = append(data, []byte{10}...) // src conn ID len @@ -257,6 +278,7 @@ var _ = Describe("Header Parsing", func() { hdr, pdata, rest, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) + Expect(hdr.Version).To(Equal(protocol.Version2)) Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) Expect(hdr.Token).To(Equal([]byte("foobar"))) @@ -266,7 +288,7 @@ var _ = Describe("Header Parsing", func() { It("errors if the Retry packet is too short for the integrity tag", func() { data := []byte{0xc0 | 0x3<<4 | (10 - 3) /* connection ID length */} - data = appendVersion(data, versionIETFFrames) + data = appendVersion(data, protocol.Version1) data = append(data, []byte{0, 0}...) // conn ID lens data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) @@ -277,7 +299,7 @@ var _ = Describe("Header Parsing", func() { It("errors if the token length is too large", func() { data := []byte{0xc0 ^ 0x1} - data = appendVersion(data, versionIETFFrames) + data = appendVersion(data, protocol.Version1) data = append(data, 0x0) // connection ID lengths data = append(data, encodeVarInt(4)...) // token length: 4 bytes (1 byte too long) data = append(data, encodeVarInt(0x42)...) // length, 1 byte @@ -289,14 +311,14 @@ var _ = Describe("Header Parsing", func() { It("errors if the 5th or 6th bit are set", func() { data := []byte{0xc0 | 0x2<<4 | 0x8 /* set the 5th bit */ | 0x1 /* 2 byte packet number */} - data = appendVersion(data, versionIETFFrames) + data = appendVersion(data, protocol.Version1) data = append(data, []byte{0x0, 0x0}...) // connection ID lengths data = append(data, encodeVarInt(2)...) // length data = append(data, []byte{0x12, 0x34}...) // packet number hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) - extHdr, err := hdr.ParseExtended(bytes.NewReader(data), versionIETFFrames) + extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) Expect(err).To(MatchError(ErrInvalidReservedBits)) Expect(extHdr).ToNot(BeNil()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1234))) @@ -304,7 +326,7 @@ var _ = Describe("Header Parsing", func() { It("errors on EOF, when parsing the header", func() { data := []byte{0xc0 ^ 0x2<<4} - data = appendVersion(data, versionIETFFrames) + data = appendVersion(data, protocol.Version1) data = append(data, 0x8) // dest conn ID len data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // dest conn ID data = append(data, 0x8) // src conn ID len @@ -317,7 +339,7 @@ var _ = Describe("Header Parsing", func() { It("errors on EOF, when parsing the extended header", func() { data := []byte{0xc0 | 0x2<<4 | 0x3} - data = appendVersion(data, versionIETFFrames) + data = appendVersion(data, protocol.Version1) data = append(data, []byte{0x0, 0x0}...) // connection ID lengths data = append(data, encodeVarInt(0)...) // length hdrLen := len(data) @@ -327,14 +349,14 @@ var _ = Describe("Header Parsing", func() { hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) - _, err = hdr.ParseExtended(b, versionIETFFrames) + _, err = hdr.ParseExtended(b, protocol.Version1) Expect(err).To(Equal(io.EOF)) } }) It("errors on EOF, for a Retry packet", func() { data := []byte{0xc0 ^ 0x3<<4} - data = appendVersion(data, versionIETFFrames) + data = appendVersion(data, protocol.Version1) data = append(data, []byte{0x0, 0x0}...) // connection ID lengths data = append(data, 0xa) // Orig Destination Connection ID length data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID @@ -344,7 +366,7 @@ var _ = Describe("Header Parsing", func() { hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) - _, err = hdr.ParseExtended(b, versionIETFFrames) + _, err = hdr.ParseExtended(b, protocol.Version1) Expect(err).To(Equal(io.EOF)) } }) @@ -357,13 +379,13 @@ var _ = Describe("Header Parsing", func() { Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, Length: 2 + 6, - Version: versionIETFFrames, + Version: protocol.Version1, } Expect((&ExtendedHeader{ Header: hdr, PacketNumber: 0x1337, PacketNumberLen: 2, - }).Write(buf, versionIETFFrames)).To(Succeed()) + }).Write(buf, protocol.Version1)).To(Succeed()) hdrRaw := append([]byte{}, buf.Bytes()...) buf.Write([]byte("foobar")) // payload of the first packet buf.Write([]byte("raboof")) // second packet @@ -383,11 +405,11 @@ var _ = Describe("Header Parsing", func() { Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, Length: 3, - Version: versionIETFFrames, + Version: protocol.Version1, }, PacketNumber: 0x1337, PacketNumberLen: 2, - }).Write(buf, versionIETFFrames)).To(Succeed()) + }).Write(buf, protocol.Version1)).To(Succeed()) _, _, _, err := ParsePacket(buf.Bytes(), 4) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)")) @@ -401,11 +423,11 @@ var _ = Describe("Header Parsing", func() { Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, Length: 1000, - Version: versionIETFFrames, + Version: protocol.Version1, }, PacketNumber: 0x1337, PacketNumberLen: 2, - }).Write(buf, versionIETFFrames)).To(Succeed()) + }).Write(buf, protocol.Version1)).To(Succeed()) buf.Write(make([]byte, 500-2 /* for packet number length */)) _, _, _, err := ParsePacket(buf.Bytes(), 4) Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) @@ -425,7 +447,7 @@ var _ = Describe("Header Parsing", func() { Expect(hdr.IsLongHeader).To(BeFalse()) Expect(hdr.DestConnectionID).To(Equal(connID)) b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) Expect(extHdr.DestConnectionID).To(Equal(connID)) @@ -451,7 +473,7 @@ var _ = Describe("Header Parsing", func() { hdr, _, _, err := ParsePacket(data, 5) Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeFalse()) - extHdr, err := hdr.ParseExtended(bytes.NewReader(data), versionIETFFrames) + extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) Expect(err).To(MatchError(ErrInvalidReservedBits)) Expect(extHdr).ToNot(BeNil()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) @@ -467,7 +489,7 @@ var _ = Describe("Header Parsing", func() { Expect(hdr.IsLongHeader).To(BeFalse()) Expect(hdr.DestConnectionID).To(Equal(connID)) b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) Expect(extHdr.DestConnectionID).To(Equal(connID)) @@ -485,7 +507,7 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) Expect(hdr.IsLongHeader).To(BeFalse()) b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseOne)) Expect(b.Len()).To(BeZero()) @@ -500,7 +522,7 @@ var _ = Describe("Header Parsing", func() { hdr, _, _, err := ParsePacket(data, 4) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.IsLongHeader).To(BeFalse()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) @@ -517,7 +539,7 @@ var _ = Describe("Header Parsing", func() { hdr, _, _, err := ParsePacket(data, 10) Expect(err).ToNot(HaveOccurred()) b := bytes.NewReader(data) - extHdr, err := hdr.ParseExtended(b, versionIETFFrames) + extHdr, err := hdr.ParseExtended(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(extHdr.IsLongHeader).To(BeFalse()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef))) @@ -548,7 +570,7 @@ var _ = Describe("Header Parsing", func() { data = data[:i] hdr, _, _, err := ParsePacket(data, 6) Expect(err).ToNot(HaveOccurred()) - _, err = hdr.ParseExtended(bytes.NewReader(data), versionIETFFrames) + _, err = hdr.ParseExtended(bytes.NewReader(data), protocol.Version1) Expect(err).To(Equal(io.EOF)) } }) diff --git a/internal/wire/max_data_frame_test.go b/internal/wire/max_data_frame_test.go index 56abc54f181..a5ee0222217 100644 --- a/internal/wire/max_data_frame_test.go +++ b/internal/wire/max_data_frame_test.go @@ -16,7 +16,7 @@ var _ = Describe("MAX_DATA frame", func() { data := []byte{0x10} data = append(data, encodeVarInt(0xdecafbad123456)...) // byte offset b := bytes.NewReader(data) - frame, err := parseMaxDataFrame(b, versionIETFFrames) + frame, err := parseMaxDataFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.MaximumData).To(Equal(protocol.ByteCount(0xdecafbad123456))) Expect(b.Len()).To(BeZero()) @@ -25,10 +25,10 @@ var _ = Describe("MAX_DATA frame", func() { It("errors on EOFs", func() { data := []byte{0x10} data = append(data, encodeVarInt(0xdecafbad1234567)...) // byte offset - _, err := parseMaxDataFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseMaxDataFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseMaxDataFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseMaxDataFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(HaveOccurred()) } }) @@ -39,7 +39,7 @@ var _ = Describe("MAX_DATA frame", func() { f := &MaxDataFrame{ MaximumData: 0xdeadbeef, } - Expect(f.Length(versionIETFFrames)).To(Equal(1 + quicvarint.Len(0xdeadbeef))) + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0xdeadbeef))) }) It("writes a MAX_DATA frame", func() { @@ -47,7 +47,7 @@ var _ = Describe("MAX_DATA frame", func() { f := &MaxDataFrame{ MaximumData: 0xdeadbeefcafe, } - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x10} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) diff --git a/internal/wire/max_stream_data_frame_test.go b/internal/wire/max_stream_data_frame_test.go index 772fde2a067..f12aac87625 100644 --- a/internal/wire/max_stream_data_frame_test.go +++ b/internal/wire/max_stream_data_frame_test.go @@ -17,7 +17,7 @@ var _ = Describe("MAX_STREAM_DATA frame", func() { data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID data = append(data, encodeVarInt(0x12345678)...) // Offset b := bytes.NewReader(data) - frame, err := parseMaxStreamDataFrame(b, versionIETFFrames) + frame, err := parseMaxStreamDataFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0x12345678))) @@ -28,10 +28,10 @@ var _ = Describe("MAX_STREAM_DATA frame", func() { data := []byte{0x11} data = append(data, encodeVarInt(0xdeadbeef)...) // Stream ID data = append(data, encodeVarInt(0x12345678)...) // Offset - _, err := parseMaxStreamDataFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseMaxStreamDataFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseMaxStreamDataFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseMaxStreamDataFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(HaveOccurred()) } }) @@ -55,7 +55,7 @@ var _ = Describe("MAX_STREAM_DATA frame", func() { expected := []byte{0x11} expected = append(expected, encodeVarInt(0xdecafbad)...) expected = append(expected, encodeVarInt(0xdeadbeefcafe42)...) - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(b.Bytes()).To(Equal(expected)) }) diff --git a/internal/wire/new_connection_id_frame_test.go b/internal/wire/new_connection_id_frame_test.go index cacadc69793..75fe85c54fd 100644 --- a/internal/wire/new_connection_id_frame_test.go +++ b/internal/wire/new_connection_id_frame_test.go @@ -20,7 +20,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token b := bytes.NewReader(data) - frame, err := parseNewConnectionIDFrame(b, versionIETFFrames) + frame, err := parseNewConnectionIDFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) Expect(frame.RetirePriorTo).To(Equal(uint64(0xcafe))) @@ -36,7 +36,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data = append(data, []byte{1, 2, 3}...) data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token b := bytes.NewReader(data) - _, err := parseNewConnectionIDFrame(b, versionIETFFrames) + _, err := parseNewConnectionIDFrame(b, protocol.Version1) Expect(err).To(MatchError("Retire Prior To value (1001) larger than Sequence Number (1000)")) }) @@ -48,7 +48,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}...) // connection ID data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token b := bytes.NewReader(data) - _, err := parseNewConnectionIDFrame(b, versionIETFFrames) + _, err := parseNewConnectionIDFrame(b, protocol.Version1) Expect(err).To(MatchError("invalid connection ID length: 21")) }) @@ -59,10 +59,10 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data = append(data, 10) // connection ID length data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // connection ID data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token - _, err := parseNewConnectionIDFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseNewConnectionIDFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseNewConnectionIDFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseNewConnectionIDFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) @@ -78,7 +78,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { StatelessResetToken: token, } b := &bytes.Buffer{} - Expect(frame.Write(b, versionIETFFrames)).To(Succeed()) + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) expected := []byte{0x18} expected = append(expected, encodeVarInt(0x1337)...) expected = append(expected, encodeVarInt(0x42)...) @@ -97,8 +97,8 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { StatelessResetToken: token, } b := &bytes.Buffer{} - Expect(frame.Write(b, versionIETFFrames)).To(Succeed()) - Expect(frame.Length(versionIETFFrames)).To(BeEquivalentTo(b.Len())) + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) }) }) }) diff --git a/internal/wire/path_challenge_frame_test.go b/internal/wire/path_challenge_frame_test.go index de1aa80a581..e7d8a97076f 100644 --- a/internal/wire/path_challenge_frame_test.go +++ b/internal/wire/path_challenge_frame_test.go @@ -22,10 +22,10 @@ var _ = Describe("PATH_CHALLENGE frame", func() { It("errors on EOFs", func() { data := []byte{0x1a, 1, 2, 3, 4, 5, 6, 7, 8} - _, err := parsePathChallengeFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parsePathChallengeFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parsePathChallengeFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parsePathChallengeFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/path_response_frame_test.go b/internal/wire/path_response_frame_test.go index 5ba78a76cf5..11e9d67aefc 100644 --- a/internal/wire/path_response_frame_test.go +++ b/internal/wire/path_response_frame_test.go @@ -21,10 +21,10 @@ var _ = Describe("PATH_RESPONSE frame", func() { It("errors on EOFs", func() { data := []byte{0x1b, 1, 2, 3, 4, 5, 6, 7, 8} - _, err := parsePathResponseFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parsePathResponseFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parsePathResponseFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parsePathResponseFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/reset_stream_frame_test.go b/internal/wire/reset_stream_frame_test.go index a1f0259d1c5..e241a8e6e44 100644 --- a/internal/wire/reset_stream_frame_test.go +++ b/internal/wire/reset_stream_frame_test.go @@ -19,7 +19,7 @@ var _ = Describe("RESET_STREAM frame", func() { data = append(data, encodeVarInt(0x1337)...) // error code data = append(data, encodeVarInt(0x987654321)...) // byte offset b := bytes.NewReader(data) - frame, err := parseResetStreamFrame(b, versionIETFFrames) + frame, err := parseResetStreamFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.FinalSize).To(Equal(protocol.ByteCount(0x987654321))) @@ -31,10 +31,10 @@ var _ = Describe("RESET_STREAM frame", func() { data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID data = append(data, encodeVarInt(0x1337)...) // error code data = append(data, encodeVarInt(0x987654321)...) // byte offset - _, err := parseResetStreamFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseResetStreamFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseResetStreamFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseResetStreamFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(HaveOccurred()) } }) @@ -48,7 +48,7 @@ var _ = Describe("RESET_STREAM frame", func() { ErrorCode: 0xcafe, } b := &bytes.Buffer{} - err := frame.Write(b, versionIETFFrames) + err := frame.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x4} expected = append(expected, encodeVarInt(0x1337)...) @@ -64,7 +64,7 @@ var _ = Describe("RESET_STREAM frame", func() { ErrorCode: 0xde, } expectedLen := 1 + quicvarint.Len(0x1337) + quicvarint.Len(0x1234567) + 2 - Expect(rst.Length(versionIETFFrames)).To(Equal(expectedLen)) + Expect(rst.Length(protocol.Version1)).To(Equal(expectedLen)) }) }) }) diff --git a/internal/wire/retire_connection_id_frame_test.go b/internal/wire/retire_connection_id_frame_test.go index 54d41c59d46..0338b6ccf92 100644 --- a/internal/wire/retire_connection_id_frame_test.go +++ b/internal/wire/retire_connection_id_frame_test.go @@ -4,6 +4,8 @@ import ( "bytes" "io" + "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -14,7 +16,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data := []byte{0x19} data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number b := bytes.NewReader(data) - frame, err := parseRetireConnectionIDFrame(b, versionIETFFrames) + frame, err := parseRetireConnectionIDFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) }) @@ -22,10 +24,10 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { It("errors on EOFs", func() { data := []byte{0x18} data = append(data, encodeVarInt(0xdeadbeef)...) // sequence number - _, err := parseRetireConnectionIDFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseRetireConnectionIDFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseRetireConnectionIDFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseRetireConnectionIDFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) @@ -35,7 +37,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { It("writes a sample frame", func() { frame := &RetireConnectionIDFrame{SequenceNumber: 0x1337} b := &bytes.Buffer{} - Expect(frame.Write(b, versionIETFFrames)).To(Succeed()) + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) expected := []byte{0x19} expected = append(expected, encodeVarInt(0x1337)...) Expect(b.Bytes()).To(Equal(expected)) @@ -44,8 +46,8 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { It("has the correct length", func() { frame := &RetireConnectionIDFrame{SequenceNumber: 0xdecafbad} b := &bytes.Buffer{} - Expect(frame.Write(b, versionIETFFrames)).To(Succeed()) - Expect(frame.Length(versionIETFFrames)).To(BeEquivalentTo(b.Len())) + Expect(frame.Write(b, protocol.Version1)).To(Succeed()) + Expect(frame.Length(protocol.Version1)).To(BeEquivalentTo(b.Len())) }) }) }) diff --git a/internal/wire/stop_sending_frame_test.go b/internal/wire/stop_sending_frame_test.go index 9a3dcda0e23..7b6793c26e3 100644 --- a/internal/wire/stop_sending_frame_test.go +++ b/internal/wire/stop_sending_frame_test.go @@ -18,7 +18,7 @@ var _ = Describe("STOP_SENDING frame", func() { data = append(data, encodeVarInt(0xdecafbad)...) // stream ID data = append(data, encodeVarInt(0x1337)...) // error code b := bytes.NewReader(data) - frame, err := parseStopSendingFrame(b, versionIETFFrames) + frame, err := parseStopSendingFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdecafbad))) Expect(frame.ErrorCode).To(Equal(qerr.StreamErrorCode(0x1337))) @@ -29,10 +29,10 @@ var _ = Describe("STOP_SENDING frame", func() { data := []byte{0x5} data = append(data, encodeVarInt(0xdecafbad)...) // stream ID data = append(data, encodeVarInt(0x123456)...) // error code - _, err := parseStopSendingFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseStopSendingFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseStopSendingFrame(bytes.NewReader(data[:i]), versionIETFFrames) + _, err := parseStopSendingFrame(bytes.NewReader(data[:i]), protocol.Version1) Expect(err).To(HaveOccurred()) } }) @@ -45,7 +45,7 @@ var _ = Describe("STOP_SENDING frame", func() { ErrorCode: 0xdecafbad, } buf := &bytes.Buffer{} - Expect(frame.Write(buf, versionIETFFrames)).To(Succeed()) + Expect(frame.Write(buf, protocol.Version1)).To(Succeed()) expected := []byte{0x5} expected = append(expected, encodeVarInt(0xdeadbeefcafe)...) expected = append(expected, encodeVarInt(0xdecafbad)...) @@ -57,7 +57,7 @@ var _ = Describe("STOP_SENDING frame", func() { StreamID: 0xdeadbeef, ErrorCode: 0x1234567, } - Expect(frame.Length(versionIETFFrames)).To(Equal(1 + quicvarint.Len(0xdeadbeef) + quicvarint.Len(0x1234567))) + Expect(frame.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0xdeadbeef) + quicvarint.Len(0x1234567))) }) }) }) diff --git a/internal/wire/stream_data_blocked_frame_test.go b/internal/wire/stream_data_blocked_frame_test.go index db8a715af4b..b5ec7cbda8d 100644 --- a/internal/wire/stream_data_blocked_frame_test.go +++ b/internal/wire/stream_data_blocked_frame_test.go @@ -17,7 +17,7 @@ var _ = Describe("STREAM_DATA_BLOCKED frame", func() { data = append(data, encodeVarInt(0xdeadbeef)...) // stream ID data = append(data, encodeVarInt(0xdecafbad)...) // offset b := bytes.NewReader(data) - frame, err := parseStreamDataBlockedFrame(b, versionIETFFrames) + frame, err := parseStreamDataBlockedFrame(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0xdeadbeef))) Expect(frame.MaximumStreamData).To(Equal(protocol.ByteCount(0xdecafbad))) @@ -28,10 +28,10 @@ var _ = Describe("STREAM_DATA_BLOCKED frame", func() { data := []byte{0x15} data = append(data, encodeVarInt(0xdeadbeef)...) data = append(data, encodeVarInt(0xc0010ff)...) - _, err := parseStreamDataBlockedFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseStreamDataBlockedFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseStreamDataBlockedFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseStreamDataBlockedFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(HaveOccurred()) } }) @@ -52,7 +52,7 @@ var _ = Describe("STREAM_DATA_BLOCKED frame", func() { StreamID: 0xdecafbad, MaximumStreamData: 0x1337, } - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x15} expected = append(expected, encodeVarInt(uint64(f.StreamID))...) diff --git a/internal/wire/stream_frame_test.go b/internal/wire/stream_frame_test.go index 08a421b60ae..92f863a53d7 100644 --- a/internal/wire/stream_frame_test.go +++ b/internal/wire/stream_frame_test.go @@ -19,7 +19,7 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(0xdecafbad)...) // offset data = append(data, []byte("foobar")...) r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, versionIETFFrames) + frame, err := parseStreamFrame(r, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) Expect(frame.Data).To(Equal([]byte("foobar"))) @@ -34,7 +34,7 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(4)...) // data length data = append(data, []byte("foobar")...) r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, versionIETFFrames) + frame, err := parseStreamFrame(r, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) Expect(frame.Data).To(Equal([]byte("foob"))) @@ -48,7 +48,7 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(9)...) // stream ID data = append(data, []byte("foobar")...) r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, versionIETFFrames) + frame, err := parseStreamFrame(r, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(9))) Expect(frame.Data).To(Equal([]byte("foobar"))) @@ -62,7 +62,7 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(0x1337)...) // stream ID data = append(data, encodeVarInt(0x12345)...) // offset r := bytes.NewReader(data) - f, err := parseStreamFrame(r, versionIETFFrames) + f, err := parseStreamFrame(r, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(f.StreamID).To(Equal(protocol.StreamID(0x1337))) Expect(f.Offset).To(Equal(protocol.ByteCount(0x12345))) @@ -76,7 +76,7 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(uint64(protocol.MaxByteCount-5))...) // offset data = append(data, []byte("foobar")...) r := bytes.NewReader(data) - _, err := parseStreamFrame(r, versionIETFFrames) + _, err := parseStreamFrame(r, protocol.Version1) Expect(err).To(MatchError("stream data overflows maximum offset")) }) @@ -86,7 +86,7 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(uint64(protocol.MaxPacketBufferSize)+1)...) // data length data = append(data, make([]byte, protocol.MaxPacketBufferSize+1)...) r := bytes.NewReader(data) - _, err := parseStreamFrame(r, versionIETFFrames) + _, err := parseStreamFrame(r, protocol.Version1) Expect(err).To(Equal(io.EOF)) }) @@ -96,10 +96,10 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(0xdecafbad)...) // offset data = append(data, encodeVarInt(6)...) // data length data = append(data, []byte("foobar")...) - _, err := parseStreamFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseStreamFrame(bytes.NewReader(data), protocol.Version1) Expect(err).NotTo(HaveOccurred()) for i := range data { - _, err := parseStreamFrame(bytes.NewReader(data[0:i]), versionIETFFrames) + _, err := parseStreamFrame(bytes.NewReader(data[0:i]), protocol.Version1) Expect(err).To(HaveOccurred()) } }) @@ -111,7 +111,7 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(0x12345)...) // stream ID data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize)...) r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, versionIETFFrames) + frame, err := parseStreamFrame(r, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize))) @@ -127,7 +127,7 @@ var _ = Describe("STREAM frame", func() { data = append(data, encodeVarInt(0x12345)...) // stream ID data = append(data, bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1)...) r := bytes.NewReader(data) - frame, err := parseStreamFrame(r, versionIETFFrames) + frame, err := parseStreamFrame(r, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(frame.StreamID).To(Equal(protocol.StreamID(0x12345))) Expect(frame.Data).To(Equal(bytes.Repeat([]byte{'f'}, protocol.MinStreamFrameBufferSize-1))) @@ -146,7 +146,7 @@ var _ = Describe("STREAM frame", func() { Data: []byte("foobar"), } b := &bytes.Buffer{} - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x8} expected = append(expected, encodeVarInt(0x1337)...) // stream ID @@ -161,7 +161,7 @@ var _ = Describe("STREAM frame", func() { Data: []byte("foobar"), } b := &bytes.Buffer{} - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x8 ^ 0x4} expected = append(expected, encodeVarInt(0x1337)...) // stream ID @@ -177,7 +177,7 @@ var _ = Describe("STREAM frame", func() { Fin: true, } b := &bytes.Buffer{} - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x8 ^ 0x4 ^ 0x1} expected = append(expected, encodeVarInt(0x1337)...) // stream ID @@ -192,7 +192,7 @@ var _ = Describe("STREAM frame", func() { DataLenPresent: true, } b := &bytes.Buffer{} - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x8 ^ 0x2} expected = append(expected, encodeVarInt(0x1337)...) // stream ID @@ -209,7 +209,7 @@ var _ = Describe("STREAM frame", func() { Offset: 0x123456, } b := &bytes.Buffer{} - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) expected := []byte{0x8 ^ 0x4 ^ 0x2} expected = append(expected, encodeVarInt(0x1337)...) // stream ID @@ -225,7 +225,7 @@ var _ = Describe("STREAM frame", func() { Offset: 0x1337, } b := &bytes.Buffer{} - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).To(MatchError("StreamFrame: attempting to write empty frame without FIN")) }) }) @@ -236,7 +236,7 @@ var _ = Describe("STREAM frame", func() { StreamID: 0x1337, Data: []byte("foobar"), } - Expect(f.Length(versionIETFFrames)).To(Equal(1 + quicvarint.Len(0x1337) + 6)) + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + 6)) }) It("has the right length for a frame with offset", func() { @@ -245,7 +245,7 @@ var _ = Describe("STREAM frame", func() { Offset: 0x42, Data: []byte("foobar"), } - Expect(f.Length(versionIETFFrames)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x42) + 6)) + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x42) + 6)) }) It("has the right length for a frame with data length", func() { @@ -255,7 +255,7 @@ var _ = Describe("STREAM frame", func() { DataLenPresent: true, Data: []byte("foobar"), } - Expect(f.Length(versionIETFFrames)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x1234567) + quicvarint.Len(6) + 6)) + Expect(f.Length(protocol.Version1)).To(Equal(1 + quicvarint.Len(0x1337) + quicvarint.Len(0x1234567) + quicvarint.Len(6) + 6)) }) }) @@ -272,17 +272,17 @@ var _ = Describe("STREAM frame", func() { for i := 1; i < 3000; i++ { b.Reset() f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), versionIETFFrames) + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(BeNumerically(">", i)) continue } f.Data = data[:int(maxDataLen)] - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(Equal(i)) } @@ -300,17 +300,17 @@ var _ = Describe("STREAM frame", func() { for i := 1; i < 3000; i++ { b.Reset() f.Data = nil - maxDataLen := f.MaxDataLen(protocol.ByteCount(i), versionIETFFrames) + maxDataLen := f.MaxDataLen(protocol.ByteCount(i), protocol.Version1) if maxDataLen == 0 { // 0 means that no valid STREAM frame can be written // check that writing a minimal size STREAM frame (i.e. with 1 byte data) is actually larger than the desired size f.Data = []byte{0} - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) Expect(b.Len()).To(BeNumerically(">", i)) continue } f.Data = data[:int(maxDataLen)] - err := f.Write(b, versionIETFFrames) + err := f.Write(b, protocol.Version1) Expect(err).ToNot(HaveOccurred()) // There's *one* pathological case, where a data length of x can be encoded into 1 byte // but a data lengths of x+1 needs 2 bytes @@ -333,11 +333,11 @@ var _ = Describe("STREAM frame", func() { Offset: 0xdeadbeef, Data: make([]byte, 100), } - frame, needsSplit := f.MaybeSplitOffFrame(f.Length(versionIETFFrames), versionIETFFrames) + frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1), protocol.Version1) Expect(needsSplit).To(BeFalse()) Expect(frame).To(BeNil()) Expect(f.DataLen()).To(BeEquivalentTo(100)) - frame, needsSplit = f.MaybeSplitOffFrame(f.Length(versionIETFFrames)-1, versionIETFFrames) + frame, needsSplit = f.MaybeSplitOffFrame(f.Length(protocol.Version1)-1, protocol.Version1) Expect(needsSplit).To(BeTrue()) Expect(frame.DataLen()).To(BeEquivalentTo(99)) f.PutBack() @@ -349,7 +349,7 @@ var _ = Describe("STREAM frame", func() { DataLenPresent: true, Data: make([]byte, 100), } - frame, needsSplit := f.MaybeSplitOffFrame(66, versionIETFFrames) + frame, needsSplit := f.MaybeSplitOffFrame(66, protocol.Version1) Expect(needsSplit).To(BeTrue()) Expect(frame).ToNot(BeNil()) Expect(f.DataLenPresent).To(BeTrue()) @@ -362,7 +362,7 @@ var _ = Describe("STREAM frame", func() { Offset: 0x100, Data: []byte("foobar"), } - frame, needsSplit := f.MaybeSplitOffFrame(f.Length(versionIETFFrames)-3, versionIETFFrames) + frame, needsSplit := f.MaybeSplitOffFrame(f.Length(protocol.Version1)-3, protocol.Version1) Expect(needsSplit).To(BeTrue()) Expect(frame).ToNot(BeNil()) Expect(frame.Offset).To(Equal(protocol.ByteCount(0x100))) @@ -378,7 +378,7 @@ var _ = Describe("STREAM frame", func() { Offset: 0xdeadbeef, Data: make([]byte, 100), } - frame, needsSplit := f.MaybeSplitOffFrame(50, versionIETFFrames) + frame, needsSplit := f.MaybeSplitOffFrame(50, protocol.Version1) Expect(needsSplit).To(BeTrue()) Expect(frame).ToNot(BeNil()) Expect(frame.Offset).To(BeNumerically("<", f.Offset)) @@ -393,18 +393,18 @@ var _ = Describe("STREAM frame", func() { Offset: 0x1234, Data: []byte{0}, } - minFrameSize := f.Length(versionIETFFrames) + minFrameSize := f.Length(protocol.Version1) for i := protocol.ByteCount(0); i < minFrameSize; i++ { - f, needsSplit := f.MaybeSplitOffFrame(i, versionIETFFrames) + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) Expect(needsSplit).To(BeTrue()) Expect(f).To(BeNil()) } for i := minFrameSize; i < size; i++ { f.fromPool = false f.Data = make([]byte, size) - f, needsSplit := f.MaybeSplitOffFrame(i, versionIETFFrames) + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) Expect(needsSplit).To(BeTrue()) - Expect(f.Length(versionIETFFrames)).To(Equal(i)) + Expect(f.Length(protocol.Version1)).To(Equal(i)) } }) @@ -416,9 +416,9 @@ var _ = Describe("STREAM frame", func() { DataLenPresent: true, Data: []byte{0}, } - minFrameSize := f.Length(versionIETFFrames) + minFrameSize := f.Length(protocol.Version1) for i := protocol.ByteCount(0); i < minFrameSize; i++ { - f, needsSplit := f.MaybeSplitOffFrame(i, versionIETFFrames) + f, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) Expect(needsSplit).To(BeTrue()) Expect(f).To(BeNil()) } @@ -426,16 +426,16 @@ var _ = Describe("STREAM frame", func() { for i := minFrameSize; i < size; i++ { f.fromPool = false f.Data = make([]byte, size) - newFrame, needsSplit := f.MaybeSplitOffFrame(i, versionIETFFrames) + newFrame, needsSplit := f.MaybeSplitOffFrame(i, protocol.Version1) Expect(needsSplit).To(BeTrue()) // There's *one* pathological case, where a data length of x can be encoded into 1 byte // but a data lengths of x+1 needs 2 bytes // In that case, it's impossible to create a STREAM frame of the desired size - if newFrame.Length(versionIETFFrames) == i-1 { + if newFrame.Length(protocol.Version1) == i-1 { frameOneByteTooSmallCounter++ continue } - Expect(newFrame.Length(versionIETFFrames)).To(Equal(i)) + Expect(newFrame.Length(protocol.Version1)).To(Equal(i)) } Expect(frameOneByteTooSmallCounter).To(Equal(1)) }) diff --git a/internal/wire/streams_blocked_frame_test.go b/internal/wire/streams_blocked_frame_test.go index f5f56b39ded..3247e2c891b 100644 --- a/internal/wire/streams_blocked_frame_test.go +++ b/internal/wire/streams_blocked_frame_test.go @@ -39,10 +39,10 @@ var _ = Describe("STREAMS_BLOCKED frame", func() { It("errors on EOFs", func() { data := []byte{0x16} data = append(data, encodeVarInt(0x12345678)...) - _, err := parseStreamsBlockedFrame(bytes.NewReader(data), versionIETFFrames) + _, err := parseStreamsBlockedFrame(bytes.NewReader(data), protocol.Version1) Expect(err).ToNot(HaveOccurred()) for i := range data { - _, err := parseStreamsBlockedFrame(bytes.NewReader(data[:i]), versionIETFFrames) + _, err := parseStreamsBlockedFrame(bytes.NewReader(data[:i]), protocol.Version1) Expect(err).To(MatchError(io.EOF)) } }) diff --git a/internal/wire/wire_suite_test.go b/internal/wire/wire_suite_test.go index 3917f3c1535..2af8bb888a1 100644 --- a/internal/wire/wire_suite_test.go +++ b/internal/wire/wire_suite_test.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "encoding/binary" "testing" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -16,13 +17,15 @@ func TestWire(t *testing.T) { RunSpecs(t, "Wire Suite") } -const ( - // a QUIC version that uses the IETF frame types - versionIETFFrames = protocol.VersionTLS -) - func encodeVarInt(i uint64) []byte { b := &bytes.Buffer{} quicvarint.Write(b, i) return b.Bytes() } + +func appendVersion(data []byte, v protocol.VersionNumber) []byte { + offset := len(data) + data = append(data, []byte{0, 0, 0, 0}...) + binary.BigEndian.PutUint32(data[offset:], uint32(v)) + return data +}