Skip to content

Commit

Permalink
Merge pull request #244 from motemen/update-expected-inside-func
Browse files Browse the repository at this point in the history
assert: allow updating expected vars/consts inside functions
  • Loading branch information
dnephin committed Sep 25, 2022
2 parents 97735af + 4c65207 commit 36dd5d1
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .golangci.yml
Expand Up @@ -5,7 +5,7 @@ linters-settings:
lll:
line-length: 100
maintidx:
under: 40
under: 35

issues:
exclude-use-default: false
Expand Down
71 changes: 70 additions & 1 deletion assert/assert_ext_test.go
@@ -1,6 +1,7 @@
package assert_test

import (
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
Expand Down Expand Up @@ -56,6 +57,48 @@ expected value
expected := "const expectedTwo = `this is the new\nexpected value\n`"
assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw))
})

t.Run("var inside function is updated when -update=true", func(t *testing.T) {
patchUpdate(t)
t.Cleanup(func() {
resetVariable(t, "expectedInsideFunc", "")
})

actual := `this is the new
expected value
for var inside function
`
expectedInsideFunc := ``

assert.Equal(t, actual, expectedInsideFunc)

raw, err := ioutil.ReadFile(fileName(t))
assert.NilError(t, err)

expected := "expectedInsideFunc := `this is the new\nexpected value\nfor var inside function\n`"
assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw))
})

t.Run("const inside function is updated when -update=true", func(t *testing.T) {
patchUpdate(t)
t.Cleanup(func() {
resetVariable(t, "expectedConstInsideFunc", "")
})

actual := `this is the new
expected value
for const inside function
`
const expectedConstInsideFunc = ``

assert.Equal(t, actual, expectedConstInsideFunc)

raw, err := ioutil.ReadFile(fileName(t))
assert.NilError(t, err)

expected := "const expectedConstInsideFunc = `this is the new\nexpected value\nfor const inside function\n`"
assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw))
})
}

// expectedOne is updated by running the tests with -update
Expand Down Expand Up @@ -87,7 +130,33 @@ func resetVariable(t *testing.T, varName string, value string) {
astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments)
assert.NilError(t, err)

err = source.UpdateVariable(filename, fileset, astFile, varName, value)
var ident *ast.Ident
ast.Inspect(astFile, func(n ast.Node) bool {
switch v := n.(type) {
case *ast.AssignStmt:
if len(v.Lhs) == 1 {
if id, ok := v.Lhs[0].(*ast.Ident); ok {
if id.Name == varName {
ident = id
return false
}
}
}

case *ast.ValueSpec:
for _, id := range v.Names {
if id.Name == varName {
ident = id
return false
}
}
}

return true
})
assert.Assert(t, ident != nil, "failed to get ident for %s", varName)

err = source.UpdateVariable(filename, fileset, astFile, ident, value)
assert.NilError(t, err, "failed to reset file")
}

Expand Down
2 changes: 1 addition & 1 deletion assert/cmd/gty-migrate-from-testify/migrate.go
Expand Up @@ -141,7 +141,7 @@ func convertTestifySingleArgCall(tcall call) ast.Node {
}
}

func convertTestifyAssertion(tcall call, migration migration) ast.Node { //nolint:maintidx
func convertTestifyAssertion(tcall call, migration migration) ast.Node {
imports := migration.importNames

switch tcall.selExpr.Sel.Name {
Expand Down
51 changes: 32 additions & 19 deletions internal/source/update.go
Expand Up @@ -54,8 +54,8 @@ func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
return ErrNotFound
}

argIndex, varName := getVarNameForExpectedValueArg(expr)
if argIndex < 0 || varName == "" {
argIndex, ident := getIdentForExpectedValueArg(expr)
if argIndex < 0 || ident == nil {
debug("no arguments started with the word 'expected': %v",
debugFormatNode{Node: &ast.CallExpr{Args: expr}})
return ErrNotFound
Expand All @@ -71,7 +71,7 @@ func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
debug("value must be type string, got %T", value)
return ErrNotFound
}
return UpdateVariable(filename, fileset, astFile, varName, strValue)
return UpdateVariable(filename, fileset, astFile, ident, strValue)
}

// UpdateVariable writes to filename the contents of astFile with the value of
Expand All @@ -80,10 +80,10 @@ func UpdateVariable(
filename string,
fileset *token.FileSet,
astFile *ast.File,
varName string,
ident *ast.Ident,
value string,
) error {
obj := astFile.Scope.Objects[varName]
obj := ident.Obj
if obj == nil {
return ErrNotFound
}
Expand All @@ -92,20 +92,33 @@ func UpdateVariable(
return ErrNotFound
}

spec, ok := obj.Decl.(*ast.ValueSpec)
if !ok {
switch decl := obj.Decl.(type) {
case *ast.ValueSpec:
if len(decl.Names) != 1 {
debug("more than one name in ast.ValueSpec")
return ErrNotFound
}

decl.Values[0] = &ast.BasicLit{
Kind: token.STRING,
Value: "`" + value + "`",
}

case *ast.AssignStmt:
if len(decl.Lhs) != 1 {
debug("more than one name in ast.AssignStmt")
return ErrNotFound
}

decl.Rhs[0] = &ast.BasicLit{
Kind: token.STRING,
Value: "`" + value + "`",
}

default:
debug("can only update *ast.ValueSpec, found %T", obj.Decl)
return ErrNotFound
}
if len(spec.Names) != 1 {
debug("more than one name in ast.ValueSpec")
return ErrNotFound
}

spec.Values[0] = &ast.BasicLit{
Kind: token.STRING,
Value: "`" + value + "`",
}

var buf bytes.Buffer
if err := format.Node(&buf, fileset, astFile); err != nil {
Expand All @@ -125,14 +138,14 @@ func UpdateVariable(
return nil
}

func getVarNameForExpectedValueArg(expr []ast.Expr) (int, string) {
func getIdentForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
for i := 1; i < 3; i++ {
switch e := expr[i].(type) {
case *ast.Ident:
if strings.HasPrefix(strings.ToLower(e.Name), "expected") {
return i, e.Name
return i, e
}
}
}
return -1, ""
return -1, nil
}

0 comments on commit 36dd5d1

Please sign in to comment.