From 2424352ffad826607f6cd6476133033f50327ddf Mon Sep 17 00:00:00 2001 From: Chris Bandy Date: Sun, 28 Feb 2021 01:44:23 -0600 Subject: [PATCH] Check more carefully for nils in WithTransform Co-authored-by: Jim Slattery See: https://golang.org/doc/faq#nil_error --- matchers/with_transform.go | 23 +++++++++--------- matchers/with_transform_test.go | 42 +++++++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/matchers/with_transform.go b/matchers/with_transform.go index a4db310b6..f3dec9101 100644 --- a/matchers/with_transform.go +++ b/matchers/with_transform.go @@ -42,18 +42,17 @@ func NewWithTransformMatcher(transform interface{}, matcher types.GomegaMatcher) func (m *WithTransformMatcher) Match(actual interface{}) (bool, error) { // prepare a parameter to pass to the Transform function var param reflect.Value - { - if actual != nil { - // return error if actual's type is incompatible with Transform function's argument type - actualType := reflect.TypeOf(actual) - if !actualType.AssignableTo(m.transformArgType) { - return false, fmt.Errorf("Transform function expects '%s' but we have '%s'", m.transformArgType, actualType) - } - param = reflect.ValueOf(actual) - } else { - // make a nil value of the expected type - param = reflect.New(m.transformArgType).Elem() - } + if actual != nil && reflect.TypeOf(actual).AssignableTo(m.transformArgType) { + // The dynamic type of actual is compatible with the transform argument. + param = reflect.ValueOf(actual) + + } else if actual == nil && m.transformArgType.Kind() == reflect.Interface { + // The dynamic type of actual is unknown, so there's no way to make its + // reflect.Value. Create a nil of the transform argument, which is known. + param = reflect.Zero(m.transformArgType) + + } else { + return false, fmt.Errorf("Transform function expects '%s' but we have '%T'", m.transformArgType, actual) } // call the Transform function with `actual` diff --git a/matchers/with_transform_test.go b/matchers/with_transform_test.go index 858039e0d..38436fe75 100644 --- a/matchers/with_transform_test.go +++ b/matchers/with_transform_test.go @@ -37,6 +37,44 @@ var _ = Describe("WithTransformMatcher", func() { }) }) + When("the actual value is incompatible", func() { + It("fails to pass int to func(string)", func() { + actual, transform := int(0), func(string) int { return 0 } + success, err := WithTransform(transform, Equal(0)).Match(actual) + Expect(success).To(BeFalse()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("function expects 'string'")) + Expect(err.Error()).To(ContainSubstring("have 'int'")) + }) + + It("fails to pass string to func(interface)", func() { + actual, transform := "bang", func(error) int { return 0 } + success, err := WithTransform(transform, Equal(0)).Match(actual) + Expect(success).To(BeFalse()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("function expects 'error'")) + Expect(err.Error()).To(ContainSubstring("have 'string'")) + }) + + It("fails to pass nil interface to func(int)", func() { + actual, transform := error(nil), func(int) int { return 0 } + success, err := WithTransform(transform, Equal(0)).Match(actual) + Expect(success).To(BeFalse()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("function expects 'int'")) + Expect(err.Error()).To(ContainSubstring("have ''")) + }) + + It("fails to pass nil interface to func(pointer)", func() { + actual, transform := error(nil), func(*string) int { return 0 } + success, err := WithTransform(transform, Equal(0)).Match(actual) + Expect(success).To(BeFalse()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("function expects '*string'")) + Expect(err.Error()).To(ContainSubstring("have ''")) + }) + }) + It("works with positive cases", func() { Expect(1).To(WithTransform(plus1, Equal(2))) Expect(1).To(WithTransform(plus1, WithTransform(plus1, Equal(3)))) @@ -53,11 +91,11 @@ var _ = Describe("WithTransformMatcher", func() { // transform expects interface errString := func(e error) string { if e == nil { - return "" + return "safe" } return e.Error() } - Expect(nil).To(WithTransform(errString, Equal("")), "handles nil actual values") + Expect(nil).To(WithTransform(errString, Equal("safe")), "handles nil actual values") Expect(errors.New("abc")).To(WithTransform(errString, Equal("abc"))) })