Skip to content

Commit

Permalink
update tests for assertion func changes
Browse files Browse the repository at this point in the history
  • Loading branch information
fairclothjm committed Apr 18, 2024
1 parent 875e7a2 commit 78f1070
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 48 deletions.
58 changes: 30 additions & 28 deletions azure_test.go
Expand Up @@ -16,7 +16,9 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/coreos/go-oidc"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault-plugin-auth-azure/client"
"github.com/hashicorp/vault/sdk/logical"
)

// mockKeySet is used in tests to bypass signature validation and return only
Expand Down Expand Up @@ -45,43 +47,43 @@ func newMockVerifier() client.TokenVerifier {
}

type mockComputeClient struct {
computeClientFunc func(vmName string) (armcompute.VirtualMachinesClientGetResponse, error)
computeClientFunc computeClientFunc
}

type mockVMSSClient struct {
vmssClientFunc func(vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error)
vmssClientFunc vmssClientFunc
}

type mockMSIClient struct {
msiClientFunc func(resourceName string) (armmsi.UserAssignedIdentitiesClientGetResponse, error)
msiListFunc func(resourceGroup string) armmsi.UserAssignedIdentitiesClientListByResourceGroupResponse
msiClientFunc msiClientFunc
msiListFunc msiListFunc
}

type mockResourceClient struct {
resourceClientFunc func(resourceID string) (armresources.ClientGetByIDResponse, error)
resourceClientFunc resourceClientFunc
}

type mockProvidersClient struct {
providersClientFunc func(string) (armresources.ProvidersClientGetResponse, error)
providersClientFunc providersClientFunc
}

func (c *mockComputeClient) Get(_ context.Context, _, vmName string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) {
func (c *mockComputeClient) Get(ctx context.Context, _, vmName string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) {
if c.computeClientFunc != nil {
return c.computeClientFunc(vmName)
return c.computeClientFunc(ctx, hclog.NewNullLogger(), nil, vmName)
}
return armcompute.VirtualMachinesClientGetResponse{}, nil
}

func (c *mockVMSSClient) Get(_ context.Context, _, vmssName string, _ *armcompute.VirtualMachineScaleSetsClientGetOptions) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) {
func (c *mockVMSSClient) Get(ctx context.Context, _, vmssName string, _ *armcompute.VirtualMachineScaleSetsClientGetOptions) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) {
if c.vmssClientFunc != nil {
return c.vmssClientFunc(vmssName)
return c.vmssClientFunc(ctx, hclog.NewNullLogger(), nil, vmssName)
}
return armcompute.VirtualMachineScaleSetsClientGetResponse{}, nil
}

func (c *mockMSIClient) Get(_ context.Context, _, resourceName string, _ *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
func (c *mockMSIClient) Get(ctx context.Context, _, resourceName string, _ *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
if c.msiClientFunc != nil {
return c.msiClientFunc(resourceName)
return c.msiClientFunc(ctx, hclog.NewNullLogger(), nil, resourceName)
}
return armmsi.UserAssignedIdentitiesClientGetResponse{}, nil
}
Expand All @@ -101,33 +103,33 @@ func (c *mockMSIClient) NewListByResourceGroupPager(resourceGroup string, _ *arm
return nil
}

func (c *mockResourceClient) GetByID(_ context.Context, resourceID, _ string, _ *armresources.ClientGetByIDOptions) (armresources.ClientGetByIDResponse, error) {
func (c *mockResourceClient) GetByID(ctx context.Context, resourceID, _ string, _ *armresources.ClientGetByIDOptions) (armresources.ClientGetByIDResponse, error) {
if c.resourceClientFunc != nil {
return c.resourceClientFunc(resourceID)
return c.resourceClientFunc(ctx, hclog.NewNullLogger(), nil, resourceID)
}
return armresources.ClientGetByIDResponse{}, nil
}

func (c *mockProvidersClient) Get(_ context.Context, resourceID string, _ *armresources.ProvidersClientGetOptions) (armresources.ProvidersClientGetResponse, error) {
func (c *mockProvidersClient) Get(ctx context.Context, resourceID string, _ *armresources.ProvidersClientGetOptions) (armresources.ProvidersClientGetResponse, error) {
if c.providersClientFunc != nil {
return c.providersClientFunc(resourceID)
return c.providersClientFunc(ctx, hclog.NewNullLogger(), nil, resourceID)
}
return armresources.ProvidersClientGetResponse{}, nil
}

type computeClientFunc func(vmName string) (armcompute.VirtualMachinesClientGetResponse, error)
type computeClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView, vmName string) (armcompute.VirtualMachinesClientGetResponse, error)

type vmssClientFunc func(vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error)
type vmssClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView, vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error)

type msiClientFunc func(resourceName string) (armmsi.UserAssignedIdentitiesClientGetResponse, error)
type msiClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView, resourceName string) (armmsi.UserAssignedIdentitiesClientGetResponse, error)

