From 26861688bfadaa58112f7a949d196ec09c1eaee5 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Sun, 1 May 2022 14:33:27 -0400 Subject: [PATCH 1/4] Move update flag to internal/source package So it can be used by other packages. --- golden/golden.go | 10 ++++------ golden/golden_test.go | 7 ++++--- internal/source/update.go | 11 +++++++++++ 3 files changed, 19 insertions(+), 9 deletions(-) create mode 100644 internal/source/update.go diff --git a/golden/golden.go b/golden/golden.go index 72ca05f..47ea85f 100644 --- a/golden/golden.go +++ b/golden/golden.go @@ -18,13 +18,11 @@ import ( "gotest.tools/v3/assert" "gotest.tools/v3/assert/cmp" "gotest.tools/v3/internal/format" + "gotest.tools/v3/internal/source" ) -var flagUpdate bool - func init() { - flag.BoolVar(&flagUpdate, "update", false, "update golden files") - flag.BoolVar(&flagUpdate, "test.update-golden", false, "deprecated flag") + flag.BoolVar(&source.Update, "test.update-golden", false, "deprecated flag") } type helperT interface { @@ -46,7 +44,7 @@ var NormalizeCRLFToLF = os.Getenv("GOTESTTOOLS_GOLDEN_NormalizeCRLFToLF") != "fa // FlagUpdate returns true when the -update flag has been set. func FlagUpdate() bool { - return flagUpdate + return source.Update } // Open opens the file in ./testdata @@ -180,7 +178,7 @@ func compare(actual []byte, filename string) (cmp.Result, []byte) { } func update(filename string, actual []byte) error { - if !flagUpdate { + if !source.Update { return nil } if dir := filepath.Dir(Path(filename)); dir != "." { diff --git a/golden/golden_test.go b/golden/golden_test.go index 54e807c..3b0bd02 100644 --- a/golden/golden_test.go +++ b/golden/golden_test.go @@ -9,6 +9,7 @@ import ( "gotest.tools/v3/assert" "gotest.tools/v3/assert/cmp" "gotest.tools/v3/fs" + "gotest.tools/v3/internal/source" ) type fakeT struct { @@ -190,10 +191,10 @@ func TestGoldenAssertBytes(t *testing.T) { } func setUpdateFlag(t *testing.T) func() { - orig := flagUpdate - flagUpdate = true + orig := source.Update + source.Update = true undo := func() { - flagUpdate = orig + source.Update = orig } t.Cleanup(undo) return undo diff --git a/internal/source/update.go b/internal/source/update.go new file mode 100644 index 0000000..f66fff9 --- /dev/null +++ b/internal/source/update.go @@ -0,0 +1,11 @@ +package source + +import "flag" + +// Update is set by the -update flag. It indicates the user running the tests +// would like to update any golden values. +var Update bool + +func init() { + flag.BoolVar(&Update, "update", false, "update golden files") +} From 3abbc52d92cd70f17a01aea4c351adf4b76f28c1 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Sun, 1 May 2022 15:43:19 -0400 Subject: [PATCH 2/4] refactor: prepare for other uses of the ast.File --- internal/source/source.go | 79 +++++++++++--------------------------- internal/source/version.go | 35 +++++++++++++++++ 2 files changed, 58 insertions(+), 56 deletions(-) create mode 100644 internal/source/version.go diff --git a/internal/source/source.go b/internal/source/source.go index 4dbc1bc..2686ded 100644 --- a/internal/source/source.go +++ b/internal/source/source.go @@ -10,12 +10,8 @@ import ( "go/token" "os" "runtime" - "strconv" - "strings" ) -const baseStackIndex = 1 - // FormattedCallExprArg returns the argument from an ast.CallExpr at the // index in the call stack. The argument is formatted using FormatNode. func FormattedCallExprArg(stackIndex int, argPos int) (string, error) { @@ -32,28 +28,26 @@ func FormattedCallExprArg(stackIndex int, argPos int) (string, error) { // CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at // the index in the call stack. func CallExprArgs(stackIndex int) ([]ast.Expr, error) { - _, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex) + _, filename, line, ok := runtime.Caller(stackIndex + 1) if !ok { return nil, errors.New("failed to get call stack") } - debug("call stack position: %s:%d", filename, lineNum) + debug("call stack position: %s:%d", filename, line) - node, err := getNodeAtLine(filename, lineNum) - if err != nil { - return nil, err - } - debug("found node: %s", debugFormatNode{node}) - - return getCallExprArgs(node) -} - -func getNodeAtLine(filename string, lineNum int) (ast.Node, error) { fileset := token.NewFileSet() astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors) if err != nil { return nil, fmt.Errorf("failed to parse source file %s: %w", filename, err) } + expr, err := getCallExprArgs(fileset, astFile, line) + if err != nil { + return nil, fmt.Errorf("call from %s:%d: %w", filename, line, err) + } + return expr, nil +} + +func getNodeAtLine(fileset *token.FileSet, astFile *ast.File, lineNum int) (ast.Node, error) { if node := scanToLine(fileset, astFile, lineNum); node != nil { return node, nil } @@ -63,8 +57,7 @@ func getNodeAtLine(filename string, lineNum int) (ast.Node, error) { return node, err } } - return nil, fmt.Errorf( - "failed to find an expression on line %d in %s", lineNum, filename) + return nil, nil } func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { @@ -73,7 +66,7 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { switch { case node == nil || matchedNode != nil: return false - case nodePosition(fileset, node).Line == lineNum: + case fileset.Position(node.Pos()).Line == lineNum: matchedNode = node return false } @@ -82,46 +75,17 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { return matchedNode } -// In golang 1.9 the line number changed from being the line where the statement -// ended to the line where the statement began. -func nodePosition(fileset *token.FileSet, node ast.Node) token.Position { - if goVersionBefore19 { - return fileset.Position(node.End()) - } - return fileset.Position(node.Pos()) -} - -// GoVersionLessThan returns true if runtime.Version() is semantically less than -// version major.minor. Returns false if a release version can not be parsed from -// runtime.Version(). -func GoVersionLessThan(major, minor int64) bool { - version := runtime.Version() - // not a release version - if !strings.HasPrefix(version, "go") { - return false - } - version = strings.TrimPrefix(version, "go") - parts := strings.Split(version, ".") - if len(parts) < 2 { - return false - } - rMajor, err := strconv.ParseInt(parts[0], 10, 32) - if err != nil { - return false - } - if rMajor != major { - return rMajor < major - } - rMinor, err := strconv.ParseInt(parts[1], 10, 32) - if err != nil { - return false +func getCallExprArgs(fileset *token.FileSet, astFile *ast.File, line int) ([]ast.Expr, error) { + node, err := getNodeAtLine(fileset, astFile, line) + switch { + case err != nil: + return nil, err + case node == nil: + return nil, fmt.Errorf("failed to find an expression") } - return rMinor < minor -} -var goVersionBefore19 = GoVersionLessThan(1, 9) + debug("found node: %s", debugFormatNode{node}) -func getCallExprArgs(node ast.Node) ([]ast.Expr, error) { visitor := &callExprVisitor{} ast.Walk(visitor, node) if visitor.expr == nil { @@ -172,6 +136,9 @@ type debugFormatNode struct { } func (n debugFormatNode) String() string { + if n.Node == nil { + return "none" + } out, err := FormatNode(n.Node) if err != nil { return fmt.Sprintf("failed to format %s: %s", n.Node, err) diff --git a/internal/source/version.go b/internal/source/version.go new file mode 100644 index 0000000..5fa8a90 --- /dev/null +++ b/internal/source/version.go @@ -0,0 +1,35 @@ +package source + +import ( + "runtime" + "strconv" + "strings" +) + +// GoVersionLessThan returns true if runtime.Version() is semantically less than +// version major.minor. Returns false if a release version can not be parsed from +// runtime.Version(). +func GoVersionLessThan(major, minor int64) bool { + version := runtime.Version() + // not a release version + if !strings.HasPrefix(version, "go") { + return false + } + version = strings.TrimPrefix(version, "go") + parts := strings.Split(version, ".") + if len(parts) < 2 { + return false + } + rMajor, err := strconv.ParseInt(parts[0], 10, 32) + if err != nil { + return false + } + if rMajor != major { + return rMajor < major + } + rMinor, err := strconv.ParseInt(parts[1], 10, 32) + if err != nil { + return false + } + return rMinor < minor +} From a0e2cd3809c1c70b8bc819453eca5e8d40197e48 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Sun, 1 May 2022 14:50:28 -0400 Subject: [PATCH 3/4] assert: first draft of inline golden variables --- assert/assert_ext_test.go | 64 ++++++++++ assert/cmp/compare.go | 8 +- assert/cmp/result.go | 5 + icmd/command_test.go | 29 ++++- .../result-match-no-match-no-error.golden | 10 -- icmd/testdata/result-match-no-match.golden | 12 -- internal/assert/result.go | 21 +++ internal/source/defers.go | 2 +- internal/source/source.go | 2 +- internal/source/update.go | 120 +++++++++++++++++- 10 files changed, 240 insertions(+), 33 deletions(-) create mode 100644 assert/assert_ext_test.go delete mode 100644 icmd/testdata/result-match-no-match-no-error.golden delete mode 100644 icmd/testdata/result-match-no-match.golden diff --git a/assert/assert_ext_test.go b/assert/assert_ext_test.go new file mode 100644 index 0000000..88b0421 --- /dev/null +++ b/assert/assert_ext_test.go @@ -0,0 +1,64 @@ +package assert_test + +import ( + "fmt" + "testing" + + "gotest.tools/v3/assert" + "gotest.tools/v3/internal/source" +) + +func TestEqual_WithGoldenUpdate(t *testing.T) { + t.Run("assert failed with update=false", func(t *testing.T) { + ft := &fakeTestingT{} + actual := `not this value` + assert.Equal(ft, actual, expectedOne) + assert.Assert(t, ft.failNowed) + }) + + t.Run("value is updated when -update=true", func(t *testing.T) { + patchUpdate(t) + ft := &fakeTestingT{} + + actual := `this is the +actual value +that we are testing against` + assert.Equal(ft, actual, expectedOne) + + // reset + fmt.Println("WHHHHHHHHHHY") + assert.Equal(ft, "\n\n\n", expectedOne) + }) +} + +var expectedOne = ` + + +` + +func patchUpdate(t *testing.T) { + source.Update = true + t.Cleanup(func() { + source.Update = false + }) +} + +type fakeTestingT struct { + failNowed bool + failed bool + msgs []string +} + +func (f *fakeTestingT) FailNow() { + f.failNowed = true +} + +func (f *fakeTestingT) Fail() { + f.failed = true +} + +func (f *fakeTestingT) Log(args ...interface{}) { + f.msgs = append(f.msgs, args[0].(string)) +} + +func (f *fakeTestingT) Helper() {} diff --git a/assert/cmp/compare.go b/assert/cmp/compare.go index 1f42bd0..78f76e4 100644 --- a/assert/cmp/compare.go +++ b/assert/cmp/compare.go @@ -35,7 +35,7 @@ func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison { if diff == "" { return ResultSuccess } - return multiLineDiffResult(diff) + return multiLineDiffResult(diff, x, y) } } @@ -102,7 +102,7 @@ func Equal(x, y interface{}) Comparison { return ResultSuccess case isMultiLineStringCompare(x, y): diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)}) - return multiLineDiffResult(diff) + return multiLineDiffResult(diff, x, y) } return ResultFailureTemplate(` {{- printf "%v" .Data.x}} ( @@ -128,12 +128,12 @@ func isMultiLineStringCompare(x, y interface{}) bool { return strings.Contains(strX, "\n") || strings.Contains(strY, "\n") } -func multiLineDiffResult(diff string) Result { +func multiLineDiffResult(diff string, x, y interface{}) Result { return ResultFailureTemplate(` --- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}} +++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}} {{ .Data.diff }}`, - map[string]interface{}{"diff": diff}) + map[string]interface{}{"diff": diff, "x": x, "y": y}) } // Len succeeds if the sequence has the expected length. diff --git a/assert/cmp/result.go b/assert/cmp/result.go index 2b0eb7e..28ef8d3 100644 --- a/assert/cmp/result.go +++ b/assert/cmp/result.go @@ -69,6 +69,11 @@ func (r templatedResult) FailureMessage(args []ast.Expr) string { return msg } +func (r templatedResult) UpdatedExpected(stackIndex int) error { + // TODO: would be nice to have structured data instead of a map + return source.UpdateExpectedValue(stackIndex+1, r.data["x"], r.data["y"]) +} + // ResultFailureTemplate returns a Result with a template string and data which // can be used to format a failure message. The template may access data from .Data, // the comparison args with the callArg function, and the formatNode function may diff --git a/icmd/command_test.go b/icmd/command_test.go index 5619c3b..1a5fef9 100644 --- a/icmd/command_test.go +++ b/icmd/command_test.go @@ -12,7 +12,6 @@ import ( exec "golang.org/x/sys/execabs" "gotest.tools/v3/assert" "gotest.tools/v3/fs" - "gotest.tools/v3/golden" "gotest.tools/v3/internal/maint" ) @@ -120,9 +119,22 @@ func TestResult_Match_NotMatched(t *testing.T) { } err := result.match(exp) assert.ErrorContains(t, err, "Failures") - golden.Assert(t, err.Error(), "result-match-no-match.golden") + assert.Equal(t, err.Error(), expectedMatch) } +var expectedMatch = ` +Command: binary arg1 +ExitCode: 99 (timeout) +Error: exit code 99 +Stdout: the output +Stderr: the stderr + +Failures: +ExitCode was 99 expected 101 +Expected command to finish, but it hit the timeout +Expected stdout to contain "Something else" +Expected stderr to contain "[NOTHING]"` + func newLockedBuffer(s string) *lockedBuffer { return &lockedBuffer{buf: *bytes.NewBufferString(s)} } @@ -140,9 +152,20 @@ func TestResult_Match_NotMatchedNoError(t *testing.T) { } err := result.match(exp) assert.ErrorContains(t, err, "Failures") - golden.Assert(t, err.Error(), "result-match-no-match-no-error.golden") + assert.Equal(t, err.Error(), expectedResultMatchNoMatch) } +var expectedResultMatchNoMatch = ` +Command: binary arg1 +ExitCode: 0 +Stdout: the output +Stderr: the stderr + +Failures: +ExitCode was 0 expected 101 +Expected stdout to contain "Something else" +Expected stderr to contain "[NOTHING]"` + func TestResult_Match_Match(t *testing.T) { result := &Result{ Cmd: exec.Command("binary", "arg1"), diff --git a/icmd/testdata/result-match-no-match-no-error.golden b/icmd/testdata/result-match-no-match-no-error.golden deleted file mode 100644 index 162d766..0000000 --- a/icmd/testdata/result-match-no-match-no-error.golden +++ /dev/null @@ -1,10 +0,0 @@ - -Command: binary arg1 -ExitCode: 0 -Stdout: the output -Stderr: the stderr - -Failures: -ExitCode was 0 expected 101 -Expected stdout to contain "Something else" -Expected stderr to contain "[NOTHING]" \ No newline at end of file diff --git a/icmd/testdata/result-match-no-match.golden b/icmd/testdata/result-match-no-match.golden deleted file mode 100644 index 819f9fd..0000000 --- a/icmd/testdata/result-match-no-match.golden +++ /dev/null @@ -1,12 +0,0 @@ - -Command: binary arg1 -ExitCode: 99 (timeout) -Error: exit code 99 -Stdout: the output -Stderr: the stderr - -Failures: -ExitCode was 99 expected 101 -Expected command to finish, but it hit the timeout -Expected stdout to contain "Something else" -Expected stderr to contain "[NOTHING]" \ No newline at end of file diff --git a/internal/assert/result.go b/internal/assert/result.go index 20cd541..3603206 100644 --- a/internal/assert/result.go +++ b/internal/assert/result.go @@ -1,6 +1,7 @@ package assert import ( + "errors" "fmt" "go/ast" @@ -25,6 +26,22 @@ func RunComparison( return true } + if source.Update { + if updater, ok := result.(updateExpected); ok { + const stackIndex = 3 // Assert/Check, assert, RunComparison + err := updater.UpdatedExpected(stackIndex) + switch { + case err == nil: + return true + case errors.Is(err, source.ErrNotFound): + // do nothing, fallthrough to regular failure message + default: + t.Log("failed to update source", err) + return false + } + } + } + var message string switch typed := result.(type) { case resultWithComparisonArgs: @@ -52,6 +69,10 @@ type resultBasic interface { FailureMessage() string } +type updateExpected interface { + UpdatedExpected(stackIndex int) error +} + // filterPrintableExpr filters the ast.Expr slice to only include Expr that are // easy to read when printed and contain relevant information to an assertion. // diff --git a/internal/source/defers.go b/internal/source/defers.go index 8e5a6fb..392d9fe 100644 --- a/internal/source/defers.go +++ b/internal/source/defers.go @@ -28,7 +28,7 @@ func guessDefer(node ast.Node) (ast.Node, error) { defers := collectDefers(node) switch len(defers) { case 0: - return nil, fmt.Errorf("failed to expression in defer") + return nil, fmt.Errorf("failed to find expression in defer") case 1: return defers[0].Call, nil default: diff --git a/internal/source/source.go b/internal/source/source.go index 2686ded..453bee4 100644 --- a/internal/source/source.go +++ b/internal/source/source.go @@ -47,7 +47,7 @@ func CallExprArgs(stackIndex int) ([]ast.Expr, error) { return expr, nil } -func getNodeAtLine(fileset *token.FileSet, astFile *ast.File, lineNum int) (ast.Node, error) { +func getNodeAtLine(fileset *token.FileSet, astFile ast.Node, lineNum int) (ast.Node, error) { if node := scanToLine(fileset, astFile, lineNum); node != nil { return node, nil } diff --git a/internal/source/update.go b/internal/source/update.go index f66fff9..1b669f6 100644 --- a/internal/source/update.go +++ b/internal/source/update.go @@ -1,11 +1,127 @@ package source -import "flag" +import ( + "bytes" + "errors" + "flag" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "os" + "runtime" + "strings" +) // Update is set by the -update flag. It indicates the user running the tests // would like to update any golden values. var Update bool func init() { - flag.BoolVar(&Update, "update", false, "update golden files") + flag.BoolVar(&Update, "update", false, "update golden values") +} + +// ErrNotFound indicates that UpdateExpectedValue failed to find the +// variable to update, likely because it is not a package level variable. +var ErrNotFound = fmt.Errorf("failed to find variable for update of golden value") + +// UpdateExpectedValue looks for a package-level variable with a name that +// starts with expected in the arguments to the caller. If the variable is +// found, the value of the variable will be updated to value of the other +// argument to the caller. +func UpdateExpectedValue(stackIndex int, x, y interface{}) error { + _, filename, line, ok := runtime.Caller(stackIndex + 1) + if !ok { + return errors.New("failed to get call stack") + } + debug("call stack position: %s:%d", filename, line) + + fileset := token.NewFileSet() + astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments) + if err != nil { + return fmt.Errorf("failed to parse source file %s: %w", filename, err) + } + + debug("before modification: %v", debugFormatNode{astFile}) + + expr, err := getCallExprArgs(fileset, astFile, line) + if err != nil { + return fmt.Errorf("call from %s:%d: %w", filename, line, err) + } + + if len(expr) < 3 { + debug("not enough arguments %d: %v", + len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}}) + return ErrNotFound + } + + argIndex, varName := getVarNameForExpectedValueArg(expr) + if argIndex < 0 || varName == "" { + debug("no arguments started with the word 'expected': %v", + debugFormatNode{Node: &ast.CallExpr{Args: expr}}) + return ErrNotFound + } + + value := x + if argIndex == 1 { + value = y + } + + obj := astFile.Scope.Objects[varName] + if obj == nil { + return ErrNotFound + } + if obj.Kind != ast.Con && obj.Kind != ast.Var { + debug("can only update var and const, found %v", obj.Kind) + return ErrNotFound + } + + spec, ok := obj.Decl.(*ast.ValueSpec) + if !ok { + debug("can only update *ast.ValueSpec, found %T", obj.Decl) + return ErrNotFound + } + if len(spec.Names) != 1 { + debug("more than one name in ast.ValueSpec") + return ErrNotFound + } + + // TODO: allow a function to wrap the string literal + spec.Values[0] = &ast.BasicLit{ + Kind: token.STRING, + // TODO: safer + Value: "`" + value.(string) + "`", + } + + debug("after modification: %v", debugFormatNode{astFile}) + + var buf bytes.Buffer + if err := format.Node(&buf, fileset, astFile); err != nil { + return fmt.Errorf("failed to format file after update: %w", err) + } + + fh, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to open file %v: %w", filename, err) + } + if _, err = fh.Write(buf.Bytes()); err != nil { + return fmt.Errorf("failed to write file %v: %w", filename, err) + } + if err := fh.Sync(); err != nil { + return fmt.Errorf("failed to sync file %v: %w", filename, err) + } + return nil +} + +func getVarNameForExpectedValueArg(expr []ast.Expr) (int, string) { + for i := 1; i < 3; i++ { + switch e := expr[i].(type) { + case *ast.Ident: + if strings.HasPrefix(strings.ToLower(e.Name), "expected") { + return i, e.Name + } + } + } + return -1, "" } From 82e8930f9ab0a6b5348258ca21b68d72b2fa7591 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Sun, 29 May 2022 16:53:20 -0400 Subject: [PATCH 4/4] assert: tests for golden variables --- assert/assert_ext_test.go | 72 ++++++++++++++++++++++++++++++++------- internal/source/source.go | 2 +- internal/source/update.go | 27 ++++++++++----- 3 files changed, 80 insertions(+), 21 deletions(-) diff --git a/assert/assert_ext_test.go b/assert/assert_ext_test.go index 88b0421..5903f70 100644 --- a/assert/assert_ext_test.go +++ b/assert/assert_ext_test.go @@ -1,7 +1,11 @@ package assert_test import ( - "fmt" + "go/parser" + "go/token" + "io/ioutil" + "runtime" + "strings" "testing" "gotest.tools/v3/assert" @@ -9,32 +13,56 @@ import ( ) func TestEqual_WithGoldenUpdate(t *testing.T) { - t.Run("assert failed with update=false", func(t *testing.T) { + t.Run("assert failed with -update=false", func(t *testing.T) { ft := &fakeTestingT{} actual := `not this value` assert.Equal(ft, actual, expectedOne) assert.Assert(t, ft.failNowed) }) - t.Run("value is updated when -update=true", func(t *testing.T) { + t.Run("var is updated when -update=true", func(t *testing.T) { patchUpdate(t) - ft := &fakeTestingT{} + t.Cleanup(func() { + resetVariable(t, "expectedOne", "") + }) actual := `this is the actual value -that we are testing against` - assert.Equal(ft, actual, expectedOne) +that we are testing +` + assert.Equal(t, actual, expectedOne) - // reset - fmt.Println("WHHHHHHHHHHY") - assert.Equal(ft, "\n\n\n", expectedOne) - }) -} + raw, err := ioutil.ReadFile(fileName(t)) + assert.NilError(t, err) -var expectedOne = ` + expected := "var expectedOne = `this is the\nactual value\nthat we are testing\n`" + assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw)) + }) + t.Run("const is updated when -update=true", func(t *testing.T) { + patchUpdate(t) + t.Cleanup(func() { + resetVariable(t, "expectedTwo", "") + }) + actual := `this is the new +expected value ` + assert.Equal(t, actual, expectedTwo) + + raw, err := ioutil.ReadFile(fileName(t)) + assert.NilError(t, err) + + expected := "const expectedTwo = `this is the new\nexpected value\n`" + assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw)) + }) +} + +// expectedOne is updated by running the tests with -update +var expectedOne = `` + +// expectedTwo is updated by running the tests with -update +const expectedTwo = `` func patchUpdate(t *testing.T) { source.Update = true @@ -43,6 +71,26 @@ func patchUpdate(t *testing.T) { }) } +func fileName(t *testing.T) string { + t.Helper() + _, filename, _, ok := runtime.Caller(1) + assert.Assert(t, ok, "failed to get call stack") + return filename +} + +func resetVariable(t *testing.T, varName string, value string) { + t.Helper() + _, filename, _, ok := runtime.Caller(1) + assert.Assert(t, ok, "failed to get call stack") + + fileset := token.NewFileSet() + astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments) + assert.NilError(t, err) + + err = source.UpdateVariable(filename, fileset, astFile, varName, value) + assert.NilError(t, err, "failed to reset file") +} + type fakeTestingT struct { failNowed bool failed bool diff --git a/internal/source/source.go b/internal/source/source.go index 453bee4..a3f7008 100644 --- a/internal/source/source.go +++ b/internal/source/source.go @@ -75,7 +75,7 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { return matchedNode } -func getCallExprArgs(fileset *token.FileSet, astFile *ast.File, line int) ([]ast.Expr, error) { +func getCallExprArgs(fileset *token.FileSet, astFile ast.Node, line int) ([]ast.Expr, error) { node, err := getNodeAtLine(fileset, astFile, line) switch { case err != nil: diff --git a/internal/source/update.go b/internal/source/update.go index 1b669f6..bd9678b 100644 --- a/internal/source/update.go +++ b/internal/source/update.go @@ -43,8 +43,6 @@ func UpdateExpectedValue(stackIndex int, x, y interface{}) error { return fmt.Errorf("failed to parse source file %s: %w", filename, err) } - debug("before modification: %v", debugFormatNode{astFile}) - expr, err := getCallExprArgs(fileset, astFile, line) if err != nil { return fmt.Errorf("call from %s:%d: %w", filename, line, err) @@ -68,6 +66,23 @@ func UpdateExpectedValue(stackIndex int, x, y interface{}) error { value = y } + strValue, ok := value.(string) + if !ok { + debug("value must be type string, got %T", value) + return ErrNotFound + } + return UpdateVariable(filename, fileset, astFile, varName, strValue) +} + +// UpdateVariable writes to filename the contents of astFile with the value of +// the variable updated to value. +func UpdateVariable( + filename string, + fileset *token.FileSet, + astFile *ast.File, + varName string, + value string, +) error { obj := astFile.Scope.Objects[varName] if obj == nil { return ErrNotFound @@ -87,15 +102,11 @@ func UpdateExpectedValue(stackIndex int, x, y interface{}) error { return ErrNotFound } - // TODO: allow a function to wrap the string literal spec.Values[0] = &ast.BasicLit{ - Kind: token.STRING, - // TODO: safer - Value: "`" + value.(string) + "`", + Kind: token.STRING, + Value: "`" + value + "`", } - debug("after modification: %v", debugFormatNode{astFile}) - var buf bytes.Buffer if err := format.Node(&buf, fileset, astFile); err != nil { return fmt.Errorf("failed to format file after update: %w", err)