Skip to content

Commit

Permalink
Merge pull request #1147 from nats-io/custom-dialer-skip-tls
Browse files Browse the repository at this point in the history
* Add SkipTLSHandshake small interface to CustomDialer

This adds a small interface `SkipTLSHandshake() bool` that when 
implemented a `CustomDialer` may opt into skipping the handshake 
if not needed by the dialer implementation.

This option to skip the TLS wrapper is meant to be used if a custom
dialer already handled the TLS handshake. I discussed my use case,
which is using NATS from a Wasm module running in the browser in
Slack with @derekcollison. There may be other use cases when the
custom dialer is using some kind of tunneling, for example.

Signed-off-by: Waldemar Quevedo <wally@nats.io>
Co-authored-by: Hans Raaf <hara@oderwat.de>
Co-authored-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
3 people committed Nov 30, 2022
2 parents 398a1ec + 65b7870 commit 907b219
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 3 deletions.
35 changes: 35 additions & 0 deletions example_test.go
Expand Up @@ -17,6 +17,7 @@ import (
"context"
"fmt"
"log"
"net"
"time"

"github.com/nats-io/nats.go"
Expand Down Expand Up @@ -44,6 +45,40 @@ func ExampleConnect() {
nc.Close()
}

type skipTLSDialer struct {
dialer *net.Dialer
skipTLS bool
}

func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
return sd.dialer.Dial(network, address)
}

func (sd *skipTLSDialer) SkipTLSHandshake() bool {
return sd.skipTLS
}

func ExampleCustomDialer() {
// Given the following CustomDialer implementation:
//
// type skipTLSDialer struct {
// dialer *net.Dialer
// skipTLS bool
// }
//
// func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
// return sd.dialer.Dial(network, address)
// }
//
// func (sd *skipTLSDialer) SkipTLSHandshake() bool {
// return true
// }
//
sd := &skipTLSDialer{dialer: &net.Dialer{Timeout: 2 * time.Second}, skipTLS: true}
nc, _ := nats.Connect("demo.nats.io", nats.SetCustomDialer(sd))
defer nc.Close()
}

// This Example shows an asynchronous subscriber.
func ExampleConn_Subscribe() {
nc, _ := nats.Connect(nats.DefaultURL)
Expand Down
16 changes: 14 additions & 2 deletions nats.go
Expand Up @@ -247,8 +247,9 @@ type asyncCallbacksHandler struct {
// Option is a function on the options for a connection.
type Option func(*Options) error

// CustomDialer can be used to specify any dialer, not necessarily
// a *net.Dialer.
// CustomDialer can be used to specify any dialer, not necessarily a
// *net.Dialer. A CustomDialer may also implement `SkipTLSHandshake() bool`
// in order to skip the TLS handshake in case not required.
type CustomDialer interface {
Dial(network, address string) (net.Conn, error)
}
Expand Down Expand Up @@ -1892,8 +1893,19 @@ func (nc *Conn) createConn() (err error) {
return nil
}

type skipTLSDialer interface {
SkipTLSHandshake() bool
}

// makeTLSConn will wrap an existing Conn using TLS
func (nc *Conn) makeTLSConn() error {
if nc.Opts.CustomDialer != nil {
// we do nothing when asked to skip the TLS wrapper
sd, ok := nc.Opts.CustomDialer.(skipTLSDialer)
if ok && sd.SkipTLSHandshake() {
return nil
}
}
// Allow the user to configure their own tls.Config structure.
var tlsCopy *tls.Config
if nc.Opts.TLSConfig != nil {
Expand Down
2 changes: 1 addition & 1 deletion services/service.go
Expand Up @@ -31,7 +31,7 @@ import (

type (

// Service is an interface for sevice management.
// Service is an interface for service management.
// It exposes methods to stop/reset a service, as well as get information on a service.
Service interface {
ID() string
Expand Down
54 changes: 54 additions & 0 deletions ws_test.go
Expand Up @@ -868,6 +868,60 @@ func TestWSWithTLS(t *testing.T) {
}
}

type testSkipTLSDialer struct {
dialer *net.Dialer
skipTLS bool
}

func (sd *testSkipTLSDialer) Dial(network, address string) (net.Conn, error) {
return sd.dialer.Dial(network, address)
}

func (sd *testSkipTLSDialer) SkipTLSHandshake() bool {
return sd.skipTLS
}

func TestWSWithTLSCustomDialer(t *testing.T) {
sopts := testWSGetDefaultOptions(t, true)
s := RunServerWithOptions(sopts)
defer s.Shutdown()

sd := &testSkipTLSDialer{
dialer: &net.Dialer{
Timeout: 2 * time.Second,
},
skipTLS: true,
}

// Connect with CustomDialer that fails since TLSHandshake is disabled.
copts := make([]Option, 0)
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
copts = append(copts, SetCustomDialer(sd))
_, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
if err == nil {
t.Fatalf("Expected error on connect: %v", err)
}
if err.Error() != `invalid websocket connection` {
t.Logf("Expected invalid websocket connection: %v", err)
}

// Retry with the dialer.
copts = make([]Option, 0)
sd = &testSkipTLSDialer{
dialer: &net.Dialer{
Timeout: 2 * time.Second,
},
skipTLS: false,
}
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
copts = append(copts, SetCustomDialer(sd))
nc, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
if err != nil {
t.Fatalf("Unexpected error on connect: %v", err)
}
defer nc.Close()
}

func TestWSTlsNoConfig(t *testing.T) {
opts := GetDefaultOptions()
opts.Servers = []string{"wss://localhost:443"}
Expand Down

0 comments on commit 907b219

Please sign in to comment.