Skip to content

Commit

Permalink
Merge pull request #561 from magnitudespace/add-tcp-options
Browse files Browse the repository at this point in the history
Add `options.SetDialer` allowing a custom `net.Dialer` to be used for TCP connections (allows users to change tcp settings including keep-alive).
  • Loading branch information
MattBrittan committed Nov 17, 2021
2 parents efeb638 + 6934fb5 commit 5208ce8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion client.go
Expand Up @@ -393,7 +393,7 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) {
tlsCfg = c.options.OnConnectAttempt(broker, c.options.TLSConfig)
}
// Start by opening the network connection (tcp, tls, ws) etc
conn, err = openConnection(broker, tlsCfg, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions)
conn, err = openConnection(broker, tlsCfg, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions, c.options.Dialer)
if err != nil {
ERROR.Println(CLI, err.Error())
WARN.Println(CLI, "failed to connect to broker, trying next")
Expand Down
11 changes: 5 additions & 6 deletions netconn.go
Expand Up @@ -37,7 +37,7 @@ import (

// openConnection opens a network connection using the protocol indicated in the URL.
// Does not carry out any MQTT specific handshakes.
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header, websocketOptions *WebsocketOptions) (net.Conn, error) {
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header, websocketOptions *WebsocketOptions, dialer *net.Dialer) (net.Conn, error) {
switch uri.Scheme {
case "ws":
conn, err := NewWebsocket(uri.String(), nil, timeout, headers, websocketOptions)
Expand All @@ -48,7 +48,7 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade
case "mqtt", "tcp":
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
conn, err := net.DialTimeout("tcp", uri.Host, timeout)
conn, err := dialer.Dial("tcp", uri.Host)
if err != nil {
return nil, err
}
Expand All @@ -68,9 +68,9 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade
// this check is preserved for compatibility with older versions
// which used uri.Host only (it works for local paths, e.g. unix://socket.sock in current dir)
if len(uri.Host) > 0 {
conn, err = net.DialTimeout("unix", uri.Host, timeout)
conn, err = dialer.Dial("unix", uri.Host)
} else {
conn, err = net.DialTimeout("unix", uri.Path, timeout)
conn, err = dialer.Dial("unix", uri.Path)
}

if err != nil {
Expand All @@ -80,14 +80,13 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade
case "ssl", "tls", "mqtts", "mqtt+ssl", "tcps":
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: timeout}, "tcp", uri.Host, tlsc)
conn, err := tls.DialWithDialer(dialer, "tcp", uri.Host, tlsc)
if err != nil {
return nil, err
}
return conn, nil
}
proxyDialer := proxy.FromEnvironment()

conn, err := proxyDialer.Dial("tcp", uri.Host)
if err != nil {
return nil, err
Expand Down
10 changes: 10 additions & 0 deletions options.go
Expand Up @@ -23,6 +23,7 @@ package mqtt

import (
"crypto/tls"
"net"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -96,6 +97,7 @@ type ClientOptions struct {
HTTPHeaders http.Header
WebsocketOptions *WebsocketOptions
MaxResumePubInFlight int // // 0 = no limit; otherwise this is the maximum simultaneous messages sent while resuming
Dialer *net.Dialer
}

// NewClientOptions will create a new ClientClientOptions type with some
Expand Down Expand Up @@ -137,6 +139,7 @@ func NewClientOptions() *ClientOptions {
ResumeSubs: false,
HTTPHeaders: make(map[string][]string),
WebsocketOptions: &WebsocketOptions{},
Dialer: &net.Dialer{Timeout: 30 * time.Second},
}
return o
}
Expand Down Expand Up @@ -355,6 +358,7 @@ func (o *ClientOptions) SetWriteTimeout(t time.Duration) *ClientOptions {
// Default 30 seconds. Currently only operational on TCP/TLS connections.
func (o *ClientOptions) SetConnectTimeout(t time.Duration) *ClientOptions {
o.ConnectTimeout = t
o.Dialer.Timeout = t
return o
}

Expand Down Expand Up @@ -419,3 +423,9 @@ func (o *ClientOptions) SetMaxResumePubInFlight(MaxResumePubInFlight int) *Clien
o.MaxResumePubInFlight = MaxResumePubInFlight
return o
}

// SetDialer sets the tcp dialer options used in a tcp connection
func (o *ClientOptions) SetDialer(dialer *net.Dialer) *ClientOptions {
o.Dialer = dialer
return o
}

0 comments on commit 5208ce8

Please sign in to comment.