From b75d2f250836b95028d2bfac2220cef5ccf08837 Mon Sep 17 00:00:00 2001 From: Jim Slattery Date: Thu, 25 Feb 2021 19:38:56 -0500 Subject: [PATCH] Allow WithTransform function to accept a nil value (#422) --- matchers/with_transform.go | 20 +++++++++++++++----- matchers/with_transform_test.go | 8 +++++++- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/matchers/with_transform.go b/matchers/with_transform.go index 8e58d8a0f..a4db310b6 100644 --- a/matchers/with_transform.go +++ b/matchers/with_transform.go @@ -40,15 +40,25 @@ func NewWithTransformMatcher(transform interface{}, matcher types.GomegaMatcher) } func (m *WithTransformMatcher) Match(actual interface{}) (bool, error) { - // return error if actual's type is incompatible with Transform function's argument type - actualType := reflect.TypeOf(actual) - if !actualType.AssignableTo(m.transformArgType) { - return false, fmt.Errorf("Transform function expects '%s' but we have '%s'", m.transformArgType, actualType) + // prepare a parameter to pass to the Transform function + var param reflect.Value + { + if actual != nil { + // return error if actual's type is incompatible with Transform function's argument type + actualType := reflect.TypeOf(actual) + if !actualType.AssignableTo(m.transformArgType) { + return false, fmt.Errorf("Transform function expects '%s' but we have '%s'", m.transformArgType, actualType) + } + param = reflect.ValueOf(actual) + } else { + // make a nil value of the expected type + param = reflect.New(m.transformArgType).Elem() + } } // call the Transform function with `actual` fn := reflect.ValueOf(m.Transform) - result := fn.Call([]reflect.Value{reflect.ValueOf(actual)}) + result := fn.Call([]reflect.Value{param}) m.transformedValue = result[0].Interface() // expect exactly one value return m.Matcher.Match(m.transformedValue) diff --git a/matchers/with_transform_test.go b/matchers/with_transform_test.go index 5fcc949ae..858039e0d 100644 --- a/matchers/with_transform_test.go +++ b/matchers/with_transform_test.go @@ -51,7 +51,13 @@ var _ = Describe("WithTransformMatcher", func() { Expect(S{1, "hi"}).To(WithTransform(transformer, Equal("hi"))) // transform expects interface - errString := func(e error) string { return e.Error() } + errString := func(e error) string { + if e == nil { + return "" + } + return e.Error() + } + Expect(nil).To(WithTransform(errString, Equal("")), "handles nil actual values") Expect(errors.New("abc")).To(WithTransform(errString, Equal("abc"))) })