/
task.go
128 lines (108 loc) · 2.01 KB
/
task.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package gosync
import (
"context"
"time"
)
// Task is a handle to a unit of work running asynchronously, typically
// started with Go. It lets callers poll for completion or block until the
// work has finished.
type Task interface {
	// IsDone reports whether the task has finished, without blocking.
	IsDone() bool
	// Wait blocks until the task has finished.
	Wait()
	// WaitCtx blocks until the task has finished or ctx is done, whichever
	// comes first. It returns nil if the task finished and a non-nil error
	// (ctx.Err() for tasks created by Go) otherwise.
	WaitCtx(ctx context.Context) error
	// WaitTimeout blocks until the task has finished or timeout elapses.
	// It returns nil if the task finished and a non-nil error
	// (context.DeadlineExceeded for tasks created by Go) otherwise.
	WaitTimeout(timeout time.Duration) error
}

// Compile-time check that *task implements Task.
var _ Task = (*task)(nil)

// task is the Task implementation returned by Go. Completion is signalled by
// closing doneCh, which releases every current and future waiter at once.
type task struct {
	doneCh chan struct{} // closed when the task finishes
}

// newTask returns a task that has not finished yet.
func newTask() *task {
	return &task{
		doneCh: make(chan struct{}),
	}
}

// done marks t as finished. It must be called at most once, because closing
// an already-closed channel panics.
func (t *task) done() {
	close(t.doneCh)
}

// IsDone reports whether t has finished, without blocking.
func (t *task) IsDone() bool {
	select {
	case <-t.doneCh:
		return true
	default:
		return false
	}
}

// Wait blocks until t has finished.
func (t *task) Wait() {
	<-t.doneCh
}

// WaitCtx blocks until t has finished or ctx is done. It returns nil if t
// finished and ctx.Err() otherwise.
//
// Completion takes priority: if t has already finished, WaitCtx returns nil
// even when ctx is already cancelled or expired. Without the up-front check
// both select cases would be ready, and select chooses among ready cases at
// random, making the result nondeterministic.
func (t *task) WaitCtx(ctx context.Context) error {
	if t.IsDone() {
		return nil
	}
	select {
	case <-t.doneCh:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// WaitTimeout blocks until t has finished or timeout elapses. It returns nil
// if t finished and context.DeadlineExceeded otherwise. A non-positive
// timeout acts as a non-blocking check that still reports nil for a finished
// task.
func (t *task) WaitTimeout(timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return t.WaitCtx(ctx)
}
// Go runs fn in a new goroutine and returns a Task that completes when fn
// exits.
//
// The task is marked done in a deferred call, so waiters are released even
// if fn leaves its goroutine via runtime.Goexit instead of returning
// normally. A panic in fn is not recovered and still terminates the program.
func Go(fn func()) Task {
	t := newTask()
	go func() {
		defer t.done()
		fn()
	}()
	return t
}
// WaitAll blocks until every one of tasks has finished. It returns
// immediately if tasks is empty.
func WaitAll(tasks ...Task) {
	for _, t := range tasks {
		t.Wait()
	}
}

// WaitAllCtx blocks until every one of tasks has finished or ctx is done.
// It returns nil once all tasks have finished. Otherwise it returns the first
// error reported by a task's WaitCtx, which for tasks created by Go is
// ctx.Err().
//
// Tasks are waited on one after another on the calling goroutine, each wait
// observing ctx directly. No helper goroutine is started, so nothing is left
// blocked in the background when WaitAllCtx returns early.
func WaitAllCtx(ctx context.Context, tasks ...Task) error {
	for _, t := range tasks {
		if err := t.WaitCtx(ctx); err != nil {
			return err
		}
	}
	return nil
}

// WaitAllTimeout is like WaitAllCtx but gives up after timeout. On timeout it
// returns the error reported by the pending task's WaitCtx, which is
// context.DeadlineExceeded for tasks created by Go.
func WaitAllTimeout(timeout time.Duration, tasks ...Task) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return WaitAllCtx(ctx, tasks...)
}
// WaitAny blocks until at least one of tasks has finished.
//
// With no tasks nothing can ever finish, so WaitAny blocks forever; use
// WaitAnyCtx or WaitAnyTimeout if tasks may be empty.
func WaitAny(tasks ...Task) {
	// context.Background is never done, so WaitAnyCtx can only return (with a
	// nil error) after one of the tasks has finished.
	_ = WaitAnyCtx(context.Background(), tasks...)
}

// WaitAnyCtx blocks until at least one of tasks has finished or ctx is done.
// It returns nil if a task finished and ctx.Err() otherwise. With no tasks it
// simply waits for ctx.
//
// A task that has already finished when WaitAnyCtx is called wins even if ctx
// is already done, so the result is deterministic in that case.
//
// Otherwise each task is watched by a helper goroutine that calls WaitCtx
// with a context derived from ctx. That context is cancelled when WaitAnyCtx
// returns, so helpers watching unfinished tasks exit promptly instead of
// blocking until those tasks complete. This relies on each Task's WaitCtx
// honouring cancellation, as the task type returned by Go does.
func WaitAnyCtx(ctx context.Context, tasks ...Task) error {
	for _, t := range tasks {
		if t.IsDone() {
			return nil
		}
	}

	watchCtx, stopWatching := context.WithCancel(ctx)
	defer stopWatching()

	// Buffered to len(tasks) so every helper can report completion without
	// blocking, even after WaitAnyCtx has already returned.
	finished := make(chan struct{}, len(tasks))
	for _, t := range tasks {
		go func(t Task) {
			if t.WaitCtx(watchCtx) == nil {
				finished <- struct{}{}
			}
		}(t)
	}

	select {
	case <-finished:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// WaitAnyTimeout is like WaitAnyCtx but gives up after timeout, returning
// context.DeadlineExceeded in that case.
func WaitAnyTimeout(timeout time.Duration, tasks ...Task) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return WaitAnyCtx(ctx, tasks...)
}