diff --git a/cpu/cpu_linux_test.go b/cpu/cpu_linux_test.go index f66d6b915..ec3330fe1 100644 --- a/cpu/cpu_linux_test.go +++ b/cpu/cpu_linux_test.go @@ -2,7 +2,6 @@ package cpu import ( "errors" - "os" "os/exec" "strconv" "strings" @@ -10,8 +9,7 @@ import ( ) func TestTimesEmpty(t *testing.T) { - orig := os.Getenv("HOST_PROC") - os.Setenv("HOST_PROC", "testdata/linux/times_empty") + t.Setenv("HOST_PROC", "testdata/linux/times_empty") _, err := Times(true) if err != nil { t.Error("Times(true) failed") @@ -20,12 +18,10 @@ func TestTimesEmpty(t *testing.T) { if err != nil { t.Error("Times(false) failed") } - os.Setenv("HOST_PROC", orig) } func TestCPUparseStatLine_424(t *testing.T) { - orig := os.Getenv("HOST_PROC") - os.Setenv("HOST_PROC", "testdata/linux/424/proc") + t.Setenv("HOST_PROC", "testdata/linux/424/proc") { l, err := Times(true) if err != nil || len(l) == 0 { @@ -40,7 +36,6 @@ func TestCPUparseStatLine_424(t *testing.T) { } t.Logf("Times(false): %#v", l) } - os.Setenv("HOST_PROC", orig) } func TestCPUCountsAgainstLscpu(t *testing.T) { @@ -93,9 +88,7 @@ func TestCPUCountsAgainstLscpu(t *testing.T) { } func TestCPUCountsLogicalAndroid_1037(t *testing.T) { // https://github.com/shirou/gopsutil/issues/1037 - orig := os.Getenv("HOST_PROC") - os.Setenv("HOST_PROC", "testdata/linux/1037/proc") - defer os.Setenv("HOST_PROC", orig) + t.Setenv("HOST_PROC", "testdata/linux/1037/proc") count, err := Counts(true) if err != nil { diff --git a/cpu/cpu_plan9_test.go b/cpu/cpu_plan9_test.go index 9acf4bf98..2820a3f41 100644 --- a/cpu/cpu_plan9_test.go +++ b/cpu/cpu_plan9_test.go @@ -4,7 +4,6 @@ package cpu import ( - "os" "path/filepath" "testing" @@ -30,13 +29,9 @@ var timesTests = []struct { } func TestTimesPlan9(t *testing.T) { - origRoot := os.Getenv("HOST_ROOT") - t.Cleanup(func() { - os.Setenv("HOST_ROOT", origRoot) - }) for _, tt := range timesTests { t.Run(tt.mockedRootFS, func(t *testing.T) { - os.Setenv("HOST_ROOT", filepath.Join("testdata/plan9", tt.mockedRootFS)) + t.Setenv("HOST_ROOT", filepath.Join("testdata/plan9", tt.mockedRootFS)) stats, err := Times(false) skipIfNotImplementedErr(t, err) if err != nil { diff --git a/internal/common/common.go b/internal/common/common.go index db93b6711..9cb752bbd 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -364,16 +364,6 @@ func HostDev(combineWith ...string) string { return GetEnv("HOST_DEV", "/dev", combineWith...) } -// MockEnv set environment variable and return revert function. -// MockEnv should be used testing only. -func MockEnv(key string, value string) func() { - original := os.Getenv(key) - os.Setenv(key, value) - return func() { - os.Setenv(key, original) - } -} - // getSysctrlEnv sets LC_ALL=C in a list of env vars for use when running // sysctl commands (see DoSysctrl). func getSysctrlEnv(env []string) []string { diff --git a/mem/mem_linux_test.go b/mem/mem_linux_test.go index 61e16a1b2..a0590c961 100644 --- a/mem/mem_linux_test.go +++ b/mem/mem_linux_test.go @@ -4,7 +4,6 @@ package mem import ( - "os" "path/filepath" "reflect" "strings" @@ -111,12 +110,9 @@ var virtualMemoryTests = []struct { } func TestVirtualMemoryLinux(t *testing.T) { - origProc := os.Getenv("HOST_PROC") - defer os.Setenv("HOST_PROC", origProc) - for _, tt := range virtualMemoryTests { t.Run(tt.mockedRootFS, func(t *testing.T) { - os.Setenv("HOST_PROC", filepath.Join("testdata/linux/virtualmemory/", tt.mockedRootFS, "proc")) + t.Setenv("HOST_PROC", filepath.Join("testdata/linux/virtualmemory/", tt.mockedRootFS, "proc")) stat, err := VirtualMemory() skipIfNotImplementedErr(t, err) diff --git a/mem/mem_plan9_test.go b/mem/mem_plan9_test.go index b3480ca4f..1ae353d3a 100644 --- a/mem/mem_plan9_test.go +++ b/mem/mem_plan9_test.go @@ -4,7 +4,6 @@ package mem import ( - "os" "reflect" "testing" ) @@ -27,14 +26,9 @@ var virtualMemoryTests = []struct { } func TestVirtualMemoryPlan9(t *testing.T) { - origProc := os.Getenv("HOST_ROOT") - t.Cleanup(func() { - os.Setenv("HOST_ROOT", origProc) - }) - for _, tt := range virtualMemoryTests { t.Run(tt.mockedRootFS, func(t *testing.T) { - os.Setenv("HOST_ROOT", "testdata/plan9/virtualmemory/") + t.Setenv("HOST_ROOT", "testdata/plan9/virtualmemory/") stat, err := VirtualMemory() skipIfNotImplementedErr(t, err) @@ -62,14 +56,9 @@ var swapMemoryTests = []struct { } func TestSwapMemoryPlan9(t *testing.T) { - origProc := os.Getenv("HOST_ROOT") - t.Cleanup(func() { - os.Setenv("HOST_ROOT", origProc) - }) - for _, tt := range swapMemoryTests { t.Run(tt.mockedRootFS, func(t *testing.T) { - os.Setenv("HOST_ROOT", "testdata/plan9/virtualmemory/") + t.Setenv("HOST_ROOT", "testdata/plan9/virtualmemory/") swap, err := SwapMemory() skipIfNotImplementedErr(t, err) diff --git a/process/process_linux_test.go b/process/process_linux_test.go index 9003095be..76dee4437 100644 --- a/process/process_linux_test.go +++ b/process/process_linux_test.go @@ -11,7 +11,6 @@ import ( "strings" "testing" - "github.com/shirou/gopsutil/v3/internal/common" "github.com/stretchr/testify/assert" ) @@ -62,8 +61,7 @@ func Test_Process_splitProcStat_fromFile(t *testing.T) { if err != nil { t.Error(err) } - f := common.MockEnv("HOST_PROC", "testdata/linux") - defer f() + t.Setenv("HOST_PROC", "testdata/linux") for _, pid := range pids { pid, err := strconv.ParseInt(pid.Name(), 0, 32) if err != nil { @@ -99,8 +97,7 @@ func Test_fillFromCommWithContext(t *testing.T) { if err != nil { t.Error(err) } - f := common.MockEnv("HOST_PROC", "testdata/linux") - defer f() + t.Setenv("HOST_PROC", "testdata/linux") for _, pid := range pids { pid, err := strconv.ParseInt(pid.Name(), 0, 32) if err != nil { @@ -121,8 +118,7 @@ func Test_fillFromStatusWithContext(t *testing.T) { if err != nil { t.Error(err) } - f := common.MockEnv("HOST_PROC", "testdata/linux") - defer f() + t.Setenv("HOST_PROC", "testdata/linux") for _, pid := range pids { pid, err := strconv.ParseInt(pid.Name(), 0, 32) if err != nil { @@ -139,8 +135,7 @@ func Test_fillFromStatusWithContext(t *testing.T) { } func Benchmark_fillFromCommWithContext(b *testing.B) { - f := common.MockEnv("HOST_PROC", "testdata/linux") - defer f() + b.Setenv("HOST_PROC", "testdata/linux") pid := 1060 p, _ := NewProcess(int32(pid)) for i := 0; i < b.N; i++ { @@ -149,8 +144,7 @@ func Benchmark_fillFromCommWithContext(b *testing.B) { } func Benchmark_fillFromStatusWithContext(b *testing.B) { - f := common.MockEnv("HOST_PROC", "testdata/linux") - defer f() + b.Setenv("HOST_PROC", "testdata/linux") pid := 1060 p, _ := NewProcess(int32(pid)) for i := 0; i < b.N; i++ { @@ -163,8 +157,7 @@ func Test_fillFromTIDStatWithContext_lx_brandz(t *testing.T) { if err != nil { t.Error(err) } - f := common.MockEnv("HOST_PROC", "testdata/lx_brandz") - defer f() + t.Setenv("HOST_PROC", "testdata/lx_brandz") for _, pid := range pids { pid, err := strconv.ParseInt(pid.Name(), 0, 32) if err != nil {