Skip to content

Commit

Permalink
Added utls to websocket (#1256)
Browse files Browse the repository at this point in the history
* Added utls to websocket

* Slightly better code

One less allocation
  • Loading branch information
HirbodBehnam committed Oct 18, 2022
1 parent 149e224 commit 1f93cbb
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
27 changes: 27 additions & 0 deletions transport/internet/tls/tls.go
Expand Up @@ -66,6 +66,33 @@ func (c *UConn) HandshakeAddress() net.Address {
return net.ParseAddress(state.ServerName)
}

// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send
// http/1.1 in its ALPN.
func (c *UConn) WebsocketHandshake() error {
// Build the handshake state. This will apply every variable of the TLS of the
// fingerprint in the UConn
if err := c.BuildHandshakeState(); err != nil {
return err
}
// Iterate over extensions and check for utls.ALPNExtension
hasALPNExtension := false
for _, extension := range c.Extensions {
if alpn, ok := extension.(*utls.ALPNExtension); ok {
hasALPNExtension = true
alpn.AlpnProtocols = []string{"http/1.1"}
break
}
}
if !hasALPNExtension { // Append extension if doesn't exists
c.Extensions = append(c.Extensions, &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}})
}
// Rebuild the client hello and do the handshake
if err := c.BuildHandshakeState(); err != nil {
return err
}
return c.Handshake()
}

func (c *UConn) NegotiatedProtocol() (name string, mutual bool) {
state := c.ConnectionState()
return state.NegotiatedProtocol, state.NegotiatedProtocolIsMutual
Expand Down
27 changes: 26 additions & 1 deletion transport/internet/websocket/dialer.go
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/base64"
"fmt"
"io"
gonet "net"
"net/http"
"os"
"time"
Expand Down Expand Up @@ -83,7 +84,31 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in

if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
protocol = "wss"
dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
dialer.TLSClientConfig = tlsConfig
if fingerprint, exists := tls.Fingerprints[config.Fingerprint]; exists {
dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (gonet.Conn, error) {
// Like the NetDial in the dialer
pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err
}
// TLS and apply the handshake
cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
if err := cn.WebsocketHandshake(); err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err
}
if !tlsConfig.InsecureSkipVerify {
if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err
}
}
return cn, nil
}
}
}

host := dest.NetAddr()
Expand Down

0 comments on commit 1f93cbb

Please sign in to comment.