Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

poll: Result from cmp.Comparison #198

Merged
merged 2 commits into from
Apr 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
169 changes: 27 additions & 142 deletions assert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,9 @@ See http://pkg.go.dev/gotest.tools/v3/assert/cmd/gty-migrate-from-testify.
package assert // import "gotest.tools/v3/assert"

import (
"fmt"
"go/ast"
"go/token"
"reflect"

gocmp "github.com/google/go-cmp/cmp"
"gotest.tools/v3/assert/cmp"
"gotest.tools/v3/internal/format"
"gotest.tools/v3/internal/source"
"gotest.tools/v3/internal/assert"
)

// BoolOrComparison can be a bool, or cmp.Comparison. See Assert() for usage.
Expand All @@ -90,133 +84,6 @@ type helperT interface {
Helper()
}

const failureMessage = "assertion failed: "

// nolint: gocyclo
func assert(
t TestingT,
failer func(),
argSelector argSelector,
comparison BoolOrComparison,
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
var success bool
switch check := comparison.(type) {
case bool:
if check {
return true
}
logFailureFromBool(t, msgAndArgs...)

// Undocumented legacy comparison without Result type
case func() (success bool, message string):
success = runCompareFunc(t, check, msgAndArgs...)

case nil:
return true

case error:
msg := failureMsgFromError(check)
t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))

case cmp.Comparison:
success = runComparison(t, argSelector, check, msgAndArgs...)

case func() cmp.Result:
success = runComparison(t, argSelector, check, msgAndArgs...)

default:
t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check))
}

if success {
return true
}
failer()
return false
}

func runCompareFunc(
t TestingT,
f func() (success bool, message string),
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if success, message := f(); !success {
t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
return false
}
return true
}

func logFailureFromBool(t TestingT, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
const stackIndex = 3 // Assert()/Check(), assert(), logFailureFromBool()
args, err := source.CallExprArgs(stackIndex)
if err != nil {
t.Log(err.Error())
return
}

const comparisonArgIndex = 1 // Assert(t, comparison)
if len(args) <= comparisonArgIndex {
t.Log(failureMessage + "but assert failed to find the expression to print")
return
}

msg, err := boolFailureMessage(args[comparisonArgIndex])
if err != nil {
t.Log(err.Error())
msg = "expression is false"
}

t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
}

func failureMsgFromError(err error) string {
// Handle errors with non-nil types
v := reflect.ValueOf(err)
if v.Kind() == reflect.Ptr && v.IsNil() {
return fmt.Sprintf("error is not nil: error has type %T", err)
}
return "error is not nil: " + err.Error()
}

func boolFailureMessage(expr ast.Expr) (string, error) {
if binaryExpr, ok := expr.(*ast.BinaryExpr); ok && binaryExpr.Op == token.NEQ {
x, err := source.FormatNode(binaryExpr.X)
if err != nil {
return "", err
}
y, err := source.FormatNode(binaryExpr.Y)
if err != nil {
return "", err
}
return x + " is " + y, nil
}

if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT {
x, err := source.FormatNode(unaryExpr.X)
if err != nil {
return "", err
}
return x + " is true", nil
}

formatted, err := source.FormatNode(expr)
if err != nil {
return "", err
}
return "expression is false: " + formatted, nil
}

// Assert performs a comparison. If the comparison fails, the test is marked as
// failed, a failure message is logged, and execution is stopped immediately.
//
Expand All @@ -235,7 +102,9 @@ func Assert(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{})
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsFromComparisonCall, comparison, msgAndArgs...)
if !assert.Eval(t, assert.ArgsFromComparisonCall, comparison, msgAndArgs...) {
t.FailNow()
}
}

// Check performs a comparison. If the comparison fails the test is marked as
Expand All @@ -247,7 +116,11 @@ func Check(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) b
if ht, ok := t.(helperT); ok {
ht.Helper()
}
return assert(t, t.Fail, argsFromComparisonCall, comparison, msgAndArgs...)
if !assert.Eval(t, assert.ArgsFromComparisonCall, comparison, msgAndArgs...) {
t.Fail()
return false
}
return true
}

// NilError fails the test immediately if err is not nil.
Expand All @@ -256,7 +129,9 @@ func NilError(t TestingT, err error, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, err, msgAndArgs...)
if !assert.Eval(t, assert.ArgsAfterT, err, msgAndArgs...) {
t.FailNow()
}
}

// Equal uses the == operator to assert two values are equal and fails the test
Expand All @@ -275,7 +150,9 @@ func Equal(t TestingT, x, y interface{}, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.Equal(x, y), msgAndArgs...)
if !assert.Eval(t, assert.ArgsAfterT, cmp.Equal(x, y), msgAndArgs...) {
t.FailNow()
}
}

