/
accesscontrol_test.go
87 lines (75 loc) · 1.93 KB
/
accesscontrol_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package accesscontrol_test
import (
"fmt"
"testing"
"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish/accesscontrol"
"github.com/charmbracelet/wish/testsession"
gossh "golang.org/x/crypto/ssh"
)
// out is the payload the test handler in setup writes to the session; tests
// compare captured session output against it to confirm the handler ran.
const out = "hello world"
// TestMiddleware exercises accesscontrol.Middleware through a test SSH
// session, combining an empty or non-empty allow-list with no command, an
// allowed command, and a disallowed command.
//
// The expectations pinned here are:
//   - a session without a command reaches the wrapped handler and yields out;
//   - a session whose command is in the allow-list also yields out;
//   - a session whose command is not in the allow-list returns an error and
//     prints "Command is not allowed: <cmd>".
func TestMiddleware(t *testing.T) {
	// requireDenied asserts that s is exactly the rejection message for cmd.
	// It reports through tb (not the enclosing t) so that a failure is
	// attributed to the subtest that called it.
	requireDenied := func(tb testing.TB, s, cmd string) {
		tb.Helper()
		expected := fmt.Sprintf("Command is not allowed: %s\n", cmd)
		if s != expected {
			tb.Errorf("expected %q, got %q", expected, s)
		}
	}

	// requireOutput asserts that s is the payload written by the handler
	// installed in setup, i.e. that the middleware let the session through.
	requireOutput := func(tb testing.TB, s string) {
		tb.Helper()
		if out != s {
			tb.Errorf("expected %q, got %q", out, s)
		}
	}

	t.Run("no allowed cmds no cmd", func(t *testing.T) {
		got, err := setup(t).Output("")
		if err != nil {
			t.Error(err)
		}
		requireOutput(t, string(got))
	})

	t.Run("no allowed cmds with cmd", func(t *testing.T) {
		got, err := setup(t).Output("echo")
		if err == nil {
			t.Error("should have errored")
		}
		requireDenied(t, string(got), "echo")
	})

	t.Run("allowed cmds no cmd", func(t *testing.T) {
		got, err := setup(t, "echo").Output("")
		if err != nil {
			t.Error(err)
		}
		requireOutput(t, string(got))
	})

	t.Run("allowed cmds with allowed cmd", func(t *testing.T) {
		got, err := setup(t, "echo").Output("echo")
		if err != nil {
			t.Error(err)
		}
		requireOutput(t, string(got))
	})

	t.Run("allowed cmds with disallowed cmd", func(t *testing.T) {
		got, err := setup(t, "echo").Output("cat")
		if err == nil {
			t.Error("should have errored")
		}
		requireDenied(t, string(got), "cat")
	})

	t.Run("allowed cmds with allowed cmd followed disallowed cmd", func(t *testing.T) {
		// The command line starts with the disallowed "cat"; the denial
		// message is expected to name only that first word, even though the
		// allowed "echo" follows it.
		got, err := setup(t, "echo").Output("cat echo")
		if err == nil {
			t.Error("should have errored")
		}
		requireDenied(t, string(got), "cat")
	})
}
// setup builds an SSH server whose handler is wrapped by
// accesscontrol.Middleware configured with allowedCmds, and hands it to
// testsession.New to obtain a client session for the test.
//
// The wrapped handler writes the package-level out payload to the session, so
// seeing that payload in the session output means the middleware let the
// request through.
func setup(tb testing.TB, allowedCmds ...string) *gossh.Session {
	tb.Helper()

	// Innermost handler: runs only if the access-control middleware calls it.
	greet := func(sess ssh.Session) {
		// The write result is deliberately discarded; tests detect a failed
		// write through the output comparison instead.
		_, _ = sess.Write([]byte(out))
	}

	guard := accesscontrol.Middleware(allowedCmds...)
	srv := &ssh.Server{Handler: guard(greet)}

	return testsession.New(tb, srv, nil)
}