From d871f336e630a41a3ed0e238d921d54225856b87 Mon Sep 17 00:00:00 2001 From: Diego Alvarez Date: Tue, 12 May 2020 15:22:50 -0700 Subject: [PATCH] Set server name only for the current broker Fixes https://github.com/Shopify/sarama/issues/1700 --- broker.go | 42 +++++++++++++++++++++++------------------- client_tls_test.go | 15 +++++++++++++++ 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/broker.go b/broker.go index cda5da8f7..916dd0d6e 100644 --- a/broker.go +++ b/broker.go @@ -162,29 +162,11 @@ func (b *Broker) Open(conf *Config) error { atomic.StoreInt32(&b.opened, 0) return } - if conf.Net.TLS.Enable { - Logger.Printf("Using tls") - cfg := conf.Net.TLS.Config - if cfg == nil { - cfg = &tls.Config{} - } - // If no ServerName is set, infer the ServerName - // from the hostname we're connecting to. - // Gets the hostname as tls.DialWithDialer does it. - if cfg.ServerName == "" { - colonPos := strings.LastIndex(b.addr, ":") - if colonPos == -1 { - colonPos = len(b.addr) - } - hostname := b.addr[:colonPos] - cfg.ServerName = hostname - } - b.conn = tls.Client(b.conn, cfg) + b.conn = tls.Client(b.conn, validServerNameTLS(b.addr, conf.Net.TLS.Config)) } b.conn = newBufConn(b.conn) - b.conf = conf // Create or reuse the global metrics shared between brokers @@ -1440,3 +1422,25 @@ func (b *Broker) registerCounter(name string) metrics.Counter { b.registeredMetrics = append(b.registeredMetrics, nameForBroker) return metrics.GetOrRegisterCounter(nameForBroker, b.conf.MetricRegistry) } + +func validServerNameTLS(addr string, conf *tls.Config) *tls.Config { + cfg := conf + if cfg == nil { + cfg = &tls.Config{} + } + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + // Gets the hostname as tls.DialWithDialer does it. + if cfg.ServerName == "" { + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + // Make a copy to avoid polluting argument or default. + c := cfg.Clone() + c.ServerName = hostname + cfg = c + } + return cfg +} diff --git a/client_tls_test.go b/client_tls_test.go index 750145610..2d9d8fdd5 100644 --- a/client_tls_test.go +++ b/client_tls_test.go @@ -210,3 +210,18 @@ func doListenerTLSTest(t *testing.T, expectSuccess bool, serverConfig, clientCon } } } + +func TestSetServerName(t *testing.T) { + if validServerNameTLS("kafka-server.domain.com", nil).ServerName != "kafka-server.domain.com" { + t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config is nil") + } + + if validServerNameTLS("kafka-server.domain.com", &tls.Config{}).ServerName != "kafka-server.domain.com" { + t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config ServerName is not provided") + } + + c := &tls.Config{ServerName: "kafka-server-other.domain.com"} + if validServerNameTLS("", c).ServerName != "kafka-server-other.domain.com" { + t.Fatal("Expected kafka-server-other.domain.com as tls.ServerName when tls config ServerName is provided") + } +}