Skip to content

Commit

Permalink
Merge pull request #228 from dnephin/update-gty-migrate
Browse files Browse the repository at this point in the history
Update gty-migrate-from-testify
  • Loading branch information
dnephin committed Apr 4, 2022
2 parents dc5149e + a9159b3 commit 99a7307
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 212 deletions.
4 changes: 2 additions & 2 deletions assert/assert.go
Expand Up @@ -16,8 +16,8 @@ The example below shows assert used with some common types.
import (
"testing"
"gotest.tools/assert"
is "gotest.tools/assert/cmp"
"gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp"
)
func TestEverything(t *testing.T) {
Expand Down
8 changes: 5 additions & 3 deletions assert/cmd/gty-migrate-from-testify/call.go
Expand Up @@ -8,14 +8,16 @@ import (
"go/token"
)

// call wraps an testify assert ast.CallExpr and exposes properties of the
// expression to facilitate migrating the expression to a gotest.tools/assert
// call wraps a testify/assert ast.CallExpr and exposes properties of the
// expression to facilitate migrating the expression to a gotest.tools/v3/assert
type call struct {
fileset *token.FileSet
expr *ast.CallExpr
xIdent *ast.Ident
selExpr *ast.SelectorExpr
assert string
// assert is Assert (if the testify package was require), or Check (if the
// testify package was assert).
assert string
}

func (c call) String() string {
Expand Down
4 changes: 2 additions & 2 deletions assert/cmd/gty-migrate-from-testify/doc.go
@@ -1,9 +1,9 @@
/*
Command gty-migrate-from-testify migrates packages from
testify/assert and testify/require to gotest.tools/assert.
testify/assert and testify/require to gotest.tools/v3/assert.
$ go get gotest.tools/assert/cmd/gty-migrate-from-testify
$ go get gotest.tools/v3/assert/cmd/gty-migrate-from-testify
Usage:
Expand Down
102 changes: 0 additions & 102 deletions assert/cmd/gty-migrate-from-testify/importer.go

This file was deleted.

106 changes: 53 additions & 53 deletions assert/cmd/gty-migrate-from-testify/main.go
Expand Up @@ -4,9 +4,7 @@ import (
"bytes"
"fmt"
"go/ast"
"go/build"
"go/format"
"go/parser"
"go/token"
"io/ioutil"
"log"
Expand All @@ -17,7 +15,7 @@ import (

"github.com/pkg/errors"
"github.com/spf13/pflag"
"golang.org/x/tools/go/loader"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/imports"
)

Expand All @@ -27,7 +25,8 @@ type options struct {
debug bool
cmpImportName string
showLoaderErrors bool
useAllFiles bool
buildFlags []string
localImportPath string
}

func main() {
Expand Down Expand Up @@ -62,12 +61,14 @@ func setupFlags(name string) (*pflag.FlagSet, *options) {
"import alias to use for the assert/cmp package")
flags.BoolVar(&opts.showLoaderErrors, "print-loader-errors", false,
"print errors from loading source")
flags.BoolVar(&opts.useAllFiles, "ignore-build-tags", false,
"migrate all files ignoring build tags")
flags.StringSliceVar(&opts.buildFlags, "build-tags", nil,
"build to pass to Go when loading source files")
flags.StringVar(&opts.localImportPath, "local-import-path", "",
"value to pass to 'goimports -local' flag for sorting local imports")
flags.Usage = func() {
fmt.Fprintf(os.Stderr, `Usage: %s [OPTIONS] PACKAGE [PACKAGE...]
Migrate calls from testify/{assert|require} to gotest.tools/assert.
Migrate calls from testify/{assert|require} to gotest.tools/v3/assert.
%s`, name, flags.FlagUsages())
}
Expand All @@ -87,18 +88,19 @@ func handleExitError(name string, err error) {
}

func run(opts options) error {
program, err := loadProgram(opts)
imports.LocalPrefix = opts.localImportPath

fset := token.NewFileSet()
pkgs, err := loadPackages(opts, fset)
if err != nil {
return errors.Wrapf(err, "failed to load program")
}

pkgs := program.InitialPackages()
debugf("package count: %d", len(pkgs))

fileset := program.Fset
for _, pkg := range pkgs {
for _, astFile := range pkg.Files {
absFilename := fileset.File(astFile.Pos()).Name()
debugf("file count for package %v: %d", pkg.PkgPath, len(pkg.Syntax))
for _, astFile := range pkg.Syntax {
absFilename := fset.File(astFile.Pos()).Name()
filename := relativePath(absFilename)
importNames := newImportNames(astFile.Imports, opts)
if !importNames.hasTestifyImports() {
Expand All @@ -109,9 +111,9 @@ func run(opts options) error {
debugf("migrating %s with imports: %#v", filename, importNames)
m := migration{
file: astFile,
fileset: fileset,
fileset: fset,
importNames: importNames,
pkgInfo: pkg,
pkgInfo: pkg.TypesInfo,
}
migrateFile(m)
if opts.dryRun {
Expand All @@ -132,47 +134,33 @@ func run(opts options) error {
return nil
}

func loadProgram(opts options) (*loader.Program, error) {
fakeImporter, err := newFakeImporter()
var loadMode = packages.NeedName |
packages.NeedFiles |
packages.NeedCompiledGoFiles |
packages.NeedDeps |
packages.NeedImports |
packages.NeedTypes |
packages.NeedTypesInfo |
packages.NeedTypesSizes |
packages.NeedSyntax

func loadPackages(opts options, fset *token.FileSet) ([]*packages.Package, error) {
conf := &packages.Config{
Mode: loadMode,
Fset: fset,
Tests: true,
Logf: debugf,
BuildFlags: opts.buildFlags,
}

pkgs, err := packages.Load(conf, opts.pkgs...)
if err != nil {
return nil, err
}
defer fakeImporter.Close()

conf := loader.Config{
Fset: token.NewFileSet(),
ParserMode: parser.ParseComments,
Build: buildContext(opts),
AllowErrors: true,
FindPackage: fakeImporter.Import,
}
for _, pkg := range opts.pkgs {
conf.ImportWithTests(pkg)
}
if !opts.showLoaderErrors {
conf.TypeChecker.Error = func(e error) {}
}
program, err := conf.Load()
if opts.showLoaderErrors {
for p, pkg := range program.AllPackages {
if len(pkg.Errors) > 0 {
fmt.Printf("Package %s loaded with some errors:\n", p.Name())
for _, err := range pkg.Errors {
fmt.Println(" ", err.Error())
}
}
}
}
return program, err
}

func buildContext(opts options) *build.Context {
c := build.Default
c.UseAllFiles = opts.useAllFiles
if val, ok := os.LookupEnv("GOPATH"); ok {
c.GOPATH = val
packages.PrintErrors(pkgs)
}
return &c
return pkgs, nil
}

func relativePath(p string) string {
Expand Down Expand Up @@ -214,8 +202,9 @@ func (p importNames) funcNameFromTestifyName(name string) string {
}

func newImportNames(imports []*ast.ImportSpec, opt options) importNames {
defaultAssertAlias := path.Base(pkgAssert)
importNames := importNames{
assert: path.Base(pkgAssert),
assert: defaultAssertAlias,
cmp: path.Base(pkgCmp),
}
for _, spec := range imports {
Expand All @@ -225,7 +214,18 @@ func newImportNames(imports []*ast.ImportSpec, opt options) importNames {
case pkgTestifyRequire, pkgGopkgTestifyRequire:
importNames.testifyRequire = identOrDefault(spec.Name, "require")
default:
if importedAs(spec, path.Base(pkgAssert)) {
pkgPath := strings.Trim(spec.Path.Value, `"`)

switch {
// v3/assert is already imported and has an alias
case pkgPath == pkgAssert:
if spec.Name != nil && spec.Name.Name != "" {
importNames.assert = spec.Name.Name
}
continue

// some other package is imported as assert
case importedAs(spec, path.Base(pkgAssert)) && importNames.assert == defaultAssertAlias:
importNames.assert = "gtyassert"
}
}
Expand Down
22 changes: 18 additions & 4 deletions assert/cmd/gty-migrate-from-testify/migrate.go
Expand Up @@ -8,16 +8,15 @@ import (
"path"

"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/loader"
)

const (
pkgTestifyAssert = "github.com/stretchr/testify/assert"
pkgGopkgTestifyAssert = "gopkg.in/stretchr/testify.v1/assert"
pkgTestifyRequire = "github.com/stretchr/testify/require"
pkgGopkgTestifyRequire = "gopkg.in/stretchr/testify.v1/require"
pkgAssert = "gotest.tools/assert"
pkgCmp = "gotest.tools/assert/cmp"
pkgAssert = "gotest.tools/v3/assert"
pkgCmp = "gotest.tools/v3/assert/cmp"
)

const (
Expand All @@ -36,7 +35,7 @@ type migration struct {
file *ast.File
fileset *token.FileSet
importNames importNames
pkgInfo *loader.PackageInfo
pkgInfo *types.Info
}

func migrateFile(migration migration) {
Expand Down Expand Up @@ -165,6 +164,8 @@ func convertTestifyAssertion(tcall call, migration migration) ast.Node {
return convertEqualError(tcall, imports)
case "Error", "Errorf":
return convertError(tcall, imports)
case "ErrorContains", "ErrorContainsf":
return convertErrorContains(tcall, imports)
case "Empty", "Emptyf":
return convertEmpty(tcall, imports)
case "Nil", "Nilf":
Expand Down Expand Up @@ -313,6 +314,19 @@ func convertError(tcall call, imports importNames) ast.Node {
tcall.extraArgs(2)...))
}

func convertErrorContains(tcall call, imports importNames) ast.Node {
return &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{
Name: imports.assert,
NamePos: tcall.xIdent.NamePos,
},
Sel: &ast.Ident{Name: "ErrorContains"},
},
Args: tcall.expr.Args,
}
}

func convertEmpty(tcall call, imports importNames) ast.Node {
cmpArgs := []ast.Expr{
tcall.arg(1),
Expand Down

0 comments on commit 99a7307

Please sign in to comment.