diff --git a/internal/aliasfix/aliasfix.go b/internal/aliasfix/aliasfix.go index f1dbfce2cb0..a259501362c 100644 --- a/internal/aliasfix/aliasfix.go +++ b/internal/aliasfix/aliasfix.go @@ -45,7 +45,7 @@ func ProcessPath(path string) error { return err } if dir.IsDir() { - filepath.WalkDir(path, func(path string, d fs.DirEntry, err error) error { + err := filepath.WalkDir(path, func(path string, d fs.DirEntry, err error) error { if err == nil && !d.IsDir() && strings.HasSuffix(d.Name(), ".go") { err = processFile(path, nil) } @@ -54,6 +54,9 @@ func ProcessPath(path string) error { } return nil }) + if err != nil { + return err + } } else { if err := processFile(path, nil); err != nil { return err @@ -64,22 +67,16 @@ func ProcessPath(path string) error { // processFile checks to see if the given file needs any imports rewritten and // does so if needed. Note an io.Writer is injected here for testability. -func processFile(name string, w io.Writer) error { - if w == nil { - file, err := os.Open(name) - if err != nil { - return err - } - defer file.Close() - w = file - } - f, err := parser.ParseFile(fset, name, nil, parser.ParseComments) +func processFile(name string, w io.Writer) (err error) { + var f *ast.File + f, err = parser.ParseFile(fset, name, nil, parser.ParseComments) if err != nil { return err } var modified bool for _, imp := range f.Imports { - importPath, err := strconv.Unquote(imp.Path.Value) + var importPath string + importPath, err = strconv.Unquote(imp.Path.Value) if err != nil { return err } @@ -98,18 +95,41 @@ func processFile(name string, w io.Writer) error { modified = true } } - if modified { - var buf bytes.Buffer - if err := format.Node(&buf, fset, f); err != nil { + if !modified { + return nil + } + + if w == nil { + backup := name + ".bak" + if err = os.Rename(name, backup); err != nil { return err } - b, err := imports.Process(name, buf.Bytes(), nil) + defer func() { + if err != nil { + os.Rename(backup, name) + } else { + os.Remove(backup) + } + }() + var file *os.File + file, err = os.Create(name) if err != nil { return err } - if _, err := w.Write(b); err != nil { - return err - } + defer file.Close() + w = file + } + var buf bytes.Buffer + if err := format.Node(&buf, fset, f); err != nil { + return err } + b, err := imports.Process(name, buf.Bytes(), nil) + if err != nil { + return err + } + if _, err := w.Write(b); err != nil { + return err + } + return nil }