Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: always read the authorized keys file #88

Merged
merged 1 commit into from Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
71 changes: 32 additions & 39 deletions options.go
Expand Up @@ -4,8 +4,8 @@ import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -72,17 +72,13 @@ func WithHostKeyPEM(pem []byte) ssh.Option {
// WithAuthorizedKeys allows to use a SSH authorized_keys file to allowlist users.
func WithAuthorizedKeys(path string) ssh.Option {
return func(s *ssh.Server) error {
keys, err := parseAuthorizedKeys(path)
if err != nil {
if _, err := os.Stat(path); err != nil {
return err
}
return WithPublicKeyAuth(func(_ ssh.Context, key ssh.PublicKey) bool {
for _, upk := range keys {
if ssh.KeysEqual(upk, key) {
return true
}
}
return false
return isAuthorized(path, func(k ssh.PublicKey) bool {
return ssh.KeysEqual(key, k)
})
})(s)
}
}
Expand All @@ -92,50 +88,43 @@ func WithAuthorizedKeys(path string) ssh.Option {
// Analogous to the TrustedUserCAKeys OpenSSH option.
func WithTrustedUserCAKeys(path string) ssh.Option {
return func(s *ssh.Server) error {
cas, err := parseAuthorizedKeys(path)
if err != nil {
if _, err := os.Stat(path); err != nil {
return err
}

return WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
cert, ok := key.(*gossh.Certificate)
if !ok {
// not a certificate...
return false
}

checker := &gossh.CertChecker{
IsUserAuthority: func(auth gossh.PublicKey) bool {
for _, ca := range cas {
if bytes.Equal(auth.Marshal(), ca.Marshal()) {
// its a cert signed by one of the CAs
return true
}
}
// it is a cert, but signed by another CA
return false
},
}
return isAuthorized(path, func(k ssh.PublicKey) bool {
checker := &gossh.CertChecker{
IsUserAuthority: func(auth gossh.PublicKey) bool {
// its a cert signed by one of the CAs
return bytes.Equal(auth.Marshal(), k.Marshal())
},
}

if !checker.IsUserAuthority(cert.SignatureKey) {
return false
}
if !checker.IsUserAuthority(cert.SignatureKey) {
return false
}

if err := checker.CheckCert(ctx.User(), cert); err != nil {
return false
}
if err := checker.CheckCert(ctx.User(), cert); err != nil {
return false
}

return true
return true
})
})(s)
}
}

func parseAuthorizedKeys(path string) ([]ssh.PublicKey, error) {
var keys []ssh.PublicKey

func isAuthorized(path string, checker func(k ssh.PublicKey) bool) bool {
f, err := os.Open(path)
if err != nil {
return keys, fmt.Errorf("failed to parse %q: %w", path, err)
log.Printf("failed to parse %q: %s", path, err)
return false
}
defer f.Close() // nolint: errcheck

Expand All @@ -146,7 +135,8 @@ func parseAuthorizedKeys(path string) ([]ssh.PublicKey, error) {
if errors.Is(err, io.EOF) {
break
}
return keys, fmt.Errorf("failed to parse %q: %w", path, err)
log.Printf("failed to parse %q: %s", path, err)
return false
}
if strings.TrimSpace(string(line)) == "" {
continue
Expand All @@ -156,11 +146,14 @@ func parseAuthorizedKeys(path string) ([]ssh.PublicKey, error) {
}
upk, _, _, _, err := ssh.ParseAuthorizedKey(line)
if err != nil {
return keys, fmt.Errorf("failed to parse %q: %w", path, err)
log.Printf("failed to parse %q: %s", path, err)
return false
}
if checker(upk) {
return true
}
keys = append(keys, upk)
}
return keys, nil
return false
}

// WithPublicKeyAuth returns an ssh.Option that sets the public key auth handler.
Expand Down
28 changes: 14 additions & 14 deletions options_test.go
Expand Up @@ -25,23 +25,17 @@ func TestWithMaxTimeout(t *testing.T) {
requireEqual(t, time.Second, s.MaxTimeout)
}

func TestParseAuthorizedKeys(t *testing.T) {
func TestIsAuthorized(t *testing.T) {
t.Run("valid", func(t *testing.T) {
keys, err := parseAuthorizedKeys("testdata/authorized_keys")
requireNoError(t, err)
requireEqual(t, 6, len(keys))
requireEqual(t, true, isAuthorized("testdata/authorized_keys", func(k ssh.PublicKey) bool { return true }))
})

t.Run("invalid", func(t *testing.T) {
keys, err := parseAuthorizedKeys("testdata/invalid_authorized_keys")
requireEqual(t, `failed to parse "testdata/invalid_authorized_keys": ssh: no key found`, err.Error())
requireEqual(t, 0, len(keys))
requireEqual(t, false, isAuthorized("testdata/invalid_authorized_keys", func(k ssh.PublicKey) bool { return true }))
})

t.Run("file not found", func(t *testing.T) {
keys, err := parseAuthorizedKeys("testdata/nope_authorized_keys")
requireEqual(t, `failed to parse "testdata/nope_authorized_keys": open testdata/nope_authorized_keys: no such file or directory`, err.Error())
requireEqual(t, 0, len(keys))
requireEqual(t, false, isAuthorized("testdata/nope_authorized_keys", func(k ssh.PublicKey) bool { return true }))
})
}

Expand All @@ -65,12 +59,18 @@ func TestWithAuthorizedKeys(t *testing.T) {

t.Run("invalid", func(t *testing.T) {
s := ssh.Server{}
requireEqual(
requireNoError(
t,
`failed to parse "testdata/invalid_authorized_keys": ssh: no key found`,
WithAuthorizedKeys("testdata/invalid_authorized_keys")(&s).Error(),
WithAuthorizedKeys("testdata/invalid_authorized_keys")(&s),
)
})

t.Run("file not found", func(t *testing.T) {
s := ssh.Server{}
if err := WithAuthorizedKeys("testdata/nope_authorized_keys")(&s); err == nil {
t.Fatal("expected an error, got nil")
}
})
}

func TestWithTrustedUserCAKeys(t *testing.T) {
Expand Down Expand Up @@ -102,7 +102,7 @@ func TestWithTrustedUserCAKeys(t *testing.T) {
t.Run("invalid ca key", func(t *testing.T) {
s := &ssh.Server{}
if err := WithTrustedUserCAKeys("testdata/invalid-path")(s); err == nil {
t.Fatal("expedted an error, got nil")
t.Fatal("expected an error, got nil")
}
})

Expand Down