From 6c46d54150f8198b4865d08e4f616983b6860abf Mon Sep 17 00:00:00 2001 From: Per Johansson Date: Fri, 14 Jan 2022 15:23:55 +0100 Subject: [PATCH] Add variadic support to mg.F Allows to pass sh.Run to mg.F as such: mg.Deps( mg.F(sh.Run, "go", "test", "./..."), ) This improves the magefile by removing some of the one-liner functions that you might end up with that are only used through mg.Deps. Resolves #401. --- mg/fn.go | 19 +++++++++++++++---- mg/fn_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/mg/fn.go b/mg/fn.go index 21487661..3856857a 100644 --- a/mg/fn.go +++ b/mg/fn.go @@ -111,8 +111,8 @@ func checkF(target interface{}, args []interface{}) (hasContext, isNamespace boo return false, false, fmt.Errorf("target's return value is not an error") } - // more inputs than slots is always an error - if len(args) > t.NumIn() { + // more inputs than slots is an error if not variadic + if len(args) > t.NumIn() && !t.IsVariadic() { return false, false, fmt.Errorf("too many arguments for target, got %d for %T", len(args), target) } @@ -142,12 +142,21 @@ func checkF(target interface{}, args []interface{}) (hasContext, isNamespace boo x++ } - if len(args) != inputs { + if t.IsVariadic() { + if len(args) < inputs-1 { + return false, false, fmt.Errorf("too few arguments for target, got %d for %T", len(args), target) + + } + } else if len(args) != inputs { return false, false, fmt.Errorf("wrong number of arguments for target, got %d for %T", len(args), target) } for _, arg := range args { argT := t.In(x) + if t.IsVariadic() && x == t.NumIn()-1 { + // For the variadic argument, use the slice element type. + argT = argT.Elem() + } if !argTypes[argT] { return false, false, fmt.Errorf("argument %d (%s), is not a supported argument type", x, argT) } @@ -155,7 +164,9 @@ func checkF(target interface{}, args []interface{}) (hasContext, isNamespace boo if argT != passedT { return false, false, fmt.Errorf("argument %d expected to be %s, but is %s", x, argT, passedT) } - x++ + if x < t.NumIn()-1 { + x++ + } } return hasContext, isNamespace, nil } diff --git a/mg/fn_test.go b/mg/fn_test.go index ee3d0bb3..8ca481c8 100644 --- a/mg/fn_test.go +++ b/mg/fn_test.go @@ -3,6 +3,7 @@ package mg import ( "context" "fmt" + "reflect" "sync/atomic" "testing" "time" @@ -213,6 +214,40 @@ func TestFNilError(t *testing.T) { } } +func TestFVariadic(t *testing.T) { + fn := F(func(args ...string) { + if !reflect.DeepEqual(args, []string{"a", "b"}) { + t.Errorf("Wrong args, got %v, want [a b]", args) + } + }, "a", "b") + err := fn.Run(context.Background()) + if err != nil { + t.Fatal(err) + } + + fn = F(func(a string, b ...string) {}, "a", "b1", "b2") + err = fn.Run(context.Background()) + if err != nil { + t.Fatal(err) + } + + fn = F(func(a ...string) {}) + err = fn.Run(context.Background()) + if err != nil { + t.Fatal(err) + } + + func() { + defer func() { + err, _ := recover().(error) + if err == nil || err.Error() != "too few arguments for target, got 0 for func(string, ...string)" { + t.Fatal(err) + } + }() + F(func(a string, b ...string) {}) + }() +} + type Foo Namespace func (Foo) Bare() {}