Skip to content

Commit

Permalink
Group context
Browse files Browse the repository at this point in the history
  • Loading branch information
alitto committed May 9, 2022
1 parent d4c09d4 commit f9197db
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 52 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Expand Up @@ -10,7 +10,7 @@ jobs:
name: Test
strategy:
matrix:
go-version: [1.15.x, 1.16.x, 1.17.x]
go-version: [1.15.x, 1.16.x, 1.17.x, 1.18.x]
os: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.os }}
steps:
Expand Down
2 changes: 1 addition & 1 deletion go.mod
@@ -1,3 +1,3 @@
module github.com/alitto/pond

go 1.17
go 1.18
90 changes: 90 additions & 0 deletions group.go
@@ -0,0 +1,90 @@
package pond

import (
"context"
"sync"
)

// TaskGroup represents a group of related tasks
type TaskGroup struct {
pool *WorkerPool
waitGroup sync.WaitGroup
}

// Submit adds a task to this group and sends it to the worker pool to be executed
func (g *TaskGroup) Submit(task func()) {
g.waitGroup.Add(1)

g.pool.Submit(func() {
defer g.waitGroup.Done()

task()
})
}

// Wait waits until all the tasks in this group have completed
func (g *TaskGroup) Wait() {

// Wait for all tasks to complete
g.waitGroup.Wait()
}

// TaskGroup represents a group of related tasks associated to a context
type TaskGroupWithContext struct {
TaskGroup
ctx context.Context
cancel context.CancelFunc
errOnce sync.Once
err error
}

// Submit adds a task to this group and sends it to the worker pool to be executed
func (g *TaskGroupWithContext) Submit(task func() error) {
g.waitGroup.Add(1)

g.pool.Submit(func() {
defer g.waitGroup.Done()

// If context has already been cancelled, skip task execution
if g.ctx != nil {
select {
case <-g.ctx.Done():
return
default:
}
}

// don't actually ignore errors
err := task()
if err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel()
}
})
}
})
}

// Wait blocks until either all the tasks submitted to this group have completed,
// one of them returned a non-nil error or the context associated to this group
// was canceled.
func (g *TaskGroupWithContext) Wait() error {

// Wait for all tasks to complete
tasksCompleted := make(chan struct{})
go func() {
g.waitGroup.Wait()
tasksCompleted <- struct{}{}
}()

select {
case <-tasksCompleted:
// If context was provided, cancel it to signal all running tasks to stop
g.cancel()
case <-g.ctx.Done():
}

return g.err
}
107 changes: 107 additions & 0 deletions group_blackbox_test.go
@@ -0,0 +1,107 @@
package pond_test

import (
"context"
"errors"
"sync/atomic"
"testing"
"time"

"github.com/alitto/pond"
)

func TestGroupSubmit(t *testing.T) {

pool := pond.New(5, 1000)
assertEqual(t, 0, pool.RunningWorkers())

// Submit groups of tasks
var doneCount, taskCount int32
var groups []*pond.TaskGroup
for i := 0; i < 5; i++ {
group := pool.Group()
for j := 0; j < i+5; j++ {
group.Submit(func() {
time.Sleep(1 * time.Millisecond)
atomic.AddInt32(&doneCount, 1)
})
taskCount++
}
groups = append(groups, group)
}

// Wait for all groups to complete
for _, group := range groups {
group.Wait()
}

assertEqual(t, int32(taskCount), atomic.LoadInt32(&doneCount))
}

func TestGroupContext(t *testing.T) {

pool := pond.New(3, 100)
assertEqual(t, 0, pool.RunningWorkers())

// Submit a group of tasks
var doneCount, startedCount int32
group, ctx := pool.GroupContext(context.Background())
for i := 0; i < 10; i++ {
group.Submit(func() error {
atomic.AddInt32(&startedCount, 1)

select {
case <-time.After(5 * time.Millisecond):
atomic.AddInt32(&doneCount, 1)
case <-ctx.Done():
}

return nil
})
}

err := group.Wait()
assertEqual(t, nil, err)
assertEqual(t, int32(10), atomic.LoadInt32(&startedCount))
assertEqual(t, int32(10), atomic.LoadInt32(&doneCount))
}

