diff --git a/fastdialer/context.go b/fastdialer/context.go new file mode 100644 index 0000000..f7f39ba --- /dev/null +++ b/fastdialer/context.go @@ -0,0 +1,8 @@ +package fastdialer + +type ContextOption string + +const ( + // SniName to use in tls connection + SniName ContextOption = "sni-name" +) diff --git a/fastdialer/dialer.go b/fastdialer/dialer.go index 52a04f4..7365846 100644 --- a/fastdialer/dialer.go +++ b/fastdialer/dialer.go @@ -202,13 +202,25 @@ func (d *Dialer) dial(ctx context.Context, network, address string, shouldUseTLS hostPort := net.JoinHostPort(ip, port) if shouldUseTLS { tlsconfigCopy := tlsconfig.Clone() - if !iputil.IsIP(hostname) { + switch { + case d.options.SNIName != "": + tlsconfigCopy.ServerName = d.options.SNIName + case ctx.Value(SniName) != nil: + sniName := ctx.Value(SniName).(string) + tlsconfigCopy.ServerName = sniName + case !iputil.IsIP(hostname): tlsconfigCopy.ServerName = hostname } - conn, err = tls.DialWithDialer(d.dialer, network, hostPort, tlsconfig) + conn, err = tls.DialWithDialer(d.dialer, network, hostPort, tlsconfigCopy) } else if shouldUseZTLS { ztlsconfigCopy := ztlsconfig.Clone() - if !iputil.IsIP(hostname) { + switch { + case d.options.SNIName != "": + ztlsconfigCopy.ServerName = d.options.SNIName + case ctx.Value(SniName) != nil: + sniName := ctx.Value(SniName).(string) + ztlsconfigCopy.ServerName = sniName + case !iputil.IsIP(hostname): ztlsconfigCopy.ServerName = hostname } conn, err = ztls.DialWithDialer(d.dialer, network, hostPort, ztlsconfigCopy) diff --git a/fastdialer/options.go b/fastdialer/options.go index 3dc17e7..cd6c57f 100644 --- a/fastdialer/options.go +++ b/fastdialer/options.go @@ -46,6 +46,7 @@ type Options struct { DialerKeepAlive time.Duration Dialer *net.Dialer WithZTLS bool + SNIName string } // DefaultOptions of the cache