diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cb0a6ec..1e81163 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -10,7 +10,7 @@ jobs: name: Test strategy: matrix: - go-version: [1.15.x, 1.16.x, 1.17.x] + go-version: [1.15.x, 1.16.x, 1.17.x, 1.18.x] os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: diff --git a/go.mod b/go.mod index c0a402f..0434d67 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/alitto/pond -go 1.17 +go 1.18 diff --git a/group.go b/group.go new file mode 100644 index 0000000..6df9d34 --- /dev/null +++ b/group.go @@ -0,0 +1,90 @@ +package pond + +import ( + "context" + "sync" +) + +// TaskGroup represents a group of related tasks +type TaskGroup struct { + pool *WorkerPool + waitGroup sync.WaitGroup +} + +// Submit adds a task to this group and sends it to the worker pool to be executed +func (g *TaskGroup) Submit(task func()) { + g.waitGroup.Add(1) + + g.pool.Submit(func() { + defer g.waitGroup.Done() + + task() + }) +} + +// Wait waits until all the tasks in this group have completed +func (g *TaskGroup) Wait() { + + // Wait for all tasks to complete + g.waitGroup.Wait() +} + +// TaskGroup represents a group of related tasks associated to a context +type TaskGroupWithContext struct { + TaskGroup + ctx context.Context + cancel context.CancelFunc + errOnce sync.Once + err error +} + +// Submit adds a task to this group and sends it to the worker pool to be executed +func (g *TaskGroupWithContext) Submit(task func() error) { + g.waitGroup.Add(1) + + g.pool.Submit(func() { + defer g.waitGroup.Done() + + // If context has already been cancelled, skip task execution + if g.ctx != nil { + select { + case <-g.ctx.Done(): + return + default: + } + } + + // don't actually ignore errors + err := task() + if err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + }) +} + +// Wait blocks until either all the tasks submitted to this group have completed, +// one of them returned a non-nil error or the context associated to this group +// was canceled. +func (g *TaskGroupWithContext) Wait() error { + + // Wait for all tasks to complete + tasksCompleted := make(chan struct{}) + go func() { + g.waitGroup.Wait() + tasksCompleted <- struct{}{} + }() + + select { + case <-tasksCompleted: + // If context was provided, cancel it to signal all running tasks to stop + g.cancel() + case <-g.ctx.Done(): + } + + return g.err +} diff --git a/group_blackbox_test.go b/group_blackbox_test.go new file mode 100644 index 0000000..306f832 --- /dev/null +++ b/group_blackbox_test.go @@ -0,0 +1,107 @@ +package pond_test + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/alitto/pond" +) + +func TestGroupSubmit(t *testing.T) { + + pool := pond.New(5, 1000) + assertEqual(t, 0, pool.RunningWorkers()) + + // Submit groups of tasks + var doneCount, taskCount int32 + var groups []*pond.TaskGroup + for i := 0; i < 5; i++ { + group := pool.Group() + for j := 0; j < i+5; j++ { + group.Submit(func() { + time.Sleep(1 * time.Millisecond) + atomic.AddInt32(&doneCount, 1) + }) + taskCount++ + } + groups = append(groups, group) + } + + // Wait for all groups to complete + for _, group := range groups { + group.Wait() + } + + assertEqual(t, int32(taskCount), atomic.LoadInt32(&doneCount)) +} + +func TestGroupContext(t *testing.T) { + + pool := pond.New(3, 100) + assertEqual(t, 0, pool.RunningWorkers()) + + // Submit a group of tasks + var doneCount, startedCount int32 + group, ctx := pool.GroupContext(context.Background()) + for i := 0; i < 10; i++ { + group.Submit(func() error { + atomic.AddInt32(&startedCount, 1) + + select { + case <-time.After(5 * time.Millisecond): + atomic.AddInt32(&doneCount, 1) + case <-ctx.Done(): + } + + return nil + }) + } + + err := group.Wait() + assertEqual(t, nil, err) + assertEqual(t, int32(10), atomic.LoadInt32(&startedCount)) + assertEqual(t, int32(10), atomic.LoadInt32(&doneCount)) +} + +func TestGroupContextWithError(t *testing.T) { + + pool := pond.New(1, 100) + assertEqual(t, 0, pool.RunningWorkers()) + + expectedErr := errors.New("Something went wrong") + + // Submit a group of tasks + var doneCount, startedCount int32 + group, ctx := pool.GroupContext(context.Background()) + for i := 0; i < 10; i++ { + n := i + group.Submit(func() error { + atomic.AddInt32(&startedCount, 1) + + // Task number 5 fails + if n == 4 { + time.Sleep(10 * time.Millisecond) + return expectedErr + } + + select { + case <-time.After(5 * time.Millisecond): + atomic.AddInt32(&doneCount, 1) + case <-ctx.Done(): + } + + return nil + }) + } + + err := group.Wait() + assertEqual(t, expectedErr, err) + + pool.StopAndWait() + + assertEqual(t, int32(5), atomic.LoadInt32(&startedCount)) + assertEqual(t, int32(4), atomic.LoadInt32(&doneCount)) +} diff --git a/pond.go b/pond.go index 865f3e1..d025b12 100644 --- a/pond.go +++ b/pond.go @@ -487,6 +487,21 @@ func (p *WorkerPool) Group() *TaskGroup { } } +// GroupContext creates a new task group and an associated Context derived from ctx. +// +// The derived Context is canceled the first time a function submitted to the group +// returns a non-nil error or the first time Wait returns, whichever occurs first. +func (p *WorkerPool) GroupContext(ctx context.Context) (*TaskGroupWithContext, context.Context) { + ctx, cancel := context.WithCancel(ctx) + return &TaskGroupWithContext{ + TaskGroup: TaskGroup{ + pool: p, + }, + ctx: ctx, + cancel: cancel, + }, ctx +} + // worker launches a worker goroutine func worker(context context.Context, firstTask func(), tasks <-chan func(), idleWorkerCount *int32, exitHandler func(), taskExecutor func(func())) { @@ -528,25 +543,3 @@ func worker(context context.Context, firstTask func(), tasks <-chan func(), idle } } } - -// TaskGroup represents a group of related tasks -type TaskGroup struct { - pool *WorkerPool - waitGroup sync.WaitGroup -} - -// Submit adds a task to this group and sends it to the worker pool to be executed -func (g *TaskGroup) Submit(task func()) { - g.waitGroup.Add(1) - g.pool.Submit(func() { - defer g.waitGroup.Done() - task() - }) -} - -// Wait waits until all the tasks in this group have completed -func (g *TaskGroup) Wait() { - - // Wait for all tasks to complete - g.waitGroup.Wait() -} diff --git a/pond_blackbox_test.go b/pond_blackbox_test.go index f81280e..ca10b11 100644 --- a/pond_blackbox_test.go +++ b/pond_blackbox_test.go @@ -450,34 +450,6 @@ func TestPoolWithCustomMinWorkers(t *testing.T) { assertEqual(t, 0, pool.RunningWorkers()) } -func TestGroupSubmit(t *testing.T) { - - pool := pond.New(5, 1000) - assertEqual(t, 0, pool.RunningWorkers()) - - // Submit groups of tasks - var doneCount, taskCount int32 - var groups []*pond.TaskGroup - for i := 0; i < 5; i++ { - group := pool.Group() - for j := 0; j < i+5; j++ { - group.Submit(func() { - time.Sleep(1 * time.Millisecond) - atomic.AddInt32(&doneCount, 1) - }) - taskCount++ - } - groups = append(groups, group) - } - - // Wait for all groups to complete - for _, group := range groups { - group.Wait() - } - - assertEqual(t, int32(taskCount), atomic.LoadInt32(&doneCount)) -} - func TestPoolWithCustomStrategy(t *testing.T) { pool := pond.New(3, 3, pond.Strategy(pond.RatedResizer(2)))