diff --git a/config/reflection_magic.go b/config/reflection_magic.go index 57e486ebcb..c462571d2d 100644 --- a/config/reflection_magic.go +++ b/config/reflection_magic.go @@ -1,6 +1,7 @@ package config import ( + "errors" "fmt" "reflect" "runtime" @@ -97,11 +98,38 @@ func makeArgumentConstructors(fnType reflect.Type, argTypes map[reflect.Type]con return out, nil } +func getConstructorOpts(t reflect.Type, opts ...interface{}) ([]reflect.Value, error) { + if !t.IsVariadic() { + if len(opts) > 0 { + return nil, errors.New("constructor doesn't accept any options") + } + return nil, nil + } + if len(opts) == 0 { + return nil, nil + } + // variadic parameters always go last + wantType := t.In(t.NumIn() - 1).Elem() + values := make([]reflect.Value, 0, len(opts)) + for _, opt := range opts { + val := reflect.ValueOf(opt) + if opt == nil { + return nil, errors.New("expected a transport option, got nil") + } + if val.Type() != wantType { + return nil, fmt.Errorf("expected option of type %s, got %s", wantType, reflect.TypeOf(opt)) + } + values = append(values, val.Convert(wantType)) + } + return values, nil +} + // makes a transport constructor. func makeConstructor( tpt interface{}, tptType reflect.Type, argTypes map[reflect.Type]constructor, + opts ...interface{}, ) (func(host.Host, *tptu.Upgrader, connmgr.ConnectionGater) (interface{}, error), error) { v := reflect.ValueOf(tpt) // avoid panicing on nil/zero value. @@ -121,19 +149,24 @@ func makeConstructor( if err != nil { return nil, err } + optValues, err := getConstructorOpts(t, opts...) + if err != nil { + return nil, err + } return func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) (interface{}, error) { - arguments := make([]reflect.Value, len(argConstructors)) + arguments := make([]reflect.Value, 0, len(argConstructors)+len(opts)) for i, makeArg := range argConstructors { if arg := makeArg(h, u, cg); arg != nil { - arguments[i] = reflect.ValueOf(arg) + arguments = append(arguments, reflect.ValueOf(arg)) } else { // ValueOf an un-typed nil yields a zero reflect // value. However, we _want_ the zero value of // the _type_. - arguments[i] = reflect.Zero(t.In(i)) + arguments = append(arguments, reflect.Zero(t.In(i))) } } + arguments = append(arguments, optValues...) return callConstructor(v, arguments) }, nil } diff --git a/config/transport.go b/config/transport.go index 9f55838db6..7b6b3d6fd5 100644 --- a/config/transport.go +++ b/config/transport.go @@ -36,14 +36,14 @@ var transportArgTypes = argTypes // // And returns a type implementing transport.Transport and, optionally, an error // (as the second argument). -func TransportConstructor(tpt interface{}) (TptC, error) { +func TransportConstructor(tpt interface{}, opts ...interface{}) (TptC, error) { // Already constructed? if t, ok := tpt.(transport.Transport); ok { return func(_ host.Host, _ *tptu.Upgrader, _ connmgr.ConnectionGater) (transport.Transport, error) { return t, nil }, nil } - ctor, err := makeConstructor(tpt, transportType, transportArgTypes) + ctor, err := makeConstructor(tpt, transportType, transportArgTypes, opts...) if err != nil { return nil, err } diff --git a/config/transport_test.go b/config/transport_test.go index 753d91a303..bca587b9a7 100644 --- a/config/transport_test.go +++ b/config/transport_test.go @@ -5,11 +5,39 @@ import ( "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/transport" + tptu "github.com/libp2p/go-libp2p-transport-upgrader" + "github.com/libp2p/go-tcp-transport" + + "github.com/stretchr/testify/require" ) func TestTransportVariadicOptions(t *testing.T) { _, err := TransportConstructor(func(_ peer.ID, _ ...int) transport.Transport { return nil }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) +} + +func TestConstructorWithoutOptsCalledWithOpts(t *testing.T) { + _, err := TransportConstructor(func(_ *tptu.Upgrader) transport.Transport { + return nil + }, 42) + require.EqualError(t, err, "constructor doesn't accept any options") +} + +func TestConstructorWithOptsTypeMismatch(t *testing.T) { + _, err := TransportConstructor(func(_ *tptu.Upgrader, opts ...int) transport.Transport { + return nil + }, 42, "foo") + require.EqualError(t, err, "expected option of type int, got string") +} + +func TestConstructorWithOpts(t *testing.T) { + var options []int + c, err := TransportConstructor(func(_ *tptu.Upgrader, opts ...int) transport.Transport { + options = opts + return tcp.NewTCPTransport(nil) + }, 42, 1337) + require.NoError(t, err) + _, err = c(nil, nil, nil) + require.NoError(t, err) + require.Equal(t, []int{42, 1337}, options) } diff --git a/options.go b/options.go index 9ded4dba42..3d20c975cc 100644 --- a/options.go +++ b/options.go @@ -125,8 +125,8 @@ func Muxer(name string, tpt interface{}) Option { // * Public Key // * Address filter (filter.Filter) // * Peerstore -func Transport(tpt interface{}) Option { - tptc, err := config.TransportConstructor(tpt) +func Transport(tpt interface{}, opts ...interface{}) Option { + tptc, err := config.TransportConstructor(tpt, opts...) err = traceError(err, 1) return func(cfg *Config) error { if err != nil {