diff --git a/app.go b/app.go index 8d0132c62..67244b913 100644 --- a/app.go +++ b/app.go @@ -300,6 +300,10 @@ type App struct { dones []chan os.Signal shutdownSig os.Signal + // Used to make sure Start/Stop is called only once. + runStart sync.Once + runStop sync.Once + osExit func(code int) // os.Exit override; used for testing only } @@ -658,21 +662,25 @@ var ( // Note that Start short-circuits immediately if the New constructor // encountered any errors in application initialization. func (app *App) Start(ctx context.Context) (err error) { - defer func() { - app.log.LogEvent(&fxevent.Started{Err: err}) - }() + app.runStart.Do(func() { + defer func() { + app.log.LogEvent(&fxevent.Started{Err: err}) + }() - if app.err != nil { - // Some provides failed, short-circuit immediately. - return app.err - } + if app.err != nil { + // Some provides failed, short-circuit immediately. + err = app.err + return + } - return withTimeout(ctx, &withTimeoutParams{ - hook: _onStartHook, - callback: app.start, - lifecycle: app.lifecycle, - log: app.log, + err = withTimeout(ctx, &withTimeoutParams{ + hook: _onStartHook, + callback: app.start, + lifecycle: app.lifecycle, + log: app.log, + }) }) + return } func (app *App) start(ctx context.Context) error { @@ -700,16 +708,20 @@ func (app *App) start(ctx context.Context) error { // called are executed. However, all those hooks are executed, even if some // fail. func (app *App) Stop(ctx context.Context) (err error) { - defer func() { - app.log.LogEvent(&fxevent.Stopped{Err: err}) - }() + app.runStop.Do(func() { + // Protect the Stop hooks from being called multiple times. + defer func() { + app.log.LogEvent(&fxevent.Stopped{Err: err}) + }() - return withTimeout(ctx, &withTimeoutParams{ - hook: _onStopHook, - callback: app.lifecycle.Stop, - lifecycle: app.lifecycle, - log: app.log, + err = withTimeout(ctx, &withTimeoutParams{ + hook: _onStopHook, + callback: app.lifecycle.Stop, + lifecycle: app.lifecycle, + log: app.log, + }) }) + return } // Done returns a channel of signals to block on after starting the diff --git a/app_test.go b/app_test.go index 001127203..971804176 100644 --- a/app_test.go +++ b/app_test.go @@ -31,6 +31,7 @@ import ( "reflect" "runtime" "strings" + "sync" "testing" "time" @@ -1281,6 +1282,47 @@ func TestAppStart(t *testing.T) { err := app.Start(context.Background()).Error() assert.Contains(t, err, "OnStart hook added by go.uber.org/fx_test.TestAppStart.func10.1 failed: goroutine exited without returning") }) + + t.Run("Start/Stop should be called exactly once only.", func(t *testing.T) { + t.Parallel() + startCalled := 0 + stopCalled := 0 + app := fxtest.New(t, + Provide(Annotate(func() int { return 0 }, + OnStart(func(context.Context) error { + startCalled += 1 + return nil + }), + OnStop(func(context.Context) error { + stopCalled += 1 + return nil + })), + ), + Invoke(func(i int) { + assert.Equal(t, 0, i) + }), + ) + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + app.Start(context.Background()) + }() + } + wg.Wait() + assert.Equal(t, 1, startCalled) + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + app.Stop(context.Background()) + }() + } + wg.Wait() + assert.Equal(t, 1, stopCalled) + }) + } func TestAppStop(t *testing.T) {