Skip to content

Commit

Permalink
Reworking the logic per feedback, adding basic test.
Browse files Browse the repository at this point in the history
  • Loading branch information
robison committed May 3, 2021
1 parent 14df1c3 commit 9733066
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 15 deletions.
125 changes: 124 additions & 1 deletion builtin/logical/ssh/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ import (

"golang.org/x/crypto/ssh"

"github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/helper/testhelpers/docker"
logicaltest "github.com/hashicorp/vault/helper/testhelpers/logical"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/mapstructure"
)
Expand Down Expand Up @@ -158,7 +160,7 @@ func prepareTestContainer(t *testing.T, tag, caPublicKeyPEM string) (func(), str

// Install util-linux for non-busybox flock that supports timeout option
err = testSSH("vaultssh", sshAddress, ssh.PublicKeys(signer), fmt.Sprintf(`
set -e;
set -e;
sudo ln -s /config /home/vaultssh
sudo apk add util-linux;
echo "LogLevel DEBUG" | sudo tee -a /config/ssh_host_keys/sshd_config;
Expand Down Expand Up @@ -1318,6 +1320,127 @@ func TestBackend_DisallowUserProvidedKeyIDs(t *testing.T) {
logicaltest.Test(t, testCase)
}

func TestBackend_DefaultExtensionsTemplating(t *testing.T) {
coreConfig := &vault.CoreConfig{
CredentialBackends: map[string]logical.Factory{
"userpass": userpass.Factory,
},
LogicalBackends: map[string]logical.Factory{
"ssh": Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client

// Write test policy for userpass auth method.
err := client.Sys().PutPolicy("test", `
path "ssh/*" {
capabilities = ["update"]
}`)
if err != nil {
t.Fatal(err)
}

// Enable userpass auth method.
if err := client.Sys().EnableAuth("userpass", "userpass", ""); err != nil {
t.Fatal(err)
}

userIdentity := "userpassname"
// Configure test role for userpass.
if _, err := client.Logical().Write("auth/userpass/users/" + userIdentity, map[string]interface{}{
"password": "test",
"policies": "test",
}); err != nil {
t.Fatal(err)
}

// Login userpass for test role and keep client token.
secret, err := client.Logical().Write("auth/userpass/login/userpassname", map[string]interface{}{
"password": "test",
})
if err != nil || secret == nil {
t.Fatal(err)
}
userpassToken := secret.Auth.ClientToken

// Get auth accessor for identity template.
auths, err := client.Sys().ListAuth()
if err != nil {
t.Fatal(err)
}
userpassAccessor := auths["userpass/"].Accessor

// Mount SSH.
err = client.Sys().Mount("ssh", &api.MountInput{
Type: "ssh",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "60h",
},
})
if err != nil {
t.Fatal(err)
}

// Generate internal SSH CA.
_, err = client.Logical().Write("ssh/config/ca", map[string]interface{}{
"generate_signing_key": true,
"key_bits": 2048,
"key_type": "ca",
})
if err != nil {
t.Fatal(err)
}

// Write SSH role.
_, err = client.Logical().Write("ssh/roles/test", map[string]interface{}{
"key_type": "ca",
"allowed_extensions": "login@zipzap.com",
"allow_user_certificates": true,
"allowed_users": "tuber",
"default_user": "tuber",
"default_extensions_template": true,
"default_extensions": map[string]interface{}{
"login@foobar.com": "{{identity.entity.aliases." + userpassAccessor + ".name}}",
},
})
if err != nil {
t.Fatal(err)
}

// Issue SSH certificate with userpassToken.
client.SetToken(userpassToken)
resp, err := client.Logical().Write("ssh/sign/test", map[string]interface{}{
"public_key": publicKey4096,
})
if err != nil {
t.Fatal(err)
}
signedKey := resp.Data["signed_key"].(string)
key, _ := base64.StdEncoding.DecodeString(strings.Split(signedKey, " ")[1])

parsedKey, err := ssh.ParsePublicKey(key)
if err != nil {
t.Fatal(err)
}

cert := parsedKey.(*ssh.Certificate)

expectedExtensionPermissions := map[string]string{
"login@foobar.com": userIdentity,
}

if !reflect.DeepEqual(cert.Permissions.Extensions, expectedExtensionPermissions) {
t.Fatalf("incorrect Permissions.Extensions: Expected: %v, Actual: %v", expectedExtensionPermissions, cert.Permissions.Extensions)
}

}

func configCaStep(caPublicKey, caPrivateKey string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Expand Down
27 changes: 13 additions & 14 deletions builtin/logical/ssh/path_sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,29 +359,27 @@ func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshR
func (b *backend) calculateExtensions(data *framework.FieldData, req *logical.Request, role *sshRole) (map[string]string, error) {
unparsedExtensions := data.Get("extensions").(map[string]interface{})
extensions := make(map[string]string)

if len(unparsedExtensions) > 0 {
extensions = convertMapToStringValue(unparsedExtensions)
parsedExtensions := convertMapToStringValue(unparsedExtensions)
if role.AllowedExtensions != "" {
notAllowed := []string{}
allowedExtensions := strings.Split(role.AllowedExtensions, ",")

for extension := range extensions {
if !strutil.StrListContains(allowedExtensions, extension) {
notAllowed = append(notAllowed, extension)
for extensionKey, extensionValue := range parsedExtensions {
if !strutil.StrListContains(allowedExtensions, extensionKey) {
notAllowed = append(notAllowed, extensionKey)
} else {
extensions[extensionKey] = extensionValue
}
}

if len(notAllowed) != 0 {
return nil, fmt.Errorf("extensions %v are not on allowed list", notAllowed)
}
}
} else {
extensions = role.DefaultExtensions
}

if role.DefaultExtensionsTemplate {
templatedExtensions := make(map[string]string)
for extensionKey, extensionValue := range extensions {
} else if role.DefaultExtensionsTemplate {
for extensionKey, extensionValue := range role.DefaultExtensions {
// Look for templating markers {{ .* }}
matched, _ := regexp.MatchString(`^{{.+?}}$`, extensionValue)
if matched {
Expand All @@ -390,17 +388,18 @@ func (b *backend) calculateExtensions(data *framework.FieldData, req *logical.Re
templateExtensionValue, err := framework.PopulateIdentityTemplate(extensionValue, req.EntityID, b.System())
if err == nil {
// Template returned an extension value that we can use
templatedExtensions[extensionKey] = templateExtensionValue
extensions[extensionKey] = templateExtensionValue
} else {
return nil, fmt.Errorf("template '%s' could not be rendered -> %s", extensionValue, err)
}
}
} else {
// Static extension value or err template
templatedExtensions[extensionKey] = extensionValue
extensions[extensionKey] = extensionValue
}
}
return templatedExtensions, nil
} else {
extensions = role.DefaultExtensions
}

return extensions, nil
Expand Down

0 comments on commit 9733066

Please sign in to comment.