From e3945ddc4681ea2e0a6f517e36675272dd34bcb2 Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 28 Jun 2022 04:54:33 +0000 Subject: [PATCH 01/16] Adds OnStart/OnStop lifecycle Annotations --- annotated.go | 380 +++++++++++++++++++++++++++++++++- annotated_test.go | 516 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 891 insertions(+), 5 deletions(-) diff --git a/annotated.go b/annotated.go index 85a8ec0e6..16a7a5ced 100644 --- a/annotated.go +++ b/annotated.go @@ -21,6 +21,7 @@ package fx import ( + "context" "errors" "fmt" "reflect" @@ -195,6 +196,310 @@ func ResultTags(tags ...string) Annotation { return resultTagsAnnotation{tags} } +type _lifecycleHookAnnotationType int + +const ( + _unknownHookType _lifecycleHookAnnotationType = iota + _onStartHookType + _onStopHookType +) + +type lifecycleHookAnnotation struct { + Type _lifecycleHookAnnotationType + Target interface{} +} + +func (la *lifecycleHookAnnotation) apply(ann *annotated) error { + name := "UnknownHookAnnotation" + switch la.Type { + case _onStartHookType: + name = _onStartHook + case _onStopHookType: + name = _onStopHook + } + + if la.Target == nil { + return fmt.Errorf( + "cannot use nil function for %v hook annotation", + name, + ) + } + + for _, h := range ann.Hooks { + if la.Type == h.Type { + return fmt.Errorf( + "cannot apply more than one %v hook annotation", + name, + ) + } + } + + ft := la.targetType() + if ft.Kind() != reflect.Func { + return fmt.Errorf( + "must provide function for hook, got %v (%T)", + la.Target, + la.Target, + ) + } + + if ft.NumIn() < 1 || ft.In(0) != _typeOfContext { + return fmt.Errorf( + "first argument of hook must be context.Context, got %v (%T)", + la.Target, + la.Target, + ) + } + + hasOut := ft.NumOut() == 1 + returnsErr := hasOut && ft.Out(0) == _typeOfError + + if !hasOut || !returnsErr { + return fmt.Errorf( + "hooks must return only an error type, got %v (%T)", + la.Target, + la.Target, + ) + } + + if ft.IsVariadic() { + return fmt.Errorf( + "hooks must not accept variatic parameters, got %v (%T)", + la.Target, + la.Target, + ) + } + + ann.Hooks = append(ann.Hooks, la) + return nil +} + +var ( + _typeOfLifecycle reflect.Type = reflect.TypeOf((*Lifecycle)(nil)).Elem() + _typeOfContext reflect.Type = reflect.TypeOf((*context.Context)(nil)).Elem() +) + +func (la *lifecycleHookAnnotation) targetType() (targetType reflect.Type) { + return reflect.TypeOf(la.Target) +} + +func (la *lifecycleHookAnnotation) resolveMap(results []reflect.Type) ( + resultMap map[reflect.Type]struct { + resolve func(reflect.Value, int) reflect.Value + }, +) { + // index the constructor results by type and position to allow + // for us to omit these from the in types that must be injected, + // and to allow us to interleave constructor results + // into our hook arguments. + resultMap = make(map[reflect.Type]struct { + resolve func(reflect.Value, int) reflect.Value + }, 0) + + for _, r := range results { + resultMap[r] = struct { + resolve func(reflect.Value, int) reflect.Value + }{ + resolve: func(v reflect.Value, pos int) (value reflect.Value) { + return v + }, + } + } + + fmt.Printf("Result map %+v\n", resultMap) + return +} + +func (la *lifecycleHookAnnotation) resolveLifecycleParamField( + param reflect.Value, + n int, +) ( + value reflect.Value, +) { + if param.Kind() == reflect.Struct { + nf := param.NumField() + if n <= nf { + value = param.FieldByName(fmt.Sprintf("Field%d", n-1)) + } + } + + return value +} + +func (la *lifecycleHookAnnotation) parameters(results ...reflect.Type) ( + in reflect.Type, + argmap func( + args []reflect.Value, + ) (Lifecycle, []reflect.Value), +) { + resultMap := la.resolveMap(results) + + // hook functions require a lifecycle, and it should be injected + params := []reflect.StructField{ + { + Name: "In", + Type: _typeOfIn, + Anonymous: true, + }, + { + Name: "Lifecycle", + Type: _typeOfLifecycle, + }, + } + + type valueResolver func(reflect.Value, int) reflect.Value + type argSource struct { + pos int + result bool + resolve valueResolver + } + + resolverIdx := make([]argSource, 1) + ft := la.targetType() + for i := 1; i < ft.NumIn(); i++ { + t := ft.In(i) + resultIdx, isProvidedByResults := resultMap[t] + + if isProvidedByResults { + resolverIdx = append(resolverIdx, argSource{ + pos: i, + result: true, + resolve: resultIdx.resolve, + }) + continue + } + + field := reflect.StructField{ + Name: fmt.Sprintf("Field%d", i), + Type: t, + } + params = append(params, field) + + resolver := func(v reflect.Value, pos int) (value reflect.Value) { + value = la.resolveLifecycleParamField(v, i) + return + } + + resolverIdx = append(resolverIdx, argSource{ + pos: i, + resolve: resolver, + }) + } + + in = reflect.StructOf(params) + + argmap = func( + args []reflect.Value, + ) (lc Lifecycle, remapped []reflect.Value) { + remapped = make([]reflect.Value, ft.NumIn()) + + if len(args) != 0 { + + p := args[0] + results := args[1] + + lc, _ = p.FieldByName("Lifecycle").Interface().(Lifecycle) + + for i := 1; i < ft.NumIn(); i++ { + resolver := resolverIdx[i] + + source := p + if resolver.result { + source = results + } + + remapped[i] = resolver.resolve(source, i) + } + } + + return + } + return +} + +func (la *lifecycleHookAnnotation) buildHook(fn func(context.Context) error) (hook Hook) { + switch la.Type { + case _onStartHookType: + hook.OnStart = fn + case _onStopHookType: + hook.OnStop = fn + } + + return +} + +func (la *lifecycleHookAnnotation) Build(results ...reflect.Type) (reflect.Value, error) { + in, paramMap := la.parameters(results...) + params := []reflect.Type{in} + for _, r := range results { + if r != _typeOfError { + params = append(params, r) + } + } + + origFn := reflect.ValueOf(la.Target) + newFnType := reflect.FuncOf(params, nil, false) + newFn := reflect.MakeFunc(newFnType, func(args []reflect.Value) []reflect.Value { + var lc Lifecycle + lc, args = paramMap(args) + hookFn := func(ctx context.Context) (err error) { + args[0] = reflect.ValueOf(ctx) + + results := origFn.Call(args) + if len(results) > 0 && results[0].Type() == _typeOfError { + err, _ = results[0].Interface().(error) + } + + return + } + + lc.Append(la.buildHook(hookFn)) + return []reflect.Value{} + }) + + return newFn, nil +} + +// OnStart is an Annotation that appends an OnStart Hook to the application +// Lifecycle when that function is called. This provides a way to create +// Lifecycle OnStart hooks without building a function that takes a dependency +// on the Lifecycle type. +// +// fx.Annotate( +// func(...) Server { ... }, +// fx.OnStart(func(ctx context.Context, server Server) error { +// return server.Listen(ctx) +// }), +// ) +// +// Only one OnStart annotation may be applied to a given function at a time. +func OnStart(onStart interface{}) Annotation { + return &lifecycleHookAnnotation{ + Type: _onStartHookType, + Target: onStart, + } +} + +// OnStop is an Annotation that appends an OnStop Hook to the application +// Lifecycle when that function is called. This provides a way to create +// Lifecycle OnStop hooks without building a function that takes a dependency +// on the Lifecycle type. +// +// fx.Annotate( +// func(...) Server { ... }, +// fx.OnStop(func(ctx context.Context, server Server) error { +// return server.Shutdown(ctx) +// }), +// ) +// +// Only one OnStop annotation may be applied to a given function at a time. +func OnStop(onStop interface{}) Annotation { + return &lifecycleHookAnnotation{ + Type: _onStopHookType, + Target: onStop, + } +} + type asAnnotation struct { targets []interface{} } @@ -265,6 +570,7 @@ type annotated struct { ResultTags []string As [][]reflect.Type FuncPtr uintptr + Hooks []*lifecycleHookAnnotation } func (ann annotated) String() string { @@ -295,16 +601,26 @@ func (ann *annotated) Build() (interface{}, error) { return nil, fmt.Errorf("invalid annotation function %T: %w", ann.Target, err) } - paramTypes, remapParams := ann.parameters() resultTypes, remapResults, err := ann.results() if err != nil { return nil, err } + paramTypes, remapParams, hookParams := ann.parameters(resultTypes...) + + var hooks []reflect.Value + for _, hook := range ann.Hooks { + if hookFn, err := hook.Build(resultTypes...); err == nil { + hooks = append(hooks, hookFn) + } + } newFnType := reflect.FuncOf(paramTypes, resultTypes, false) origFn := reflect.ValueOf(ann.Target) ann.FuncPtr = origFn.Pointer() + newFn := reflect.MakeFunc(newFnType, func(args []reflect.Value) []reflect.Value { + origArgs := make([]reflect.Value, len(args)) + copy(origArgs, args) args = remapParams(args) var results []reflect.Value if ft.IsVariadic() { @@ -313,6 +629,22 @@ func (ann *annotated) Build() (interface{}, error) { results = origFn.Call(args) } results = remapResults(results) + + // if the results are greater than zero and the final result + // is a non-nil error, do not execute hook installers + hasErrorResult := len(results) > 0 && results[len(results)-1].Type() == _typeOfError + if hasErrorResult { + err, ok := results[len(results)-1].Interface().(error) + if ok && err != nil { + return results + } + } + + for i, hook := range hooks { + hookArgs := hookParams(i, origArgs, results) + hook.Call(hookArgs) + } + return results }) @@ -349,9 +681,10 @@ func (ann *annotated) typeCheckOrigFn() error { // parameters returns the type for the parameters of the annotated function, // and a function that maps the arguments of the annotated function // back to the arguments of the target function. -func (ann *annotated) parameters() ( +func (ann *annotated) parameters(results ...reflect.Type) ( types []reflect.Type, remap func([]reflect.Value) []reflect.Value, + hookValueMap func(int, []reflect.Value, []reflect.Value) []reflect.Value, ) { ft := reflect.TypeOf(ann.Target) @@ -362,10 +695,10 @@ func (ann *annotated) parameters() ( // No parameter annotations. Return the original types // and an identity function. - if len(ann.ParamTags) == 0 && !ft.IsVariadic() { + if len(ann.ParamTags) == 0 && !ft.IsVariadic() && len(ann.Hooks) == 0 { return types, func(args []reflect.Value) []reflect.Value { return args - } + }, nil } // Turn parameters into an fx.In struct. @@ -395,8 +728,20 @@ func (ann *annotated) parameters() ( inFields = append(inFields, field) } + // append required types for hooks to types field, but do not + // include them as params in constructor call + for h, t := range ann.Hooks { + params, _ := t.parameters(results...) + field := reflect.StructField{ + Name: fmt.Sprintf("Hook%d", h), + Type: params, + } + inFields = append(inFields, field) + } + types = []reflect.Type{reflect.StructOf(inFields)} - return types, func(args []reflect.Value) []reflect.Value { + + remap = func(args []reflect.Value) []reflect.Value { params := args[0] args = args[:0] for i := 0; i < ft.NumIn(); i++ { @@ -404,6 +749,31 @@ func (ann *annotated) parameters() ( } return args } + + hookValueMap = func(hook int, args []reflect.Value, results []reflect.Value) (out []reflect.Value) { + params := args[0] + + if params.Kind() == reflect.Struct { + var zero reflect.Value + value := params.FieldByNameFunc(func(name string) bool { + return name == fmt.Sprintf("Hook%d", hook) + }) + + if value != zero { + out = append(out, value) + } + } + + for _, r := range results { + if r.Type() != _typeOfError { + out = append(out, r) + } + } + + return + } + + return } // results returns the types of the results of the annotated function, diff --git a/annotated_test.go b/annotated_test.go index f8ba3b9e1..63634b534 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -22,6 +22,7 @@ package fx_test import ( "bytes" + "context" "errors" "fmt" "io" @@ -983,4 +984,519 @@ func TestAnnotate(t *testing.T) { assert.Contains(t, err.Error(), "invalid annotation function func(fx_test.B) string") assert.Contains(t, err.Error(), "fx.In structs cannot be annotated") }) + + t.Run("Hooks", testHookAnnotations) +} + +func testHookAnnotations(t *testing.T) { + t.Parallel() + + validateApp := func(t *testing.T, opts ...fx.Option) error { + return fx.ValidateApp( + append(opts, fx.Logger(fxtest.NewTestPrinter(t)))..., + ) + } + + t.Run("depend on result interface of target", func(t *testing.T) { + //t.Skip() + type stub interface { + String() string + } + + var started bool + + app := fxtest.New(t, + fx.Provide( + fx.Annotate( + func() (stub, error) { + b := []byte("expected") + return bytes.NewBuffer(b), nil + }, + fx.OnStart(func(_ context.Context, s stub) error { + started = true + if !assert.Equal(t, "expected", s.String()) { + return fmt.Errorf( + "expected %q got %q", + "expected", + s.String(), + ) + } + return nil + }), + ), + ), + fx.Invoke(func(s stub) { + require.Equal(t, "expected", s.String()) + }), + ) + + ctx := context.Background() + assert.False(t, started) + require.NoError(t, app.Start(ctx)) + assert.True(t, started) + require.NoError(t, app.Stop(ctx)) + }) + + t.Run("start and stop without dependencies", func(t *testing.T) { + t.Parallel() + + type stub interface{} + + var ( + invoked bool + started bool + stopped bool + ) + + app := fxtest.New(t, + fx.Provide( + fx.Annotate( + func() (stub, error) { return nil, nil }, + fx.OnStart(func(context.Context) error { + started = true + return nil + }), + fx.OnStop(func(context.Context) error { + stopped = true + return nil + }), + ), + ), + fx.Invoke(func(s stub) { + invoked = s == nil + }), + ) + + ctx := context.Background() + assert.False(t, started) + require.NoError(t, app.Start(ctx)) + assert.True(t, invoked) + assert.True(t, started) + assert.False(t, stopped) + require.NoError(t, app.Stop(ctx)) + assert.True(t, stopped) + + }) + + t.Run("depedency chain", func(t *testing.T) { + t.Parallel() + + type ( + a interface{} + b interface{} + c interface{} + ) + + var aHook, bHook bool + + app := fxtest.New(t, + fx.Provide( + fx.Annotate( + func() (a, error) { return nil, nil }, + fx.OnStart(func(context.Context) error { + aHook = true + return nil + }), + ), + ), + fx.Provide( + fx.Annotate( + func() (b, error) { return nil, nil }, + fx.OnStart(func(context.Context) error { + bHook = true + return nil + }), + ), + ), + fx.Provide(func(a, b) c { return nil }), + fx.Invoke(func(c) {}), + ) + + ctx := context.Background() + assert.False(t, aHook) + assert.False(t, bHook) + require.NoError(t, app.Start(ctx)) + assert.True(t, aHook) + assert.True(t, bHook) + require.NoError(t, app.Stop(ctx)) + }) + + t.Run("with extra dependency parameter", func(t *testing.T) { + t.Parallel() + + type ( + a interface{} + b interface{} + c interface{} + ) + + var aHook bool + + app := fxtest.New(t, + fx.Provide( + fx.Annotate( + func() (a, error) { return nil, nil }, + fx.OnStart(func(context.Context, b) error { + aHook = true + return nil + }), + ), + ), + fx.Provide(func() b { return nil }), + fx.Provide(func(a, b) c { return nil }), + fx.Invoke(func(c) {}), + ) + + ctx := context.Background() + assert.False(t, aHook) + require.NoError(t, app.Start(ctx)) + defer func() { + require.NoError(t, app.Stop(ctx)) + }() + assert.True(t, aHook) + }) + + t.Run("with multiple extra dependency parameters", func(t *testing.T) { + t.Parallel() + + type ( + a interface{} + b interface{} + c interface{} + ) + + var aHook bool + + app := fxtest.New(t, + fx.Provide( + fx.Annotate( + func() (a, error) { return nil, nil }, + fx.OnStart(func(context.Context, b, c) error { + aHook = true + return nil + }), + ), + ), + fx.Provide(func() b { return nil }), + fx.Provide(func() c { return nil }), + fx.Invoke(func(a) {}), + ) + + ctx := context.Background() + assert.False(t, aHook) + require.NoError(t, app.Start(ctx)) + defer func() { + require.NoError(t, app.Stop(ctx)) + }() + assert.True(t, aHook) + }) + + t.Run("with unprovided dependency", func(t *testing.T) { + t.Parallel() + + type ( + a interface{} + b interface{} + ) + + err := validateApp(t, + fx.Provide( + fx.Annotate( + func() (a, error) { return nil, nil }, + fx.OnStart(func(context.Context, b) error { + return nil + }), + ), + ), + fx.Invoke(func(a) {}), + ) + + require.Error(t, err) + require.Contains(t, err.Error(), "missing type: fx_test.b") + }) + + t.Run("that returns error", func(t *testing.T) { + t.Parallel() + + type stub interface{} + + app := fxtest.New(t, + fx.Provide( + fx.Annotate( + func() (stub, error) { return nil, nil }, + fx.OnStart(func(context.Context) error { + return errors.New("hook failed") + }), + ), + ), + fx.Invoke(func(stub) {}), + ) + + err := app.Start(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "hook failed") + }) + + t.Run("error with multiple hooks of the same type", func(t *testing.T) { + t.Parallel() + + type stub interface{} + + err := validateApp(t, + fx.Provide( + fx.Annotate( + func() stub { return nil }, + fx.OnStart(func(context.Context) error { return nil }), + fx.OnStart(func(context.Context) error { return nil }), + ), + ), + fx.Invoke(func(s stub) {}), + ) + + require.Error(t, err) + require.Contains(t, err.Error(), "cannot apply more than one") + }) + + t.Run("with Supply", func(t *testing.T) { + t.Parallel() + + type ( + A interface { + WriteString(string) (int, error) + } + ) + + buf := bytes.NewBuffer(nil) + cotr := fx.Provide( + fx.Annotate( + func() A { + return buf + }, + fx.OnStart(func(_ context.Context, a A, s fmt.Stringer) error { + a.WriteString(s.String()) + return nil + }), + ), + ) + + supply := fx.Supply( + fx.Annotate( + &asStringer{"supply"}, + fx.OnStart(func(_ context.Context) error { + return nil + }), + fx.As(new(fmt.Stringer)), + ), + ) + + opts := fx.Options( + cotr, + supply, + fx.Invoke(func(A) {}), + ) + + app := fxtest.New(t, opts) + ctx := context.Background() + err := app.Start(ctx) + require.NoError(t, err) + require.NoError(t, app.Stop(ctx)) + require.Equal(t, "supply", buf.String()) + }) + + t.Run("with Decorate", func(t *testing.T) { + t.Parallel() + + type ( + A interface { + WriteString(string) (int, error) + } + ) + + buf := bytes.NewBuffer(nil) + cotr := fx.Provide(func() A { return buf }) + + var called bool + decorated := fx.Decorate( + fx.Annotate( + func(in A) A { + in.WriteString("decorated") + return in + }, + fx.OnStart(func(_ context.Context, a A) error { + // assert that the interface we get is the decorated one + called = assert.Equal(t, "decorated", buf.String()) + return nil + }), + ), + ) + + opts := fx.Options( + cotr, + decorated, + fx.Invoke(func(A) {}), + ) + + app := fxtest.New(t, opts) + ctx := context.Background() + err := app.Start(ctx) + require.NoError(t, err) + require.NoError(t, app.Stop(ctx)) + require.True(t, called) + require.Equal(t, "decorated", buf.String()) + }) + + t.Run("with Supply and Decorate", func(t *testing.T) { + t.Parallel() + + type ( + A interface{} + ) + + ch := make(chan string, 3) + + cotr := fx.Provide( + fx.Annotate( + func() A { return nil }, + fx.OnStart(func(_ context.Context, s fmt.Stringer) error { + ch <- "constructor" + fmt.Printf("executing!\n") + require.Equal(t, "supply", s.String()) + return nil + }), + ), + ) + + supply := fx.Supply( + fx.Annotate( + &asStringer{"supply"}, + fx.OnStart(func(_ context.Context) error { + ch <- "supply" + return nil + }), + fx.As(new(fmt.Stringer)), + ), + ) + + decorated := fx.Decorate( + fx.Annotate( + func(in A) A { return in }, + fx.OnStart(func(_ context.Context) error { + ch <- "decorated" + return nil + }), + ), + ) + + opts := fx.Options( + cotr, + supply, + decorated, + fx.Invoke(func(A) {}), + ) + + app := fxtest.New(t, opts) + ctx := context.Background() + err := app.Start(ctx) + require.NoError(t, err) + require.NoError(t, app.Stop(ctx)) + close(ch) + + require.Equal(t, "supply", <-ch) + require.Equal(t, "constructor", <-ch) + require.Equal(t, "decorated", <-ch) + }) + + t.Run("with nil target", func(t *testing.T) { + type A interface{} + err := validateApp(t, + fx.Provide( + fx.Annotate( + func() A { return nil }, + fx.OnStart(nil), + ), + ), + ) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot use nil function") + }) + + t.Run("with non-func target", func(t *testing.T) { + type A interface{} + err := validateApp(t, + fx.Provide( + fx.Annotate( + func() A { return nil }, + fx.OnStart(&struct{}{}), + ), + ), + ) + require.Error(t, err) + require.Contains(t, err.Error(), "must provide function") + }) + + t.Run("without context parameter", func(t *testing.T) { + type A interface{} + err := validateApp(t, + fx.Provide( + fx.Annotate( + func() A { return nil }, + fx.OnStart(func() {}), + ), + ), + ) + require.Error(t, err) + require.Contains(t, err.Error(), "must be context.Context") + }) + + t.Run("with variatic hook", func(t *testing.T) { + type A interface{} + err := validateApp(t, + fx.Provide( + fx.Annotate( + func() A { return nil }, + fx.OnStart(func(context.Context, ...A) error { + return nil + }), + ), + ), + ) + require.Error(t, err) + require.Contains(t, err.Error(), "must not accept variatic") + }) + + t.Run("without returning error", func(t *testing.T) { + type A interface{} + err := validateApp(t, + fx.Provide( + fx.Annotate( + func() A { return nil }, + fx.OnStart(func(context.Context) {}), + ), + ), + ) + require.Error(t, err) + require.Contains(t, err.Error(), "must return only an error") + }) + + t.Run("with constructor error", func(t *testing.T) { + type A interface{} + app := fx.New( + fx.Provide( + fx.Annotate( + func() (A, error) { + return nil, errors.New("hooks should not be installed") + }, + fx.OnStart(func(context.Context) error { + require.FailNow(t, "hook should not be called") + return nil + }), + ), + ), + fx.Invoke(func(A) {}), + ) + + err := app.Start(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "hooks should not be installed") + }) } From f4cf21be6706731e117516889149b4e3628a5184 Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 5 Jul 2022 21:45:28 +0000 Subject: [PATCH 02/16] resolve non-indent and table comments --- annotated.go | 8 +-- annotated_test.go | 121 +++++++++++++++------------------------------- 2 files changed, 40 insertions(+), 89 deletions(-) diff --git a/annotated.go b/annotated.go index 16a7a5ced..c2076e323 100644 --- a/annotated.go +++ b/annotated.go @@ -306,7 +306,6 @@ func (la *lifecycleHookAnnotation) resolveMap(results []reflect.Type) ( } } - fmt.Printf("Result map %+v\n", resultMap) return } @@ -319,7 +318,7 @@ func (la *lifecycleHookAnnotation) resolveLifecycleParamField( if param.Kind() == reflect.Struct { nf := param.NumField() if n <= nf { - value = param.FieldByName(fmt.Sprintf("Field%d", n-1)) + value = param.FieldByName(fmt.Sprintf("Field%d", n)) } } @@ -375,10 +374,7 @@ func (la *lifecycleHookAnnotation) parameters(results ...reflect.Type) ( } params = append(params, field) - resolver := func(v reflect.Value, pos int) (value reflect.Value) { - value = la.resolveLifecycleParamField(v, i) - return - } + resolver := la.resolveLifecycleParamField resolverIdx = append(resolverIdx, argSource{ pos: i, diff --git a/annotated_test.go b/annotated_test.go index 63634b534..5851c2cf0 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -985,10 +985,9 @@ func TestAnnotate(t *testing.T) { assert.Contains(t, err.Error(), "fx.In structs cannot be annotated") }) - t.Run("Hooks", testHookAnnotations) } -func testHookAnnotations(t *testing.T) { +func TestHookAnnotations(t *testing.T) { t.Parallel() validateApp := func(t *testing.T, opts ...fx.Option) error { @@ -998,7 +997,6 @@ func testHookAnnotations(t *testing.T) { } t.Run("depend on result interface of target", func(t *testing.T) { - //t.Skip() type stub interface { String() string } @@ -1014,13 +1012,7 @@ func testHookAnnotations(t *testing.T) { }, fx.OnStart(func(_ context.Context, s stub) error { started = true - if !assert.Equal(t, "expected", s.String()) { - return fmt.Errorf( - "expected %q got %q", - "expected", - s.String(), - ) - } + require.Equal(t, "expected", s.String()) return nil }), ), @@ -1098,8 +1090,7 @@ func testHookAnnotations(t *testing.T) { return nil }), ), - ), - fx.Provide( + fx.Annotate( func() (b, error) { return nil, nil }, fx.OnStart(func(context.Context) error { @@ -1107,8 +1098,9 @@ func testHookAnnotations(t *testing.T) { return nil }), ), + + func(a, b) c { return nil }, ), - fx.Provide(func(a, b) c { return nil }), fx.Invoke(func(c) {}), ) @@ -1120,75 +1112,41 @@ func testHookAnnotations(t *testing.T) { assert.True(t, bHook) require.NoError(t, app.Stop(ctx)) }) - - t.Run("with extra dependency parameter", func(t *testing.T) { - t.Parallel() - - type ( - a interface{} - b interface{} - c interface{} - ) - - var aHook bool - - app := fxtest.New(t, - fx.Provide( - fx.Annotate( - func() (a, error) { return nil, nil }, - fx.OnStart(func(context.Context, b) error { - aHook = true - return nil - }), - ), - ), - fx.Provide(func() b { return nil }), - fx.Provide(func(a, b) c { return nil }), - fx.Invoke(func(c) {}), - ) - - ctx := context.Background() - assert.False(t, aHook) - require.NoError(t, app.Start(ctx)) - defer func() { - require.NoError(t, app.Stop(ctx)) - }() - assert.True(t, aHook) - }) - t.Run("with multiple extra dependency parameters", func(t *testing.T) { t.Parallel() type ( - a interface{} - b interface{} - c interface{} + A interface{} + B interface{} + C interface{} ) - var aHook bool + var value int app := fxtest.New(t, fx.Provide( fx.Annotate( - func() (a, error) { return nil, nil }, - fx.OnStart(func(context.Context, b, c) error { - aHook = true + func() (A, error) { return nil, nil }, + fx.OnStart(func(_ context.Context, b B, c C) error { + b1, _ := b.(int) + c1, _ := c.(int) + value = b1 + c1 return nil }), ), ), - fx.Provide(func() b { return nil }), - fx.Provide(func() c { return nil }), - fx.Invoke(func(a) {}), + fx.Provide(func() B { return int(1) }), + fx.Provide(func() C { return int(2) }), + fx.Invoke(func(A) {}), ) ctx := context.Background() - assert.False(t, aHook) + assert.Zero(t, value) require.NoError(t, app.Start(ctx)) defer func() { require.NoError(t, app.Stop(ctx)) }() - assert.True(t, aHook) + assert.Equal(t, 3, value) }) t.Run("with unprovided dependency", func(t *testing.T) { @@ -1254,20 +1212,19 @@ func testHookAnnotations(t *testing.T) { ) require.Error(t, err) - require.Contains(t, err.Error(), "cannot apply more than one") + require.Contains(t, err.Error(), "cannot apply more than one OnStart hook annotation") }) t.Run("with Supply", func(t *testing.T) { t.Parallel() - type ( - A interface { - WriteString(string) (int, error) - } - ) + type A interface { + WriteString(string) (int, error) + } buf := bytes.NewBuffer(nil) - cotr := fx.Provide( + var called bool + ctor := fx.Provide( fx.Annotate( func() A { return buf @@ -1282,7 +1239,8 @@ func testHookAnnotations(t *testing.T) { supply := fx.Supply( fx.Annotate( &asStringer{"supply"}, - fx.OnStart(func(_ context.Context) error { + fx.OnStart(func(context.Context) error { + called = true return nil }), fx.As(new(fmt.Stringer)), @@ -1290,30 +1248,30 @@ func testHookAnnotations(t *testing.T) { ) opts := fx.Options( - cotr, + ctor, supply, fx.Invoke(func(A) {}), ) app := fxtest.New(t, opts) ctx := context.Background() + require.False(t, called) err := app.Start(ctx) require.NoError(t, err) require.NoError(t, app.Stop(ctx)) require.Equal(t, "supply", buf.String()) + require.True(t, called) }) t.Run("with Decorate", func(t *testing.T) { t.Parallel() - type ( - A interface { - WriteString(string) (int, error) - } - ) + type A interface { + WriteString(string) (int, error) + } buf := bytes.NewBuffer(nil) - cotr := fx.Provide(func() A { return buf }) + ctor := fx.Provide(func() A { return buf }) var called bool decorated := fx.Decorate( @@ -1331,7 +1289,7 @@ func testHookAnnotations(t *testing.T) { ) opts := fx.Options( - cotr, + ctor, decorated, fx.Invoke(func(A) {}), ) @@ -1348,18 +1306,15 @@ func testHookAnnotations(t *testing.T) { t.Run("with Supply and Decorate", func(t *testing.T) { t.Parallel() - type ( - A interface{} - ) + type A interface{} ch := make(chan string, 3) - cotr := fx.Provide( + ctor := fx.Provide( fx.Annotate( func() A { return nil }, fx.OnStart(func(_ context.Context, s fmt.Stringer) error { ch <- "constructor" - fmt.Printf("executing!\n") require.Equal(t, "supply", s.String()) return nil }), @@ -1388,7 +1343,7 @@ func testHookAnnotations(t *testing.T) { ) opts := fx.Options( - cotr, + ctor, supply, decorated, fx.Invoke(func(A) {}), From 72cc8a0e93a0f6d49ad1dcbe6131ae7993212e68 Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 5 Jul 2022 21:52:53 +0000 Subject: [PATCH 03/16] move failure cases to their own test --- annotated_test.go | 202 +++++++++++++++++++++++----------------------- 1 file changed, 103 insertions(+), 99 deletions(-) diff --git a/annotated_test.go b/annotated_test.go index 5851c2cf0..8b0d15d12 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -990,12 +990,6 @@ func TestAnnotate(t *testing.T) { func TestHookAnnotations(t *testing.T) { t.Parallel() - validateApp := func(t *testing.T, opts ...fx.Option) error { - return fx.ValidateApp( - append(opts, fx.Logger(fxtest.NewTestPrinter(t)))..., - ) - } - t.Run("depend on result interface of target", func(t *testing.T) { type stub interface { String() string @@ -1149,72 +1143,6 @@ func TestHookAnnotations(t *testing.T) { assert.Equal(t, 3, value) }) - t.Run("with unprovided dependency", func(t *testing.T) { - t.Parallel() - - type ( - a interface{} - b interface{} - ) - - err := validateApp(t, - fx.Provide( - fx.Annotate( - func() (a, error) { return nil, nil }, - fx.OnStart(func(context.Context, b) error { - return nil - }), - ), - ), - fx.Invoke(func(a) {}), - ) - - require.Error(t, err) - require.Contains(t, err.Error(), "missing type: fx_test.b") - }) - - t.Run("that returns error", func(t *testing.T) { - t.Parallel() - - type stub interface{} - - app := fxtest.New(t, - fx.Provide( - fx.Annotate( - func() (stub, error) { return nil, nil }, - fx.OnStart(func(context.Context) error { - return errors.New("hook failed") - }), - ), - ), - fx.Invoke(func(stub) {}), - ) - - err := app.Start(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "hook failed") - }) - - t.Run("error with multiple hooks of the same type", func(t *testing.T) { - t.Parallel() - - type stub interface{} - - err := validateApp(t, - fx.Provide( - fx.Annotate( - func() stub { return nil }, - fx.OnStart(func(context.Context) error { return nil }), - fx.OnStart(func(context.Context) error { return nil }), - ), - ), - fx.Invoke(func(s stub) {}), - ) - - require.Error(t, err) - require.Contains(t, err.Error(), "cannot apply more than one OnStart hook annotation") - }) - t.Run("with Supply", func(t *testing.T) { t.Parallel() @@ -1361,18 +1289,115 @@ func TestHookAnnotations(t *testing.T) { require.Equal(t, "decorated", <-ch) }) - t.Run("with nil target", func(t *testing.T) { +} + +func TestHookAnnotationFailures(t *testing.T) { + validateApp := func(t *testing.T, opts ...fx.Option) error { + return fx.ValidateApp( + append(opts, fx.Logger(fxtest.NewTestPrinter(t)))..., + ) + } + + t.Run("with unprovided dependency", func(t *testing.T) { + t.Parallel() + + type ( + a interface{} + b interface{} + ) + + err := validateApp(t, + fx.Provide( + fx.Annotate( + func() (a, error) { return nil, nil }, + fx.OnStart(func(context.Context, b) error { + return nil + }), + ), + ), + fx.Invoke(func(a) {}), + ) + + require.Error(t, err) + require.Contains(t, err.Error(), "missing type: fx_test.b") + }) + + t.Run("that returns error", func(t *testing.T) { + t.Parallel() + + type stub interface{} + + app := fxtest.New(t, + fx.Provide( + fx.Annotate( + func() (stub, error) { return nil, nil }, + fx.OnStart(func(context.Context) error { + return errors.New("hook failed") + }), + ), + ), + fx.Invoke(func(stub) {}), + ) + + err := app.Start(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "hook failed") + }) + + t.Run("error with multiple hooks of the same type", func(t *testing.T) { + t.Parallel() + + type stub interface{} + + err := validateApp(t, + fx.Provide( + fx.Annotate( + func() stub { return nil }, + fx.OnStart(func(context.Context) error { return nil }), + fx.OnStart(func(context.Context) error { return nil }), + ), + ), + fx.Invoke(func(s stub) {}), + ) + + require.Error(t, err) + require.Contains(t, err.Error(), "cannot apply more than one OnStart hook annotation") + }) + + t.Run("without returning error", func(t *testing.T) { type A interface{} err := validateApp(t, fx.Provide( fx.Annotate( func() A { return nil }, - fx.OnStart(nil), + fx.OnStart(func(context.Context) {}), ), ), ) require.Error(t, err) - require.Contains(t, err.Error(), "cannot use nil function") + require.Contains(t, err.Error(), "must return only an error") + }) + + t.Run("with constructor error", func(t *testing.T) { + type A interface{} + app := fx.New( + fx.Provide( + fx.Annotate( + func() (A, error) { + return nil, errors.New("hooks should not be installed") + }, + fx.OnStart(func(context.Context) error { + require.FailNow(t, "hook should not be called") + return nil + }), + ), + ), + fx.Invoke(func(A) {}), + ) + + err := app.Start(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "hooks should not be installed") }) t.Run("with non-func target", func(t *testing.T) { @@ -1419,39 +1444,18 @@ func TestHookAnnotations(t *testing.T) { require.Contains(t, err.Error(), "must not accept variatic") }) - t.Run("without returning error", func(t *testing.T) { + t.Run("with nil target", func(t *testing.T) { type A interface{} err := validateApp(t, fx.Provide( fx.Annotate( func() A { return nil }, - fx.OnStart(func(context.Context) {}), + fx.OnStart(nil), ), ), ) require.Error(t, err) - require.Contains(t, err.Error(), "must return only an error") + require.Contains(t, err.Error(), "cannot use nil function") }) - t.Run("with constructor error", func(t *testing.T) { - type A interface{} - app := fx.New( - fx.Provide( - fx.Annotate( - func() (A, error) { - return nil, errors.New("hooks should not be installed") - }, - fx.OnStart(func(context.Context) error { - require.FailNow(t, "hook should not be called") - return nil - }), - ), - ), - fx.Invoke(func(A) {}), - ) - - err := app.Start(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "hooks should not be installed") - }) } From 740afda19821750215d9beea2bc684d12c418c20 Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 5 Jul 2022 22:12:17 +0000 Subject: [PATCH 04/16] move failure cases to their own test --- annotated_test.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/annotated_test.go b/annotated_test.go index 8b0d15d12..9ee8a4198 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -1292,12 +1292,39 @@ func TestHookAnnotations(t *testing.T) { } func TestHookAnnotationFailures(t *testing.T) { + t.Parallel() validateApp := func(t *testing.T, opts ...fx.Option) error { return fx.ValidateApp( append(opts, fx.Logger(fxtest.NewTestPrinter(t)))..., ) } + table := []struct { + name string + opts fx.Option + validateApp bool + errContains string + }{} + + for _, tt := range table { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if tt.validateApp { + err := validateApp(t, tt.opts) + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + return + } + + app := fx.New(tt.opts) + ctx := context.Background() + err := app.Start(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + }) + } + t.Run("with unprovided dependency", func(t *testing.T) { t.Parallel() From d797a0797c824ce79006b4932329f7eee9a8e61e Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 5 Jul 2022 22:35:39 +0000 Subject: [PATCH 05/16] refactor failure hook tests to table and clean up indentation --- annotated_test.go | 497 ++++++++++++++++++++-------------------------- 1 file changed, 217 insertions(+), 280 deletions(-) diff --git a/annotated_test.go b/annotated_test.go index 9ee8a4198..109064964 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -997,20 +997,20 @@ func TestHookAnnotations(t *testing.T) { var started bool + hook := fx.Annotate( + func() (stub, error) { + b := []byte("expected") + return bytes.NewBuffer(b), nil + }, + fx.OnStart(func(_ context.Context, s stub) error { + started = true + require.Equal(t, "expected", s.String()) + return nil + }), + ) + app := fxtest.New(t, - fx.Provide( - fx.Annotate( - func() (stub, error) { - b := []byte("expected") - return bytes.NewBuffer(b), nil - }, - fx.OnStart(func(_ context.Context, s stub) error { - started = true - require.Equal(t, "expected", s.String()) - return nil - }), - ), - ), + fx.Provide(hook), fx.Invoke(func(s stub) { require.Equal(t, "expected", s.String()) }), @@ -1034,20 +1034,20 @@ func TestHookAnnotations(t *testing.T) { stopped bool ) + hook := fx.Annotate( + func() (stub, error) { return nil, nil }, + fx.OnStart(func(context.Context) error { + started = true + return nil + }), + fx.OnStop(func(context.Context) error { + stopped = true + return nil + }), + ) + app := fxtest.New(t, - fx.Provide( - fx.Annotate( - func() (stub, error) { return nil, nil }, - fx.OnStart(func(context.Context) error { - started = true - return nil - }), - fx.OnStop(func(context.Context) error { - stopped = true - return nil - }), - ), - ), + fx.Provide(hook), fx.Invoke(func(s stub) { invoked = s == nil }), @@ -1075,29 +1075,28 @@ func TestHookAnnotations(t *testing.T) { var aHook, bHook bool - app := fxtest.New(t, - fx.Provide( - fx.Annotate( - func() (a, error) { return nil, nil }, - fx.OnStart(func(context.Context) error { - aHook = true - return nil - }), - ), - - fx.Annotate( - func() (b, error) { return nil, nil }, - fx.OnStart(func(context.Context) error { - bHook = true - return nil - }), - ), + provided := fx.Provide( + fx.Annotate( + func() (a, error) { return nil, nil }, + fx.OnStart(func(context.Context) error { + aHook = true + return nil + }), + ), - func(a, b) c { return nil }, + fx.Annotate( + func() (b, error) { return nil, nil }, + fx.OnStart(func(context.Context) error { + bHook = true + return nil + }), ), - fx.Invoke(func(c) {}), + + func(a, b) c { return nil }, ) + app := fxtest.New(t, provided, fx.Invoke(func(c) {})) + ctx := context.Background() assert.False(t, aHook) assert.False(t, bHook) @@ -1117,18 +1116,18 @@ func TestHookAnnotations(t *testing.T) { var value int + hook := fx.Annotate( + func() (A, error) { return nil, nil }, + fx.OnStart(func(_ context.Context, b B, c C) error { + b1, _ := b.(int) + c1, _ := c.(int) + value = b1 + c1 + return nil + }), + ) + app := fxtest.New(t, - fx.Provide( - fx.Annotate( - func() (A, error) { return nil, nil }, - fx.OnStart(func(_ context.Context, b B, c C) error { - b1, _ := b.(int) - c1, _ := c.(int) - value = b1 + c1 - return nil - }), - ), - ), + fx.Provide(hook), fx.Provide(func() B { return int(1) }), fx.Provide(func() C { return int(2) }), fx.Invoke(func(A) {}), @@ -1152,29 +1151,30 @@ func TestHookAnnotations(t *testing.T) { buf := bytes.NewBuffer(nil) var called bool - ctor := fx.Provide( - fx.Annotate( - func() A { - return buf - }, - fx.OnStart(func(_ context.Context, a A, s fmt.Stringer) error { - a.WriteString(s.String()) - return nil - }), - ), + + hook := fx.Annotate( + func() A { + return buf + }, + fx.OnStart(func(_ context.Context, a A, s fmt.Stringer) error { + a.WriteString(s.String()) + return nil + }), ) - supply := fx.Supply( - fx.Annotate( - &asStringer{"supply"}, - fx.OnStart(func(context.Context) error { - called = true - return nil - }), - fx.As(new(fmt.Stringer)), - ), + ctor := fx.Provide(hook) + + hook = fx.Annotate( + &asStringer{"supply"}, + fx.OnStart(func(context.Context) error { + called = true + return nil + }), + fx.As(new(fmt.Stringer)), ) + supply := fx.Supply(hook) + opts := fx.Options( ctor, supply, @@ -1202,20 +1202,21 @@ func TestHookAnnotations(t *testing.T) { ctor := fx.Provide(func() A { return buf }) var called bool - decorated := fx.Decorate( - fx.Annotate( - func(in A) A { - in.WriteString("decorated") - return in - }, - fx.OnStart(func(_ context.Context, a A) error { - // assert that the interface we get is the decorated one - called = assert.Equal(t, "decorated", buf.String()) - return nil - }), - ), + + hook := fx.Annotate( + func(in A) A { + in.WriteString("decorated") + return in + }, + fx.OnStart(func(_ context.Context, a A) error { + // assert that the interface we get is the decorated one + called = assert.Equal(t, "decorated", buf.String()) + return nil + }), ) + decorated := fx.Decorate(hook) + opts := fx.Options( ctor, decorated, @@ -1238,38 +1239,38 @@ func TestHookAnnotations(t *testing.T) { ch := make(chan string, 3) - ctor := fx.Provide( - fx.Annotate( - func() A { return nil }, - fx.OnStart(func(_ context.Context, s fmt.Stringer) error { - ch <- "constructor" - require.Equal(t, "supply", s.String()) - return nil - }), - ), + hook := fx.Annotate( + func() A { return nil }, + fx.OnStart(func(_ context.Context, s fmt.Stringer) error { + ch <- "constructor" + require.Equal(t, "supply", s.String()) + return nil + }), ) - supply := fx.Supply( - fx.Annotate( - &asStringer{"supply"}, - fx.OnStart(func(_ context.Context) error { - ch <- "supply" - return nil - }), - fx.As(new(fmt.Stringer)), - ), + ctor := fx.Provide(hook) + + hook = fx.Annotate( + &asStringer{"supply"}, + fx.OnStart(func(_ context.Context) error { + ch <- "supply" + return nil + }), + fx.As(new(fmt.Stringer)), ) - decorated := fx.Decorate( - fx.Annotate( - func(in A) A { return in }, - fx.OnStart(func(_ context.Context) error { - ch <- "decorated" - return nil - }), - ), + supply := fx.Supply(hook) + + hook = fx.Annotate( + func(in A) A { return in }, + fx.OnStart(func(_ context.Context) error { + ch <- "decorated" + return nil + }), ) + decorated := fx.Decorate(hook) + opts := fx.Options( ctor, supply, @@ -1299,190 +1300,126 @@ func TestHookAnnotationFailures(t *testing.T) { ) } + type ( + A interface{} + B interface{} + ) + table := []struct { name string - opts fx.Option - validateApp bool + annotation interface{} + useNew bool errContains string - }{} + }{ + { + name: "with unprovided dependency", + errContains: "missing type: fx_test.B", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(func(context.Context, B) error { + return nil + }), + ), + }, + { + name: "with hook that errors", + errContains: "hook failed", + useNew: true, + annotation: fx.Annotate( + func() (A, error) { return nil, nil }, + fx.OnStart(func(context.Context) error { + return errors.New("hook failed") + }), + ), + }, + { + name: "with with multiple hooks of the same type", + errContains: "cannot apply more than one OnStart hook annotation", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(func(context.Context) error { return nil }), + fx.OnStart(func(context.Context) error { return nil }), + ), + }, + { + name: "with hook that doesn't return an error", + errContains: "must return only an error", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(func(context.Context) {}), + ), + }, + { + name: "with constructor that errors", + errContains: "hooks should not be installed", + useNew: true, + annotation: fx.Annotate( + func() (A, error) { + return nil, errors.New("hooks should not be installed") + }, + fx.OnStart(func(context.Context) error { + require.FailNow(t, "hook should not be called") + return nil + }), + ), + }, + { + name: "without a function target", + errContains: "must provide function", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(&struct{}{}), + ), + }, + { + name: "without context.Context as first parameter", + errContains: "must be context.Context", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(func() {}), + ), + }, + { + name: "with variactic hook", + errContains: "must not accept variatic", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(func(context.Context, ...A) error { + return nil + }), + ), + }, + { + name: "with nil hook target", + errContains: "cannot use nil function", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(nil), + ), + }, + } for _, tt := range table { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - if tt.validateApp { - err := validateApp(t, tt.opts) + opts := fx.Options( + fx.Provide(tt.annotation), + fx.Invoke(func(A) {}), + ) + + if !tt.useNew { + err := validateApp(t, opts) require.Error(t, err) require.Contains(t, err.Error(), tt.errContains) return } - app := fx.New(tt.opts) + app := fx.New(opts) ctx := context.Background() err := app.Start(ctx) require.Error(t, err) require.Contains(t, err.Error(), tt.errContains) }) } - - t.Run("with unprovided dependency", func(t *testing.T) { - t.Parallel() - - type ( - a interface{} - b interface{} - ) - - err := validateApp(t, - fx.Provide( - fx.Annotate( - func() (a, error) { return nil, nil }, - fx.OnStart(func(context.Context, b) error { - return nil - }), - ), - ), - fx.Invoke(func(a) {}), - ) - - require.Error(t, err) - require.Contains(t, err.Error(), "missing type: fx_test.b") - }) - - t.Run("that returns error", func(t *testing.T) { - t.Parallel() - - type stub interface{} - - app := fxtest.New(t, - fx.Provide( - fx.Annotate( - func() (stub, error) { return nil, nil }, - fx.OnStart(func(context.Context) error { - return errors.New("hook failed") - }), - ), - ), - fx.Invoke(func(stub) {}), - ) - - err := app.Start(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "hook failed") - }) - - t.Run("error with multiple hooks of the same type", func(t *testing.T) { - t.Parallel() - - type stub interface{} - - err := validateApp(t, - fx.Provide( - fx.Annotate( - func() stub { return nil }, - fx.OnStart(func(context.Context) error { return nil }), - fx.OnStart(func(context.Context) error { return nil }), - ), - ), - fx.Invoke(func(s stub) {}), - ) - - require.Error(t, err) - require.Contains(t, err.Error(), "cannot apply more than one OnStart hook annotation") - }) - - t.Run("without returning error", func(t *testing.T) { - type A interface{} - err := validateApp(t, - fx.Provide( - fx.Annotate( - func() A { return nil }, - fx.OnStart(func(context.Context) {}), - ), - ), - ) - require.Error(t, err) - require.Contains(t, err.Error(), "must return only an error") - }) - - t.Run("with constructor error", func(t *testing.T) { - type A interface{} - app := fx.New( - fx.Provide( - fx.Annotate( - func() (A, error) { - return nil, errors.New("hooks should not be installed") - }, - fx.OnStart(func(context.Context) error { - require.FailNow(t, "hook should not be called") - return nil - }), - ), - ), - fx.Invoke(func(A) {}), - ) - - err := app.Start(context.Background()) - require.Error(t, err) - require.Contains(t, err.Error(), "hooks should not be installed") - }) - - t.Run("with non-func target", func(t *testing.T) { - type A interface{} - err := validateApp(t, - fx.Provide( - fx.Annotate( - func() A { return nil }, - fx.OnStart(&struct{}{}), - ), - ), - ) - require.Error(t, err) - require.Contains(t, err.Error(), "must provide function") - }) - - t.Run("without context parameter", func(t *testing.T) { - type A interface{} - err := validateApp(t, - fx.Provide( - fx.Annotate( - func() A { return nil }, - fx.OnStart(func() {}), - ), - ), - ) - require.Error(t, err) - require.Contains(t, err.Error(), "must be context.Context") - }) - - t.Run("with variatic hook", func(t *testing.T) { - type A interface{} - err := validateApp(t, - fx.Provide( - fx.Annotate( - func() A { return nil }, - fx.OnStart(func(context.Context, ...A) error { - return nil - }), - ), - ), - ) - require.Error(t, err) - require.Contains(t, err.Error(), "must not accept variatic") - }) - - t.Run("with nil target", func(t *testing.T) { - type A interface{} - err := validateApp(t, - fx.Provide( - fx.Annotate( - func() A { return nil }, - fx.OnStart(nil), - ), - ), - ) - require.Error(t, err) - require.Contains(t, err.Error(), "cannot use nil function") - }) - } From 994442022f58adb6b7ba69217295cb0d17055814 Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 5 Jul 2022 22:38:12 +0000 Subject: [PATCH 06/16] move stringification of life cycle hook type to stringer --- annotated.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/annotated.go b/annotated.go index c2076e323..6c042dc3a 100644 --- a/annotated.go +++ b/annotated.go @@ -209,7 +209,7 @@ type lifecycleHookAnnotation struct { Target interface{} } -func (la *lifecycleHookAnnotation) apply(ann *annotated) error { +func (la *lifecycleHookAnnotation) String() string { name := "UnknownHookAnnotation" switch la.Type { case _onStartHookType: @@ -217,11 +217,14 @@ func (la *lifecycleHookAnnotation) apply(ann *annotated) error { case _onStopHookType: name = _onStopHook } + return name +} +func (la *lifecycleHookAnnotation) apply(ann *annotated) error { if la.Target == nil { return fmt.Errorf( "cannot use nil function for %v hook annotation", - name, + la, ) } @@ -229,7 +232,7 @@ func (la *lifecycleHookAnnotation) apply(ann *annotated) error { if la.Type == h.Type { return fmt.Errorf( "cannot apply more than one %v hook annotation", - name, + la, ) } } From a6a6ade5935ec54693f50c4634706a940918023a Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 5 Jul 2022 22:41:43 +0000 Subject: [PATCH 07/16] update godoc --- annotated.go | 18 +++++++++++------- annotated_test.go | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/annotated.go b/annotated.go index 6c042dc3a..f2288dabd 100644 --- a/annotated.go +++ b/annotated.go @@ -223,7 +223,7 @@ func (la *lifecycleHookAnnotation) String() string { func (la *lifecycleHookAnnotation) apply(ann *annotated) error { if la.Target == nil { return fmt.Errorf( - "cannot use nil function for %v hook annotation", + "cannot use nil function for %q hook annotation", la, ) } @@ -231,7 +231,7 @@ func (la *lifecycleHookAnnotation) apply(ann *annotated) error { for _, h := range ann.Hooks { if la.Type == h.Type { return fmt.Errorf( - "cannot apply more than one %v hook annotation", + "cannot apply more than one %q hook annotation", la, ) } @@ -240,7 +240,8 @@ func (la *lifecycleHookAnnotation) apply(ann *annotated) error { ft := la.targetType() if ft.Kind() != reflect.Func { return fmt.Errorf( - "must provide function for hook, got %v (%T)", + "must provide function for %q hook, got %v (%T)", + la, la.Target, la.Target, ) @@ -319,8 +320,7 @@ func (la *lifecycleHookAnnotation) resolveLifecycleParamField( value reflect.Value, ) { if param.Kind() == reflect.Struct { - nf := param.NumField() - if n <= nf { + if n <= param.NumField() { value = param.FieldByName(fmt.Sprintf("Field%d", n)) } } @@ -471,7 +471,9 @@ func (la *lifecycleHookAnnotation) Build(results ...reflect.Type) (reflect.Value // }), // ) // -// Only one OnStart annotation may be applied to a given function at a time. +// Only one OnStart annotation may be applied to a given function at a time, +// however functions may be annotated with other types of lifecylce Hooks, such +// as OnStop. func OnStart(onStart interface{}) Annotation { return &lifecycleHookAnnotation{ Type: _onStartHookType, @@ -491,7 +493,9 @@ func OnStart(onStart interface{}) Annotation { // }), // ) // -// Only one OnStop annotation may be applied to a given function at a time. +// Only one OnStop annotation may be applied to a given function at a time, +// however functions may be annotated with other types of lifecylce Hooks, such +// as OnStart. func OnStop(onStop interface{}) Annotation { return &lifecycleHookAnnotation{ Type: _onStopHookType, diff --git a/annotated_test.go b/annotated_test.go index 109064964..e295609c2 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -1334,7 +1334,7 @@ func TestHookAnnotationFailures(t *testing.T) { }, { name: "with with multiple hooks of the same type", - errContains: "cannot apply more than one OnStart hook annotation", + errContains: "cannot apply more than one \"OnStart\" hook annotation", annotation: fx.Annotate( func() A { return nil }, fx.OnStart(func(context.Context) error { return nil }), From 4160f0bc606bdea57d4d513c027aa78ed4f74b09 Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 5 Jul 2022 22:43:31 +0000 Subject: [PATCH 08/16] remove hook annotation Build() error --- annotated.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/annotated.go b/annotated.go index f2288dabd..92738fc3e 100644 --- a/annotated.go +++ b/annotated.go @@ -427,7 +427,7 @@ func (la *lifecycleHookAnnotation) buildHook(fn func(context.Context) error) (ho return } -func (la *lifecycleHookAnnotation) Build(results ...reflect.Type) (reflect.Value, error) { +func (la *lifecycleHookAnnotation) Build(results ...reflect.Type) reflect.Value { in, paramMap := la.parameters(results...) params := []reflect.Type{in} for _, r := range results { @@ -456,7 +456,7 @@ func (la *lifecycleHookAnnotation) Build(results ...reflect.Type) (reflect.Value return []reflect.Value{} }) - return newFn, nil + return newFn } // OnStart is an Annotation that appends an OnStart Hook to the application @@ -612,9 +612,7 @@ func (ann *annotated) Build() (interface{}, error) { var hooks []reflect.Value for _, hook := range ann.Hooks { - if hookFn, err := hook.Build(resultTypes...); err == nil { - hooks = append(hooks, hookFn) - } + hooks = append(hooks, hook.Build(resultTypes...)) } newFnType := reflect.FuncOf(paramTypes, resultTypes, false) From 97694734ecd0c6062c46ff9067cd3ef6e3e06666 Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 5 Jul 2022 22:57:28 +0000 Subject: [PATCH 09/16] address style and coverage --- annotated.go | 13 ++++++------- annotated_test.go | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/annotated.go b/annotated.go index 92738fc3e..ad2f8562d 100644 --- a/annotated.go +++ b/annotated.go @@ -611,8 +611,8 @@ func (ann *annotated) Build() (interface{}, error) { paramTypes, remapParams, hookParams := ann.parameters(resultTypes...) var hooks []reflect.Value - for _, hook := range ann.Hooks { - hooks = append(hooks, hook.Build(resultTypes...)) + for _, hookBuilder := range ann.Hooks { + hooks = append(hooks, hookBuilder.Build(resultTypes...)) } newFnType := reflect.FuncOf(paramTypes, resultTypes, false) @@ -631,19 +631,18 @@ func (ann *annotated) Build() (interface{}, error) { } results = remapResults(results) - // if the results are greater than zero and the final result + // if the number of results is greater than zero and the final result // is a non-nil error, do not execute hook installers hasErrorResult := len(results) > 0 && results[len(results)-1].Type() == _typeOfError if hasErrorResult { - err, ok := results[len(results)-1].Interface().(error) - if ok && err != nil { + if err, ok := results[len(results)-1].Interface().(error); ok && err != nil { return results } } - for i, hook := range hooks { + for i, hookBuilder := range hooks { hookArgs := hookParams(i, origArgs, results) - hook.Call(hookArgs) + hookBuilder.Call(hookArgs) } return results diff --git a/annotated_test.go b/annotated_test.go index e295609c2..dfb615fdc 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -1394,7 +1394,7 @@ func TestHookAnnotationFailures(t *testing.T) { errContains: "cannot use nil function", annotation: fx.Annotate( func() A { return nil }, - fx.OnStart(nil), + fx.OnStop(nil), ), }, } From 6e84e73f30bbed5f00e6b1ff810516879041b3f5 Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 5 Jul 2022 22:58:54 +0000 Subject: [PATCH 10/16] address style --- annotated.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/annotated.go b/annotated.go index ad2f8562d..274b831e1 100644 --- a/annotated.go +++ b/annotated.go @@ -730,10 +730,10 @@ func (ann *annotated) parameters(results ...reflect.Type) ( // append required types for hooks to types field, but do not // include them as params in constructor call - for h, t := range ann.Hooks { + for i, t := range ann.Hooks { params, _ := t.parameters(results...) field := reflect.StructField{ - Name: fmt.Sprintf("Hook%d", h), + Name: fmt.Sprintf("Hook%d", i), Type: params, } inFields = append(inFields, field) From 5604c2596d5cb6ee7b5aa6f14b9315fc5bf9e258 Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 5 Jul 2022 23:15:27 +0000 Subject: [PATCH 11/16] add invoke test case + fix --- annotated.go | 7 ++++++- annotated_test.go | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/annotated.go b/annotated.go index 274b831e1..2d5f8de7b 100644 --- a/annotated.go +++ b/annotated.go @@ -395,7 +395,12 @@ func (la *lifecycleHookAnnotation) parameters(results ...reflect.Type) ( if len(args) != 0 { p := args[0] - results := args[1] + + var results reflect.Value + + if len(args) > 1 { + results = args[1] + } lc, _ = p.FieldByName("Lifecycle").Interface().(Lifecycle) diff --git a/annotated_test.go b/annotated_test.go index dfb615fdc..d9de60c61 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -990,6 +990,24 @@ func TestAnnotate(t *testing.T) { func TestHookAnnotations(t *testing.T) { t.Parallel() + t.Run("with hook on invoke", func(t *testing.T) { + t.Parallel() + + var called bool + hook := fx.Annotate( + func() {}, + fx.OnStart(func(context.Context) error { + called = true + return nil + }), + ) + app := fxtest.New(t, fx.Invoke(hook)) + + require.False(t, called) + require.NoError(t, app.Start(context.Background())) + require.True(t, called) + }) + t.Run("depend on result interface of target", func(t *testing.T) { type stub interface { String() string From bf0496b857be7f95a4573aea97d93b552124d902 Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 26 Jul 2022 21:45:02 +0000 Subject: [PATCH 12/16] style fixes --- annotated.go | 59 ++++++++++++++++++++-------------------------------- 1 file changed, 23 insertions(+), 36 deletions(-) diff --git a/annotated.go b/annotated.go index 2d5f8de7b..2afd4b7a6 100644 --- a/annotated.go +++ b/annotated.go @@ -237,7 +237,8 @@ func (la *lifecycleHookAnnotation) apply(ann *annotated) error { } } - ft := la.targetType() + ft := reflect.TypeOf(la.Target) + if ft.Kind() != reflect.Func { return fmt.Errorf( "must provide function for %q hook, got %v (%T)", @@ -283,30 +284,20 @@ var ( _typeOfContext reflect.Type = reflect.TypeOf((*context.Context)(nil)).Elem() ) -func (la *lifecycleHookAnnotation) targetType() (targetType reflect.Type) { - return reflect.TypeOf(la.Target) -} +type valueResolver func(reflect.Value, int) reflect.Value func (la *lifecycleHookAnnotation) resolveMap(results []reflect.Type) ( - resultMap map[reflect.Type]struct { - resolve func(reflect.Value, int) reflect.Value - }, + resultMap map[reflect.Type]valueResolver, ) { // index the constructor results by type and position to allow // for us to omit these from the in types that must be injected, // and to allow us to interleave constructor results // into our hook arguments. - resultMap = make(map[reflect.Type]struct { - resolve func(reflect.Value, int) reflect.Value - }, 0) + resultMap = make(map[reflect.Type]valueResolver, len(results)) for _, r := range results { - resultMap[r] = struct { - resolve func(reflect.Value, int) reflect.Value - }{ - resolve: func(v reflect.Value, pos int) (value reflect.Value) { - return v - }, + resultMap[r] = func(v reflect.Value, pos int) (value reflect.Value) { + return v } } @@ -349,24 +340,24 @@ func (la *lifecycleHookAnnotation) parameters(results ...reflect.Type) ( }, } - type valueResolver func(reflect.Value, int) reflect.Value type argSource struct { pos int result bool resolve valueResolver } + ft := reflect.TypeOf(la.Target) resolverIdx := make([]argSource, 1) - ft := la.targetType() + for i := 1; i < ft.NumIn(); i++ { t := ft.In(i) - resultIdx, isProvidedByResults := resultMap[t] + result, isProvidedByResults := resultMap[t] if isProvidedByResults { resolverIdx = append(resolverIdx, argSource{ pos: i, result: true, - resolve: resultIdx.resolve, + resolve: result, }) continue } @@ -377,11 +368,9 @@ func (la *lifecycleHookAnnotation) parameters(results ...reflect.Type) ( } params = append(params, field) - resolver := la.resolveLifecycleParamField - resolverIdx = append(resolverIdx, argSource{ pos: i, - resolve: resolver, + resolve: la.resolveLifecycleParamField, }) } @@ -393,29 +382,25 @@ func (la *lifecycleHookAnnotation) parameters(results ...reflect.Type) ( remapped = make([]reflect.Value, ft.NumIn()) if len(args) != 0 { - - p := args[0] - - var results reflect.Value + var ( + results reflect.Value + p = args[0] + ) if len(args) > 1 { results = args[1] } lc, _ = p.FieldByName("Lifecycle").Interface().(Lifecycle) - for i := 1; i < ft.NumIn(); i++ { resolver := resolverIdx[i] - source := p if resolver.result { source = results } - remapped[i] = resolver.resolve(source, i) } } - return } return @@ -615,9 +600,11 @@ func (ann *annotated) Build() (interface{}, error) { } paramTypes, remapParams, hookParams := ann.parameters(resultTypes...) - var hooks []reflect.Value - for _, hookBuilder := range ann.Hooks { - hooks = append(hooks, hookBuilder.Build(resultTypes...)) + hookFns := make([]reflect.Value, len(ann.Hooks)) + for i, builder := range ann.Hooks { + if hookFn := builder.Build(resultTypes...); !hookFn.IsZero() { + hookFns[i] = hookFn + } } newFnType := reflect.FuncOf(paramTypes, resultTypes, false) @@ -645,9 +632,9 @@ func (ann *annotated) Build() (interface{}, error) { } } - for i, hookBuilder := range hooks { + for i, hookFn := range hookFns { hookArgs := hookParams(i, origArgs, results) - hookBuilder.Call(hookArgs) + hookFn.Call(hookArgs) } return results From b9b87e14a4ec3f4d582e6dde1f79b8f424a8f6f6 Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 26 Jul 2022 21:46:40 +0000 Subject: [PATCH 13/16] style fixes --- annotated.go | 9 +-------- annotated_test.go | 6 +++++- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/annotated.go b/annotated.go index 2afd4b7a6..be44191ee 100644 --- a/annotated.go +++ b/annotated.go @@ -732,7 +732,6 @@ func (ann *annotated) parameters(results ...reflect.Type) ( } types = []reflect.Type{reflect.StructOf(inFields)} - remap = func(args []reflect.Value) []reflect.Value { params := args[0] args = args[:0] @@ -744,27 +743,21 @@ func (ann *annotated) parameters(results ...reflect.Type) ( hookValueMap = func(hook int, args []reflect.Value, results []reflect.Value) (out []reflect.Value) { params := args[0] - if params.Kind() == reflect.Struct { var zero reflect.Value - value := params.FieldByNameFunc(func(name string) bool { - return name == fmt.Sprintf("Hook%d", hook) - }) + value := params.FieldByName(fmt.Sprintf("Hook%d", hook)) if value != zero { out = append(out, value) } } - for _, r := range results { if r.Type() != _typeOfError { out = append(out, r) } } - return } - return } diff --git a/annotated_test.go b/annotated_test.go index d9de60c61..d06d48c5c 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -994,8 +994,11 @@ func TestHookAnnotations(t *testing.T) { t.Parallel() var called bool + var invoked bool hook := fx.Annotate( - func() {}, + func() { + invoked = true + }, fx.OnStart(func(context.Context) error { called = true return nil @@ -1006,6 +1009,7 @@ func TestHookAnnotations(t *testing.T) { require.False(t, called) require.NoError(t, app.Start(context.Background())) require.True(t, called) + require.True(t, invoked) }) t.Run("depend on result interface of target", func(t *testing.T) { From dbd375e01e7076457a0f7b9b5ef603b36357563b Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 26 Jul 2022 22:07:58 +0000 Subject: [PATCH 14/16] update docs' --- annotated.go | 43 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/annotated.go b/annotated.go index be44191ee..f9efcc574 100644 --- a/annotated.go +++ b/annotated.go @@ -451,16 +451,29 @@ func (la *lifecycleHookAnnotation) Build(results ...reflect.Type) reflect.Value // OnStart is an Annotation that appends an OnStart Hook to the application // Lifecycle when that function is called. This provides a way to create -// Lifecycle OnStart hooks without building a function that takes a dependency -// on the Lifecycle type. +// Lifecycle OnStart (see Lifecycle type documentation) hooks without building a +// function that takes a dependency on the Lifecycle type. // // fx.Annotate( -// func(...) Server { ... }, +// NewServer, // fx.OnStart(func(ctx context.Context, server Server) error { // return server.Listen(ctx) // }), // ) // +// Which is functionally the same as: +// +// fx.Provide( +// func(lifecycle fx.Lifecycle, p Params) Server { +// server := NewServer(p) +// lifecycle.Append(fx.Hook{ +// OnStart: func(ctx context.Context) error { +// return server.Listen(ctx) +// }, +// }) +// } +// ) +// // Only one OnStart annotation may be applied to a given function at a time, // however functions may be annotated with other types of lifecylce Hooks, such // as OnStop. @@ -473,16 +486,29 @@ func OnStart(onStart interface{}) Annotation { // OnStop is an Annotation that appends an OnStop Hook to the application // Lifecycle when that function is called. This provides a way to create -// Lifecycle OnStop hooks without building a function that takes a dependency -// on the Lifecycle type. +// Lifecycle OnStop (see Lifecycle type documentation) hooks without building a +// function that takes a dependency on the Lifecycle type. // // fx.Annotate( -// func(...) Server { ... }, +// NewServer, // fx.OnStop(func(ctx context.Context, server Server) error { // return server.Shutdown(ctx) // }), // ) // +// Which is functionally the same as: +// +// fx.Provide( +// func(lifecycle fx.Lifecycle, p Params) Server { +// server := NewServer(p) +// lifecycle.Append(fx.Hook{ +// OnStart: func(ctx context.Context) error { +// return server.Shutdown(ctx) +// }, +// }) +// } +// ) +// // Only one OnStop annotation may be applied to a given function at a time, // however functions may be annotated with other types of lifecylce Hooks, such // as OnStart. @@ -672,7 +698,10 @@ func (ann *annotated) typeCheckOrigFn() error { // parameters returns the type for the parameters of the annotated function, // and a function that maps the arguments of the annotated function -// back to the arguments of the target function. +// back to the arguments of the target function and a function that maps +// values to any lifecycle hook annotations. It accepts a variactic set +// of reflect.Type which allows for omitting any resulting constructor types +// from required parameters for annotation hooks. func (ann *annotated) parameters(results ...reflect.Type) ( types []reflect.Type, remap func([]reflect.Value) []reflect.Value, From 8b7e96e52b648c30ce8be4cbd2eef1c73f4374cb Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 26 Jul 2022 22:23:00 +0000 Subject: [PATCH 15/16] address unit test comments --- annotated_test.go | 142 ++++++++++++++++++---------------------------- 1 file changed, 55 insertions(+), 87 deletions(-) diff --git a/annotated_test.go b/annotated_test.go index d06d48c5c..d259e7f33 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -987,29 +987,53 @@ func TestAnnotate(t *testing.T) { } +func assertApp( + t *testing.T, + app interface { + Start(context.Context) error + Stop(context.Context) error + }, + started *bool, + stopped *bool, + invoked *bool, +) { + t.Helper() + ctx := context.Background() + assert.False(t, *started) + require.NoError(t, app.Start(ctx)) + assert.True(t, *started) + + if invoked != nil { + assert.True(t, *invoked) + } + + if stopped != nil { + assert.False(t, *stopped) + require.NoError(t, app.Stop(ctx)) + assert.True(t, *stopped) + } +} + func TestHookAnnotations(t *testing.T) { t.Parallel() t.Run("with hook on invoke", func(t *testing.T) { t.Parallel() - var called bool + var started bool var invoked bool hook := fx.Annotate( func() { invoked = true }, fx.OnStart(func(context.Context) error { - called = true + started = true return nil }), ) app := fxtest.New(t, fx.Invoke(hook)) - require.False(t, called) - require.NoError(t, app.Start(context.Background())) - require.True(t, called) - require.True(t, invoked) + assertApp(t, app, &started, nil, &invoked) }) t.Run("depend on result interface of target", func(t *testing.T) { @@ -1038,11 +1062,7 @@ func TestHookAnnotations(t *testing.T) { }), ) - ctx := context.Background() - assert.False(t, started) - require.NoError(t, app.Start(ctx)) - assert.True(t, started) - require.NoError(t, app.Stop(ctx)) + assertApp(t, app, &started, nil, nil) }) t.Run("start and stop without dependencies", func(t *testing.T) { @@ -1075,58 +1095,9 @@ func TestHookAnnotations(t *testing.T) { }), ) - ctx := context.Background() - assert.False(t, started) - require.NoError(t, app.Start(ctx)) - assert.True(t, invoked) - assert.True(t, started) - assert.False(t, stopped) - require.NoError(t, app.Stop(ctx)) - assert.True(t, stopped) - + assertApp(t, app, &started, &stopped, &invoked) }) - t.Run("depedency chain", func(t *testing.T) { - t.Parallel() - - type ( - a interface{} - b interface{} - c interface{} - ) - - var aHook, bHook bool - - provided := fx.Provide( - fx.Annotate( - func() (a, error) { return nil, nil }, - fx.OnStart(func(context.Context) error { - aHook = true - return nil - }), - ), - - fx.Annotate( - func() (b, error) { return nil, nil }, - fx.OnStart(func(context.Context) error { - bHook = true - return nil - }), - ), - - func(a, b) c { return nil }, - ) - - app := fxtest.New(t, provided, fx.Invoke(func(c) {})) - - ctx := context.Background() - assert.False(t, aHook) - assert.False(t, bHook) - require.NoError(t, app.Start(ctx)) - assert.True(t, aHook) - assert.True(t, bHook) - require.NoError(t, app.Stop(ctx)) - }) t.Run("with multiple extra dependency parameters", func(t *testing.T) { t.Parallel() @@ -1174,28 +1145,27 @@ func TestHookAnnotations(t *testing.T) { buf := bytes.NewBuffer(nil) var called bool - hook := fx.Annotate( - func() A { - return buf - }, - fx.OnStart(func(_ context.Context, a A, s fmt.Stringer) error { - a.WriteString(s.String()) - return nil - }), - ) - - ctor := fx.Provide(hook) - - hook = fx.Annotate( - &asStringer{"supply"}, - fx.OnStart(func(context.Context) error { - called = true - return nil - }), - fx.As(new(fmt.Stringer)), + ctor := fx.Provide( + fx.Annotate( + func() A { + return buf + }, + fx.OnStart(func(_ context.Context, a A, s fmt.Stringer) error { + a.WriteString(s.String()) + return nil + }), + ), ) - supply := fx.Supply(hook) + supply := fx.Supply( + fx.Annotate( + &asStringer{"supply"}, + fx.OnStart(func(context.Context) error { + called = true + return nil + }), + fx.As(new(fmt.Stringer)), + )) opts := fx.Options( ctor, @@ -1231,7 +1201,6 @@ func TestHookAnnotations(t *testing.T) { return in }, fx.OnStart(func(_ context.Context, a A) error { - // assert that the interface we get is the decorated one called = assert.Equal(t, "decorated", buf.String()) return nil }), @@ -1247,8 +1216,7 @@ func TestHookAnnotations(t *testing.T) { app := fxtest.New(t, opts) ctx := context.Background() - err := app.Start(ctx) - require.NoError(t, err) + require.NoError(t, app.Start(ctx)) require.NoError(t, app.Stop(ctx)) require.True(t, called) require.Equal(t, "decorated", buf.String()) @@ -1345,7 +1313,7 @@ func TestHookAnnotationFailures(t *testing.T) { }, { name: "with hook that errors", - errContains: "hook failed", + errContains: "OnStart hook failed", useNew: true, annotation: fx.Annotate( func() (A, error) { return nil, nil }, @@ -1355,8 +1323,8 @@ func TestHookAnnotationFailures(t *testing.T) { ), }, { - name: "with with multiple hooks of the same type", - errContains: "cannot apply more than one \"OnStart\" hook annotation", + name: "with multiple hooks of the same type", + errContains: `cannot apply more than one "OnStart" hook annotation`, annotation: fx.Annotate( func() A { return nil }, fx.OnStart(func(context.Context) error { return nil }), From 67de07fa1a8bf2590b757f0b10fd987368d185eb Mon Sep 17 00:00:00 2001 From: jmills Date: Tue, 26 Jul 2022 22:58:43 +0000 Subject: [PATCH 16/16] fix bad error assertion in hook tests --- annotated_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/annotated_test.go b/annotated_test.go index d259e7f33..a0def4f91 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -1313,7 +1313,7 @@ func TestHookAnnotationFailures(t *testing.T) { }, { name: "with hook that errors", - errContains: "OnStart hook failed", + errContains: "hook failed", useNew: true, annotation: fx.Annotate( func() (A, error) { return nil, nil },