diff --git a/clock.go b/clock.go index 7783201..40555b3 100644 --- a/clock.go +++ b/clock.go @@ -1,6 +1,7 @@ package clock import ( + "context" "sort" "sync" "time" @@ -24,6 +25,8 @@ type Clock interface { Tick(d time.Duration) <-chan time.Time Ticker(d time.Duration) *Ticker Timer(d time.Duration) *Timer + WithDeadline(parent context.Context, d time.Time) (context.Context, context.CancelFunc) + WithTimeout(parent context.Context, t time.Duration) (context.Context, context.CancelFunc) } // New returns an instance of a real-time clock. @@ -60,7 +63,15 @@ func (c *clock) Timer(d time.Duration) *Timer { return &Timer{C: t.C, timer: t} } -// Mock represents a mock clock that only moves forward programmatically. +func (c *clock) WithDeadline(parent context.Context, d time.Time) (context.Context, context.CancelFunc) { + return context.WithDeadline(parent, d) +} + +func (c *clock) WithTimeout(parent context.Context, t time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, t) +} + +// Mock represents a mock clock that only moves forward programmically. // It can be preferable to a real-time clock when testing time-based functionality. type Mock struct { mu sync.Mutex @@ -360,3 +371,8 @@ func (t *internalTicker) Tick(now time.Time) { // Sleep momentarily so that other goroutines can process. func gosched() { time.Sleep(1 * time.Millisecond) } + +var ( + // type checking + _ Clock = &Mock{} +) diff --git a/context.go b/context.go new file mode 100644 index 0000000..eb67594 --- /dev/null +++ b/context.go @@ -0,0 +1,86 @@ +package clock + +import ( + "context" + "fmt" + "sync" + "time" +) + +func (m *Mock) WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + return m.WithDeadline(parent, m.Now().Add(timeout)) +} + +func (m *Mock) WithDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) { + if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { + // The current deadline is already sooner than the new one. + return context.WithCancel(parent) + } + ctx := &timerCtx{clock: m, parent: parent, deadline: deadline, done: make(chan struct{})} + propagateCancel(parent, ctx) + dur := m.Until(deadline) + if dur <= 0 { + ctx.cancel(context.DeadlineExceeded) // deadline has already passed + return ctx, func() {} + } + ctx.Lock() + defer ctx.Unlock() + if ctx.err == nil { + ctx.timer = m.AfterFunc(dur, func() { + ctx.cancel(context.DeadlineExceeded) + }) + } + return ctx, func() { ctx.cancel(context.Canceled) } +} + +// propagateCancel arranges for child to be canceled when parent is. +func propagateCancel(parent context.Context, child *timerCtx) { + if parent.Done() == nil { + return // parent is never canceled + } + go func() { + select { + case <-parent.Done(): + child.cancel(parent.Err()) + case <-child.Done(): + } + }() +} + +type timerCtx struct { + sync.Mutex + + clock Clock + parent context.Context + deadline time.Time + done chan struct{} + + err error + timer *Timer +} + +func (c *timerCtx) cancel(err error) { + c.Lock() + defer c.Unlock() + if c.err != nil { + return // already canceled + } + c.err = err + close(c.done) + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } +} + +func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { return c.deadline, true } + +func (c *timerCtx) Done() <-chan struct{} { return c.done } + +func (c *timerCtx) Err() error { return c.err } + +func (c *timerCtx) Value(key interface{}) interface{} { return c.parent.Value(key) } + +func (c *timerCtx) String() string { + return fmt.Sprintf("clock.WithDeadline(%s [%s])", c.deadline, c.deadline.Sub(c.clock.Now())) +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..176d013 --- /dev/null +++ b/context_test.go @@ -0,0 +1,99 @@ +package clock + +import ( + "context" + "errors" + "testing" + "time" +) + +// Ensure that WithDeadline is cancelled when deadline exceeded. +func TestMock_WithDeadline(t *testing.T) { + m := NewMock() + ctx, _ := m.WithDeadline(context.Background(), m.Now().Add(time.Second)) + m.Add(time.Second) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when deadline exceeded") + } + default: + t.Error("context is not cancelled when deadline exceeded") + } +} + +// Ensure that WithDeadline does nothing when the deadline is later than the current deadline. +func TestMock_WithDeadlineLaterThanCurrent(t *testing.T) { + m := NewMock() + ctx, _ := m.WithDeadline(context.Background(), m.Now().Add(time.Second)) + ctx, _ = m.WithDeadline(ctx, m.Now().Add(10*time.Second)) + m.Add(time.Second) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when deadline exceeded") + } + default: + t.Error("context is not cancelled when deadline exceeded") + } +} + +// Ensure that WithDeadline cancel closes Done channel with context.Canceled error. +func TestMock_WithDeadlineCancel(t *testing.T) { + m := NewMock() + ctx, cancel := m.WithDeadline(context.Background(), m.Now().Add(time.Second)) + cancel() + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.Canceled) { + t.Error("invalid type of error returned after cancellation") + } + case <-time.After(time.Second): + t.Error("context is not cancelled after cancel was called") + } +} + +// Ensure that WithDeadline closes child contexts after it was closed. +func TestMock_WithDeadlineCancelledWithParent(t *testing.T) { + m := NewMock() + parent, cancel := context.WithCancel(context.Background()) + ctx, _ := m.WithDeadline(parent, m.Now().Add(time.Second)) + cancel() + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.Canceled) { + t.Error("invalid type of error returned after cancellation") + } + case <-time.After(time.Second): + t.Error("context is not cancelled when parent context is cancelled") + } +} + +// Ensure that WithDeadline cancelled immediately when deadline has already passed. +func TestMock_WithDeadlineImmediate(t *testing.T) { + m := NewMock() + ctx, _ := m.WithDeadline(context.Background(), m.Now().Add(-time.Second)) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when deadline has already passed") + } + default: + t.Error("context is not cancelled when deadline has already passed") + } +} + +// Ensure that WithTimeout is cancelled when deadline exceeded. +func TestMock_WithTimeout(t *testing.T) { + m := NewMock() + ctx, _ := m.WithTimeout(context.Background(), time.Second) + m.Add(time.Second) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when time is over") + } + default: + t.Error("context is not cancelled when time is over") + } +}