Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent metrics leak with cleanup #2340

Merged
merged 1 commit into from Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 15 additions & 9 deletions async_producer.go
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/eapache/go-resiliency/breaker"
"github.com/eapache/queue"
"github.com/rcrowley/go-metrics"
)

// AsyncProducer publishes Kafka messages using a non-blocking API. It routes messages
Expand Down Expand Up @@ -122,6 +123,8 @@ type asyncProducer struct {
brokerLock sync.Mutex

txnmgr *transactionManager

metricsRegistry metrics.Registry
}

// NewAsyncProducer creates a new AsyncProducer using the given broker addresses and configuration.
Expand Down Expand Up @@ -154,15 +157,16 @@ func newAsyncProducer(client Client) (AsyncProducer, error) {
}

p := &asyncProducer{
client: client,
conf: client.Config(),
errors: make(chan *ProducerError),
input: make(chan *ProducerMessage),
successes: make(chan *ProducerMessage),
retries: make(chan *ProducerMessage),
brokers: make(map[*Broker]*brokerProducer),
brokerRefs: make(map[*brokerProducer]int),
txnmgr: txnmgr,
client: client,
conf: client.Config(),
errors: make(chan *ProducerError),
input: make(chan *ProducerMessage),
successes: make(chan *ProducerMessage),
retries: make(chan *ProducerMessage),
brokers: make(map[*Broker]*brokerProducer),
brokerRefs: make(map[*brokerProducer]int),
txnmgr: txnmgr,
metricsRegistry: newCleanupRegistry(client.Config().MetricRegistry),
}

// launch our singleton dispatchers
Expand Down Expand Up @@ -1134,6 +1138,8 @@ func (p *asyncProducer) shutdown() {
close(p.retries)
close(p.errors)
close(p.successes)

p.metricsRegistry.UnregisterAll()
}

