Skip to content

Commit

Permalink
Refactor updating user values
Browse files Browse the repository at this point in the history
  • Loading branch information
vishalnayak committed Mar 16, 2016
1 parent 5905429 commit cfbab2c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 24 deletions.
29 changes: 27 additions & 2 deletions builtin/credential/userpass/path_user_password.go
Expand Up @@ -3,6 +3,8 @@ package userpass
import (
"fmt"

"golang.org/x/crypto/bcrypt"

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
Expand Down Expand Up @@ -33,11 +35,34 @@ func pathUserPassword(b *backend) *framework.Path {

func (b *backend) pathUserPasswordUpdate(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {

username := d.Get("username").(string)

userEntry, err := b.user(req.Storage, username)
if err != nil {
return nil, err
}

err = b.updateUserPassword(req, d, userEntry)
if err != nil {
return nil, err
}

return nil, b.setUser(req.Storage, username, userEntry)
}

func (b *backend) updateUserPassword(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) error {
password := d.Get("password").(string)
if password == "" {
return nil, fmt.Errorf("missing password")
return fmt.Errorf("missing password")
}
// Generate a hash of the password
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
return b.userCreateUpdate(req, d)
userEntry.PasswordHash = hash
return nil
}

const pathUserPasswordHelpSyn = `
Expand Down
26 changes: 25 additions & 1 deletion builtin/credential/userpass/path_user_policies.go
@@ -1,6 +1,8 @@
package userpass

import (
"strings"

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
Expand Down Expand Up @@ -30,7 +32,29 @@ func pathUserPolicies(b *backend) *framework.Path {

func (b *backend) pathUserPoliciesUpdate(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
return b.userCreateUpdate(req, d)

username := d.Get("username").(string)

userEntry, err := b.user(req.Storage, username)
if err != nil {
return nil, err
}

err = b.updateUserPolicies(req, d, userEntry)
if err != nil {
return nil, err
}

return nil, b.setUser(req.Storage, username, userEntry)
}

func (b *backend) updateUserPolicies(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) error {
policies := strings.Split(d.Get("policies").(string), ",")
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
}
userEntry.Policies = policies
return nil
}

const pathUserPoliciesHelpSyn = `
Expand Down
44 changes: 23 additions & 21 deletions builtin/credential/userpass/path_users.go
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"golang.org/x/crypto/bcrypt"
)

func pathUsers(b *backend) *framework.Path {
Expand Down Expand Up @@ -132,33 +131,36 @@ func (b *backend) userCreateUpdate(req *logical.Request, d *framework.FieldData)
userEntry = &UserEntry{}
}

// Set/update the values of UserEntry only if fields are supplied
if passwordRaw, ok := d.GetOk("password"); ok {
// Generate a hash of the password
hash, err := bcrypt.GenerateFromPassword([]byte(passwordRaw.(string)), bcrypt.DefaultCost)
// "password" will always be set here
err = b.updateUserPassword(req, d, userEntry)
if err != nil {
return nil, err
}

if _, ok := d.GetOk("policies"); ok {
err = b.updateUserPolicies(req, d, userEntry)
if err != nil {
return nil, err
}
userEntry.PasswordHash = hash
}

if policiesRaw, ok := d.GetOk("policies"); ok {
policies := strings.Split(policiesRaw.(string), ",")
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
}
userEntry.Policies = policies
ttlStr := ""
if ttlStrRaw, ok := d.GetOk("ttl"); ok {
ttlStr = ttlStrRaw.(string)
} else if req.Operation == logical.CreateOperation {
ttlStr = d.Get("ttl").(string)
}

_, ttlSet := d.GetOk("ttl")
_, maxTTLSet := d.GetOk("max_ttl")
if ttlSet || maxTTLSet {
ttlStr := d.Get("ttl").(string)
maxTTLStr := d.Get("max_ttl").(string)
userEntry.TTL, userEntry.MaxTTL, err = b.SanitizeTTL(ttlStr, maxTTLStr)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("err: %s", err)), nil
}
maxTTLStr := ""
if maxTTLStrRaw, ok := d.GetOk("max_ttl"); ok {
maxTTLStr = maxTTLStrRaw.(string)
} else if req.Operation == logical.CreateOperation {
maxTTLStr = d.Get("max_ttl").(string)
}

userEntry.TTL, userEntry.MaxTTL, err = b.SanitizeTTL(ttlStr, maxTTLStr)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("err: %s", err)), nil
}

return nil, b.setUser(req.Storage, username, userEntry)
Expand Down

0 comments on commit cfbab2c

Please sign in to comment.