Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
thediveo committed May 20, 2022
1 parent eeb3ba4 commit ed51c4b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
16 changes: 11 additions & 5 deletions matchers/have_existing_field_matcher.go
Expand Up @@ -12,13 +12,19 @@ type HaveExistingFieldMatcher struct {
}

func (matcher *HaveExistingFieldMatcher) Match(actual interface{}) (success bool, err error) {
// we don't care about the field's actual value, just about any error in
// trying to find the field (or method).
_, err = extractField(actual, matcher.Field, "HaveExistingField")
// Report "severe" field extraction errors, but not a missing field.
var ferr fieldError
if errors.As(err, &ferr) {
return false, ferr
if err == nil {
return true, nil
}
return err == nil, nil
var mferr missingFieldError
if errors.As(err, &mferr) {
// missing field errors aren't errors in this context, but instead
// unsuccessful matches.
return false, nil
}
return false, err
}

func (matcher *HaveExistingFieldMatcher) FailureMessage(actual interface{}) (message string) {
Expand Down
20 changes: 10 additions & 10 deletions matchers/have_field.go
Expand Up @@ -8,12 +8,12 @@ import (
"github.com/onsi/gomega/format"
)

// fieldError represents a "severe" field extraction error that
// HaveExistingFieldMatcher must not ignore and thus not interpret as simply a
// missing field.
type fieldError string
// missingFieldError represents a missing field extraction error that
// HaveExistingFieldMatcher can ignore, as opposed to other, sever field
// extraction errors, such as nil pointers, et cetera.
type missingFieldError string

func (e fieldError) Error() string {
func (e missingFieldError) Error() string {
return string(e)
}

Expand All @@ -25,11 +25,11 @@ func extractField(actual interface{}, field string, matchername string) (interfa
actualValue = actualValue.Elem()
}
if actualValue == (reflect.Value{}) {
return nil, fieldError(fmt.Sprintf("%s encountered nil while dereferencing a pointer of type %T.", matchername, actual))
return nil, fmt.Errorf("%s encountered nil while dereferencing a pointer of type %T.", matchername, actual)
}

if actualValue.Kind() != reflect.Struct {
return nil, fieldError(fmt.Sprintf("%s encountered:\n%s\nWhich is not a struct.", matchername, format.Object(actual, 1)))
return nil, fmt.Errorf("%s encountered:\n%s\nWhich is not a struct.", matchername, format.Object(actual, 1))
}

var extractedValue reflect.Value
Expand All @@ -40,17 +40,17 @@ func extractField(actual interface{}, field string, matchername string) (interfa
extractedValue = actualValue.Addr().MethodByName(strings.TrimSuffix(fields[0], "()"))
}
if extractedValue == (reflect.Value{}) {
return nil, fmt.Errorf("%s could not find method named '%s' in struct of type %T.", matchername, fields[0], actual)
return nil, missingFieldError(fmt.Sprintf("%s could not find method named '%s' in struct of type %T.", matchername, fields[0], actual))
}
t := extractedValue.Type()
if t.NumIn() != 0 || t.NumOut() != 1 {
return nil, fieldError(fmt.Sprintf("%s found an invalid method named '%s' in struct of type %T.\nMethods must take no arguments and return exactly one value.", matchername, fields[0], actual))
return nil, fmt.Errorf("%s found an invalid method named '%s' in struct of type %T.\nMethods must take no arguments and return exactly one value.", matchername, fields[0], actual)
}
extractedValue = extractedValue.Call([]reflect.Value{})[0]
} else {
extractedValue = actualValue.FieldByName(fields[0])
if extractedValue == (reflect.Value{}) {
return nil, fmt.Errorf("%s could not find field named '%s' in struct:\n%s", matchername, fields[0], format.Object(actual, 1))
return nil, missingFieldError(fmt.Sprintf("%s could not find field named '%s' in struct:\n%s", matchername, fields[0], format.Object(actual, 1)))
}
}

Expand Down

0 comments on commit ed51c4b

Please sign in to comment.