func (p *asyncProducer) bumpIdempotentProducerEpoch() {
Expand Down
62 changes: 22 additions & 40 deletions broker.go
Expand Up @@ -33,8 +33,7 @@ type Broker struct {
responses chan *responsePromise
done chan bool

registeredMetrics map[string]struct{}

metricRegistry metrics.Registry
incomingByteRate metrics.Meter
requestRate metrics.Meter
fetchRate metrics.Meter
Expand Down Expand Up @@ -174,6 +173,8 @@ func (b *Broker) Open(conf *Config) error {

b.lock.Lock()

b.metricRegistry = newCleanupRegistry(conf.MetricRegistry)

go withRecover(func() {
defer func() {
b.lock.Unlock()
Expand Down Expand Up @@ -208,15 +209,15 @@ func (b *Broker) Open(conf *Config) error {
b.conf = conf

// Create or reuse the global metrics shared between brokers
b.incomingByteRate = metrics.GetOrRegisterMeter("incoming-byte-rate", conf.MetricRegistry)
b.requestRate = metrics.GetOrRegisterMeter("request-rate", conf.MetricRegistry)
b.fetchRate = metrics.GetOrRegisterMeter("consumer-fetch-rate", conf.MetricRegistry)
b.requestSize = getOrRegisterHistogram("request-size", conf.MetricRegistry)
b.requestLatency = getOrRegisterHistogram("request-latency-in-ms", conf.MetricRegistry)
b.outgoingByteRate = metrics.GetOrRegisterMeter("outgoing-byte-rate", conf.MetricRegistry)
b.responseRate = metrics.GetOrRegisterMeter("response-rate", conf.MetricRegistry)
b.responseSize = getOrRegisterHistogram("response-size", conf.MetricRegistry)
b.requestsInFlight = metrics.GetOrRegisterCounter("requests-in-flight", conf.MetricRegistry)
b.incomingByteRate = metrics.GetOrRegisterMeter("incoming-byte-rate", b.metricRegistry)
b.requestRate = metrics.GetOrRegisterMeter("request-rate", b.metricRegistry)
b.fetchRate = metrics.GetOrRegisterMeter("consumer-fetch-rate", b.metricRegistry)
b.requestSize = getOrRegisterHistogram("request-size", b.metricRegistry)
b.requestLatency = getOrRegisterHistogram("request-latency-in-ms", b.metricRegistry)
b.outgoingByteRate = metrics.GetOrRegisterMeter("outgoing-byte-rate", b.metricRegistry)
b.responseRate = metrics.GetOrRegisterMeter("response-rate", b.metricRegistry)
b.responseSize = getOrRegisterHistogram("response-size", b.metricRegistry)
b.requestsInFlight = metrics.GetOrRegisterCounter("requests-in-flight", b.metricRegistry)
// Do not gather metrics for seeded broker (only used during bootstrap) because they share
// the same id (-1) and are already exposed through the global metrics above
if b.id >= 0 && !metrics.UseNilMetrics {
Expand Down Expand Up @@ -319,7 +320,7 @@ func (b *Broker) Close() error {
b.done = nil
b.responses = nil

b.unregisterMetrics()
b.metricRegistry.UnregisterAll()

if err == nil {
DebugLogger.Printf("Closed connection to broker %s\n", b.addr)
Expand Down Expand Up @@ -435,7 +436,7 @@ func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error
return
}

if err := versionedDecode(packets, res, request.version(), b.conf.MetricRegistry); err != nil {
if err := versionedDecode(packets, res, request.version(), b.metricRegistry); err != nil {
// Malformed response
cb(nil, err)
return
Expand Down Expand Up @@ -979,7 +980,7 @@ func (b *Broker) sendInternal(rb protocolBody, promise *responsePromise) error {
}

req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
buf, err := encode(req, b.metricRegistry)
if err != nil {
return err
}
Expand Down Expand Up @@ -1029,7 +1030,7 @@ func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error {
func (b *Broker) handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise) error {
select {
case buf := <-promise.packets:
return versionedDecode(buf, res, req.version(), b.conf.MetricRegistry)
return versionedDecode(buf, res, req.version(), b.metricRegistry)
case err := <-promise.errors:
return err
}
Expand Down Expand Up @@ -1121,7 +1122,7 @@ func (b *Broker) responseReceiver() {
}

decodedHeader := responseHeader{}
err = versionedDecode(header, &decodedHeader, response.headerVersion, b.conf.MetricRegistry)
err = versionedDecode(header, &decodedHeader, response.headerVersion, b.metricRegistry)
if err != nil {
b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
dead = err
Expand Down Expand Up @@ -1243,7 +1244,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version}

req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
buf, err := encode(req, b.conf.MetricRegistry)
buf, err := encode(req, b.metricRegistry)
if err != nil {
return err
}
Expand Down Expand Up @@ -1280,7 +1281,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
b.updateIncomingCommunicationMetrics(n+8, time.Since(requestTime))
res := &SaslHandshakeResponse{}

err = versionedDecode(payload, res, 0, b.conf.MetricRegistry)
err = versionedDecode(payload, res, 0, b.metricRegistry)
if err != nil {
Logger.Printf("Failed to parse SASL handshake : %s\n", err.Error())
return err
Expand Down Expand Up @@ -1622,38 +1623,19 @@ func (b *Broker) registerMetrics() {
b.brokerThrottleTime = b.registerHistogram("throttle-time-in-ms")
}

func (b *Broker) unregisterMetrics() {
for name := range b.registeredMetrics {
b.conf.MetricRegistry.Unregister(name)
}
b.registeredMetrics = nil
}

func (b *Broker) registerMeter(name string) metrics.Meter {
nameForBroker := getMetricNameForBroker(name, b)
if b.registeredMetrics == nil {
b.registeredMetrics = map[string]struct{}{}
}
b.registeredMetrics[nameForBroker] = struct{}{}
return metrics.GetOrRegisterMeter(nameForBroker, b.conf.MetricRegistry)
return metrics.GetOrRegisterMeter(nameForBroker, b.metricRegistry)
}

func (b *Broker) registerHistogram(name string) metrics.Histogram {
nameForBroker := getMetricNameForBroker(name, b)
if b.registeredMetrics == nil {
b.registeredMetrics = map[string]struct{}{}
}
b.registeredMetrics[nameForBroker] = struct{}{}
return getOrRegisterHistogram(nameForBroker, b.conf.MetricRegistry)
return getOrRegisterHistogram(nameForBroker, b.metricRegistry)
}

func (b *Broker) registerCounter(name string) metrics.Counter {
nameForBroker := getMetricNameForBroker(name, b)
if b.registeredMetrics == nil {
b.registeredMetrics = map[string]struct{}{}
}
b.registeredMetrics[nameForBroker] = struct{}{}
return metrics.GetOrRegisterCounter(nameForBroker, b.conf.MetricRegistry)
return metrics.GetOrRegisterCounter(nameForBroker, b.metricRegistry)
}

func validServerNameTLS(addr string, cfg *tls.Config) *tls.Config {
Expand Down
24 changes: 24 additions & 0 deletions client_test.go
Expand Up @@ -8,6 +8,8 @@ import (
"syscall"
"testing"
"time"

"github.com/rcrowley/go-metrics"
)

func safeClose(t testing.TB, c io.Closer) {
Expand Down Expand Up @@ -1096,3 +1098,25 @@ func TestInitProducerIDConnectionRefused(t *testing.T) {

safeClose(t, client)
}

func TestMetricsCleanup(t *testing.T) {
seedBroker := NewMockBroker(t, 1)
seedBroker.Returns(new(MetadataResponse))

config := NewTestConfig()
metrics.GetOrRegisterMeter("a", config.MetricRegistry)

client, err := NewClient([]string{seedBroker.Addr()}, config)
if err != nil {
t.Fatal(err)
}
safeClose(t, client)

// Wait async close
time.Sleep(10 * time.Millisecond)

all := config.MetricRegistry.GetAll()
if len(all) != 1 || all["a"] == nil {
t.Errorf("excepted 1 metric, found: %v", all)
}
}
17 changes: 9 additions & 8 deletions consumer.go
Expand Up @@ -104,6 +104,7 @@ type consumer struct {
children map[string]map[int32]*partitionConsumer
brokerConsumers map[*Broker]*brokerConsumer
client Client
metricRegistry metrics.Registry
lock sync.Mutex
}

Expand Down Expand Up @@ -136,12 +137,14 @@ func newConsumer(client Client) (Consumer, error) {
conf: client.Config(),
children: make(map[string]map[int32]*partitionConsumer),
brokerConsumers: make(map[*Broker]*brokerConsumer),
metricRegistry: newCleanupRegistry(client.Config().MetricRegistry),
}

return c, nil
}

func (c *consumer) Close() error {
c.metricRegistry.UnregisterAll()
return c.client.Close()
}

Expand Down Expand Up @@ -678,13 +681,9 @@ func (child *partitionConsumer) parseRecords(batch *RecordBatch) ([]*ConsumerMes
}

func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*ConsumerMessage, error) {
var (
metricRegistry = child.conf.MetricRegistry
consumerBatchSizeMetric metrics.Histogram
)

if metricRegistry != nil {
consumerBatchSizeMetric = getOrRegisterHistogram("consumer-batch-size", metricRegistry)
var consumerBatchSizeMetric metrics.Histogram
if child.consumer != nil && child.consumer.metricRegistry != nil {
consumerBatchSizeMetric = getOrRegisterHistogram("consumer-batch-size", child.consumer.metricRegistry)
}

// If request was throttled and empty we log and return without error
Expand All @@ -709,7 +708,9 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
return nil, err
}

consumerBatchSizeMetric.Update(int64(nRecs))
if consumerBatchSizeMetric != nil {
consumerBatchSizeMetric.Update(int64(nRecs))
}

if block.PreferredReadReplica != invalidPreferredReplicaID {
child.preferredReadReplica = block.PreferredReadReplica
Expand Down
21 changes: 13 additions & 8 deletions consumer_group.go
Expand Up @@ -91,6 +91,8 @@ type consumerGroup struct {
closeOnce sync.Once

userData []byte

metricRegistry metrics.Registry
}

// NewConsumerGroup creates a new consumer group the given broker addresses and configuration.
Expand Down Expand Up @@ -129,13 +131,14 @@ func newConsumerGroup(groupID string, client Client) (ConsumerGroup, error) {
}

cg := &consumerGroup{
client: client,
consumer: consumer,
config: config,
groupID: groupID,
errors: make(chan error, config.ChannelBufferSize),
closed: make(chan none),
userData: config.Consumer.Group.Member.UserData,
client: client,
consumer: consumer,
config: config,
groupID: groupID,
errors: make(chan error, config.ChannelBufferSize),
closed: make(chan none),
userData: config.Consumer.Group.Member.UserData,
metricRegistry: newCleanupRegistry(config.MetricRegistry),
}
if client.Config().Consumer.Group.InstanceId != "" && config.Version.IsAtLeast(V2_3_0_0) {
cg.groupInstanceId = &client.Config().Consumer.Group.InstanceId
Expand Down Expand Up @@ -167,6 +170,8 @@ func (c *consumerGroup) Close() (err error) {
if e := c.client.Close(); e != nil {
err = e
}

c.metricRegistry.UnregisterAll()
})
return
}
Expand Down Expand Up @@ -261,7 +266,7 @@ func (c *consumerGroup) newSession(ctx context.Context, topics []string, handler
}

var (
metricRegistry = c.config.MetricRegistry
metricRegistry = c.metricRegistry
consumerGroupJoinTotal metrics.Counter
consumerGroupJoinFailed metrics.Counter
consumerGroupSyncTotal metrics.Counter
Expand Down