diff --git a/mount/mount_unix.go b/mount/mount_unix.go index 4053fbbe..964263fd 100644 --- a/mount/mount_unix.go +++ b/mount/mount_unix.go @@ -21,7 +21,7 @@ func Mount(device, target, mType, options string) error { // Unmount lazily unmounts a filesystem on supported platforms, otherwise does // a normal unmount. If target is not a mount point, no error is returned. -func Unmount(target string) error { +var Unmount = func(target string) error { err := unix.Unmount(target, mntDetach) if err == nil || err == unix.EINVAL { //nolint:errorlint // unix errors are bare // Ignore "not mounted" error here. Note the same error diff --git a/mount/mount_unix_test.go b/mount/mount_unix_test.go index 1a703edd..3f345de4 100644 --- a/mount/mount_unix_test.go +++ b/mount/mount_unix_test.go @@ -4,6 +4,7 @@ package mount import ( + "fmt" "io/ioutil" "os" "path" @@ -251,3 +252,42 @@ func TestRecursiveUnmountTooGreedy(t *testing.T) { t.Fatal("expected dir-other to be mounted, but it's not") } } + +func TestRecursiveUnmount_SubMountFailsToUnmount(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("root required") + } + + tmp, err := ioutil.TempDir("", t.Name()) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmp) + + // Create a bunch of tmpfs mounts. Make sure "dir" itself is not + // a mount point, or we'll hit the fast path in RecursiveUnmount. + dirs := []string{"dir-other", "dir/subdir1", "dir/subdir1/subsub", "dir/subdir2/subsub"} + for _, d := range dirs { + dir := path.Join(tmp, d) + if err := os.MkdirAll(dir, 0o700); err != nil { + t.Fatal(err) + } + if err := Mount("tmpfs", dir, "tmpfs", ""); err != nil { + t.Fatal(err) + } + //nolint:errcheck + defer Unmount(dir) + } + + var mockUnmount = func(target string) error { + return fmt.Errorf("error on calling unmount") + } + var originalUnmount = Unmount + Unmount = mockUnmount + + // Unmount dir, make sure dir-other is still mounted. + if err := RecursiveUnmount(path.Join(tmp, "dir")); err == nil { + t.Fatal(err) + } + Unmount = originalUnmount +}