type msiListFunc func(resoucename string) armmsi.UserAssignedIdentitiesClientListByResourceGroupResponse

type msGraphClientFunc func() (client.MSGraphClient, error)
type msGraphClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView) (client.MSGraphClient, error)

type resourceClientFunc func(resourceID string) (armresources.ClientGetByIDResponse, error)
type resourceClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView, resourceID string) (armresources.ClientGetByIDResponse, error)

type providersClientFunc func(string) (armresources.ProvidersClientGetResponse, error)
type providersClientFunc func(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (armresources.ProvidersClientGetResponse, error)

type mockProvider struct {
computeClientFunc
Expand All @@ -153,36 +155,36 @@ func (*mockProvider) TokenVerifier() client.TokenVerifier {
return newMockVerifier()
}

func (p *mockProvider) ComputeClient(string) (client.ComputeClient, error) {
func (p *mockProvider) ComputeClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (client.ComputeClient, error) {
return &mockComputeClient{
computeClientFunc: p.computeClientFunc,
}, nil
}

func (p *mockProvider) VMSSClient(string) (client.VMSSClient, error) {
func (p *mockProvider) VMSSClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (client.VMSSClient, error) {
return &mockVMSSClient{
vmssClientFunc: p.vmssClientFunc,
}, nil
}

func (p *mockProvider) MSIClient(string) (client.MSIClient, error) {
func (p *mockProvider) MSIClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (client.MSIClient, error) {
return &mockMSIClient{
msiClientFunc: p.msiClientFunc,
msiListFunc: p.msiListFunc,
}, nil
}

func (p *mockProvider) MSGraphClient() (client.MSGraphClient, error) {
func (p *mockProvider) MSGraphClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView) (client.MSGraphClient, error) {
return nil, nil
}

func (p *mockProvider) ResourceClient(string) (client.ResourceClient, error) {
func (p *mockProvider) ResourceClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (client.ResourceClient, error) {
return &mockResourceClient{
resourceClientFunc: p.resourceClientFunc,
}, nil
}

func (p *mockProvider) ProvidersClient(string) (client.ProvidersClient, error) {
func (p *mockProvider) ProvidersClient(ctx context.Context, logger hclog.Logger, sys logical.SystemView, s string) (client.ProvidersClient, error) {
return &mockProvidersClient{
providersClientFunc: p.providersClientFunc,
}, nil
Expand Down
2 changes: 1 addition & 1 deletion path_config_test.go
Expand Up @@ -129,7 +129,7 @@ func TestConfig(t *testing.T) {
"tenant_id": "foo",
}

err = testConfigUpdate(t, b, s, configSubset)
_, err = testConfigUpdate(t, b, s, configSubset)
if err != nil {
t.Fatal(err)
}
Expand Down
39 changes: 20 additions & 19 deletions path_login_test.go
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/coreos/go-oidc"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault-plugin-auth-azure/client"
"github.com/hashicorp/vault/sdk/helper/policyutil"
"github.com/hashicorp/vault/sdk/logical"
Expand Down Expand Up @@ -184,7 +185,7 @@ func TestLogin_ManagedIdentity(t *testing.T) {
roleName := "test-role"

// setup test response functions that mock the client GetByID response
nilIdentityRespFunc := func(_ string) (armresources.ClientGetByIDResponse, error) {
nilIdentityRespFunc := func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error) {
return armresources.ClientGetByIDResponse{}, nil
}
userAssignedRespFunc, systemAssignedRespFunc := getResourceByIDResponses(t, principalID)
Expand All @@ -195,7 +196,7 @@ func TestLogin_ManagedIdentity(t *testing.T) {
claims map[string]interface{}
roleData map[string]interface{}
loginData map[string]interface{}
clientFunc func(resourceID string) (armresources.ClientGetByIDResponse, error)
clientFunc func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error)
expectError bool
}{
"login happy path user-assigned managed identity": {
Expand Down Expand Up @@ -945,8 +946,8 @@ func TestGetAPIVersionForResource(t *testing.T) {
// the azure arm resource client responses. If principalID is an empty string
// then no identity data will be set in the response.
func getResourceByIDResponses(t *testing.T, principalID string) (
func(_ string) (armresources.ClientGetByIDResponse, error),
func(_ string) (armresources.ClientGetByIDResponse, error),
func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error),
func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error),
) {
t.Helper()
u := armresources.ClientGetByIDResponse{
Expand All @@ -972,10 +973,10 @@ func getResourceByIDResponses(t *testing.T, principalID string) (
s.GenericResource.Identity.PrincipalID = &principalID
}

userAssignedRespFunc := func(_ string) (armresources.ClientGetByIDResponse, error) {
userAssignedRespFunc := func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error) {
return u, nil
}
systemAssignedRespFunc := func(_ string) (armresources.ClientGetByIDResponse, error) {
systemAssignedRespFunc := func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ClientGetByIDResponse, error) {
return s, nil
}

Expand All @@ -984,7 +985,7 @@ func getResourceByIDResponses(t *testing.T, principalID string) (

// getProvidersResponse is a test helper to get the function that returns
// the azure arm resource providers client response.
func getProvidersResponse(t *testing.T, resourceID string) func(_ string) (armresources.ProvidersClientGetResponse, error) {
func getProvidersResponse(t *testing.T, resourceID string) func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ProvidersClientGetResponse, error) {
t.Helper()

resourceType, err := arm.ParseResourceType(resourceID)
Expand All @@ -1008,7 +1009,7 @@ func getProvidersResponse(t *testing.T, resourceID string) func(_ string) (armre
},
},
}
providersRespFunc := func(_ string) (armresources.ProvidersClientGetResponse, error) {
providersRespFunc := func(context.Context, hclog.Logger, logical.SystemView, string) (armresources.ProvidersClientGetResponse, error) {
return u, nil
}
return providersRespFunc
Expand Down Expand Up @@ -1036,14 +1037,14 @@ func testJWT(t *testing.T, payload map[string]interface{}) string {
}

func getTestBackendFunctions(withLocation bool) (
func(_ string) (armcompute.VirtualMachinesClientGetResponse, error),
func(_ string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error),
func(_ string) (armmsi.UserAssignedIdentitiesClientGetResponse, error),
func(context.Context, hclog.Logger, logical.SystemView, string) (armcompute.VirtualMachinesClientGetResponse, error),
func(context.Context, hclog.Logger, logical.SystemView, string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error),
func(context.Context, hclog.Logger, logical.SystemView, string) (armmsi.UserAssignedIdentitiesClientGetResponse, error),
) {
principalID := "123e4567-e89b-12d3-a456-426655440000"

if !withLocation {
c := func(_ string) (armcompute.VirtualMachinesClientGetResponse, error) {
c := func(context.Context, hclog.Logger, logical.SystemView, string) (armcompute.VirtualMachinesClientGetResponse, error) {
id := armcompute.VirtualMachineIdentity{
PrincipalID: &principalID,
}
Expand All @@ -1053,7 +1054,7 @@ func getTestBackendFunctions(withLocation bool) (
},
}, nil
}
v := func(_ string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) {
v := func(context.Context, hclog.Logger, logical.SystemView, string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) {
id := armcompute.VirtualMachineScaleSetIdentity{
PrincipalID: &principalID,
}
Expand All @@ -1062,7 +1063,7 @@ func getTestBackendFunctions(withLocation bool) (
}}, nil
}

m := func(_ string) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
m := func(context.Context, hclog.Logger, logical.SystemView, string) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
userAssignedIdentityProperties := armmsi.UserAssignedIdentityProperties{
PrincipalID: &principalID,
}
Expand All @@ -1075,7 +1076,7 @@ func getTestBackendFunctions(withLocation bool) (
} else {
location := "loc"

c := func(vmName string) (armcompute.VirtualMachinesClientGetResponse, error) {
c := func(_ context.Context, _ hclog.Logger, _ logical.SystemView, vmName string) (armcompute.VirtualMachinesClientGetResponse, error) {
id := armcompute.VirtualMachineIdentity{
PrincipalID: &principalID,
}
Expand All @@ -1094,7 +1095,7 @@ func getTestBackendFunctions(withLocation bool) (
}
return armcompute.VirtualMachinesClientGetResponse{}, nil
}
v := func(vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) {
v := func(_ context.Context, _ hclog.Logger, _ logical.SystemView, vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) {
id := armcompute.VirtualMachineScaleSetIdentity{
PrincipalID: &principalID,
}
Expand All @@ -1114,7 +1115,7 @@ func getTestBackendFunctions(withLocation bool) (
return armcompute.VirtualMachineScaleSetsClientGetResponse{}, nil
}

m := func(_ string) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
m := func(context.Context, hclog.Logger, logical.SystemView, string) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
userAssignedIdentityProperties := armmsi.UserAssignedIdentityProperties{
PrincipalID: &principalID,
}
Expand All @@ -1127,8 +1128,8 @@ func getTestBackendFunctions(withLocation bool) (
}
}

func getTestMSGraphClient() func() (client.MSGraphClient, error) {
return func() (client.MSGraphClient, error) {
func getTestMSGraphClient() func(context.Context, hclog.Logger, logical.SystemView) (client.MSGraphClient, error) {
return func(context.Context, hclog.Logger, logical.SystemView) (client.MSGraphClient, error) {
return nil, nil
}
}

0 comments on commit 78f1070

Please sign in to comment.