diff --git a/assert/assert_ext_test.go b/assert/assert_ext_test.go new file mode 100644 index 00000000..88b04215 --- /dev/null +++ b/assert/assert_ext_test.go @@ -0,0 +1,64 @@ +package assert_test + +import ( + "fmt" + "testing" + + "gotest.tools/v3/assert" + "gotest.tools/v3/internal/source" +) + +func TestEqual_WithGoldenUpdate(t *testing.T) { + t.Run("assert failed with update=false", func(t *testing.T) { + ft := &fakeTestingT{} + actual := `not this value` + assert.Equal(ft, actual, expectedOne) + assert.Assert(t, ft.failNowed) + }) + + t.Run("value is updated when -update=true", func(t *testing.T) { + patchUpdate(t) + ft := &fakeTestingT{} + + actual := `this is the +actual value +that we are testing against` + assert.Equal(ft, actual, expectedOne) + + // reset + fmt.Println("WHHHHHHHHHHY") + assert.Equal(ft, "\n\n\n", expectedOne) + }) +} + +var expectedOne = ` + + +` + +func patchUpdate(t *testing.T) { + source.Update = true + t.Cleanup(func() { + source.Update = false + }) +} + +type fakeTestingT struct { + failNowed bool + failed bool + msgs []string +} + +func (f *fakeTestingT) FailNow() { + f.failNowed = true +} + +func (f *fakeTestingT) Fail() { + f.failed = true +} + +func (f *fakeTestingT) Log(args ...interface{}) { + f.msgs = append(f.msgs, args[0].(string)) +} + +func (f *fakeTestingT) Helper() {} diff --git a/assert/cmp/compare.go b/assert/cmp/compare.go index 1f42bd0c..78f76e4e 100644 --- a/assert/cmp/compare.go +++ b/assert/cmp/compare.go @@ -35,7 +35,7 @@ func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison { if diff == "" { return ResultSuccess } - return multiLineDiffResult(diff) + return multiLineDiffResult(diff, x, y) } } @@ -102,7 +102,7 @@ func Equal(x, y interface{}) Comparison { return ResultSuccess case isMultiLineStringCompare(x, y): diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)}) - return multiLineDiffResult(diff) + return multiLineDiffResult(diff, x, y) } return ResultFailureTemplate(` {{- printf "%v" .Data.x}} ( @@ -128,12 +128,12 @@ func isMultiLineStringCompare(x, y interface{}) bool { return strings.Contains(strX, "\n") || strings.Contains(strY, "\n") } -func multiLineDiffResult(diff string) Result { +func multiLineDiffResult(diff string, x, y interface{}) Result { return ResultFailureTemplate(` --- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}} +++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}} {{ .Data.diff }}`, - map[string]interface{}{"diff": diff}) + map[string]interface{}{"diff": diff, "x": x, "y": y}) } // Len succeeds if the sequence has the expected length. diff --git a/assert/cmp/result.go b/assert/cmp/result.go index 2b0eb7e3..28ef8d3d 100644 --- a/assert/cmp/result.go +++ b/assert/cmp/result.go @@ -69,6 +69,11 @@ func (r templatedResult) FailureMessage(args []ast.Expr) string { return msg } +func (r templatedResult) UpdatedExpected(stackIndex int) error { + // TODO: would be nice to have structured data instead of a map + return source.UpdateExpectedValue(stackIndex+1, r.data["x"], r.data["y"]) +} + // ResultFailureTemplate returns a Result with a template string and data which // can be used to format a failure message. The template may access data from .Data, // the comparison args with the callArg function, and the formatNode function may diff --git a/icmd/command_test.go b/icmd/command_test.go index 5619c3bc..1a5fef93 100644 --- a/icmd/command_test.go +++ b/icmd/command_test.go @@ -12,7 +12,6 @@ import ( exec "golang.org/x/sys/execabs" "gotest.tools/v3/assert" "gotest.tools/v3/fs" - "gotest.tools/v3/golden" "gotest.tools/v3/internal/maint" ) @@ -120,9 +119,22 @@ func TestResult_Match_NotMatched(t *testing.T) { } err := result.match(exp) assert.ErrorContains(t, err, "Failures") - golden.Assert(t, err.Error(), "result-match-no-match.golden") + assert.Equal(t, err.Error(), expectedMatch) } +var expectedMatch = ` +Command: binary arg1 +ExitCode: 99 (timeout) +Error: exit code 99 +Stdout: the output +Stderr: the stderr + +Failures: +ExitCode was 99 expected 101 +Expected command to finish, but it hit the timeout +Expected stdout to contain "Something else" +Expected stderr to contain "[NOTHING]"` + func newLockedBuffer(s string) *lockedBuffer { return &lockedBuffer{buf: *bytes.NewBufferString(s)} } @@ -140,9 +152,20 @@ func TestResult_Match_NotMatchedNoError(t *testing.T) { } err := result.match(exp) assert.ErrorContains(t, err, "Failures") - golden.Assert(t, err.Error(), "result-match-no-match-no-error.golden") + assert.Equal(t, err.Error(), expectedResultMatchNoMatch) } +var expectedResultMatchNoMatch = ` +Command: binary arg1 +ExitCode: 0 +Stdout: the output +Stderr: the stderr + +Failures: +ExitCode was 0 expected 101 +Expected stdout to contain "Something else" +Expected stderr to contain "[NOTHING]"` + func TestResult_Match_Match(t *testing.T) { result := &Result{ Cmd: exec.Command("binary", "arg1"), diff --git a/icmd/testdata/result-match-no-match-no-error.golden b/icmd/testdata/result-match-no-match-no-error.golden deleted file mode 100644 index 162d7665..00000000 --- a/icmd/testdata/result-match-no-match-no-error.golden +++ /dev/null @@ -1,10 +0,0 @@ - -Command: binary arg1 -ExitCode: 0 -Stdout: the output -Stderr: the stderr - -Failures: -ExitCode was 0 expected 101 -Expected stdout to contain "Something else" -Expected stderr to contain "[NOTHING]" \ No newline at end of file diff --git a/icmd/testdata/result-match-no-match.golden b/icmd/testdata/result-match-no-match.golden deleted file mode 100644 index 819f9fdd..00000000 --- a/icmd/testdata/result-match-no-match.golden +++ /dev/null @@ -1,12 +0,0 @@ - -Command: binary arg1 -ExitCode: 99 (timeout) -Error: exit code 99 -Stdout: the output -Stderr: the stderr - -Failures: -ExitCode was 99 expected 101 -Expected command to finish, but it hit the timeout -Expected stdout to contain "Something else" -Expected stderr to contain "[NOTHING]" \ No newline at end of file diff --git a/internal/assert/result.go b/internal/assert/result.go index 20cd5412..36032061 100644 --- a/internal/assert/result.go +++ b/internal/assert/result.go @@ -1,6 +1,7 @@ package assert import ( + "errors" "fmt" "go/ast" @@ -25,6 +26,22 @@ func RunComparison( return true } + if source.Update { + if updater, ok := result.(updateExpected); ok { + const stackIndex = 3 // Assert/Check, assert, RunComparison + err := updater.UpdatedExpected(stackIndex) + switch { + case err == nil: + return true + case errors.Is(err, source.ErrNotFound): + // do nothing, fallthrough to regular failure message + default: + t.Log("failed to update source", err) + return false + } + } + } + var message string switch typed := result.(type) { case resultWithComparisonArgs: @@ -52,6 +69,10 @@ type resultBasic interface { FailureMessage() string } +type updateExpected interface { + UpdatedExpected(stackIndex int) error +} + // filterPrintableExpr filters the ast.Expr slice to only include Expr that are // easy to read when printed and contain relevant information to an assertion. // diff --git a/internal/source/defers.go b/internal/source/defers.go index 8e5a6fb7..392d9fe0 100644 --- a/internal/source/defers.go +++ b/internal/source/defers.go @@ -28,7 +28,7 @@ func guessDefer(node ast.Node) (ast.Node, error) { defers := collectDefers(node) switch len(defers) { case 0: - return nil, fmt.Errorf("failed to expression in defer") + return nil, fmt.Errorf("failed to find expression in defer") case 1: return defers[0].Call, nil default: diff --git a/internal/source/source.go b/internal/source/source.go index 2686ded8..453bee4d 100644 --- a/internal/source/source.go +++ b/internal/source/source.go @@ -47,7 +47,7 @@ func CallExprArgs(stackIndex int) ([]ast.Expr, error) { return expr, nil } -func getNodeAtLine(fileset *token.FileSet, astFile *ast.File, lineNum int) (ast.Node, error) { +func getNodeAtLine(fileset *token.FileSet, astFile ast.Node, lineNum int) (ast.Node, error) { if node := scanToLine(fileset, astFile, lineNum); node != nil { return node, nil } diff --git a/internal/source/update.go b/internal/source/update.go index f66fff99..1b669f6b 100644 --- a/internal/source/update.go +++ b/internal/source/update.go @@ -1,11 +1,127 @@ package source -import "flag" +import ( + "bytes" + "errors" + "flag" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "os" + "runtime" + "strings" +) // Update is set by the -update flag. It indicates the user running the tests // would like to update any golden values. var Update bool func init() { - flag.BoolVar(&Update, "update", false, "update golden files") + flag.BoolVar(&Update, "update", false, "update golden values") +} + +// ErrNotFound indicates that UpdateExpectedValue failed to find the +// variable to update, likely because it is not a package level variable. +var ErrNotFound = fmt.Errorf("failed to find variable for update of golden value") + +// UpdateExpectedValue looks for a package-level variable with a name that +// starts with expected in the arguments to the caller. If the variable is +// found, the value of the variable will be updated to value of the other +// argument to the caller. +func UpdateExpectedValue(stackIndex int, x, y interface{}) error { + _, filename, line, ok := runtime.Caller(stackIndex + 1) + if !ok { + return errors.New("failed to get call stack") + } + debug("call stack position: %s:%d", filename, line) + + fileset := token.NewFileSet() + astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments) + if err != nil { + return fmt.Errorf("failed to parse source file %s: %w", filename, err) + } + + debug("before modification: %v", debugFormatNode{astFile}) + + expr, err := getCallExprArgs(fileset, astFile, line) + if err != nil { + return fmt.Errorf("call from %s:%d: %w", filename, line, err) + } + + if len(expr) < 3 { + debug("not enough arguments %d: %v", + len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}}) + return ErrNotFound + } + + argIndex, varName := getVarNameForExpectedValueArg(expr) + if argIndex < 0 || varName == "" { + debug("no arguments started with the word 'expected': %v", + debugFormatNode{Node: &ast.CallExpr{Args: expr}}) + return ErrNotFound + } + + value := x + if argIndex == 1 { + value = y + } + + obj := astFile.Scope.Objects[varName] + if obj == nil { + return ErrNotFound + } + if obj.Kind != ast.Con && obj.Kind != ast.Var { + debug("can only update var and const, found %v", obj.Kind) + return ErrNotFound + } + + spec, ok := obj.Decl.(*ast.ValueSpec) + if !ok { + debug("can only update *ast.ValueSpec, found %T", obj.Decl) + return ErrNotFound + } + if len(spec.Names) != 1 { + debug("more than one name in ast.ValueSpec") + return ErrNotFound + } + + // TODO: allow a function to wrap the string literal + spec.Values[0] = &ast.BasicLit{ + Kind: token.STRING, + // TODO: safer + Value: "`" + value.(string) + "`", + } + + debug("after modification: %v", debugFormatNode{astFile}) + + var buf bytes.Buffer + if err := format.Node(&buf, fileset, astFile); err != nil { + return fmt.Errorf("failed to format file after update: %w", err) + } + + fh, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to open file %v: %w", filename, err) + } + if _, err = fh.Write(buf.Bytes()); err != nil { + return fmt.Errorf("failed to write file %v: %w", filename, err) + } + if err := fh.Sync(); err != nil { + return fmt.Errorf("failed to sync file %v: %w", filename, err) + } + return nil +} + +func getVarNameForExpectedValueArg(expr []ast.Expr) (int, string) { + for i := 1; i < 3; i++ { + switch e := expr[i].(type) { + case *ast.Ident: + if strings.HasPrefix(strings.ToLower(e.Name), "expected") { + return i, e.Name + } + } + } + return -1, "" }