diff --git a/options.go b/options.go index a08bcfe..9478e91 100644 --- a/options.go +++ b/options.go @@ -4,8 +4,8 @@ import ( "bufio" "bytes" "errors" - "fmt" "io" + "log" "os" "path/filepath" "strings" @@ -72,17 +72,13 @@ func WithHostKeyPEM(pem []byte) ssh.Option { // WithAuthorizedKeys allows to use a SSH authorized_keys file to allowlist users. func WithAuthorizedKeys(path string) ssh.Option { return func(s *ssh.Server) error { - keys, err := parseAuthorizedKeys(path) - if err != nil { + if _, err := os.Stat(path); err != nil { return err } return WithPublicKeyAuth(func(_ ssh.Context, key ssh.PublicKey) bool { - for _, upk := range keys { - if ssh.KeysEqual(upk, key) { - return true - } - } - return false + return isAuthorized(path, func(k ssh.PublicKey) bool { + return ssh.KeysEqual(key, k) + }) })(s) } } @@ -92,11 +88,9 @@ func WithAuthorizedKeys(path string) ssh.Option { // Analogous to the TrustedUserCAKeys OpenSSH option. func WithTrustedUserCAKeys(path string) ssh.Option { return func(s *ssh.Server) error { - cas, err := parseAuthorizedKeys(path) - if err != nil { + if _, err := os.Stat(path); err != nil { return err } - return WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { cert, ok := key.(*gossh.Certificate) if !ok { @@ -104,38 +98,33 @@ func WithTrustedUserCAKeys(path string) ssh.Option { return false } - checker := &gossh.CertChecker{ - IsUserAuthority: func(auth gossh.PublicKey) bool { - for _, ca := range cas { - if bytes.Equal(auth.Marshal(), ca.Marshal()) { - // its a cert signed by one of the CAs - return true - } - } - // it is a cert, but signed by another CA - return false - }, - } + return isAuthorized(path, func(k ssh.PublicKey) bool { + checker := &gossh.CertChecker{ + IsUserAuthority: func(auth gossh.PublicKey) bool { + // its a cert signed by one of the CAs + return bytes.Equal(auth.Marshal(), k.Marshal()) + }, + } - if !checker.IsUserAuthority(cert.SignatureKey) { - return false - } + if !checker.IsUserAuthority(cert.SignatureKey) { + return false + } - if err := checker.CheckCert(ctx.User(), cert); err != nil { - return false - } + if err := checker.CheckCert(ctx.User(), cert); err != nil { + return false + } - return true + return true + }) })(s) } } -func parseAuthorizedKeys(path string) ([]ssh.PublicKey, error) { - var keys []ssh.PublicKey - +func isAuthorized(path string, checker func(k ssh.PublicKey) bool) bool { f, err := os.Open(path) if err != nil { - return keys, fmt.Errorf("failed to parse %q: %w", path, err) + log.Printf("failed to parse %q: %s", path, err) + return false } defer f.Close() // nolint: errcheck @@ -146,7 +135,8 @@ func parseAuthorizedKeys(path string) ([]ssh.PublicKey, error) { if errors.Is(err, io.EOF) { break } - return keys, fmt.Errorf("failed to parse %q: %w", path, err) + log.Printf("failed to parse %q: %s", path, err) + return false } if strings.TrimSpace(string(line)) == "" { continue @@ -156,11 +146,14 @@ func parseAuthorizedKeys(path string) ([]ssh.PublicKey, error) { } upk, _, _, _, err := ssh.ParseAuthorizedKey(line) if err != nil { - return keys, fmt.Errorf("failed to parse %q: %w", path, err) + log.Printf("failed to parse %q: %s", path, err) + return false + } + if checker(upk) { + return true } - keys = append(keys, upk) } - return keys, nil + return false } // WithPublicKeyAuth returns an ssh.Option that sets the public key auth handler. diff --git a/options_test.go b/options_test.go index 5108944..f8845b6 100755 --- a/options_test.go +++ b/options_test.go @@ -25,23 +25,17 @@ func TestWithMaxTimeout(t *testing.T) { requireEqual(t, time.Second, s.MaxTimeout) } -func TestParseAuthorizedKeys(t *testing.T) { +func TestIsAuthorized(t *testing.T) { t.Run("valid", func(t *testing.T) { - keys, err := parseAuthorizedKeys("testdata/authorized_keys") - requireNoError(t, err) - requireEqual(t, 6, len(keys)) + requireEqual(t, true, isAuthorized("testdata/authorized_keys", func(k ssh.PublicKey) bool { return true })) }) t.Run("invalid", func(t *testing.T) { - keys, err := parseAuthorizedKeys("testdata/invalid_authorized_keys") - requireEqual(t, `failed to parse "testdata/invalid_authorized_keys": ssh: no key found`, err.Error()) - requireEqual(t, 0, len(keys)) + requireEqual(t, false, isAuthorized("testdata/invalid_authorized_keys", func(k ssh.PublicKey) bool { return true })) }) t.Run("file not found", func(t *testing.T) { - keys, err := parseAuthorizedKeys("testdata/nope_authorized_keys") - requireEqual(t, `failed to parse "testdata/nope_authorized_keys": open testdata/nope_authorized_keys: no such file or directory`, err.Error()) - requireEqual(t, 0, len(keys)) + requireEqual(t, false, isAuthorized("testdata/nope_authorized_keys", func(k ssh.PublicKey) bool { return true })) }) } @@ -65,12 +59,18 @@ func TestWithAuthorizedKeys(t *testing.T) { t.Run("invalid", func(t *testing.T) { s := ssh.Server{} - requireEqual( + requireNoError( t, - `failed to parse "testdata/invalid_authorized_keys": ssh: no key found`, - WithAuthorizedKeys("testdata/invalid_authorized_keys")(&s).Error(), + WithAuthorizedKeys("testdata/invalid_authorized_keys")(&s), ) }) + + t.Run("file not found", func(t *testing.T) { + s := ssh.Server{} + if err := WithAuthorizedKeys("testdata/nope_authorized_keys")(&s); err == nil { + t.Fatal("expected an error, got nil") + } + }) } func TestWithTrustedUserCAKeys(t *testing.T) { @@ -102,7 +102,7 @@ func TestWithTrustedUserCAKeys(t *testing.T) { t.Run("invalid ca key", func(t *testing.T) { s := &ssh.Server{} if err := WithTrustedUserCAKeys("testdata/invalid-path")(s); err == nil { - t.Fatal("expedted an error, got nil") + t.Fatal("expected an error, got nil") } })