Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ChannelMerge #241

Merged
merged 1 commit into from Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions channel.go
Expand Up @@ -2,6 +2,7 @@ package lo

import (
"math/rand"
"sync"
"time"
)

Expand Down Expand Up @@ -237,3 +238,28 @@ func BatchWithTimeout[T any](ch <-chan T, size int, timeout time.Duration) (coll

return buffer, index, time.Since(now), true
}

// ChannelMerge collect messages from multiple input channels into one buffered channel.
// output messages has no order guarantee
func ChannelMerge[T any](channelBufferCap int, upstreams ...<-chan T) <-chan T {
out := make(chan T, channelBufferCap)
var wg sync.WaitGroup

// Start an output goroutine for each input channel in upstreams.
wg.Add(len(upstreams))
for _, c := range upstreams {
go func(c <-chan T) {
for n := range c {
out <- n
}
wg.Done()
}(c)
}

// Start a goroutine to close out once all the output goroutines are done.
go func() {
wg.Wait()
close(out)
}()
return out
}
40 changes: 40 additions & 0 deletions channel_test.go
Expand Up @@ -302,3 +302,43 @@ func TestBatchWithTimeout(t *testing.T) {
is.Equal(0, length5)
is.False(ok5)
}

func TestChannelMerge(t *testing.T) {
t.Parallel()
testWithTimeout(t, 100*time.Millisecond)
is := assert.New(t)

upstreams := createChannels[int](3, 10)
roupstreams := channelsToReadOnly(upstreams)
for i := range roupstreams {
go func(i int) {
upstreams[i] <- 1
upstreams[i] <- 1
close(upstreams[i])
}(i)
}
out := ChannelMerge(10, roupstreams...)
time.Sleep(10 * time.Millisecond)

// check input channels
is.Equal(0, len(roupstreams[0]))
is.Equal(0, len(roupstreams[1]))
is.Equal(0, len(roupstreams[2]))

// check channels allocation
is.Equal(6, len(out))
is.Equal(10, cap(out))

// check channels content
for i := 0; i < 6; i++ {
msg0, ok0 := <-out
is.Equal(true, ok0)
is.Equal(1, msg0)
}

// check it is closed
time.Sleep(10 * time.Millisecond)
msg0, ok0 := <-out
is.Equal(false, ok0)
is.Equal(0, msg0)
}