Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable login from Azure VMs with user-assigned identities #33

Merged
merged 2 commits into from Apr 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion azure.go
Expand Up @@ -11,7 +11,7 @@ import (
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2017-12-01/compute"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-07-01/compute"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/azure/auth"
Expand Down
2 changes: 1 addition & 1 deletion azure_test.go
Expand Up @@ -7,7 +7,7 @@ import (
"fmt"
"strings"

"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2017-12-01/compute"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-07-01/compute"
oidc "github.com/coreos/go-oidc"
)

Expand Down
52 changes: 30 additions & 22 deletions path_login.go
Expand Up @@ -6,7 +6,7 @@ import (
"fmt"
"time"

"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2017-12-01/compute"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-07-01/compute"
"github.com/Azure/go-autorest/autorest/to"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/sdk/framework"
Expand Down Expand Up @@ -208,7 +208,8 @@ func (b *azureAuthBackend) verifyResource(ctx context.Context, subscriptionID, r
return errors.New("subscription_id and resource_group_name are required")
}

var principalID, location *string
var location *string
principalIDs := map[string]struct{}{}

switch {
// If vmss name is specified, the vm name will be ignored and only the scale set
Expand All @@ -224,21 +225,25 @@ func (b *azureAuthBackend) verifyResource(ctx context.Context, subscriptionID, r
return errwrap.Wrapf("unable to retrieve virtual machine scale set metadata: {{err}}", err)
}

if vmss.Identity == nil {
return errors.New("vmss client did not return identity information")
}
if vmss.Identity.PrincipalID == nil {
return errors.New("vmss principal id is empty")
}

// Check bound scale sets
if len(role.BoundScaleSets) > 0 && !strListContains(role.BoundScaleSets, vmssName) {
return errors.New("scale set not authorized")
}

principalID = vmss.Identity.PrincipalID
location = vmss.Location

if vmss.Identity == nil {
return errors.New("vmss client did not return identity information")
}
// if system-assigned identity's principal id is available
if vmss.Identity.PrincipalID != nil {
principalIDs[to.String(vmss.Identity.PrincipalID)] = struct{}{}
break
}
// if not, look for user-assigned identities
for _, userIdentity := range vmss.Identity.UserAssignedIdentities {
principalIDs[to.String(userIdentity.PrincipalID)] = struct{}{}
}
case vmName != "":
client, err := b.provider.ComputeClient(subscriptionID)
if err != nil {
Expand All @@ -250,29 +255,32 @@ func (b *azureAuthBackend) verifyResource(ctx context.Context, subscriptionID, r
return errwrap.Wrapf("unable to retrieve virtual machine metadata: {{err}}", err)
}

location = vm.Location

if vm.Identity == nil {
return errors.New("vm client did not return identity information")
}

if vm.Identity.PrincipalID == nil {
return errors.New("vm principal id is empty")
}

// Check bound scale sets
if len(role.BoundScaleSets) > 0 {
return errors.New("bound scale set defined but this vm isn't in a scale set")
}

principalID = vm.Identity.PrincipalID
location = vm.Location

// if system-assigned identity's principal id is available
if vm.Identity.PrincipalID != nil {
principalIDs[to.String(vm.Identity.PrincipalID)] = struct{}{}
break
}
// if not, look for user-assigned identities
for _, userIdentity := range vm.Identity.UserAssignedIdentities {
principalIDs[to.String(userIdentity.PrincipalID)] = struct{}{}
}
default:
return errors.New("either vm_name or vmss_name is required")
}

// Ensure the principal id for the VM matches the verified token OID
if to.String(principalID) != claims.ObjectID {
return errors.New("token object id does not match virtual machine principal id")
// Ensure the token OID is the principal id of the system-assigned identity
// or one of the user-assigned identities of the VM
if _, ok := principalIDs[claims.ObjectID]; !ok {
return errors.New("token object id does not match virtual machine identities")
}

// Check bound subscriptions
Expand Down
74 changes: 73 additions & 1 deletion path_login_test.go
Expand Up @@ -8,7 +8,7 @@ import (
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2017-12-01/compute"
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-07-01/compute"
oidc "github.com/coreos/go-oidc"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/logical"
Expand Down Expand Up @@ -216,6 +216,78 @@ func TestLogin_BoundResourceGroup(t *testing.T) {
testLoginFailure(t, b, s, loginData, claims, roleData)
}

func TestLogin_BoundResourceGroupWithUserAssignedID(t *testing.T) {
principalID := "prinID"
badPrincipalID := "badID"
c := func(vmName string) (compute.VirtualMachine, error) {
id := compute.VirtualMachineIdentity{
UserAssignedIdentities: map[string]*compute.VirtualMachineIdentityUserAssignedIdentitiesValue{
"mockuserassignedmsi": &compute.VirtualMachineIdentityUserAssignedIdentitiesValue{
PrincipalID: &principalID,
},
},
}
return compute.VirtualMachine{
Identity: &id,
}, nil
}
v := func(vmName string) (compute.VirtualMachineScaleSet, error) {
id := compute.VirtualMachineScaleSetIdentity{
UserAssignedIdentities: map[string]*compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue{
"mockuserassignedmsi": &compute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue{
PrincipalID: &principalID,
},
},
}
return compute.VirtualMachineScaleSet{
Identity: &id,
}, nil
}
b, s := getTestBackendWithComputeClient(t, c, v)

roleName := "testrole"
rg := "rg"
roleData := map[string]interface{}{
"name": roleName,
"policies": []string{"dev", "prod"},
"bound_resource_groups": []string{rg},
}
testRoleCreate(t, b, s, roleData)

claims := map[string]interface{}{
"exp": time.Now().Add(60 * time.Second).Unix(),
"nbf": time.Now().Add(-60 * time.Second).Unix(),
"oid": principalID,
}
badClaims := map[string]interface{}{
"exp": time.Now().Add(60 * time.Second).Unix(),
"nbf": time.Now().Add(-60 * time.Second).Unix(),
"oid": badPrincipalID,
}

loginData := map[string]interface{}{
"role": roleName,
}
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["subscription_id"] = "sub"
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["resource_group_name"] = rg
testLoginFailure(t, b, s, loginData, claims, roleData)

loginData["vmss_name"] = "vmss"
testLoginSuccess(t, b, s, loginData, claims, roleData)
delete(loginData, "vmss_name")

loginData["vm_name"] = "vm"
testLoginSuccess(t, b, s, loginData, claims, roleData)
testLoginFailure(t, b, s, loginData, badClaims, roleData)

loginData["resource_group_name"] = "bad rg"
testLoginFailure(t, b, s, loginData, claims, roleData)
}

func TestLogin_BoundLocation(t *testing.T) {
principalID := "prinID"
location := "loc"
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.