Skip to content

Commit

Permalink
fix zstd decoder leak (#543)
Browse files Browse the repository at this point in the history
* fix zstd decoder leak

* fix tests

* fix panic

* fix tests (2)

* fix tests (3)

* fix tests (4)

* move ConnWaitGroup to testing package

* fix zstd codec

* Update compress/zstd/zstd.go

Co-authored-by: Nicholas Sun <olassun2@gmail.com>

* PR feedback

Co-authored-by: Nicholas Sun <olassun2@gmail.com>
  • Loading branch information
Achille and nlsun committed Oct 21, 2020
1 parent b7a001a commit e6b8599
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 137 deletions.
78 changes: 34 additions & 44 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,43 @@ import (
"io"
"math/rand"
"net"
"sync"
"testing"
"time"

"github.com/segmentio/kafka-go/compress"
ktesting "github.com/segmentio/kafka-go/testing"
)

func newLocalClientAndTopic() (*Client, string, func()) {
topic := makeTopic()
client, shutdown := newClient(TCP("localhost"))
client, shutdown := newLocalClientWithTopic(topic, 1)
return client, topic, shutdown
}

func newLocalClientWithTopic(topic string, partitions int) (*Client, func()) {
client, shutdown := newLocalClient()
if err := clientCreateTopic(client, topic, partitions); err != nil {
shutdown()
panic(err)
}
return client, func() {
client.DeleteTopics(context.Background(), &DeleteTopicsRequest{
Topics: []string{topic},
})
shutdown()
}
}

func clientCreateTopic(client *Client, topic string, partitions int) error {
_, err := client.CreateTopics(context.Background(), &CreateTopicsRequest{
Topics: []TopicConfig{{
Topic: topic,
NumPartitions: 1,
NumPartitions: partitions,
ReplicationFactor: 1,
}},
})
if err != nil {
shutdown()
panic(err)
return err
}

// Topic creation seems to be asynchronous. Metadata for the topic partition
Expand All @@ -48,21 +64,16 @@ func newLocalClientAndTopic() (*Client, string, func()) {
time.Sleep(100 * time.Millisecond)
}

return client, topic, func() {
client.DeleteTopics(context.Background(), &DeleteTopicsRequest{
Topics: []string{topic},
})
shutdown()
}
return nil
}

func newLocalClient() (*Client, func()) {
return newClient(TCP("localhost"))
}

func newClient(addr net.Addr) (*Client, func()) {
conns := &connWaitGroup{
dial: (&net.Dialer{}).DialContext,
conns := &ktesting.ConnWaitGroup{
DialFunc: (&net.Dialer{}).DialContext,
}

transport := &Transport{
Expand All @@ -79,31 +90,6 @@ func newClient(addr net.Addr) (*Client, func()) {
return client, func() { transport.CloseIdleConnections(); conns.Wait() }
}

type connWaitGroup struct {
dial func(context.Context, string, string) (net.Conn, error)
sync.WaitGroup
}

func (g *connWaitGroup) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c, err := g.dial(ctx, network, address)
if err != nil {
return nil, err
}
g.Add(1)
return &groupConn{Conn: c, group: g}, nil
}

type groupConn struct {
net.Conn
group *connWaitGroup
once sync.Once
}

func (c *groupConn) Close() error {
defer c.once.Do(c.group.Done)
return c.Conn.Close()
}

func TestClient(t *testing.T) {
tests := []struct {
scenario string
Expand All @@ -121,20 +107,23 @@ func TestClient(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

c := &Client{Addr: TCP("localhost:9092")}
testFunc(t, ctx, c)
client, shutdown := newLocalClient()
defer shutdown()

testFunc(t, ctx, client)
})
}
}

func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, c *Client) {
func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, client *Client) {
const totalMessages = 144
const partitions = 12
const msgPerPartition = totalMessages / partitions

topic := makeTopic()
createTopic(t, topic, partitions)
defer deleteTopic(t, topic)
if err := clientCreateTopic(client, topic, partitions); err != nil {
t.Fatal(err)
}

groupId := makeGroupID()
brokers := []string{"localhost:9092"}
Expand All @@ -144,6 +133,7 @@ func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, c *Client)
Topic: topic,
Balancer: &RoundRobin{},
BatchSize: 1,
Transport: client.Transport,
}
if err := writer.WriteMessages(ctx, makeTestSequence(totalMessages)...); err != nil {
t.Fatalf("bad write messages: %v", err)
Expand Down Expand Up @@ -172,7 +162,7 @@ func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, c *Client)
}
}

offsets, err := c.ConsumerOffsets(ctx, TopicAndGroup{GroupId: groupId, Topic: topic})
offsets, err := client.ConsumerOffsets(ctx, TopicAndGroup{GroupId: groupId, Topic: topic})
if err != nil {
t.Fatal(err)
}
Expand Down
118 changes: 71 additions & 47 deletions compress/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,24 @@ func testEncodeDecode(t *testing.T, m kafka.Message, codec pkg.Codec) {
t.Run("encode with "+codec.Name(), func(t *testing.T) {
r1, err = compress(codec, m.Value)
if err != nil {
t.Error(err)
t.Fatal(err)
}
})

t.Run("decode with "+codec.Name(), func(t *testing.T) {
if r1 == nil {
if r1, err = compress(codec, m.Value); err != nil {
t.Fatal(err)
}
}
r2, err = decompress(codec, r1)
if err != nil {
t.Error(err)
t.Fatal(err)
}
if string(r2) != "message" {
t.Error("bad message")
t.Log("got: ", string(r2))
t.Log("expected: ", string(m.Value))
t.Logf("expected: %q", string(m.Value))
t.Logf("got: %q", string(r2))
}
})
}
Expand All @@ -116,15 +121,16 @@ func TestCompressedMessages(t *testing.T) {
}

