Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add variadic support to mg.F #402

Merged
merged 1 commit into from Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 15 additions & 4 deletions mg/fn.go
Expand Up @@ -111,8 +111,8 @@ func checkF(target interface{}, args []interface{}) (hasContext, isNamespace boo
return false, false, fmt.Errorf("target's return value is not an error")
}

// more inputs than slots is always an error
if len(args) > t.NumIn() {
// more inputs than slots is an error if not variadic
if len(args) > t.NumIn() && !t.IsVariadic() {
return false, false, fmt.Errorf("too many arguments for target, got %d for %T", len(args), target)
}

Expand Down Expand Up @@ -142,20 +142,31 @@ func checkF(target interface{}, args []interface{}) (hasContext, isNamespace boo
x++
}

if len(args) != inputs {
if t.IsVariadic() {
if len(args) < inputs-1 {
return false, false, fmt.Errorf("too few arguments for target, got %d for %T", len(args), target)

}
} else if len(args) != inputs {
return false, false, fmt.Errorf("wrong number of arguments for target, got %d for %T", len(args), target)
}

for _, arg := range args {
argT := t.In(x)
if t.IsVariadic() && x == t.NumIn()-1 {
// For the variadic argument, use the slice element type.
argT = argT.Elem()
}
if !argTypes[argT] {
return false, false, fmt.Errorf("argument %d (%s), is not a supported argument type", x, argT)
}
passedT := reflect.TypeOf(arg)
if argT != passedT {
return false, false, fmt.Errorf("argument %d expected to be %s, but is %s", x, argT, passedT)
}
x++
if x < t.NumIn()-1 {
x++
}
}
return hasContext, isNamespace, nil
}
Expand Down
35 changes: 35 additions & 0 deletions mg/fn_test.go
Expand Up @@ -3,6 +3,7 @@ package mg
import (
"context"
"fmt"
"reflect"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -213,6 +214,40 @@ func TestFNilError(t *testing.T) {
}
}

func TestFVariadic(t *testing.T) {
fn := F(func(args ...string) {
if !reflect.DeepEqual(args, []string{"a", "b"}) {
t.Errorf("Wrong args, got %v, want [a b]", args)
}
}, "a", "b")
err := fn.Run(context.Background())
if err != nil {
t.Fatal(err)
}

fn = F(func(a string, b ...string) {}, "a", "b1", "b2")
err = fn.Run(context.Background())
if err != nil {
t.Fatal(err)
}

fn = F(func(a ...string) {})
err = fn.Run(context.Background())
if err != nil {
t.Fatal(err)
}

func() {
defer func() {
err, _ := recover().(error)
if err == nil || err.Error() != "too few arguments for target, got 0 for func(string, ...string)" {
t.Fatal(err)
}
}()
F(func(a string, b ...string) {})
}()
}

type Foo Namespace

func (Foo) Bare() {}
Expand Down