diff --git a/async_producer.go b/async_producer.go index d0ce01b66e..f1ffc8f92d 100644 --- a/async_producer.go +++ b/async_producer.go @@ -348,6 +348,10 @@ func (p *asyncProducer) dispatcher() { p.inFlight.Add(1) } + for _, interceptor := range p.conf.Producer.Interceptors { + msg.safelyApplyInterceptor(interceptor) + } + version := 1 if p.conf.Version.IsAtLeast(V0_11_0_0) { version = 2 diff --git a/async_producer_test.go b/async_producer_test.go index 46b97790a9..338ca3656b 100644 --- a/async_producer_test.go +++ b/async_producer_test.go @@ -5,6 +5,7 @@ import ( "log" "os" "os/signal" + "strconv" "sync" "sync/atomic" "testing" @@ -1230,6 +1231,126 @@ func TestBrokerProducerShutdown(t *testing.T) { mockBroker.Close() } +type appendInterceptor struct { + i int +} + +func (b *appendInterceptor) onSend(msg *ProducerMessage) { + if b.i < 0 { + panic("hey, the interceptor have failed") + } + v, _ := msg.Value.Encode() + msg.Value = StringEncoder(string(v) + strconv.Itoa(b.i)) + b.i++ +} + +func (b *appendInterceptor) onConsume(msg *ConsumerMessage) { + if b.i < 0 { + panic("hey, the interceptor have failed") + } + msg.Value = []byte(string(msg.Value) + strconv.Itoa(b.i)) + b.i++ +} + +func testProducerInterceptor( + t *testing.T, + interceptors []ProducerInterceptor, + expectationFn func(*testing.T, int, *ProducerMessage), +) { + seedBroker := NewMockBroker(t, 1) + leader := NewMockBroker(t, 2) + metadataLeader := new(MetadataResponse) + metadataLeader.AddBroker(leader.Addr(), leader.BrokerID()) + metadataLeader.AddTopicPartition("my_topic", 0, leader.BrokerID(), nil, nil, nil, ErrNoError) + seedBroker.Returns(metadataLeader) + + config := NewConfig() + config.Producer.Flush.Messages = 10 + config.Producer.Return.Successes = true + config.Producer.Interceptors = interceptors + producer, err := NewAsyncProducer([]string{seedBroker.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "my_topic", Key: nil, Value: StringEncoder(TestMessage)} + } + + prodSuccess := new(ProduceResponse) + prodSuccess.AddTopicPartition("my_topic", 0, ErrNoError) + leader.Returns(prodSuccess) + + for i := 0; i < 10; i++ { + select { + case msg := <-producer.Errors(): + t.Error(msg.Err) + case msg := <-producer.Successes(): + expectationFn(t, i, msg) + } + } + + closeProducer(t, producer) + leader.Close() + seedBroker.Close() +} + +func TestAsyncProducerInterceptors(t *testing.T) { + tests := []struct { + name string + interceptors []ProducerInterceptor + expectationFn func(*testing.T, int, *ProducerMessage) + }{ + { + name: "intercept messages", + interceptors: []ProducerInterceptor{&appendInterceptor{i: 0}}, + expectationFn: func(t *testing.T, i int, msg *ProducerMessage) { + v, _ := msg.Value.Encode() + expected := TestMessage + strconv.Itoa(i) + if string(v) != expected { + t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected) + } + }, + }, + { + name: "interceptor chain", + interceptors: []ProducerInterceptor{&appendInterceptor{i: 0}, &appendInterceptor{i: 1000}}, + expectationFn: func(t *testing.T, i int, msg *ProducerMessage) { + v, _ := msg.Value.Encode() + expected := TestMessage + strconv.Itoa(i) + strconv.Itoa(i+1000) + if string(v) != expected { + t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected) + } + }, + }, + { + name: "interceptor chain with one interceptor failing", + interceptors: []ProducerInterceptor{&appendInterceptor{i: -1}, &appendInterceptor{i: 1000}}, + expectationFn: func(t *testing.T, i int, msg *ProducerMessage) { + v, _ := msg.Value.Encode() + expected := TestMessage + strconv.Itoa(i+1000) + if string(v) != expected { + t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected) + } + }, + }, + { + name: "interceptor chain with all interceptors failing", + interceptors: []ProducerInterceptor{&appendInterceptor{i: -1}, &appendInterceptor{i: -1}}, + expectationFn: func(t *testing.T, i int, msg *ProducerMessage) { + v, _ := msg.Value.Encode() + expected := TestMessage + if string(v) != expected { + t.Errorf("Interceptor should have not changed the value, got %s, expected %s", v, expected) + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { testProducerInterceptor(t, tt.interceptors, tt.expectationFn) }) + } +} + // This example shows how to use the producer while simultaneously // reading the Errors channel to know about any failures. func ExampleAsyncProducer_select() { diff --git a/config.go b/config.go index 0ce308f80a..b346c58a12 100644 --- a/config.go +++ b/config.go @@ -229,6 +229,14 @@ type Config struct { // `Backoff` if set. BackoffFunc func(retries, maxRetries int) time.Duration } + + // Interceptors to be called when the producer dispatcher reads the + // message for the first time. Interceptors allows to intercept and + // possible mutate the message before they are published to Kafka + // cluster. *ProducerMessage modified by the first interceptor's + // onSend() is passed to the second interceptor onSend(), and so on in + // the interceptor chain. + Interceptors []ProducerInterceptor } // Consumer is the namespace for configuration related to consuming messages, @@ -391,6 +399,14 @@ type Config struct { // - use `ReadUncommitted` (default) to consume and return all messages in message channel // - use `ReadCommitted` to hide messages that are part of an aborted transaction IsolationLevel IsolationLevel + + // Interceptors to be called just before the record is sent to the + // messages channel. Interceptors allows to intercept and possible + // mutate the message before they are returned to the client. + // *ConsumerMessage modified by the first interceptor's onConsume() is + // passed to the second interceptor onConsume(), and so on in the + // interceptor chain. + Interceptors []ConsumerInterceptor } // A user-provided string sent with every request to the brokers for logging, diff --git a/consumer.go b/consumer.go index e16d08aa9f..b0cdfc3a70 100644 --- a/consumer.go +++ b/consumer.go @@ -451,6 +451,9 @@ feederLoop: } for i, msg := range msgs { + for _, interceptor := range child.conf.Consumer.Interceptors { + msg.safelyApplyInterceptor(interceptor) + } messageSelect: select { case <-child.dying: diff --git a/consumer_test.go b/consumer_test.go index d0617f2ab3..230582e5ff 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -5,6 +5,7 @@ import ( "os" "os/signal" "reflect" + "strconv" "sync" "sync/atomic" "testing" @@ -1342,3 +1343,112 @@ func Test_partitionConsumer_parseResponse(t *testing.T) { }) } } + +func testConsumerInterceptor( + t *testing.T, + interceptors []ConsumerInterceptor, + expectationFn func(*testing.T, int, *ConsumerMessage), +) { + // Given + broker0 := NewMockBroker(t, 0) + + mockFetchResponse := NewMockFetchResponse(t, 1) + for i := 0; i < 10; i++ { + mockFetchResponse.SetMessage("my_topic", 0, int64(i), testMsg) + } + + broker0.SetHandlerByMap(map[string]MockResponse{ + "MetadataRequest": NewMockMetadataResponse(t). + SetBroker(broker0.Addr(), broker0.BrokerID()). + SetLeader("my_topic", 0, broker0.BrokerID()), + "OffsetRequest": NewMockOffsetResponse(t). + SetOffset("my_topic", 0, OffsetOldest, 0). + SetOffset("my_topic", 0, OffsetNewest, 0), + "FetchRequest": mockFetchResponse, + }) + config := NewConfig() + config.Consumer.Interceptors = interceptors + // When + master, err := NewConsumer([]string{broker0.Addr()}, config) + if err != nil { + t.Fatal(err) + } + + consumer, err := master.ConsumePartition("my_topic", 0, 0) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + select { + case msg := <-consumer.Messages(): + expectationFn(t, i, msg) + case err := <-consumer.Errors(): + t.Error(err) + } + } + + safeClose(t, consumer) + safeClose(t, master) + broker0.Close() +} + +func TestConsumerInterceptors(t *testing.T) { + tests := []struct { + name string + interceptors []ConsumerInterceptor + expectationFn func(*testing.T, int, *ConsumerMessage) + }{ + { + name: "intercept messages", + interceptors: []ConsumerInterceptor{&appendInterceptor{i: 0}}, + expectationFn: func(t *testing.T, i int, msg *ConsumerMessage) { + ev, _ := testMsg.Encode() + expected := string(ev) + strconv.Itoa(i) + v := string(msg.Value) + if string(v) != expected { + t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected) + } + }, + }, + { + name: "interceptor chain", + interceptors: []ConsumerInterceptor{&appendInterceptor{i: 0}, &appendInterceptor{i: 1000}}, + expectationFn: func(t *testing.T, i int, msg *ConsumerMessage) { + ev, _ := testMsg.Encode() + expected := string(ev) + strconv.Itoa(i) + strconv.Itoa(i+1000) + v := string(msg.Value) + if string(v) != expected { + t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected) + } + }, + }, + { + name: "interceptor chain with one interceptor failing", + interceptors: []ConsumerInterceptor{&appendInterceptor{i: -1}, &appendInterceptor{i: 1000}}, + expectationFn: func(t *testing.T, i int, msg *ConsumerMessage) { + ev, _ := testMsg.Encode() + expected := string(ev) + strconv.Itoa(i+1000) + v := string(msg.Value) + if string(v) != expected { + t.Errorf("Interceptor should have not changed the value, got %s, expected %s", v, expected) + } + }, + }, + { + name: "interceptor chain with all interceptors failing", + interceptors: []ConsumerInterceptor{&appendInterceptor{i: -1}, &appendInterceptor{i: -1}}, + expectationFn: func(t *testing.T, i int, msg *ConsumerMessage) { + ev, _ := testMsg.Encode() + expected := string(ev) + v := string(msg.Value) + if string(v) != expected { + t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected) + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { testConsumerInterceptor(t, tt.interceptors, tt.expectationFn) }) + } +} diff --git a/functional_producer_test.go b/functional_producer_test.go index 9e153b0d94..474645fd7c 100644 --- a/functional_producer_test.go +++ b/functional_producer_test.go @@ -3,6 +3,7 @@ package sarama import ( "fmt" "os" + "strconv" "strings" "sync" "testing" @@ -25,6 +26,13 @@ func TestFuncProducingGzip(t *testing.T) { testProducingMessages(t, config) } +func TestFuncProducingZstd(t *testing.T) { + config := NewConfig() + config.Version = V2_1_0_0 + config.Producer.Compression = CompressionZSTD + testProducingMessages(t, config) +} + func TestFuncProducingSnappy(t *testing.T) { config := NewConfig() config.Producer.Compression = CompressionSnappy @@ -181,6 +189,78 @@ func TestFuncProducingIdempotentWithBrokerFailure(t *testing.T) { } } +func TestInterceptors(t *testing.T) { + config := NewConfig() + setupFunctionalTest(t) + defer teardownFunctionalTest(t) + + config.Producer.Return.Successes = true + config.Consumer.Return.Errors = true + config.Producer.Interceptors = []ProducerInterceptor{&appendInterceptor{i: 0}, &appendInterceptor{i: 100}} + config.Consumer.Interceptors = []ConsumerInterceptor{&appendInterceptor{i: 20}} + + client, err := NewClient(kafkaBrokers, config) + if err != nil { + t.Fatal(err) + } + + initialOffset, err := client.GetOffset("test.1", 0, OffsetNewest) + if err != nil { + t.Fatal(err) + } + + producer, err := NewAsyncProducerFromClient(client) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + producer.Input() <- &ProducerMessage{Topic: "test.1", Key: nil, Value: StringEncoder(TestMessage)} + } + + for i := 0; i < 10; i++ { + select { + case msg := <-producer.Errors(): + t.Error(msg.Err) + case msg := <-producer.Successes(): + v, _ := msg.Value.Encode() + expected := TestMessage + strconv.Itoa(i) + strconv.Itoa(i+100) + if string(v) != expected { + t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected) + } + } + } + safeClose(t, producer) + + master, err := NewConsumerFromClient(client) + if err != nil { + t.Fatal(err) + } + consumer, err := master.ConsumePartition("test.1", 0, initialOffset) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 10; i++ { + select { + case <-time.After(10 * time.Second): + t.Fatal("Not received any more events in the last 10 seconds.") + case err := <-consumer.Errors(): + t.Error(err) + case msg := <-consumer.Messages(): + // producer interceptors: strconv.Itoa(i) + strconv.Itoa(i+100) + // consumer interceptor: strconv.Itoa(i+20) + expected := TestMessage + strconv.Itoa(i) + strconv.Itoa(i+100) + strconv.Itoa(i+20) + v := string(msg.Value) + if string(v) != expected { + t.Errorf("Interceptor should have incremented the value, got %s, expected %s", v, expected) + } + } + } + safeClose(t, consumer) + safeClose(t, client) +} + func testProducingMessages(t *testing.T, config *Config) { setupFunctionalTest(t) defer teardownFunctionalTest(t) diff --git a/interceptors.go b/interceptors.go new file mode 100644 index 0000000000..c6eaeda9d6 --- /dev/null +++ b/interceptors.go @@ -0,0 +1,43 @@ +package sarama + +// ProducerInterceptor allows you to intercept (and possibly mutate) the records +// received by the producer before they are published to the Kafka cluster. +// https://cwiki.apache.org/confluence/display/KAFKA/KIP-42%3A+Add+Producer+and+Consumer+Interceptors#KIP42:AddProducerandConsumerInterceptors-Motivation +type ProducerInterceptor interface { + + // onSend is called when the producer message is intercepted. Please avoid + // modifying the message until it's safe to do so, as this is _not_ a copy + // of the message. + onSend(*ProducerMessage) +} + +// ConsumerInterceptor allows you to intercept (and possibly mutate) the records +// received by the consumer before they are sent to the messages channel. +// https://cwiki.apache.org/confluence/display/KAFKA/KIP-42%3A+Add+Producer+and+Consumer+Interceptors#KIP42:AddProducerandConsumerInterceptors-Motivation +type ConsumerInterceptor interface { + + // onConsume is called when the consumed message is intercepted. Please + // avoid modifying the message until it's safe to do so, as this is _not_ a + // copy of the message. + onConsume(*ConsumerMessage) +} + +func (msg *ProducerMessage) safelyApplyInterceptor(interceptor ProducerInterceptor) { + defer func() { + if r := recover(); r != nil { + Logger.Printf("Error when calling producer interceptor: %s, %w\n", interceptor, r) + } + }() + + interceptor.onSend(msg) +} + +func (msg *ConsumerMessage) safelyApplyInterceptor(interceptor ConsumerInterceptor) { + defer func() { + if r := recover(); r != nil { + Logger.Printf("Error when calling consumer interceptor: %s, %w\n", interceptor, r) + } + }() + + interceptor.onConsume(msg) +}