diff --git a/annotated.go b/annotated.go index 85a8ec0e6..f9efcc574 100644 --- a/annotated.go +++ b/annotated.go @@ -21,6 +21,7 @@ package fx import ( + "context" "errors" "fmt" "reflect" @@ -195,6 +196,329 @@ func ResultTags(tags ...string) Annotation { return resultTagsAnnotation{tags} } +type _lifecycleHookAnnotationType int + +const ( + _unknownHookType _lifecycleHookAnnotationType = iota + _onStartHookType + _onStopHookType +) + +type lifecycleHookAnnotation struct { + Type _lifecycleHookAnnotationType + Target interface{} +} + +func (la *lifecycleHookAnnotation) String() string { + name := "UnknownHookAnnotation" + switch la.Type { + case _onStartHookType: + name = _onStartHook + case _onStopHookType: + name = _onStopHook + } + return name +} + +func (la *lifecycleHookAnnotation) apply(ann *annotated) error { + if la.Target == nil { + return fmt.Errorf( + "cannot use nil function for %q hook annotation", + la, + ) + } + + for _, h := range ann.Hooks { + if la.Type == h.Type { + return fmt.Errorf( + "cannot apply more than one %q hook annotation", + la, + ) + } + } + + ft := reflect.TypeOf(la.Target) + + if ft.Kind() != reflect.Func { + return fmt.Errorf( + "must provide function for %q hook, got %v (%T)", + la, + la.Target, + la.Target, + ) + } + + if ft.NumIn() < 1 || ft.In(0) != _typeOfContext { + return fmt.Errorf( + "first argument of hook must be context.Context, got %v (%T)", + la.Target, + la.Target, + ) + } + + hasOut := ft.NumOut() == 1 + returnsErr := hasOut && ft.Out(0) == _typeOfError + + if !hasOut || !returnsErr { + return fmt.Errorf( + "hooks must return only an error type, got %v (%T)", + la.Target, + la.Target, + ) + } + + if ft.IsVariadic() { + return fmt.Errorf( + "hooks must not accept variatic parameters, got %v (%T)", + la.Target, + la.Target, + ) + } + + ann.Hooks = append(ann.Hooks, la) + return nil +} + +var ( + _typeOfLifecycle reflect.Type = reflect.TypeOf((*Lifecycle)(nil)).Elem() + _typeOfContext reflect.Type = reflect.TypeOf((*context.Context)(nil)).Elem() +) + +type valueResolver func(reflect.Value, int) reflect.Value + +func (la *lifecycleHookAnnotation) resolveMap(results []reflect.Type) ( + resultMap map[reflect.Type]valueResolver, +) { + // index the constructor results by type and position to allow + // for us to omit these from the in types that must be injected, + // and to allow us to interleave constructor results + // into our hook arguments. + resultMap = make(map[reflect.Type]valueResolver, len(results)) + + for _, r := range results { + resultMap[r] = func(v reflect.Value, pos int) (value reflect.Value) { + return v + } + } + + return +} + +func (la *lifecycleHookAnnotation) resolveLifecycleParamField( + param reflect.Value, + n int, +) ( + value reflect.Value, +) { + if param.Kind() == reflect.Struct { + if n <= param.NumField() { + value = param.FieldByName(fmt.Sprintf("Field%d", n)) + } + } + + return value +} + +func (la *lifecycleHookAnnotation) parameters(results ...reflect.Type) ( + in reflect.Type, + argmap func( + args []reflect.Value, + ) (Lifecycle, []reflect.Value), +) { + resultMap := la.resolveMap(results) + + // hook functions require a lifecycle, and it should be injected + params := []reflect.StructField{ + { + Name: "In", + Type: _typeOfIn, + Anonymous: true, + }, + { + Name: "Lifecycle", + Type: _typeOfLifecycle, + }, + } + + type argSource struct { + pos int + result bool + resolve valueResolver + } + + ft := reflect.TypeOf(la.Target) + resolverIdx := make([]argSource, 1) + + for i := 1; i < ft.NumIn(); i++ { + t := ft.In(i) + result, isProvidedByResults := resultMap[t] + + if isProvidedByResults { + resolverIdx = append(resolverIdx, argSource{ + pos: i, + result: true, + resolve: result, + }) + continue + } + + field := reflect.StructField{ + Name: fmt.Sprintf("Field%d", i), + Type: t, + } + params = append(params, field) + + resolverIdx = append(resolverIdx, argSource{ + pos: i, + resolve: la.resolveLifecycleParamField, + }) + } + + in = reflect.StructOf(params) + + argmap = func( + args []reflect.Value, + ) (lc Lifecycle, remapped []reflect.Value) { + remapped = make([]reflect.Value, ft.NumIn()) + + if len(args) != 0 { + var ( + results reflect.Value + p = args[0] + ) + + if len(args) > 1 { + results = args[1] + } + + lc, _ = p.FieldByName("Lifecycle").Interface().(Lifecycle) + for i := 1; i < ft.NumIn(); i++ { + resolver := resolverIdx[i] + source := p + if resolver.result { + source = results + } + remapped[i] = resolver.resolve(source, i) + } + } + return + } + return +} + +func (la *lifecycleHookAnnotation) buildHook(fn func(context.Context) error) (hook Hook) { + switch la.Type { + case _onStartHookType: + hook.OnStart = fn + case _onStopHookType: + hook.OnStop = fn + } + + return +} + +func (la *lifecycleHookAnnotation) Build(results ...reflect.Type) reflect.Value { + in, paramMap := la.parameters(results...) + params := []reflect.Type{in} + for _, r := range results { + if r != _typeOfError { + params = append(params, r) + } + } + + origFn := reflect.ValueOf(la.Target) + newFnType := reflect.FuncOf(params, nil, false) + newFn := reflect.MakeFunc(newFnType, func(args []reflect.Value) []reflect.Value { + var lc Lifecycle + lc, args = paramMap(args) + hookFn := func(ctx context.Context) (err error) { + args[0] = reflect.ValueOf(ctx) + + results := origFn.Call(args) + if len(results) > 0 && results[0].Type() == _typeOfError { + err, _ = results[0].Interface().(error) + } + + return + } + + lc.Append(la.buildHook(hookFn)) + return []reflect.Value{} + }) + + return newFn +} + +// OnStart is an Annotation that appends an OnStart Hook to the application +// Lifecycle when that function is called. This provides a way to create +// Lifecycle OnStart (see Lifecycle type documentation) hooks without building a +// function that takes a dependency on the Lifecycle type. +// +// fx.Annotate( +// NewServer, +// fx.OnStart(func(ctx context.Context, server Server) error { +// return server.Listen(ctx) +// }), +// ) +// +// Which is functionally the same as: +// +// fx.Provide( +// func(lifecycle fx.Lifecycle, p Params) Server { +// server := NewServer(p) +// lifecycle.Append(fx.Hook{ +// OnStart: func(ctx context.Context) error { +// return server.Listen(ctx) +// }, +// }) +// } +// ) +// +// Only one OnStart annotation may be applied to a given function at a time, +// however functions may be annotated with other types of lifecylce Hooks, such +// as OnStop. +func OnStart(onStart interface{}) Annotation { + return &lifecycleHookAnnotation{ + Type: _onStartHookType, + Target: onStart, + } +} + +// OnStop is an Annotation that appends an OnStop Hook to the application +// Lifecycle when that function is called. This provides a way to create +// Lifecycle OnStop (see Lifecycle type documentation) hooks without building a +// function that takes a dependency on the Lifecycle type. +// +// fx.Annotate( +// NewServer, +// fx.OnStop(func(ctx context.Context, server Server) error { +// return server.Shutdown(ctx) +// }), +// ) +// +// Which is functionally the same as: +// +// fx.Provide( +// func(lifecycle fx.Lifecycle, p Params) Server { +// server := NewServer(p) +// lifecycle.Append(fx.Hook{ +// OnStart: func(ctx context.Context) error { +// return server.Shutdown(ctx) +// }, +// }) +// } +// ) +// +// Only one OnStop annotation may be applied to a given function at a time, +// however functions may be annotated with other types of lifecylce Hooks, such +// as OnStart. +func OnStop(onStop interface{}) Annotation { + return &lifecycleHookAnnotation{ + Type: _onStopHookType, + Target: onStop, + } +} + type asAnnotation struct { targets []interface{} } @@ -265,6 +589,7 @@ type annotated struct { ResultTags []string As [][]reflect.Type FuncPtr uintptr + Hooks []*lifecycleHookAnnotation } func (ann annotated) String() string { @@ -295,16 +620,26 @@ func (ann *annotated) Build() (interface{}, error) { return nil, fmt.Errorf("invalid annotation function %T: %w", ann.Target, err) } - paramTypes, remapParams := ann.parameters() resultTypes, remapResults, err := ann.results() if err != nil { return nil, err } + paramTypes, remapParams, hookParams := ann.parameters(resultTypes...) + + hookFns := make([]reflect.Value, len(ann.Hooks)) + for i, builder := range ann.Hooks { + if hookFn := builder.Build(resultTypes...); !hookFn.IsZero() { + hookFns[i] = hookFn + } + } newFnType := reflect.FuncOf(paramTypes, resultTypes, false) origFn := reflect.ValueOf(ann.Target) ann.FuncPtr = origFn.Pointer() + newFn := reflect.MakeFunc(newFnType, func(args []reflect.Value) []reflect.Value { + origArgs := make([]reflect.Value, len(args)) + copy(origArgs, args) args = remapParams(args) var results []reflect.Value if ft.IsVariadic() { @@ -313,6 +648,21 @@ func (ann *annotated) Build() (interface{}, error) { results = origFn.Call(args) } results = remapResults(results) + + // if the number of results is greater than zero and the final result + // is a non-nil error, do not execute hook installers + hasErrorResult := len(results) > 0 && results[len(results)-1].Type() == _typeOfError + if hasErrorResult { + if err, ok := results[len(results)-1].Interface().(error); ok && err != nil { + return results + } + } + + for i, hookFn := range hookFns { + hookArgs := hookParams(i, origArgs, results) + hookFn.Call(hookArgs) + } + return results }) @@ -348,10 +698,14 @@ func (ann *annotated) typeCheckOrigFn() error { // parameters returns the type for the parameters of the annotated function, // and a function that maps the arguments of the annotated function -// back to the arguments of the target function. -func (ann *annotated) parameters() ( +// back to the arguments of the target function and a function that maps +// values to any lifecycle hook annotations. It accepts a variactic set +// of reflect.Type which allows for omitting any resulting constructor types +// from required parameters for annotation hooks. +func (ann *annotated) parameters(results ...reflect.Type) ( types []reflect.Type, remap func([]reflect.Value) []reflect.Value, + hookValueMap func(int, []reflect.Value, []reflect.Value) []reflect.Value, ) { ft := reflect.TypeOf(ann.Target) @@ -362,10 +716,10 @@ func (ann *annotated) parameters() ( // No parameter annotations. Return the original types // and an identity function. - if len(ann.ParamTags) == 0 && !ft.IsVariadic() { + if len(ann.ParamTags) == 0 && !ft.IsVariadic() && len(ann.Hooks) == 0 { return types, func(args []reflect.Value) []reflect.Value { return args - } + }, nil } // Turn parameters into an fx.In struct. @@ -395,8 +749,19 @@ func (ann *annotated) parameters() ( inFields = append(inFields, field) } + // append required types for hooks to types field, but do not + // include them as params in constructor call + for i, t := range ann.Hooks { + params, _ := t.parameters(results...) + field := reflect.StructField{ + Name: fmt.Sprintf("Hook%d", i), + Type: params, + } + inFields = append(inFields, field) + } + types = []reflect.Type{reflect.StructOf(inFields)} - return types, func(args []reflect.Value) []reflect.Value { + remap = func(args []reflect.Value) []reflect.Value { params := args[0] args = args[:0] for i := 0; i < ft.NumIn(); i++ { @@ -404,6 +769,25 @@ func (ann *annotated) parameters() ( } return args } + + hookValueMap = func(hook int, args []reflect.Value, results []reflect.Value) (out []reflect.Value) { + params := args[0] + if params.Kind() == reflect.Struct { + var zero reflect.Value + value := params.FieldByName(fmt.Sprintf("Hook%d", hook)) + + if value != zero { + out = append(out, value) + } + } + for _, r := range results { + if r.Type() != _typeOfError { + out = append(out, r) + } + } + return + } + return } // results returns the types of the results of the annotated function, diff --git a/annotated_test.go b/annotated_test.go index f8ba3b9e1..a0def4f91 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -22,6 +22,7 @@ package fx_test import ( "bytes" + "context" "errors" "fmt" "io" @@ -983,4 +984,432 @@ func TestAnnotate(t *testing.T) { assert.Contains(t, err.Error(), "invalid annotation function func(fx_test.B) string") assert.Contains(t, err.Error(), "fx.In structs cannot be annotated") }) + +} + +func assertApp( + t *testing.T, + app interface { + Start(context.Context) error + Stop(context.Context) error + }, + started *bool, + stopped *bool, + invoked *bool, +) { + t.Helper() + ctx := context.Background() + assert.False(t, *started) + require.NoError(t, app.Start(ctx)) + assert.True(t, *started) + + if invoked != nil { + assert.True(t, *invoked) + } + + if stopped != nil { + assert.False(t, *stopped) + require.NoError(t, app.Stop(ctx)) + assert.True(t, *stopped) + } +} + +func TestHookAnnotations(t *testing.T) { + t.Parallel() + + t.Run("with hook on invoke", func(t *testing.T) { + t.Parallel() + + var started bool + var invoked bool + hook := fx.Annotate( + func() { + invoked = true + }, + fx.OnStart(func(context.Context) error { + started = true + return nil + }), + ) + app := fxtest.New(t, fx.Invoke(hook)) + + assertApp(t, app, &started, nil, &invoked) + }) + + t.Run("depend on result interface of target", func(t *testing.T) { + type stub interface { + String() string + } + + var started bool + + hook := fx.Annotate( + func() (stub, error) { + b := []byte("expected") + return bytes.NewBuffer(b), nil + }, + fx.OnStart(func(_ context.Context, s stub) error { + started = true + require.Equal(t, "expected", s.String()) + return nil + }), + ) + + app := fxtest.New(t, + fx.Provide(hook), + fx.Invoke(func(s stub) { + require.Equal(t, "expected", s.String()) + }), + ) + + assertApp(t, app, &started, nil, nil) + }) + + t.Run("start and stop without dependencies", func(t *testing.T) { + t.Parallel() + + type stub interface{} + + var ( + invoked bool + started bool + stopped bool + ) + + hook := fx.Annotate( + func() (stub, error) { return nil, nil }, + fx.OnStart(func(context.Context) error { + started = true + return nil + }), + fx.OnStop(func(context.Context) error { + stopped = true + return nil + }), + ) + + app := fxtest.New(t, + fx.Provide(hook), + fx.Invoke(func(s stub) { + invoked = s == nil + }), + ) + + assertApp(t, app, &started, &stopped, &invoked) + }) + + t.Run("with multiple extra dependency parameters", func(t *testing.T) { + t.Parallel() + + type ( + A interface{} + B interface{} + C interface{} + ) + + var value int + + hook := fx.Annotate( + func() (A, error) { return nil, nil }, + fx.OnStart(func(_ context.Context, b B, c C) error { + b1, _ := b.(int) + c1, _ := c.(int) + value = b1 + c1 + return nil + }), + ) + + app := fxtest.New(t, + fx.Provide(hook), + fx.Provide(func() B { return int(1) }), + fx.Provide(func() C { return int(2) }), + fx.Invoke(func(A) {}), + ) + + ctx := context.Background() + assert.Zero(t, value) + require.NoError(t, app.Start(ctx)) + defer func() { + require.NoError(t, app.Stop(ctx)) + }() + assert.Equal(t, 3, value) + }) + + t.Run("with Supply", func(t *testing.T) { + t.Parallel() + + type A interface { + WriteString(string) (int, error) + } + + buf := bytes.NewBuffer(nil) + var called bool + + ctor := fx.Provide( + fx.Annotate( + func() A { + return buf + }, + fx.OnStart(func(_ context.Context, a A, s fmt.Stringer) error { + a.WriteString(s.String()) + return nil + }), + ), + ) + + supply := fx.Supply( + fx.Annotate( + &asStringer{"supply"}, + fx.OnStart(func(context.Context) error { + called = true + return nil + }), + fx.As(new(fmt.Stringer)), + )) + + opts := fx.Options( + ctor, + supply, + fx.Invoke(func(A) {}), + ) + + app := fxtest.New(t, opts) + ctx := context.Background() + require.False(t, called) + err := app.Start(ctx) + require.NoError(t, err) + require.NoError(t, app.Stop(ctx)) + require.Equal(t, "supply", buf.String()) + require.True(t, called) + }) + + t.Run("with Decorate", func(t *testing.T) { + t.Parallel() + + type A interface { + WriteString(string) (int, error) + } + + buf := bytes.NewBuffer(nil) + ctor := fx.Provide(func() A { return buf }) + + var called bool + + hook := fx.Annotate( + func(in A) A { + in.WriteString("decorated") + return in + }, + fx.OnStart(func(_ context.Context, a A) error { + called = assert.Equal(t, "decorated", buf.String()) + return nil + }), + ) + + decorated := fx.Decorate(hook) + + opts := fx.Options( + ctor, + decorated, + fx.Invoke(func(A) {}), + ) + + app := fxtest.New(t, opts) + ctx := context.Background() + require.NoError(t, app.Start(ctx)) + require.NoError(t, app.Stop(ctx)) + require.True(t, called) + require.Equal(t, "decorated", buf.String()) + }) + + t.Run("with Supply and Decorate", func(t *testing.T) { + t.Parallel() + + type A interface{} + + ch := make(chan string, 3) + + hook := fx.Annotate( + func() A { return nil }, + fx.OnStart(func(_ context.Context, s fmt.Stringer) error { + ch <- "constructor" + require.Equal(t, "supply", s.String()) + return nil + }), + ) + + ctor := fx.Provide(hook) + + hook = fx.Annotate( + &asStringer{"supply"}, + fx.OnStart(func(_ context.Context) error { + ch <- "supply" + return nil + }), + fx.As(new(fmt.Stringer)), + ) + + supply := fx.Supply(hook) + + hook = fx.Annotate( + func(in A) A { return in }, + fx.OnStart(func(_ context.Context) error { + ch <- "decorated" + return nil + }), + ) + + decorated := fx.Decorate(hook) + + opts := fx.Options( + ctor, + supply, + decorated, + fx.Invoke(func(A) {}), + ) + + app := fxtest.New(t, opts) + ctx := context.Background() + err := app.Start(ctx) + require.NoError(t, err) + require.NoError(t, app.Stop(ctx)) + close(ch) + + require.Equal(t, "supply", <-ch) + require.Equal(t, "constructor", <-ch) + require.Equal(t, "decorated", <-ch) + }) + +} + +func TestHookAnnotationFailures(t *testing.T) { + t.Parallel() + validateApp := func(t *testing.T, opts ...fx.Option) error { + return fx.ValidateApp( + append(opts, fx.Logger(fxtest.NewTestPrinter(t)))..., + ) + } + + type ( + A interface{} + B interface{} + ) + + table := []struct { + name string + annotation interface{} + useNew bool + errContains string + }{ + { + name: "with unprovided dependency", + errContains: "missing type: fx_test.B", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(func(context.Context, B) error { + return nil + }), + ), + }, + { + name: "with hook that errors", + errContains: "hook failed", + useNew: true, + annotation: fx.Annotate( + func() (A, error) { return nil, nil }, + fx.OnStart(func(context.Context) error { + return errors.New("hook failed") + }), + ), + }, + { + name: "with multiple hooks of the same type", + errContains: `cannot apply more than one "OnStart" hook annotation`, + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(func(context.Context) error { return nil }), + fx.OnStart(func(context.Context) error { return nil }), + ), + }, + { + name: "with hook that doesn't return an error", + errContains: "must return only an error", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(func(context.Context) {}), + ), + }, + { + name: "with constructor that errors", + errContains: "hooks should not be installed", + useNew: true, + annotation: fx.Annotate( + func() (A, error) { + return nil, errors.New("hooks should not be installed") + }, + fx.OnStart(func(context.Context) error { + require.FailNow(t, "hook should not be called") + return nil + }), + ), + }, + { + name: "without a function target", + errContains: "must provide function", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(&struct{}{}), + ), + }, + { + name: "without context.Context as first parameter", + errContains: "must be context.Context", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(func() {}), + ), + }, + { + name: "with variactic hook", + errContains: "must not accept variatic", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStart(func(context.Context, ...A) error { + return nil + }), + ), + }, + { + name: "with nil hook target", + errContains: "cannot use nil function", + annotation: fx.Annotate( + func() A { return nil }, + fx.OnStop(nil), + ), + }, + } + + for _, tt := range table { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + opts := fx.Options( + fx.Provide(tt.annotation), + fx.Invoke(func(A) {}), + ) + + if !tt.useNew { + err := validateApp(t, opts) + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + return + } + + app := fx.New(opts) + ctx := context.Background() + err := app.Start(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + }) + } }