diff --git a/gomega_dsl.go b/gomega_dsl.go index 6936e2411..e7cf8cce5 100644 --- a/gomega_dsl.go +++ b/gomega_dsl.go @@ -52,7 +52,7 @@ var Default = Gomega(internal.NewGomega(internal.FetchDefaultDurationBundle())) // rich ecosystem of matchers without causing a test to fail. For example, to aggregate a series of potential failures // or for use in a non-test setting. func NewGomega(fail types.GomegaFailHandler) Gomega { - return internal.NewGomega(Default.(*internal.Gomega).DurationBundle).ConfigureWithFailHandler(fail) + return internal.NewGomega(internalGomega(Default).DurationBundle).ConfigureWithFailHandler(fail) } // WithT wraps a *testing.T and provides `Expect`, `Eventually`, and `Consistently` methods. This allows you to leverage @@ -69,6 +69,20 @@ type WithT = internal.Gomega // GomegaWithT is deprecated in favor of gomega.WithT, which does not stutter. type GomegaWithT = WithT +// inner is an interface that allows users to provide a wrapper around Default. The wrapper +// must implement the inner interface and return either the original Default or the result of +// a call to NewGomega(). +type inner interface { + Inner() Gomega +} + +func internalGomega(g Gomega) *internal.Gomega { + if v, ok := g.(inner); ok { + return v.Inner().(*internal.Gomega) + } + return g.(*internal.Gomega) +} + // NewWithT takes a *testing.T and returngs a `gomega.WithT` allowing you to use `Expect`, `Eventually`, and `Consistently` along with // Gomega's rich ecosystem of matchers in standard `testing` test suits. // @@ -79,7 +93,7 @@ type GomegaWithT = WithT // g.Expect(f.HasCow()).To(BeTrue(), "Farm should have cow") // } func NewWithT(t types.GomegaTestingT) *WithT { - return internal.NewGomega(Default.(*internal.Gomega).DurationBundle).ConfigureWithT(t) + return internal.NewGomega(internalGomega(Default).DurationBundle).ConfigureWithT(t) } // NewGomegaWithT is deprecated in favor of gomega.NewWithT, which does not stutter. @@ -88,20 +102,20 @@ var NewGomegaWithT = NewWithT // RegisterFailHandler connects Ginkgo to Gomega. When a matcher fails // the fail handler passed into RegisterFailHandler is called. func RegisterFailHandler(fail types.GomegaFailHandler) { - Default.(*internal.Gomega).ConfigureWithFailHandler(fail) + internalGomega(Default).ConfigureWithFailHandler(fail) } // RegisterFailHandlerWithT is deprecated and will be removed in a future release. // users should use RegisterFailHandler, or RegisterTestingT func RegisterFailHandlerWithT(_ types.GomegaTestingT, fail types.GomegaFailHandler) { fmt.Println("RegisterFailHandlerWithT is deprecated. Please use RegisterFailHandler or RegisterTestingT instead.") - Default.(*internal.Gomega).ConfigureWithFailHandler(fail) + internalGomega(Default).ConfigureWithFailHandler(fail) } // RegisterTestingT connects Gomega to Golang's XUnit style // Testing.T tests. It is now deprecated and you should use NewWithT() instead to get a fresh instance of Gomega for each test. func RegisterTestingT(t types.GomegaTestingT) { - Default.(*internal.Gomega).ConfigureWithT(t) + internalGomega(Default).ConfigureWithT(t) } // InterceptGomegaFailures runs a given callback and returns an array of @@ -112,13 +126,13 @@ func RegisterTestingT(t types.GomegaTestingT) { // This is most useful when testing custom matchers, but can also be used to check // on a value using a Gomega assertion without causing a test failure. func InterceptGomegaFailures(f func()) []string { - originalHandler := Default.(*internal.Gomega).Fail + originalHandler := internalGomega(Default).Fail failures := []string{} - Default.(*internal.Gomega).Fail = func(message string, callerSkip ...int) { + internalGomega(Default).Fail = func(message string, callerSkip ...int) { failures = append(failures, message) } defer func() { - Default.(*internal.Gomega).Fail = originalHandler + internalGomega(Default).Fail = originalHandler }() f() return failures @@ -131,14 +145,14 @@ func InterceptGomegaFailures(f func()) []string { // does not register a failure with the FailHandler registered via RegisterFailHandler - it is up // to the user to decide what to do with the returned error func InterceptGomegaFailure(f func()) (err error) { - originalHandler := Default.(*internal.Gomega).Fail - Default.(*internal.Gomega).Fail = func(message string, callerSkip ...int) { + originalHandler := internalGomega(Default).Fail + internalGomega(Default).Fail = func(message string, callerSkip ...int) { err = errors.New(message) panic("stop execution") } defer func() { - Default.(*internal.Gomega).Fail = originalHandler + internalGomega(Default).Fail = originalHandler if e := recover(); e != nil { if err == nil { panic(e) @@ -151,7 +165,7 @@ func InterceptGomegaFailure(f func()) (err error) { } func ensureDefaultGomegaIsConfigured() { - if !Default.(*internal.Gomega).IsConfigured() { + if !internalGomega(Default).IsConfigured() { panic(nilGomegaPanic) } }