func testCompressedMessages(t *testing.T, codec pkg.Codec) {
t.Run("produce/consume with"+codec.Name(), func(t *testing.T) {
topic := createTopic(t, 1)
defer deleteTopic(t, topic)
t.Run(codec.Name(), func(t *testing.T) {
client, topic, shutdown := newLocalClientAndTopic()
defer shutdown()

w := &kafka.Writer{
Addr: kafka.TCP("127.0.0.1:9092"),
Topic: topic,
Compression: kafka.Compression(codec.Code()),
BatchTimeout: 10 * time.Millisecond,
Transport: client.Transport,
}
defer w.Close()

Expand Down Expand Up @@ -185,19 +191,23 @@ func testCompressedMessages(t *testing.T, codec pkg.Codec) {
}

func TestMixedCompressedMessages(t *testing.T) {
topic := createTopic(t, 1)
defer deleteTopic(t, topic)
client, topic, shutdown := newLocalClientAndTopic()
defer shutdown()

offset := 0
var values []string
produce := func(n int, codec pkg.Codec) {
w := &kafka.Writer{
Addr: kafka.TCP("127.0.0.1:9092"),
Topic: topic,
Compression: kafka.Compression(codec.Code()),
Addr: kafka.TCP("127.0.0.1:9092"),
Topic: topic,
Transport: client.Transport,
}
defer w.Close()

if codec != nil {
w.Compression = kafka.Compression(codec.Code())
}

msgs := make([]kafka.Message, n)
for i := range msgs {
value := fmt.Sprintf("Hello World %d!", offset)
Expand Down Expand Up @@ -407,58 +417,72 @@ func benchmarkCompression(b *testing.B, codec pkg.Codec, buf *bytes.Buffer, payl
return 1 - (float64(buf.Len()) / float64(len(payload)))
}

func init() {
rand.Seed(time.Now().UnixNano())
}

func makeTopic() string {
return fmt.Sprintf("kafka-go-%016x", rand.Int63())
}

func createTopic(t *testing.T, partitions int) string {
func newLocalClientAndTopic() (*kafka.Client, string, func()) {
topic := makeTopic()

conn, err := kafka.Dial("tcp", "localhost:9092")
client, shutdown := newLocalClient()

_, err := client.CreateTopics(context.Background(), &kafka.CreateTopicsRequest{
Topics: []kafka.TopicConfig{{
Topic: topic,
NumPartitions: 1,
ReplicationFactor: 1,
}},
})
if err != nil {
t.Fatal(err)
shutdown()
panic(err)
}
defer conn.Close()

err = conn.CreateTopics(kafka.TopicConfig{
Topic: topic,
NumPartitions: partitions,
ReplicationFactor: 1,
})
// Topic creation seems to be asynchronous. Metadata for the topic partition
// layout in the cluster is available in the controller before being synced
// with the other brokers, which causes "Error:[3] Unknown Topic Or Partition"
// when sending requests to the partition leaders.
for i := 0; i < 20; i++ {
r, err := client.Fetch(context.Background(), &kafka.FetchRequest{
Topic: topic,
Partition: 0,
Offset: 0,
})
if err == nil && r.Error == nil {
break
}
time.Sleep(100 * time.Millisecond)
}

switch err {
case nil:
// ok
case kafka.TopicAlreadyExists:
// ok
default:
t.Error("bad createTopics", err)
t.FailNow()
return client, topic, func() {
client.DeleteTopics(context.Background(), &kafka.DeleteTopicsRequest{
Topics: []string{topic},
})
shutdown()
}
}

return topic
func newLocalClient() (*kafka.Client, func()) {
return newClient(kafka.TCP("127.0.0.1:9092"))
}

func deleteTopic(t *testing.T, topic ...string) {
conn, err := kafka.Dial("tcp", "localhost:9092")
if err != nil {
t.Fatal(err)
func newClient(addr net.Addr) (*kafka.Client, func()) {
conns := &ktesting.ConnWaitGroup{
DialFunc: (&net.Dialer{}).DialContext,
}
defer conn.Close()

controller, err := conn.Controller()
if err != nil {
t.Fatal(err)
transport := &kafka.Transport{
Dial: conns.Dial,
}

conn, err = kafka.Dial("tcp", net.JoinHostPort(controller.Host, strconv.Itoa(controller.Port)))
if err != nil {
t.Fatal(err)
client := &kafka.Client{
Addr: addr,
Timeout: 5 * time.Second,
Transport: transport,
}

conn.SetDeadline(time.Now().Add(2 * time.Second))

if err := conn.DeleteTopics(topic...); err != nil {
t.Fatal(err)
}
return client, func() { transport.CloseIdleConnections(); conns.Wait() }
}

0 comments on commit e6b8599

Please sign in to comment.