Skip to content

Commit

Permalink
refactor: prepare for other uses of the ast.File
Browse files Browse the repository at this point in the history
  • Loading branch information
dnephin committed May 29, 2022
1 parent 2686168 commit 3abbc52
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 56 deletions.
79 changes: 23 additions & 56 deletions internal/source/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@ import (
"go/token"
"os"
"runtime"
"strconv"
"strings"
)

const baseStackIndex = 1

// FormattedCallExprArg returns the argument from an ast.CallExpr at the
// index in the call stack. The argument is formatted using FormatNode.
func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
Expand All @@ -32,28 +28,26 @@ func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
// CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
// the index in the call stack.
func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
_, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex)
_, filename, line, ok := runtime.Caller(stackIndex + 1)
if !ok {
return nil, errors.New("failed to get call stack")
}
debug("call stack position: %s:%d", filename, lineNum)
debug("call stack position: %s:%d", filename, line)

node, err := getNodeAtLine(filename, lineNum)
if err != nil {
return nil, err
}
debug("found node: %s", debugFormatNode{node})

return getCallExprArgs(node)
}

func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
fileset := token.NewFileSet()
astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
if err != nil {
return nil, fmt.Errorf("failed to parse source file %s: %w", filename, err)
}

expr, err := getCallExprArgs(fileset, astFile, line)
if err != nil {
return nil, fmt.Errorf("call from %s:%d: %w", filename, line, err)
}
return expr, nil
}

func getNodeAtLine(fileset *token.FileSet, astFile *ast.File, lineNum int) (ast.Node, error) {
if node := scanToLine(fileset, astFile, lineNum); node != nil {
return node, nil
}
Expand All @@ -63,8 +57,7 @@ func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
return node, err
}
}
return nil, fmt.Errorf(
"failed to find an expression on line %d in %s", lineNum, filename)
return nil, nil
}

func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
Expand All @@ -73,7 +66,7 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
switch {
case node == nil || matchedNode != nil:
return false
case nodePosition(fileset, node).Line == lineNum:
case fileset.Position(node.Pos()).Line == lineNum:
matchedNode = node
return false
}
Expand All @@ -82,46 +75,17 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
return matchedNode
}

// In golang 1.9 the line number changed from being the line where the statement
// ended to the line where the statement began.
func nodePosition(fileset *token.FileSet, node ast.Node) token.Position {
if goVersionBefore19 {
return fileset.Position(node.End())
}
return fileset.Position(node.Pos())
}

// GoVersionLessThan returns true if runtime.Version() is semantically less than
// version major.minor. Returns false if a release version can not be parsed from
// runtime.Version().
func GoVersionLessThan(major, minor int64) bool {
version := runtime.Version()
// not a release version
if !strings.HasPrefix(version, "go") {
return false
}
version = strings.TrimPrefix(version, "go")
parts := strings.Split(version, ".")
if len(parts) < 2 {
return false
}
rMajor, err := strconv.ParseInt(parts[0], 10, 32)
if err != nil {
return false
}
if rMajor != major {
return rMajor < major
}
rMinor, err := strconv.ParseInt(parts[1], 10, 32)
if err != nil {
return false
func getCallExprArgs(fileset *token.FileSet, astFile *ast.File, line int) ([]ast.Expr, error) {
node, err := getNodeAtLine(fileset, astFile, line)
switch {
case err != nil:
return nil, err
case node == nil:
return nil, fmt.Errorf("failed to find an expression")
}
return rMinor < minor
}

var goVersionBefore19 = GoVersionLessThan(1, 9)
debug("found node: %s", debugFormatNode{node})

func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
visitor := &callExprVisitor{}
ast.Walk(visitor, node)
if visitor.expr == nil {
Expand Down Expand Up @@ -172,6 +136,9 @@ type debugFormatNode struct {
}

func (n debugFormatNode) String() string {
if n.Node == nil {
return "none"
}
out, err := FormatNode(n.Node)
if err != nil {
return fmt.Sprintf("failed to format %s: %s", n.Node, err)
Expand Down
35 changes: 35 additions & 0 deletions internal/source/version.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package source

import (
"runtime"
"strconv"
"strings"
)

// GoVersionLessThan returns true if runtime.Version() is semantically less than
// version major.minor. Returns false if a release version can not be parsed from
// runtime.Version().
func GoVersionLessThan(major, minor int64) bool {
version := runtime.Version()
// not a release version
if !strings.HasPrefix(version, "go") {
return false
}
version = strings.TrimPrefix(version, "go")
parts := strings.Split(version, ".")
if len(parts) < 2 {
return false
}
rMajor, err := strconv.ParseInt(parts[0], 10, 32)
if err != nil {
return false
}
if rMajor != major {
return rMajor < major
}
rMinor, err := strconv.ParseInt(parts[1], 10, 32)
if err != nil {
return false
}
return rMinor < minor
}

0 comments on commit 3abbc52

Please sign in to comment.