Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update gty-migrate-from-testify #228

Merged
merged 6 commits into from Apr 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions assert/assert.go
Expand Up @@ -16,8 +16,8 @@ The example below shows assert used with some common types.
import (
"testing"

"gotest.tools/assert"
is "gotest.tools/assert/cmp"
"gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp"
)

func TestEverything(t *testing.T) {
Expand Down
8 changes: 5 additions & 3 deletions assert/cmd/gty-migrate-from-testify/call.go
Expand Up @@ -8,14 +8,16 @@ import (
"go/token"
)

// call wraps an testify assert ast.CallExpr and exposes properties of the
// expression to facilitate migrating the expression to a gotest.tools/assert
// call wraps a testify/assert ast.CallExpr and exposes properties of the
// expression to facilitate migrating the expression to a gotest.tools/v3/assert
type call struct {
fileset *token.FileSet
expr *ast.CallExpr
xIdent *ast.Ident
selExpr *ast.SelectorExpr
assert string
// assert is Assert (if the testify package was require), or Check (if the
// testify package was assert).
assert string
}

func (c call) String() string {
Expand Down
4 changes: 2 additions & 2 deletions assert/cmd/gty-migrate-from-testify/doc.go
@@ -1,9 +1,9 @@
/*

Command gty-migrate-from-testify migrates packages from
testify/assert and testify/require to gotest.tools/assert.
testify/assert and testify/require to gotest.tools/v3/assert.

$ go get gotest.tools/assert/cmd/gty-migrate-from-testify
$ go get gotest.tools/v3/assert/cmd/gty-migrate-from-testify

Usage:

Expand Down
102 changes: 0 additions & 102 deletions assert/cmd/gty-migrate-from-testify/importer.go

This file was deleted.

106 changes: 53 additions & 53 deletions assert/cmd/gty-migrate-from-testify/main.go
Expand Up @@ -4,9 +4,7 @@ import (
"bytes"
"fmt"
"go/ast"
"go/build"
"go/format"
"go/parser"
"go/token"
"io/ioutil"
"log"
Expand All @@ -17,7 +15,7 @@ import (

"github.com/pkg/errors"
"github.com/spf13/pflag"
"golang.org/x/tools/go/loader"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/imports"
)

Expand All @@ -27,7 +25,8 @@ type options struct {
debug bool
cmpImportName string
showLoaderErrors bool
useAllFiles bool
buildFlags []string
localImportPath string
}

func main() {
Expand Down Expand Up @@ -62,12 +61,14 @@ func setupFlags(name string) (*pflag.FlagSet, *options) {
"import alias to use for the assert/cmp package")
flags.BoolVar(&opts.showLoaderErrors, "print-loader-errors", false,
"print errors from loading source")
flags.BoolVar(&opts.useAllFiles, "ignore-build-tags", false,
"migrate all files ignoring build tags")
flags.StringSliceVar(&opts.buildFlags, "build-tags", nil,
"build to pass to Go when loading source files")
flags.StringVar(&opts.localImportPath, "local-import-path", "",
"value to pass to 'goimports -local' flag for sorting local imports")
flags.Usage = func() {
fmt.Fprintf(os.Stderr, `Usage: %s [OPTIONS] PACKAGE [PACKAGE...]

Migrate calls from testify/{assert|require} to gotest.tools/assert.
Migrate calls from testify/{assert|require} to gotest.tools/v3/assert.

%s`, name, flags.FlagUsages())
}
Expand All @@ -87,18 +88,19 @@ func handleExitError(name string, err error) {
}

func run(opts options) error {
program, err := loadProgram(opts)
imports.LocalPrefix = opts.localImportPath

fset := token.NewFileSet()
pkgs, err := loadPackages(opts, fset)
if err != nil {
return errors.Wrapf(err, "failed to load program")
}

pkgs := program.InitialPackages()
debugf("package count: %d", len(pkgs))

fileset := program.Fset
for _, pkg := range pkgs {
for _, astFile := range pkg.Files {
absFilename := fileset.File(astFile.Pos()).Name()
debugf("file count for package %v: %d", pkg.PkgPath, len(pkg.Syntax))
for _, astFile := range pkg.Syntax {
absFilename := fset.File(astFile.Pos()).Name()
filename := relativePath(absFilename)
importNames := newImportNames(astFile.Imports, opts)
if !importNames.hasTestifyImports() {
Expand All @@ -109,9 +111,9 @@ func run(opts options) error {
debugf("migrating %s with imports: %#v", filename, importNames)
m := migration{
file: astFile,
fileset: fileset,
fileset: fset,
importNames: importNames,
pkgInfo: pkg,
pkgInfo: pkg.TypesInfo,
}
migrateFile(m)
if opts.dryRun {
Expand All @@ -132,47 +134,33 @@ func run(opts options) error {
return nil
}

func loadProgram(opts options) (*loader.Program, error) {
fakeImporter, err := newFakeImporter()
var loadMode = packages.NeedName |
packages.NeedFiles |
packages.NeedCompiledGoFiles |
packages.NeedDeps |
packages.NeedImports |
packages.NeedTypes |
packages.NeedTypesInfo |
packages.NeedTypesSizes |
packages.NeedSyntax

func loadPackages(opts options, fset *token.FileSet) ([]*packages.Package, error) {
conf := &packages.Config{
Mode: loadMode,
Fset: fset,
Tests: true,
Logf: debugf,
BuildFlags: opts.buildFlags,
}

pkgs, err := packages.Load(conf, opts.pkgs...)
if err != nil {
return nil, err
}
defer fakeImporter.Close()

conf := loader.Config{
Fset: token.NewFileSet(),
ParserMode: parser.ParseComments,
Build: buildContext(opts),
AllowErrors: true,
FindPackage: fakeImporter.Import,
}
for _, pkg := range opts.pkgs {
conf.ImportWithTests(pkg)
}
if !opts.showLoaderErrors {
conf.TypeChecker.Error = func(e error) {}
}
program, err := conf.Load()
if opts.showLoaderErrors {
for p, pkg := range program.AllPackages {
if len(pkg.Errors) > 0 {
fmt.Printf("Package %s loaded with some errors:\n", p.Name())
for _, err := range pkg.Errors {
fmt.Println(" ", err.Error())
}
}
}
}
return program, err
}

func buildContext(opts options) *build.Context {
c := build.Default
c.UseAllFiles = opts.useAllFiles
if val, ok := os.LookupEnv("GOPATH"); ok {
c.GOPATH = val
packages.PrintErrors(pkgs)
}
return &c
return pkgs, nil
}

func relativePath(p string) string {
Expand Down Expand Up @@ -214,8 +202,9 @@ func (p importNames) funcNameFromTestifyName(name string) string {
}

func newImportNames(imports []*ast.ImportSpec, opt options) importNames {
defaultAssertAlias := path.Base(pkgAssert)
importNames := importNames{
assert: path.Base(pkgAssert),
assert: defaultAssertAlias,
cmp: path.Base(pkgCmp),
}
for _, spec := range imports {
Expand All @@ -225,7 +214,18 @@ func newImportNames(imports []*ast.ImportSpec, opt options) importNames {
case pkgTestifyRequire, pkgGopkgTestifyRequire:
importNames.testifyRequire = identOrDefault(spec.Name, "require")
default:
if importedAs(spec, path.Base(pkgAssert)) {
pkgPath := strings.Trim(spec.Path.Value, `"`)

switch {
// v3/assert is already imported and has an alias
case pkgPath == pkgAssert:
if spec.Name != nil && spec.Name.Name != "" {
importNames.assert = spec.Name.Name
}
continue

// some other package is imported as assert
case importedAs(spec, path.Base(pkgAssert)) && importNames.assert == defaultAssertAlias:
importNames.assert = "gtyassert"
}
}
Expand Down
22 changes: 18 additions & 4 deletions assert/cmd/gty-migrate-from-testify/migrate.go
Expand Up @@ -8,16 +8,15 @@ import (
"path"

"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/loader"
)

const (
pkgTestifyAssert = "github.com/stretchr/testify/assert"
pkgGopkgTestifyAssert = "gopkg.in/stretchr/testify.v1/assert"
pkgTestifyRequire = "github.com/stretchr/testify/require"
pkgGopkgTestifyRequire = "gopkg.in/stretchr/testify.v1/require"
pkgAssert = "gotest.tools/assert"
pkgCmp = "gotest.tools/assert/cmp"
pkgAssert = "gotest.tools/v3/assert"
pkgCmp = "gotest.tools/v3/assert/cmp"
)

const (
Expand All @@ -36,7 +35,7 @@ type migration struct {
file *ast.File
fileset *token.FileSet
importNames importNames
pkgInfo *loader.PackageInfo
pkgInfo *types.Info
}

func migrateFile(migration migration) {
Expand Down Expand Up @@ -165,6 +164,8 @@ func convertTestifyAssertion(tcall call, migration migration) ast.Node {
return convertEqualError(tcall, imports)
case "Error", "Errorf":
return convertError(tcall, imports)
case "ErrorContains", "ErrorContainsf":
return convertErrorContains(tcall, imports)
case "Empty", "Emptyf":
return convertEmpty(tcall, imports)
case "Nil", "Nilf":
Expand Down Expand Up @@ -313,6 +314,19 @@ func convertError(tcall call, imports importNames) ast.Node {
tcall.extraArgs(2)...))
}

func convertErrorContains(tcall call, imports importNames) ast.Node {
return &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{
Name: imports.assert,
NamePos: tcall.xIdent.NamePos,
},
Sel: &ast.Ident{Name: "ErrorContains"},
},
Args: tcall.expr.Args,
}
}

func convertEmpty(tcall call, imports importNames) ast.Node {
cmpArgs := []ast.Expr{
tcall.arg(1),
Expand Down