Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tcp option to set keep alive value on the tcp layer #561

Merged
merged 2 commits into from Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion client.go
Expand Up @@ -393,7 +393,7 @@ func (c *client) attemptConnection() (net.Conn, byte, bool, error) {
tlsCfg = c.options.OnConnectAttempt(broker, c.options.TLSConfig)
}
// Start by opening the network connection (tcp, tls, ws) etc
conn, err = openConnection(broker, tlsCfg, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions)
conn, err = openConnection(broker, tlsCfg, c.options.ConnectTimeout, c.options.HTTPHeaders, c.options.WebsocketOptions, c.options.Dialer)
if err != nil {
ERROR.Println(CLI, err.Error())
WARN.Println(CLI, "failed to connect to broker, trying next")
Expand Down
11 changes: 5 additions & 6 deletions netconn.go
Expand Up @@ -37,7 +37,7 @@ import (

// openConnection opens a network connection using the protocol indicated in the URL.
// Does not carry out any MQTT specific handshakes.
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header, websocketOptions *WebsocketOptions) (net.Conn, error) {
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header, websocketOptions *WebsocketOptions, dialer *net.Dialer) (net.Conn, error) {
switch uri.Scheme {
case "ws":
conn, err := NewWebsocket(uri.String(), nil, timeout, headers, websocketOptions)
Expand All @@ -48,7 +48,7 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade
case "mqtt", "tcp":
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
conn, err := net.DialTimeout("tcp", uri.Host, timeout)
conn, err := dialer.Dial("tcp", uri.Host)
if err != nil {
return nil, err
}
Expand All @@ -68,9 +68,9 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade
// this check is preserved for compatibility with older versions
// which used uri.Host only (it works for local paths, e.g. unix://socket.sock in current dir)
if len(uri.Host) > 0 {
conn, err = net.DialTimeout("unix", uri.Host, timeout)
conn, err = dialer.Dial("unix", uri.Host)
} else {
conn, err = net.DialTimeout("unix", uri.Path, timeout)
conn, err = dialer.Dial("unix", uri.Path)
}

if err != nil {
Expand All @@ -80,14 +80,13 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, heade
case "ssl", "tls", "mqtts", "mqtt+ssl", "tcps":
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: timeout}, "tcp", uri.Host, tlsc)
conn, err := tls.DialWithDialer(dialer, "tcp", uri.Host, tlsc)
if err != nil {
return nil, err
}
return conn, nil
}
proxyDialer := proxy.FromEnvironment()

conn, err := proxyDialer.Dial("tcp", uri.Host)
if err != nil {
return nil, err
Expand Down
10 changes: 10 additions & 0 deletions options.go
Expand Up @@ -23,6 +23,7 @@ package mqtt

import (
"crypto/tls"
"net"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -96,6 +97,7 @@ type ClientOptions struct {
HTTPHeaders http.Header
WebsocketOptions *WebsocketOptions
MaxResumePubInFlight int // // 0 = no limit; otherwise this is the maximum simultaneous messages sent while resuming
Dialer *net.Dialer
}

// NewClientOptions will create a new ClientClientOptions type with some
Expand Down Expand Up @@ -137,6 +139,7 @@ func NewClientOptions() *ClientOptions {
ResumeSubs: false,
HTTPHeaders: make(map[string][]string),
WebsocketOptions: &WebsocketOptions{},
Dialer: &net.Dialer{Timeout: 30 * time.Second},
}
return o
}
Expand Down Expand Up @@ -355,6 +358,7 @@ func (o *ClientOptions) SetWriteTimeout(t time.Duration) *ClientOptions {
// Default 30 seconds. Currently only operational on TCP/TLS connections.
func (o *ClientOptions) SetConnectTimeout(t time.Duration) *ClientOptions {
o.ConnectTimeout = t
o.Dialer.Timeout = t
return o
}

Expand Down Expand Up @@ -419,3 +423,9 @@ func (o *ClientOptions) SetMaxResumePubInFlight(MaxResumePubInFlight int) *Clien
o.MaxResumePubInFlight = MaxResumePubInFlight
return o
}

// SetDialer sets the tcp dialer options used in a tcp connection
func (o *ClientOptions) SetDialer(dialer *net.Dialer) *ClientOptions {
o.Dialer = dialer
return o
}