Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added utls to websocket #1256

Merged
merged 2 commits into from Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 27 additions & 0 deletions transport/internet/tls/tls.go
Expand Up @@ -66,6 +66,33 @@ func (c *UConn) HandshakeAddress() net.Address {
return net.ParseAddress(state.ServerName)
}

// WebsocketHandshake basically calls UConn.Handshake inside it but it will only send
// http/1.1 in its ALPN.
func (c *UConn) WebsocketHandshake() error {
// Build the handshake state. This will apply every variable of the TLS of the
// fingerprint in the UConn
if err := c.BuildHandshakeState(); err != nil {
return err
}
// Iterate over extensions and check for utls.ALPNExtension
hasALPNExtension := false
for _, extension := range c.Extensions {
if alpn, ok := extension.(*utls.ALPNExtension); ok {
hasALPNExtension = true
alpn.AlpnProtocols = []string{"http/1.1"}
break
}
}
if !hasALPNExtension { // Append extension if doesn't exists
c.Extensions = append(c.Extensions, &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}})
}
// Rebuild the client hello and do the handshake
if err := c.BuildHandshakeState(); err != nil {
return err
}
return c.Handshake()
}

func (c *UConn) NegotiatedProtocol() (name string, mutual bool) {
state := c.ConnectionState()
return state.NegotiatedProtocol, state.NegotiatedProtocolIsMutual
Expand Down
27 changes: 26 additions & 1 deletion transport/internet/websocket/dialer.go
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/base64"
"fmt"
"io"
gonet "net"
"net/http"
"os"
"time"
Expand Down Expand Up @@ -83,7 +84,31 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in

if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
protocol = "wss"
dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
dialer.TLSClientConfig = tlsConfig
if fingerprint, exists := tls.Fingerprints[config.Fingerprint]; exists {
dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (gonet.Conn, error) {
// Like the NetDial in the dialer
pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err
}
// TLS and apply the handshake
cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
if err := cn.WebsocketHandshake(); err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err
}
if !tlsConfig.InsecureSkipVerify {
if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
return nil, err
}
}
return cn, nil
}
}
}

host := dest.NetAddr()
Expand Down