Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SkipTLSHandshake optional function to CustomDialer #1147

Merged
merged 2 commits into from Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 35 additions & 0 deletions example_test.go
Expand Up @@ -17,6 +17,7 @@ import (
"context"
"fmt"
"log"
"net"
"time"

"github.com/nats-io/nats.go"
Expand Down Expand Up @@ -44,6 +45,40 @@ func ExampleConnect() {
nc.Close()
}

type skipTLSDialer struct {
dialer *net.Dialer
skipTLS bool
}

func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
return sd.dialer.Dial(network, address)
}

func (sd *skipTLSDialer) SkipTLSHandshake() bool {
return sd.skipTLS
}

func ExampleCustomDialer() {
// Given the following CustomDialer implementation:
//
// type skipTLSDialer struct {
// dialer *net.Dialer
// skipTLS bool
// }
//
// func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
// return sd.dialer.Dial(network, address)
// }
//
// func (sd *skipTLSDialer) SkipTLSHandshake() bool {
// return true
// }
//
sd := &skipTLSDialer{dialer: &net.Dialer{Timeout: 2 * time.Second}, skipTLS: true}
nc, _ := nats.Connect("demo.nats.io", nats.SetCustomDialer(sd))
defer nc.Close()
}

// This Example shows an asynchronous subscriber.
func ExampleConn_Subscribe() {
nc, _ := nats.Connect(nats.DefaultURL)
Expand Down
16 changes: 14 additions & 2 deletions nats.go
Expand Up @@ -247,8 +247,9 @@ type asyncCallbacksHandler struct {
// Option is a function on the options for a connection.
type Option func(*Options) error

// CustomDialer can be used to specify any dialer, not necessarily
// a *net.Dialer.
// CustomDialer can be used to specify any dialer, not necessarily a
// *net.Dialer. A CustomDialer may also implement `SkipTLSHandshake() bool`
// in order to skip the TLS handshake in case not required.
type CustomDialer interface {
Dial(network, address string) (net.Conn, error)
}
Expand Down Expand Up @@ -1892,8 +1893,19 @@ func (nc *Conn) createConn() (err error) {
return nil
}

type skipTLSDialer interface {
SkipTLSHandshake() bool
}

// makeTLSConn will wrap an existing Conn using TLS
func (nc *Conn) makeTLSConn() error {
if nc.Opts.CustomDialer != nil {
// we do nothing when asked to skip the TLS wrapper
sd, ok := nc.Opts.CustomDialer.(skipTLSDialer)
if ok && sd.SkipTLSHandshake() {
return nil
}
}
// Allow the user to configure their own tls.Config structure.
var tlsCopy *tls.Config
if nc.Opts.TLSConfig != nil {
Expand Down
2 changes: 1 addition & 1 deletion services/service.go
Expand Up @@ -31,7 +31,7 @@ import (

type (

// Service is an interface for sevice management.
// Service is an interface for service management.
// It exposes methods to stop/reset a service, as well as get information on a service.
Service interface {
ID() string
Expand Down
54 changes: 54 additions & 0 deletions ws_test.go
Expand Up @@ -868,6 +868,60 @@ func TestWSWithTLS(t *testing.T) {
}
}

type testSkipTLSDialer struct {
dialer *net.Dialer
skipTLS bool
}

func (sd *testSkipTLSDialer) Dial(network, address string) (net.Conn, error) {
return sd.dialer.Dial(network, address)
}

func (sd *testSkipTLSDialer) SkipTLSHandshake() bool {
return sd.skipTLS
}

func TestWSWithTLSCustomDialer(t *testing.T) {
sopts := testWSGetDefaultOptions(t, true)
s := RunServerWithOptions(sopts)
defer s.Shutdown()

sd := &testSkipTLSDialer{
dialer: &net.Dialer{
Timeout: 2 * time.Second,
},
skipTLS: true,
}

// Connect with CustomDialer that fails since TLSHandshake is disabled.
copts := make([]Option, 0)
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
copts = append(copts, SetCustomDialer(sd))
_, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
if err == nil {
t.Fatalf("Expected error on connect: %v", err)
}
if err.Error() != `invalid websocket connection` {
t.Logf("Expected invalid websocket connection: %v", err)
}

// Retry with the dialer.
copts = make([]Option, 0)
sd = &testSkipTLSDialer{
dialer: &net.Dialer{
Timeout: 2 * time.Second,
},
skipTLS: false,
}
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
copts = append(copts, SetCustomDialer(sd))
nc, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
if err != nil {
t.Fatalf("Unexpected error on connect: %v", err)
}
defer nc.Close()
}

func TestWSTlsNoConfig(t *testing.T) {
opts := GetDefaultOptions()
opts.Servers = []string{"wss://localhost:443"}
Expand Down