From d871f336e630a41a3ed0e238d921d54225856b87 Mon Sep 17 00:00:00 2001 From: Diego Alvarez Date: Tue, 12 May 2020 15:22:50 -0700 Subject: [PATCH 1/3] Set server name only for the current broker Fixes https://github.com/Shopify/sarama/issues/1700 --- broker.go | 42 +++++++++++++++++++++++------------------- client_tls_test.go | 15 +++++++++++++++ 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/broker.go b/broker.go index cda5da8f7..916dd0d6e 100644 --- a/broker.go +++ b/broker.go @@ -162,29 +162,11 @@ func (b *Broker) Open(conf *Config) error { atomic.StoreInt32(&b.opened, 0) return } - if conf.Net.TLS.Enable { - Logger.Printf("Using tls") - cfg := conf.Net.TLS.Config - if cfg == nil { - cfg = &tls.Config{} - } - // If no ServerName is set, infer the ServerName - // from the hostname we're connecting to. - // Gets the hostname as tls.DialWithDialer does it. - if cfg.ServerName == "" { - colonPos := strings.LastIndex(b.addr, ":") - if colonPos == -1 { - colonPos = len(b.addr) - } - hostname := b.addr[:colonPos] - cfg.ServerName = hostname - } - b.conn = tls.Client(b.conn, cfg) + b.conn = tls.Client(b.conn, validServerNameTLS(b.addr, conf.Net.TLS.Config)) } b.conn = newBufConn(b.conn) - b.conf = conf // Create or reuse the global metrics shared between brokers @@ -1440,3 +1422,25 @@ func (b *Broker) registerCounter(name string) metrics.Counter { b.registeredMetrics = append(b.registeredMetrics, nameForBroker) return metrics.GetOrRegisterCounter(nameForBroker, b.conf.MetricRegistry) } + +func validServerNameTLS(addr string, conf *tls.Config) *tls.Config { + cfg := conf + if cfg == nil { + cfg = &tls.Config{} + } + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + // Gets the hostname as tls.DialWithDialer does it. + if cfg.ServerName == "" { + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + // Make a copy to avoid polluting argument or default. + c := cfg.Clone() + c.ServerName = hostname + cfg = c + } + return cfg +} diff --git a/client_tls_test.go b/client_tls_test.go index 750145610..2d9d8fdd5 100644 --- a/client_tls_test.go +++ b/client_tls_test.go @@ -210,3 +210,18 @@ func doListenerTLSTest(t *testing.T, expectSuccess bool, serverConfig, clientCon } } } + +func TestSetServerName(t *testing.T) { + if validServerNameTLS("kafka-server.domain.com", nil).ServerName != "kafka-server.domain.com" { + t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config is nil") + } + + if validServerNameTLS("kafka-server.domain.com", &tls.Config{}).ServerName != "kafka-server.domain.com" { + t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config ServerName is not provided") + } + + c := &tls.Config{ServerName: "kafka-server-other.domain.com"} + if validServerNameTLS("", c).ServerName != "kafka-server-other.domain.com" { + t.Fatal("Expected kafka-server-other.domain.com as tls.ServerName when tls config ServerName is provided") + } +} From 1628dc1dbfa4b0676d77af43b8c0b9f17d7a59e0 Mon Sep 17 00:00:00 2001 From: Diego Alvarez Date: Thu, 14 May 2020 16:10:05 -0700 Subject: [PATCH 2/3] use net.SplitHostPort and don't separate conf var --- broker.go | 18 ++++++------------ client_tls_test.go | 8 ++++++-- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/broker.go b/broker.go index 916dd0d6e..e8bb2ed10 100644 --- a/broker.go +++ b/broker.go @@ -1423,23 +1423,17 @@ func (b *Broker) registerCounter(name string) metrics.Counter { return metrics.GetOrRegisterCounter(nameForBroker, b.conf.MetricRegistry) } -func validServerNameTLS(addr string, conf *tls.Config) *tls.Config { - cfg := conf +func validServerNameTLS(addr string, cfg *tls.Config) *tls.Config { if cfg == nil { cfg = &tls.Config{} } - // If no ServerName is set, infer the ServerName - // from the hostname we're connecting to. - // Gets the hostname as tls.DialWithDialer does it. if cfg.ServerName == "" { - colonPos := strings.LastIndex(addr, ":") - if colonPos == -1 { - colonPos = len(addr) - } - hostname := addr[:colonPos] - // Make a copy to avoid polluting argument or default. c := cfg.Clone() - c.ServerName = hostname + sn, _, err := net.SplitHostPort(addr) + if err != nil { + Logger.Println(fmt.Errorf("failed to get ServerName from addr %w", err)) + } + c.ServerName = sn cfg = c } return cfg diff --git a/client_tls_test.go b/client_tls_test.go index 2d9d8fdd5..9731e44b7 100644 --- a/client_tls_test.go +++ b/client_tls_test.go @@ -212,11 +212,11 @@ func doListenerTLSTest(t *testing.T, expectSuccess bool, serverConfig, clientCon } func TestSetServerName(t *testing.T) { - if validServerNameTLS("kafka-server.domain.com", nil).ServerName != "kafka-server.domain.com" { + if validServerNameTLS("kafka-server.domain.com:9093", nil).ServerName != "kafka-server.domain.com" { t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config is nil") } - if validServerNameTLS("kafka-server.domain.com", &tls.Config{}).ServerName != "kafka-server.domain.com" { + if validServerNameTLS("kafka-server.domain.com:9093", &tls.Config{}).ServerName != "kafka-server.domain.com" { t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config ServerName is not provided") } @@ -224,4 +224,8 @@ func TestSetServerName(t *testing.T) { if validServerNameTLS("", c).ServerName != "kafka-server-other.domain.com" { t.Fatal("Expected kafka-server-other.domain.com as tls.ServerName when tls config ServerName is provided") } + + if validServerNameTLS("host-no-port", nil).ServerName != "" { + t.Fatal("Expected empty ServerName as the broker addr is missing the port") + } } From 5ff9581b104a24f3b99dfa18704ada932af74a01 Mon Sep 17 00:00:00 2001 From: Diego Alvarez Date: Thu, 14 May 2020 16:14:48 -0700 Subject: [PATCH 3/3] more go like code, returning earlier --- broker.go | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/broker.go b/broker.go index e8bb2ed10..232559aea 100644 --- a/broker.go +++ b/broker.go @@ -1427,14 +1427,15 @@ func validServerNameTLS(addr string, cfg *tls.Config) *tls.Config { if cfg == nil { cfg = &tls.Config{} } - if cfg.ServerName == "" { - c := cfg.Clone() - sn, _, err := net.SplitHostPort(addr) - if err != nil { - Logger.Println(fmt.Errorf("failed to get ServerName from addr %w", err)) - } - c.ServerName = sn - cfg = c + if cfg.ServerName != "" { + return cfg + } + + c := cfg.Clone() + sn, _, err := net.SplitHostPort(addr) + if err != nil { + Logger.Println(fmt.Errorf("failed to get ServerName from addr %w", err)) } - return cfg + c.ServerName = sn + return c }