diff --git a/internal/lifecycle/lifecycle.go b/internal/lifecycle/lifecycle.go index 69163330a..af458fa6b 100644 --- a/internal/lifecycle/lifecycle.go +++ b/internal/lifecycle/lifecycle.go @@ -77,6 +77,11 @@ func (l *Lifecycle) Start(ctx context.Context) error { l.mu.Unlock() for _, hook := range l.hooks { + // if ctx has cancelled, bail out of the loop. + if err := ctx.Err(); err != nil { + return err + } + if hook.OnStart != nil { l.mu.Lock() l.runningHook = hook @@ -131,6 +136,9 @@ func (l *Lifecycle) Stop(ctx context.Context) error { // Run backward from last successful OnStart. var errs []error for ; l.numStarted > 0; l.numStarted-- { + if err := ctx.Err(); err != nil { + return err + } hook := l.hooks[l.numStarted-1] if hook.OnStop == nil { continue diff --git a/internal/lifecycle/lifecycle_test.go b/internal/lifecycle/lifecycle_test.go index f325de5ce..4fe75e256 100644 --- a/internal/lifecycle/lifecycle_test.go +++ b/internal/lifecycle/lifecycle_test.go @@ -119,6 +119,30 @@ func TestLifecycleStart(t *testing.T) { assert.Equal(t, 2, starterCount, "expected the first and second starter to execute") assert.Equal(t, 1, stopperCount, "expected the first stopper to execute since the second starter failed") }) + + t.Run("DoNotRunStartHooksWithExpiredCtx", func(t *testing.T) { + t.Parallel() + + l := New(testLogger(t), fxclock.System) + l.Append(Hook{ + OnStart: func(context.Context) error { + assert.Fail(t, "this hook should not run") + return nil + }, + OnStop: func(context.Context) error { + assert.Fail(t, "this hook should not run") + return nil + }, + }) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := l.Start(ctx) + require.Error(t, err) + // Note: Stop does not return an error here because no hooks + // have been started, so we don't end up any of the corresponding + // stop hooks. + require.NoError(t, l.Stop(ctx)) + }) } func TestLifecycleStop(t *testing.T) { @@ -247,6 +271,26 @@ func TestLifecycleStop(t *testing.T) { assert.Equal(t, err, l.Start(context.Background())) l.Stop(context.Background()) }) + + t.Run("DoNotRunStopHooksWithExpiredCtx", func(t *testing.T) { + t.Parallel() + + l := New(testLogger(t), fxclock.System) + l.Append(Hook{ + OnStart: func(context.Context) error { + return nil + }, + OnStop: func(context.Context) error { + assert.Fail(t, "this hook should not run") + return nil + }, + }) + ctx, cancel := context.WithCancel(context.Background()) + err := l.Start(ctx) + require.NoError(t, err) + cancel() + require.Error(t, l.Stop(ctx)) + }) } func TestHookRecordsFormat(t *testing.T) {