Skip to content

Commit

Permalink
feat: add ChannelCollector + BufferedChannelCollector
Browse files Browse the repository at this point in the history
  • Loading branch information
hhu committed Oct 12, 2022
1 parent a2c5202 commit fd76d7e
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
38 changes: 38 additions & 0 deletions channel.go
Expand Up @@ -2,6 +2,7 @@ package lo

import (
"math/rand"
"sync"
"time"
)

Expand Down Expand Up @@ -237,3 +238,40 @@ func BatchWithTimeout[T any](ch <-chan T, size int, timeout time.Duration) (coll

return buffer, index, time.Since(now), true
}

// ChannelCollector collect messages from multiple input channels into one channel.
// output messages has no order guarantee
func ChannelCollector[T any](upstreams ...<-chan T) <-chan T {
out := make(chan T)
channelCollector(out, upstreams...)
return out
}

// BufferedChannelCollector collect messages from multiple input channels into one buffered channel.
// output messages has no order guarantee
func BufferedChannelCollector[T any](channelBufferCap int, upstreams ...<-chan T) <-chan T {
out := make(chan T, channelBufferCap)
channelCollector(out, upstreams...)
return out
}

func channelCollector[T any](out chan T, upstreams ...<-chan T) {
var wg sync.WaitGroup

// Start an output goroutine for each input channel in upstreams.
wg.Add(len(upstreams))
for _, c := range upstreams {
go func(c <-chan T) {
for n := range c {
out <- n
}
wg.Done()
}(c)
}

// Start a goroutine to close out once all the output goroutines are done.
go func() {
wg.Wait()
close(out)
}()
}
83 changes: 83 additions & 0 deletions channel_test.go
Expand Up @@ -302,3 +302,86 @@ func TestBatchWithTimeout(t *testing.T) {
is.Equal(0, length5)
is.False(ok5)
}

func TestChannelCollector(t *testing.T) {
t.Parallel()
testWithTimeout(t, 100*time.Millisecond)
is := assert.New(t)

upstreams := createChannels[int](3, 10)
roupstreams := channelsToReadOnly(upstreams)
for i := range roupstreams {
go func(i int) {
upstreams[i] <- 1
upstreams[i] <- 1
close(upstreams[i])
}(i)
}
out := ChannelCollector(roupstreams...)
time.Sleep(10 * time.Millisecond)

is.Equal(1, len(roupstreams[0]))
is.Equal(1, len(roupstreams[1]))
is.Equal(1, len(roupstreams[2]))

// check channels allocation
is.Equal(0, len(out))
is.Equal(0, cap(out))

// check channels content
for i := 0; i < 6; i++ {
msg0, ok0 := <-out
is.Equal(true, ok0)
is.Equal(1, msg0)
}

is.Equal(0, len(roupstreams[0]))
is.Equal(0, len(roupstreams[1]))
is.Equal(0, len(roupstreams[2]))

// check it is closed
time.Sleep(10 * time.Millisecond)
msg0, ok0 := <-out
is.Equal(false, ok0)
is.Equal(0, msg0)
}

func TestBufferedChannelCollector(t *testing.T) {
t.Parallel()
testWithTimeout(t, 100*time.Millisecond)
is := assert.New(t)

upstreams := createChannels[int](3, 10)
roupstreams := channelsToReadOnly(upstreams)
for i := range roupstreams {
go func(i int) {
upstreams[i] <- 1
upstreams[i] <- 1
close(upstreams[i])
}(i)
}
out := BufferedChannelCollector(10, roupstreams...)
time.Sleep(10 * time.Millisecond)

// check input channels
is.Equal(0, len(roupstreams[0]))
is.Equal(0, len(roupstreams[1]))
is.Equal(0, len(roupstreams[2]))

// check channels allocation
is.Equal(6, len(out))
is.Equal(10, cap(out))

// check channels content
for i := 0; i < 6; i++ {
msg0, ok0 := <-out
is.Equal(true, ok0)
is.Equal(1, msg0)
}

// check it is closed
time.Sleep(10 * time.Millisecond)
msg0, ok0 := <-out
is.Equal(false, ok0)
is.Equal(0, msg0)
}

0 comments on commit fd76d7e

Please sign in to comment.