Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement ALPS extension #142

Merged
merged 4 commits into from Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 11 additions & 1 deletion common.go
Expand Up @@ -89,6 +89,7 @@ const (
extensionSupportedPoints uint16 = 11
extensionSignatureAlgorithms uint16 = 13
extensionALPN uint16 = 16
extensionStatusRequestV2 uint16 = 17
extensionSCT uint16 = 18
extensionDelegatedCredentials uint16 = 34
extensionSessionTicket uint16 = 35
Expand All @@ -100,7 +101,7 @@ const (
extensionCertificateAuthorities uint16 = 47
extensionSignatureAlgorithmsCert uint16 = 50
extensionKeyShare uint16 = 51
extensionNextProtoNeg uint16 = 13172 // not IANA assigned // Pending discussion on whether or not remove this. crypto/tls removed it on Nov 21, 2019.
extensionNextProtoNeg uint16 = 13172 // not IANA assigned // Pending discussion on whether or not remove this. crypto/tls removed it on Nov 21, 2019.
extensionRenegotiationInfo uint16 = 0xff01
)

Expand Down Expand Up @@ -237,6 +238,10 @@ type ConnectionState struct {
// Deprecated: this value is always true.
NegotiatedProtocolIsMutual bool

// PeerApplicationSettings is the Application-Layer Protocol Settings (ALPS)
// provided by peer.
PeerApplicationSettings []byte // [uTLS]

// ServerName is the value of the Server Name Indication extension sent by
// the client. It's available both on the server and on the client side.
ServerName string
Expand Down Expand Up @@ -624,6 +629,10 @@ type Config struct {
// ConnectionState.NegotiatedProtocol will be empty.
NextProtos []string

// ApplicationSettings is a set of application settings (ALPS) to use
// with each application protocol (ALPN).
ApplicationSettings map[string][]byte // [uTLS]

// ServerName is used to verify the hostname on the returned
// certificates unless InsecureSkipVerify is given. It is also included
// in the client's handshake to support virtual hosting unless it is
Expand Down Expand Up @@ -799,6 +808,7 @@ func (c *Config) Clone() *Config {
VerifyConnection: c.VerifyConnection,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ApplicationSettings: c.ApplicationSettings,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
Expand Down
22 changes: 16 additions & 6 deletions conn.go
Expand Up @@ -92,6 +92,10 @@ type Conn struct {
// clientProtocol is the negotiated ALPN protocol.
clientProtocol string

// [UTLS SECTION START]
utls utlsConnExtraFields // used for extensive things such as ALPS
// [UTLS SECTION END]

// input/output
in, out halfConn
rawInput bytes.Buffer // raw input, starting with a record header
Expand Down Expand Up @@ -1075,17 +1079,22 @@ func (c *Conn) readHandshake() (any, error) {
}
case typeFinished:
m = new(finishedMsg)
case typeEncryptedExtensions:
m = new(encryptedExtensionsMsg)
// [uTLS] Commented typeEncryptedExtensions to force
// utlsHandshakeMessageType to handle it
// case typeEncryptedExtensions:
// m = new(encryptedExtensionsMsg)
case typeEndOfEarlyData:
m = new(endOfEarlyDataMsg)
case typeKeyUpdate:
m = new(keyUpdateMsg)
// [UTLS SECTION BEGINS]
case typeCompressedCertificate:
m = new(compressedCertificateMsg)
// [UTLS SECTION ENDS]
default:
// [UTLS SECTION BEGINS]
var err error
m, err = c.utlsHandshakeMessageType(data[0]) // see u_conn.go
if err == nil {
break
}
// [UTLS SECTION ENDS]
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}

Expand Down Expand Up @@ -1514,6 +1523,7 @@ func (c *Conn) connectionStateLocked() ConnectionState {
} else {
state.ekm = c.ekm
}

return state
}

Expand Down
121 changes: 26 additions & 95 deletions handshake_client_tls13.go
Expand Up @@ -6,20 +6,15 @@ package tls

import (
"bytes"
"compress/zlib"
"context"
"crypto"
"crypto/hmac"
"crypto/rsa"
"errors"
"fmt"
"hash"
"io"
"sync/atomic"
"time"

"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
)

type clientHandshakeStateTLS13 struct {
Expand Down Expand Up @@ -103,6 +98,11 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
if err := hs.readServerFinished(); err != nil {
return err
}
// [UTLS SECTION START]
if err := hs.serverFinishedReceived(); err != nil {
return err
}
// [UTLS SECTION END]
if err := hs.sendClientCertificate(); err != nil {
return err
}
Expand Down Expand Up @@ -477,6 +477,15 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error {
}
c.clientProtocol = encryptedExtensions.alpnProtocol

// [UTLS SECTION STARTS]
if hs.uconn != nil {
err = hs.utlsReadServerParameters(encryptedExtensions)
if err != nil {
c.sendAlert(alertUnsupportedExtension)
return err
}
}
// [UTLS SECTION ENDS]
return nil
}

Expand Down Expand Up @@ -516,19 +525,15 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
}

// [UTLS SECTION BEGINS]
receivedCompressedCert := false
// Check to see if we advertised any compression algorithms
if hs.uconn != nil && len(hs.uconn.certCompressionAlgs) > 0 {
// Check to see if the message is a compressed certificate message, otherwise move on.
compressedCertMsg, ok := msg.(*compressedCertificateMsg)
if ok {
receivedCompressedCert = true
hs.transcript.Write(compressedCertMsg.marshal())

msg, err = hs.decompressCert(*compressedCertMsg)
if err != nil {
return fmt.Errorf("tls: failed to decompress certificate message: %w", err)
}
var skipWritingCertToTranscript bool = false
if hs.uconn != nil {
processedMsg, err := hs.utlsReadServerCertificate(msg)
if err != nil {
return err
}
if processedMsg != nil {
skipWritingCertToTranscript = true
msg = processedMsg // msg is now a processed-by-extension certificateMsg
}
}
// [UTLS SECTION ENDS]
Expand All @@ -544,7 +549,7 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
}
// [UTLS SECTION BEGINS]
// Previously, this was simply 'hs.transcript.Write(certMsg.marshal())' (without the if).
if !receivedCompressedCert {
if !skipWritingCertToTranscript {
hs.transcript.Write(certMsg.marshal())
}
// [UTLS SECTION ENDS]
Expand All @@ -570,15 +575,15 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
// See RFC 8446, Section 4.4.3.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
return errors.New("tls: certificate used with invalid signature algorithm -- not implemented")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
return errors.New("tls: certificate used with invalid signature algorithm -- obsolete")
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
Expand Down Expand Up @@ -729,80 +734,6 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
return nil
}

// [UTLS SECTION BEGINS]
func (hs *clientHandshakeStateTLS13) decompressCert(m compressedCertificateMsg) (*certificateMsgTLS13, error) {
var (
decompressed io.Reader
compressed = bytes.NewReader(m.compressedCertificateMessage)
c = hs.c
)

// Check to see if the peer responded with an algorithm we advertised.
supportedAlg := false
for _, alg := range hs.uconn.certCompressionAlgs {
if m.algorithm == uint16(alg) {
supportedAlg = true
}
}
if !supportedAlg {
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("unadvertised algorithm (%d)", m.algorithm)
}

switch CertCompressionAlgo(m.algorithm) {
case CertCompressionBrotli:
decompressed = brotli.NewReader(compressed)

case CertCompressionZlib:
rc, err := zlib.NewReader(compressed)
if err != nil {
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("failed to open zlib reader: %w", err)
}
defer rc.Close()
decompressed = rc

case CertCompressionZstd:
rc, err := zstd.NewReader(compressed)
if err != nil {
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("failed to open zstd reader: %w", err)
}
defer rc.Close()
decompressed = rc

default:
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("unsupported algorithm (%d)", m.algorithm)
}

rawMsg := make([]byte, m.uncompressedLength+4) // +4 for message type and uint24 length field
rawMsg[0] = typeCertificate
rawMsg[1] = uint8(m.uncompressedLength >> 16)
rawMsg[2] = uint8(m.uncompressedLength >> 8)
rawMsg[3] = uint8(m.uncompressedLength)

n, err := decompressed.Read(rawMsg[4:])
if err != nil {
c.sendAlert(alertBadCertificate)
return nil, err
}
if n < len(rawMsg)-4 {
// If, after decompression, the specified length does not match the actual length, the party
// receiving the invalid message MUST abort the connection with the "bad_certificate" alert.
// https://datatracker.ietf.org/doc/html/rfc8879#section-4
c.sendAlert(alertBadCertificate)
return nil, fmt.Errorf("decompressed len (%d) does not match specified len (%d)", n, m.uncompressedLength)
}
certMsg := new(certificateMsgTLS13)
if !certMsg.unmarshal(rawMsg) {
return nil, c.sendAlert(alertUnexpectedMessage)
}
return certMsg, nil
}

// [UTLS SECTION ENDS]

func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
if !c.isClient {
c.sendAlert(alertUnexpectedMessage)
Expand Down
7 changes: 7 additions & 0 deletions handshake_messages.go
Expand Up @@ -868,6 +868,8 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
type encryptedExtensionsMsg struct {
raw []byte
alpnProtocol string

utls utlsEncryptedExtensionsMsgExtraFields // [uTLS]
}

func (m *encryptedExtensionsMsg) marshal() []byte {
Expand Down Expand Up @@ -927,6 +929,11 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
}
m.alpnProtocol = string(proto)
default:
// [UTLS SECTION START]
if !m.utlsUnmarshal(extension, extData) {
return false // return false when ERROR
}
// [UTLS SECTION END]
// Ignore unknown extensions.
continue
}
Expand Down
6 changes: 3 additions & 3 deletions handshake_messages_test.go
Expand Up @@ -36,7 +36,7 @@ var tests = []any{
&newSessionTicketMsgTLS13{},
&certificateRequestMsgTLS13{},
&certificateMsgTLS13{},
&compressedCertificateMsg{}, // [UTLS]
&utlsCompressedCertificateMsg{}, // [UTLS]
}

func TestMarshalUnmarshal(t *testing.T) {
Expand Down Expand Up @@ -406,8 +406,8 @@ func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
}

// [UTLS]
func (*compressedCertificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &compressedCertificateMsg{}
func (*utlsCompressedCertificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &utlsCompressedCertificateMsg{}
m.algorithm = uint16(rand.Intn(2 << 15))
m.uncompressedLength = uint32(rand.Intn(2 << 23))
m.compressedCertificateMessage = randomBytes(rand.Intn(500)+1, rand)
Expand Down
2 changes: 1 addition & 1 deletion key_agreement.go
Expand Up @@ -319,7 +319,7 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell
}

if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) {
return errors.New("tls: certificate used with invalid signature algorithm")
return fmt.Errorf("tls: certificate used with invalid signature algorithm -- ClientHello not advertising %04x", uint16(signatureAlgorithm))
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion tls_test.go
Expand Up @@ -12,7 +12,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/refraction-networking/utls/testenv"
"io"
"math"
"net"
Expand All @@ -22,6 +21,8 @@ import (
"strings"
"testing"
"time"

"github.com/refraction-networking/utls/testenv"
)

var rsaCertPEM = `-----BEGIN CERTIFICATE-----
Expand Down Expand Up @@ -827,6 +828,8 @@ func TestCloneNonFuncFields(t *testing.T) {
f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
continue // these are unexported fields that are handled separately
case "ApplicationSettings":
f.Set(reflect.ValueOf(map[string][]byte{"a": {1}}))
default:
t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
}
Expand Down