diff --git a/mg/deps.go b/mg/deps.go index cae591d3..f0c2509b 100644 --- a/mg/deps.go +++ b/mg/deps.go @@ -134,6 +134,13 @@ func checkFns(fns []interface{}) []Fn { funcs[i] = fn continue } + + // Check if the target provided is a not function so we can give a clear warning + t := reflect.TypeOf(f) + if t == nil || t.Kind() != reflect.Func { + panic(fmt.Errorf("non-function used as a target dependency: %T. The mg.Deps, mg.SerialDeps and mg.CtxDeps functions accept function names, such as mg.Deps(TargetA, TargetB)", f)) + } + funcs[i] = F(f) } return funcs diff --git a/mg/deps_internal_test.go b/mg/deps_internal_test.go index 5b561bf5..cf28e4a4 100644 --- a/mg/deps_internal_test.go +++ b/mg/deps_internal_test.go @@ -2,6 +2,7 @@ package mg import ( "bytes" + "fmt" "log" "os" "strings" @@ -33,3 +34,23 @@ func bar() { } func baz() {} + +func TestDepWasNotInvoked(t *testing.T) { + fn1 := func() error { + return nil + } + defer func() { + err := recover() + if err == nil { + t.Fatal("expected panic, but didn't get one") + } + gotErr := fmt.Sprint(err) + wantErr := "non-function used as a target dependency: . The mg.Deps, mg.SerialDeps and mg.CtxDeps functions accept function names, such as mg.Deps(TargetA, TargetB)" + if !strings.Contains(gotErr, wantErr) { + t.Fatalf(`expected to get "%s" but got "%s"`, wantErr, gotErr) + } + }() + func(fns ...interface{}) { + checkFns(fns) + }(fn1()) +} diff --git a/mg/fn.go b/mg/fn.go index 57376081..21487661 100644 --- a/mg/fn.go +++ b/mg/fn.go @@ -100,8 +100,8 @@ func (f fn) Run(ctx context.Context) error { func checkF(target interface{}, args []interface{}) (hasContext, isNamespace bool, _ error) { t := reflect.TypeOf(target) - if t.Kind() != reflect.Func { - return false, false, fmt.Errorf("non-function passed to mg.F: %T", target) + if t == nil || t.Kind() != reflect.Func { + return false, false, fmt.Errorf("non-function passed to mg.F: %T. The mg.F function accepts function names, such as mg.F(TargetA, \"arg1\", \"arg2\")", target) } if t.NumOut() > 1 { diff --git a/mg/fn_test.go b/mg/fn_test.go index 80111678..ee3d0bb3 100644 --- a/mg/fn_test.go +++ b/mg/fn_test.go @@ -114,6 +114,16 @@ func TestFuncCheck(t *testing.T) { if err == nil { t.Error("expected func(*int) error to be invalid") } + + defer func() { + if r := recover(); r !=nil { + t.Error("expected a nil function argument to be handled gracefully") + } + }() + _, _, err = checkF(nil, []interface{}{1,2}) + if err == nil { + t.Error("expected a nil function argument to be invalid") + } } func TestF(t *testing.T) {