Skip to content

Commit

Permalink
Group context
Browse files Browse the repository at this point in the history
  • Loading branch information
alitto committed May 9, 2022
1 parent d4c09d4 commit b1a3c9c
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Expand Up @@ -10,7 +10,7 @@ jobs:
name: Test
strategy:
matrix:
go-version: [1.15.x, 1.16.x, 1.17.x]
go-version: [1.15.x, 1.16.x, 1.17.x, 1.18.x]
os: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.os }}
steps:
Expand Down
2 changes: 1 addition & 1 deletion go.mod
@@ -1,3 +1,3 @@
module github.com/alitto/pond

go 1.17
go 1.18
71 changes: 71 additions & 0 deletions group.go
@@ -0,0 +1,71 @@
package pond

import (
"context"
"fmt"
"sync"
)

type TaskGroup = Group[func()]

type TaskGroupWithContext = Group[func() error]

// Group represents a group of related tasks
type Group[T func() | func() error] struct {
pool *WorkerPool
waitGroup sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
errOnce sync.Once
err error
}

// Submit adds a task to this group and sends it to the worker pool to be executed
func (g *Group[T]) Submit(task T) {
g.waitGroup.Add(1)

g.pool.Submit(func() {
defer g.waitGroup.Done()

// If context has already been cancelled, skip task execution
if g.ctx != nil {
select {
case <-g.ctx.Done():
return
default:
}
}

switch any(task).(type) {
case func():
any(task).(func())()
case func() error:
// don't actually ignore errors
err := any(task).(func() error)()
if err != nil {
fmt.Printf("Error not nil %v\n", err)
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel()
}
})
} else {
fmt.Printf("Error is nil %v\n", err)
}
}
})
}

// Wait waits until all the tasks in this group have completed
func (g *Group[T]) Wait() error {

// Wait for all tasks to complete
g.waitGroup.Wait()

// If context was provided, cancel it to signal all running tasks to stop
if g.cancel != nil {
g.cancel()
}
return g.err
}
35 changes: 13 additions & 22 deletions pond.go
Expand Up @@ -487,6 +487,19 @@ func (p *WorkerPool) Group() *TaskGroup {
}
}

// GroupContext creates a new task group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function submitted to the group
// returns a non-nil error or the first time Wait returns, whichever occurs first.
func (p *WorkerPool) GroupContext(ctx context.Context) (*TaskGroupWithContext, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return &TaskGroupWithContext{
pool: p,
ctx: ctx,
cancel: cancel,
}, ctx
}

// worker launches a worker goroutine
func worker(context context.Context, firstTask func(), tasks <-chan func(), idleWorkerCount *int32, exitHandler func(), taskExecutor func(func())) {

Expand Down Expand Up @@ -528,25 +541,3 @@ func worker(context context.Context, firstTask func(), tasks <-chan func(), idle
}
}
}

// TaskGroup represents a group of related tasks
type TaskGroup struct {
pool *WorkerPool
waitGroup sync.WaitGroup
}

// Submit adds a task to this group and sends it to the worker pool to be executed
func (g *TaskGroup) Submit(task func()) {
g.waitGroup.Add(1)
g.pool.Submit(func() {
defer g.waitGroup.Done()
task()
})
}

// Wait waits until all the tasks in this group have completed
func (g *TaskGroup) Wait() {

// Wait for all tasks to complete
g.waitGroup.Wait()
}
36 changes: 36 additions & 0 deletions pond_blackbox_test.go
Expand Up @@ -2,6 +2,7 @@ package pond_test

import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -478,6 +479,41 @@ func TestGroupSubmit(t *testing.T) {
assertEqual(t, int32(taskCount), atomic.LoadInt32(&doneCount))
}

func TestGroupContext(t *testing.T) {

pool := pond.New(5, 100)
assertEqual(t, 0, pool.RunningWorkers())

expectedErr := errors.New("Something went wrong")

// Submit a group of tasks
var doneCount int32
group, ctx := pool.GroupContext(context.Background())
for i := 0; i < 20; i++ {
n := i
group.Submit(func() error {

if n == 2 {
time.Sleep(5 * time.Millisecond)
return expectedErr
}

select {
case <-time.After(time.Duration(n) * time.Millisecond):
atomic.AddInt32(&doneCount, 1)
case <-ctx.Done():
}

return nil
})
}

err := group.Wait()
assertEqual(t, expectedErr, err)

assertEqual(t, int32(5), atomic.LoadInt32(&doneCount))
}

func TestPoolWithCustomStrategy(t *testing.T) {

pool := pond.New(3, 3, pond.Strategy(pond.RatedResizer(2)))
Expand Down

0 comments on commit b1a3c9c

Please sign in to comment.