// DeepEqual uses google/go-cmp (https://godoc.org/github.com/google/go-cmp/cmp)
Expand All @@ -289,7 +166,9 @@ func DeepEqual(t TestingT, x, y interface{}, opts ...gocmp.Option) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.DeepEqual(x, y, opts...))
if !assert.Eval(t, assert.ArgsAfterT, cmp.DeepEqual(x, y, opts...)) {
t.FailNow()
}
}

// Error fails the test if err is nil, or the error message is not the expected
Expand All @@ -299,7 +178,9 @@ func Error(t TestingT, err error, message string, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.Error(err, message), msgAndArgs...)
if !assert.Eval(t, assert.ArgsAfterT, cmp.Error(err, message), msgAndArgs...) {
t.FailNow()
}
}

// ErrorContains fails the test if err is nil, or the error message does not
Expand All @@ -309,7 +190,9 @@ func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...interf
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.ErrorContains(err, substring), msgAndArgs...)
if !assert.Eval(t, assert.ArgsAfterT, cmp.ErrorContains(err, substring), msgAndArgs...) {
t.FailNow()
}
}

// ErrorType fails the test if err is nil, or err is not the expected type.
Expand All @@ -330,5 +213,7 @@ func ErrorType(t TestingT, err error, expected interface{}, msgAndArgs ...interf
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.ErrorType(err, expected), msgAndArgs...)
if !assert.Eval(t, assert.ArgsAfterT, cmp.ErrorType(err, expected), msgAndArgs...) {
t.FailNow()
}
}
143 changes: 143 additions & 0 deletions internal/assert/assert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package assert

import (
"fmt"
"go/ast"
"go/token"
"reflect"

"gotest.tools/v3/assert/cmp"
"gotest.tools/v3/internal/format"
"gotest.tools/v3/internal/source"
)

// LogT is the subset of testing.T used by the assert package.
type LogT interface {
Log(args ...interface{})
}

type helperT interface {
Helper()
}

const failureMessage = "assertion failed: "

// Eval the comparison and print a failure messages if the comparison has failed.
// nolint: gocyclo
func Eval(
t LogT,
argSelector argSelector,
comparison interface{},
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
var success bool
switch check := comparison.(type) {
case bool:
if check {
return true
}
logFailureFromBool(t, msgAndArgs...)

// Undocumented legacy comparison without Result type
case func() (success bool, message string):
success = runCompareFunc(t, check, msgAndArgs...)

case nil:
return true

case error:
msg := failureMsgFromError(check)
t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))

case cmp.Comparison:
success = RunComparison(t, argSelector, check, msgAndArgs...)

case func() cmp.Result:
success = RunComparison(t, argSelector, check, msgAndArgs...)

default:
t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check))
}
return success
}

func runCompareFunc(
t LogT,
f func() (success bool, message string),
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if success, message := f(); !success {
t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
return false
}
return true
}

func logFailureFromBool(t LogT, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
const stackIndex = 3 // Assert()/Check(), assert(), logFailureFromBool()
args, err := source.CallExprArgs(stackIndex)
if err != nil {
t.Log(err.Error())
return
}

const comparisonArgIndex = 1 // Assert(t, comparison)
if len(args) <= comparisonArgIndex {
t.Log(failureMessage + "but assert failed to find the expression to print")
return
}

msg, err := boolFailureMessage(args[comparisonArgIndex])
if err != nil {
t.Log(err.Error())
msg = "expression is false"
}

t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
}

func failureMsgFromError(err error) string {
// Handle errors with non-nil types
v := reflect.ValueOf(err)
if v.Kind() == reflect.Ptr && v.IsNil() {
return fmt.Sprintf("error is not nil: error has type %T", err)
}
return "error is not nil: " + err.Error()
}

func boolFailureMessage(expr ast.Expr) (string, error) {
if binaryExpr, ok := expr.(*ast.BinaryExpr); ok && binaryExpr.Op == token.NEQ {
x, err := source.FormatNode(binaryExpr.X)
if err != nil {
return "", err
}
y, err := source.FormatNode(binaryExpr.Y)
if err != nil {
return "", err
}
return x + " is " + y, nil
}

if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT {
x, err := source.FormatNode(unaryExpr.X)
if err != nil {
return "", err
}
return x + " is true", nil
}

formatted, err := source.FormatNode(expr)
if err != nil {
return "", err
}
return "expression is false: " + formatted, nil
}