From 9d3ee7bea6d45727a7659bc7da8b9bbcd9ac7133 Mon Sep 17 00:00:00 2001 From: Hirbod Behnam Date: Sat, 15 Oct 2022 11:49:07 +0330 Subject: [PATCH 1/2] Added utls to websocket --- transport/internet/tls/tls.go | 27 ++++++++++++++++++++++++++ transport/internet/websocket/dialer.go | 27 +++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/transport/internet/tls/tls.go b/transport/internet/tls/tls.go index ea86c0cea75..27cde1e4186 100644 --- a/transport/internet/tls/tls.go +++ b/transport/internet/tls/tls.go @@ -66,6 +66,33 @@ func (c *UConn) HandshakeAddress() net.Address { return net.ParseAddress(state.ServerName) } +// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send +// http/1.1 in its ALPN. +func (c *UConn) WebsocketHandshake() error { + // Build the handshake state. This will apply every variable of the TLS of the + // fingerprint in the UConn + if err := c.BuildHandshakeState(); err != nil { + return err + } + // Iterate over extensions and check for utls.ALPNExtension + hasALPNExtension := false + for i, extension := range c.Extensions { + if _, ok := extension.(*utls.ALPNExtension); ok { + hasALPNExtension = true + c.Extensions[i] = &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}} + break + } + } + if !hasALPNExtension { // Append extension if doesn't exists + c.Extensions = append(c.Extensions, &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}}) + } + // Rebuild the client hello and do the handshake + if err := c.BuildHandshakeState(); err != nil { + return err + } + return c.Handshake() +} + func (c *UConn) NegotiatedProtocol() (name string, mutual bool) { state := c.ConnectionState() return state.NegotiatedProtocol, state.NegotiatedProtocolIsMutual diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 284d8deed87..a8f712647ad 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "fmt" "io" + gonet "net" "net/http" "os" "time" @@ -83,7 +84,31 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { protocol = "wss" - dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1")) + tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1")) + dialer.TLSClientConfig = tlsConfig + if fingerprint, exists := tls.Fingerprints[config.Fingerprint]; exists { + dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (gonet.Conn, error) { + // Like the NetDial in the dialer + pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) + if err != nil { + newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() + return nil, err + } + // TLS and apply the handshake + cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn) + if err := cn.WebsocketHandshake(); err != nil { + newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() + return nil, err + } + if !tlsConfig.InsecureSkipVerify { + if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil { + newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() + return nil, err + } + } + return cn, nil + } + } } host := dest.NetAddr() From d3bb7a11041f78e77cb2b3ca3787ba81ee4437d5 Mon Sep 17 00:00:00 2001 From: Hirbod Behnam Date: Sat, 15 Oct 2022 12:34:51 +0330 Subject: [PATCH 2/2] Slightly better code One less allocation --- transport/internet/tls/tls.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transport/internet/tls/tls.go b/transport/internet/tls/tls.go index 27cde1e4186..f1291e81c6e 100644 --- a/transport/internet/tls/tls.go +++ b/transport/internet/tls/tls.go @@ -76,10 +76,10 @@ func (c *UConn) WebsocketHandshake() error { } // Iterate over extensions and check for utls.ALPNExtension hasALPNExtension := false - for i, extension := range c.Extensions { - if _, ok := extension.(*utls.ALPNExtension); ok { + for _, extension := range c.Extensions { + if alpn, ok := extension.(*utls.ALPNExtension); ok { hasALPNExtension = true - c.Extensions[i] = &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}} + alpn.AlpnProtocols = []string{"http/1.1"} break } }