func TestGroupContextWithError(t *testing.T) {

pool := pond.New(1, 100)
assertEqual(t, 0, pool.RunningWorkers())

expectedErr := errors.New("Something went wrong")

// Submit a group of tasks
var doneCount, startedCount int32
group, ctx := pool.GroupContext(context.Background())
for i := 0; i < 10; i++ {
n := i
group.Submit(func() error {
atomic.AddInt32(&startedCount, 1)

// Task number 5 fails
if n == 4 {
time.Sleep(10 * time.Millisecond)
return expectedErr
}

select {
case <-time.After(5 * time.Millisecond):
atomic.AddInt32(&doneCount, 1)
case <-ctx.Done():
}

return nil
})
}

err := group.Wait()
assertEqual(t, expectedErr, err)

pool.StopAndWait()

assertEqual(t, int32(5), atomic.LoadInt32(&startedCount))
assertEqual(t, int32(4), atomic.LoadInt32(&doneCount))
}
37 changes: 15 additions & 22 deletions pond.go
Expand Up @@ -487,6 +487,21 @@ func (p *WorkerPool) Group() *TaskGroup {
}
}

// GroupContext creates a new task group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function submitted to the group
// returns a non-nil error or the first time Wait returns, whichever occurs first.
func (p *WorkerPool) GroupContext(ctx context.Context) (*TaskGroupWithContext, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return &TaskGroupWithContext{
TaskGroup: TaskGroup{
pool: p,
},
ctx: ctx,
cancel: cancel,
}, ctx
}

// worker launches a worker goroutine
func worker(context context.Context, firstTask func(), tasks <-chan func(), idleWorkerCount *int32, exitHandler func(), taskExecutor func(func())) {

Expand Down Expand Up @@ -528,25 +543,3 @@ func worker(context context.Context, firstTask func(), tasks <-chan func(), idle
}
}
}

// TaskGroup represents a group of related tasks
type TaskGroup struct {
pool *WorkerPool
waitGroup sync.WaitGroup
}

// Submit adds a task to this group and sends it to the worker pool to be executed
func (g *TaskGroup) Submit(task func()) {
g.waitGroup.Add(1)
g.pool.Submit(func() {
defer g.waitGroup.Done()
task()
})
}

// Wait waits until all the tasks in this group have completed
func (g *TaskGroup) Wait() {

// Wait for all tasks to complete
g.waitGroup.Wait()
}
28 changes: 0 additions & 28 deletions pond_blackbox_test.go
Expand Up @@ -450,34 +450,6 @@ func TestPoolWithCustomMinWorkers(t *testing.T) {
assertEqual(t, 0, pool.RunningWorkers())
}

func TestGroupSubmit(t *testing.T) {

pool := pond.New(5, 1000)
assertEqual(t, 0, pool.RunningWorkers())

// Submit groups of tasks
var doneCount, taskCount int32
var groups []*pond.TaskGroup
for i := 0; i < 5; i++ {
group := pool.Group()
for j := 0; j < i+5; j++ {
group.Submit(func() {
time.Sleep(1 * time.Millisecond)
atomic.AddInt32(&doneCount, 1)
})
taskCount++
}
groups = append(groups, group)
}

// Wait for all groups to complete
for _, group := range groups {
group.Wait()
}

assertEqual(t, int32(taskCount), atomic.LoadInt32(&doneCount))
}

func TestPoolWithCustomStrategy(t *testing.T) {

pool := pond.New(3, 3, pond.Strategy(pond.RatedResizer(2)))
Expand Down

0 comments on commit f9197db

Please sign in to comment.