diff --git a/internal/goenv/goenv.go b/internal/goenv/goenv.go new file mode 100644 index 00000000..2f207aa0 --- /dev/null +++ b/internal/goenv/goenv.go @@ -0,0 +1,54 @@ +package goenv + +import ( + "errors" + "os/exec" + "runtime" + "strconv" + "strings" +) + +func Read() (map[string]string, error) { + out, err := exec.Command("go", "env").CombinedOutput() + if err != nil { + return nil, err + } + return parseGoEnv(out, runtime.GOOS) +} + +func parseGoEnv(data []byte, goos string) (map[string]string, error) { + vars := make(map[string]string) + + lines := strings.Split(strings.ReplaceAll(string(data), "\r\n", "\n"), "\n") + + if goos == "windows" { + // Line format is: `set $name=$value` + for _, l := range lines { + l = strings.TrimPrefix(l, "set ") + parts := strings.Split(l, "=") + if len(parts) != 2 { + continue + } + vars[parts[0]] = parts[1] + } + } else { + // Line format is: `$name="$value"` + for _, l := range lines { + parts := strings.Split(strings.TrimSpace(l), "=") + if len(parts) != 2 { + continue + } + val, err := strconv.Unquote(parts[1]) + if err != nil { + continue + } + vars[parts[0]] = val + } + } + + if len(vars) == 0 { + return nil, errors.New("empty env set") + } + + return vars, nil +} diff --git a/internal/goenv/goenv_test.go b/internal/goenv/goenv_test.go new file mode 100644 index 00000000..957e7059 --- /dev/null +++ b/internal/goenv/goenv_test.go @@ -0,0 +1,84 @@ +package goenv + +import ( + "strings" + "testing" +) + +func TestParse(t *testing.T) { + tests := []struct { + goos string + lines []string + goroot string + gopath string + }{ + { + goos: "windows", + lines: []string{ + "set GOROOT=C:\\Program Files\\Go\r\n", + "set GOPATH=C:\\Users\\me\\go\r\n", + }, + goroot: "C:\\Program Files\\Go", + gopath: "C:\\Users\\me\\go", + }, + + // Don't do trim on Windows. + { + goos: "windows", + lines: []string{ + "set GOROOT=C:\\Program Files\\Go \r\n", + "set GOPATH=C:\\Users\\me\\go \r\n", + }, + goroot: "C:\\Program Files\\Go ", + gopath: "C:\\Users\\me\\go ", + }, + + { + goos: "linux", + lines: []string{ + "GOROOT=\"/usr/local/go\"\n", + "GOPATH=\"/home/me/go\"\n", + }, + goroot: "/usr/local/go", + gopath: "/home/me/go", + }, + + // Trim lines on Linux. + { + goos: "linux", + lines: []string{ + " GOROOT=\"/usr/local/go\" \n", + "GOPATH=\"/home/me/go\" \n", + }, + goroot: "/usr/local/go", + gopath: "/home/me/go", + }, + + // Quotes preserve the whitespace. + { + goos: "linux", + lines: []string{ + " GOROOT=\"/usr/local/go \" \n", + "GOPATH=\"/home/me/go \" \n", + }, + goroot: "/usr/local/go ", + gopath: "/home/me/go ", + }, + } + + for i, test := range tests { + data := []byte(strings.Join(test.lines, "")) + vars, err := parseGoEnv(data, test.goos) + if err != nil { + t.Fatalf("test %d failed: %v", i, err) + } + if vars["GOROOT"] != test.goroot { + t.Errorf("test %d GOROOT mismatch: have %q, want %q", i, vars["GOROOT"], test.goroot) + continue + } + if vars["GOPATH"] != test.gopath { + t.Errorf("test %d GOPATH mismatch: have %q, want %q", i, vars["GOPATH"], test.gopath) + continue + } + } +} diff --git a/ruleguard/engine.go b/ruleguard/engine.go index 1a0e577a..e00706c3 100644 --- a/ruleguard/engine.go +++ b/ruleguard/engine.go @@ -1,7 +1,6 @@ package ruleguard import ( - "bytes" "errors" "fmt" "go/ast" @@ -10,12 +9,12 @@ import ( "go/types" "io" "io/ioutil" - "os/exec" + "os" "sort" - "strconv" "strings" "sync" + "github.com/quasilyte/go-ruleguard/internal/goenv" "github.com/quasilyte/go-ruleguard/internal/stdinfo" "github.com/quasilyte/go-ruleguard/ruleguard/ir" "github.com/quasilyte/go-ruleguard/ruleguard/quasigo" @@ -239,35 +238,27 @@ func (state *engineState) findTypeNoCache(importer *goImporter, currentPkg *type } func inferBuildContext() *build.Context { - goEnv := func() map[string]string { - out, err := exec.Command("go", "env").CombinedOutput() - if err != nil { - return nil - } - vars := make(map[string]string) - for _, l := range bytes.Split(out, []byte("\n")) { - parts := strings.Split(strings.TrimSpace(string(l)), "=") - if len(parts) != 2 { - continue - } - val, err := strconv.Unquote(parts[1]) - if err != nil { - continue - } - vars[parts[0]] = val - } - return vars - } - // Inherit most fields from the build.Default. ctx := build.Default - env := goEnv() + env, err := goenv.Read() + if err != nil { + return &ctx + } ctx.GOROOT = env["GOROOT"] ctx.GOPATH = env["GOPATH"] ctx.GOARCH = env["GOARCH"] ctx.GOOS = env["GOOS"] + switch os.Getenv("CGO_ENABLED") { + case "0": + ctx.CgoEnabled = false + case "1": + ctx.CgoEnabled = true + default: + ctx.CgoEnabled = env["CGO_ENABLED"] == "1" + } + return &ctx }