From 43bcca9bb1746402ec884526670ac06fc04e42f2 Mon Sep 17 00:00:00 2001 From: Michael Golowka <72365+pcman312@users.noreply.github.com> Date: Tue, 28 Apr 2020 14:48:36 -0600 Subject: [PATCH] Add alias types for IAM and GCE logins (#89) (#95) Resolves https://github.com/hashicorp/vault/issues/8761 Allows users to specify an alias type field for IAM and GCE logins which will then switch between the current behavior (a unique ID for IAM, and the instance ID for GCE) and a newly created role_id field. The role_id is a UUID generated when the role is created. All existing roles without a role_id will have one generated and saved when the role is read or written. --- Makefile | 3 + go.mod | 2 + go.sum | 10 + plugin/aliasing.go | 91 +++++++ plugin/aliasing_test.go | 150 +++++++++++ plugin/gcp_role.go | 429 ++++++++++++++++++++++++++++++ plugin/mocks_test.go | 308 ++++++++++++++++++++++ plugin/path_config.go | 10 +- plugin/path_login.go | 26 +- plugin/path_login_test.go | 2 +- plugin/path_role.go | 487 ++++++---------------------------- plugin/path_role_test.go | 539 +++++++++++++++++++++++++++++++++++++- 12 files changed, 1635 insertions(+), 422 deletions(-) create mode 100644 plugin/aliasing.go create mode 100644 plugin/aliasing_test.go create mode 100644 plugin/gcp_role.go create mode 100644 plugin/mocks_test.go diff --git a/Makefile b/Makefile index 8428eca6..a20899de 100644 --- a/Makefile +++ b/Makefile @@ -51,5 +51,8 @@ fmtcheck: fmt: gofmt -w $(GOFMT_FILES) +mocks: + mockgen -destination ${CURDIR}/plugin/mocks_test.go -package gcpauth github.com/hashicorp/vault/sdk/logical SystemView,Storage + .PHONY: bin default generate test vet bootstrap fmt fmtcheck diff --git a/go.mod b/go.mod index 87647f25..c1193bd6 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,12 @@ module github.com/hashicorp/vault-plugin-auth-gcp go 1.12 require ( + github.com/golang/mock v1.4.3 github.com/hashicorp/errwrap v1.0.0 github.com/hashicorp/go-cleanhttp v0.5.1 github.com/hashicorp/go-gcp-common v0.6.0 github.com/hashicorp/go-hclog v0.12.0 + github.com/hashicorp/go-uuid v1.0.2 github.com/hashicorp/vault/api v1.0.5-0.20200317185738-82f498082f02 github.com/hashicorp/vault/sdk v0.1.14-0.20200317185738-82f498082f02 github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d // indirect diff --git a/go.sum b/go.sum index 650821c0..1224e306 100644 --- a/go.sum +++ b/go.sum @@ -43,7 +43,10 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekf github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -215,6 +218,7 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqY golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191008105621-543471e840be h1:QAcqgptGM8IQBC9K/RC4o+O9YmqEm0diQn9QmZw/0mU= golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20181227161524-e6919f6577db h1:6/JqlYfC1CCaLnGceQTI+sDGhC9UBSPAsBqI0Gun6kU= golang.org/x/text v0.3.1-0.20181227161524-e6919f6577db/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -227,6 +231,8 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135 h1:5Beo0mZN8dRzgrMMkDp0jc8YXQKx9DiJ2k1dkvGsn5A= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= google.golang.org/api v0.0.0-20181220000619-583d854617af/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.2.0/go.mod h1:IfRCZScioGtypHNTlz3gFk67J8uePVW7uDTBzXuIkhU= @@ -262,3 +268,7 @@ honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20180920025451-e3ad64cb4ed3/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +rsc.io/quote/v3 v3.1.0 h1:9JKUTTIUgS6kzR9mK1YuGKv6Nl+DijDNIc0ghT58FaY= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0 h1:7uVkIFmeBqHfdjD+gZwtXXI+RODJ2Wc4O7MPEh/QiW4= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/plugin/aliasing.go b/plugin/aliasing.go new file mode 100644 index 00000000..b5617468 --- /dev/null +++ b/plugin/aliasing.go @@ -0,0 +1,91 @@ +package gcpauth + +import ( + "fmt" + "sort" + "strconv" + "strings" + + "google.golang.org/api/compute/v1" + "google.golang.org/api/iam/v1" +) + +type iamAliaser func(role *gcpRole, svcAccount *iam.ServiceAccount) (alias string) +type gceAliaser func(role *gcpRole, instance *compute.Instance) (alias string) + +const ( + defaultIAMAlias = "unique_id" + defaultGCEAlias = "instance_id" +) + +var ( + allowedIAMAliases = map[string]iamAliaser{ + defaultIAMAlias: getIAMSvcAccountUniqueID, + "": getIAMSvcAccountUniqueID, // For backwards compatibility + + "role_id": getIAMRoleID, + } + allowedGCEAliases = map[string]gceAliaser{ + defaultGCEAlias: getGCEInstanceID, + "": getGCEInstanceID, // For backwards compatibility + + "role_id": getGCERoleID, + } + + allowedIAMAliasesSlice = iamMapKeyToSlice(allowedIAMAliases) + allowedGCEAliasesSlice = gceMapKeyToSlice(allowedGCEAliases) +) + +func iamMapKeyToSlice(m map[string]iamAliaser) (s []string) { + for key := range m { + if key == "" { + continue + } + s = append(s, key) + } + sort.Strings(s) + return s +} + +func gceMapKeyToSlice(m map[string]gceAliaser) (s []string) { + for key := range m { + if key == "" { + continue + } + s = append(s, key) + } + sort.Strings(s) + return s +} + +func getIAMSvcAccountUniqueID(_ *gcpRole, svcAccount *iam.ServiceAccount) (alias string) { + return svcAccount.UniqueId +} + +func getIAMRoleID(role *gcpRole, _ *iam.ServiceAccount) (alias string) { + return role.RoleID +} + +func getGCEInstanceID(_ *gcpRole, instance *compute.Instance) (alias string) { + return fmt.Sprintf("gce-%s", strconv.FormatUint(instance.Id, 10)) +} + +func getGCERoleID(role *gcpRole, _ *compute.Instance) (alias string) { + return role.RoleID +} + +func getIAMAlias(role *gcpRole, svcAccount *iam.ServiceAccount) (alias string, err error) { + aliaser, exists := allowedIAMAliases[role.IAMAliasType] + if !exists { + return "", fmt.Errorf("invalid IAM alias type: must be one of: %s", strings.Join(allowedIAMAliasesSlice, ", ")) + } + return aliaser(role, svcAccount), nil +} + +func getGCEAlias(role *gcpRole, instance *compute.Instance) (alias string, err error) { + aliaser, exists := allowedGCEAliases[role.GCEAliasType] + if !exists { + return "", fmt.Errorf("invalid GCE alias type: must be one of: %s", strings.Join(allowedIAMAliasesSlice, ", ")) + } + return aliaser(role, instance), nil +} diff --git a/plugin/aliasing_test.go b/plugin/aliasing_test.go new file mode 100644 index 00000000..22d8005e --- /dev/null +++ b/plugin/aliasing_test.go @@ -0,0 +1,150 @@ +package gcpauth + +import ( + "testing" + + "google.golang.org/api/compute/v1" + "google.golang.org/api/iam/v1" +) + +func TestGetIAMAlias(t *testing.T) { + type testCase struct { + role *gcpRole + svcAccount *iam.ServiceAccount + expectedAlias string + expectErr bool + } + + tests := map[string]testCase{ + "invalid type": { + role: &gcpRole{ + IAMAliasType: "bogus", + RoleID: "testRoleID", + }, + svcAccount: &iam.ServiceAccount{ + UniqueId: "iamUniqueID", + }, + expectedAlias: "", + expectErr: true, + }, + "empty type goes to default": { + role: &gcpRole{ + IAMAliasType: "", + RoleID: "testRoleID", + }, + svcAccount: &iam.ServiceAccount{ + UniqueId: "iamUniqueID", + }, + expectedAlias: "iamUniqueID", + expectErr: false, + }, + "default type": { + role: &gcpRole{ + IAMAliasType: defaultIAMAlias, + RoleID: "testRoleID", + }, + svcAccount: &iam.ServiceAccount{ + UniqueId: "iamUniqueID", + }, + expectedAlias: "iamUniqueID", + expectErr: false, + }, + "role_id": { + role: &gcpRole{ + IAMAliasType: "role_id", + RoleID: "testRoleID", + }, + svcAccount: &iam.ServiceAccount{ + UniqueId: "iamUniqueID", + }, + expectedAlias: "testRoleID", + expectErr: false, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + actualAlias, err := getIAMAlias(test.role, test.svcAccount) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + if actualAlias != test.expectedAlias { + t.Fatalf("Actual alias: %s Expected Alias: %s", actualAlias, test.expectedAlias) + } + }) + } +} + +func TestGetGCEAlias(t *testing.T) { + type testCase struct { + role *gcpRole + instance *compute.Instance + expectedAlias string + expectErr bool + } + + tests := map[string]testCase{ + "invalid type": { + role: &gcpRole{ + GCEAliasType: "bogus", + RoleID: "testRoleID", + }, + instance: &compute.Instance{ + Id: 123, + }, + expectedAlias: "", + expectErr: true, + }, + "empty type goes to default": { + role: &gcpRole{ + GCEAliasType: "", + RoleID: "testRoleID", + }, + instance: &compute.Instance{ + Id: 123, + }, + expectedAlias: "gce-123", + expectErr: false, + }, + "default type": { + role: &gcpRole{ + GCEAliasType: defaultGCEAlias, + RoleID: "testRoleID", + }, + instance: &compute.Instance{ + Id: 123, + }, + expectedAlias: "gce-123", + expectErr: false, + }, + "role_id": { + role: &gcpRole{ + GCEAliasType: "role_id", + RoleID: "testRoleID", + }, + instance: &compute.Instance{ + Id: 123, + }, + expectedAlias: "testRoleID", + expectErr: false, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + actualAlias, err := getGCEAlias(test.role, test.instance) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + if actualAlias != test.expectedAlias { + t.Fatalf("Actual alias: %s Expected Alias: %s", actualAlias, test.expectedAlias) + } + }) + } +} diff --git a/plugin/gcp_role.go b/plugin/gcp_role.go new file mode 100644 index 00000000..8ec4d292 --- /dev/null +++ b/plugin/gcp_role.go @@ -0,0 +1,429 @@ +package gcpauth + +import ( + "errors" + "fmt" + "strings" + "time" + + "github.com/hashicorp/go-gcp-common/gcputil" + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/strutil" + "github.com/hashicorp/vault/sdk/helper/tokenutil" + "github.com/hashicorp/vault/sdk/logical" +) + +const ( + currentGCPRoleVersion = 1 +) + +type gcpRole struct { + tokenutil.TokenParams + + // RoleID is a unique identifier for this role. + RoleID string `json:"role_id"` + + // Type of this role. See path_role constants for currently supported types. + RoleType string `json:"role_type,omitempty"` + + // Policies for Vault to assign to authorized entities. + Policies []string `json:"policies,omitempty"` + + // TTL of Vault auth leases under this role. + TTL time.Duration `json:"ttl,omitempty"` + + // Max total TTL including renewals, of Vault auth leases under this role. + MaxTTL time.Duration `json:"max_ttl,omitempty"` + + // Period, If set, indicates that this token should not expire and + // should be automatically renewed within this time period + // with TTL equal to this value. + Period time.Duration `json:"period,omitempty"` + + // Projects that entities must belong to + BoundProjects []string `json:"bound_projects,omitempty"` + + // Service accounts allowed to login under this role. + BoundServiceAccounts []string `json:"bound_service_accounts,omitempty"` + + // AddGroupAliases adds Vault group aliases to the response. + AddGroupAliases bool `json:"add_group_aliases,omitempty"` + + // --| IAM-only attributes |-- + // MaxJwtExp is the duration from time of authentication that a JWT used to authenticate to role must expire within. + // TODO(emilymye): Allow this to be updated for GCE roles once 'exp' parameter has been allowed for GCE metadata. + MaxJwtExp time.Duration `json:"max_jwt_exp,omitempty"` + + // AllowGCEInference, if false, does not allow a GCE instance to login under this 'iam' role. If true (default), + // a service account is inferred from the instance metadata and used as the authenticating instance. + AllowGCEInference bool `json:"allow_gce_inference,omitempty"` + + // --| GCE-only attributes |-- + // BoundRegions that instances must belong to in order to login under this role. + BoundRegions []string `json:"bound_regions,omitempty"` + + // BoundZones that instances must belong to in order to login under this role. + BoundZones []string `json:"bound_zones,omitempty"` + + // BoundInstanceGroups are the instance group that instances must belong to in order to login under this role. + BoundInstanceGroups []string `json:"bound_instance_groups,omitempty"` + + // BoundLabels that instances must currently have set in order to login under this role. + BoundLabels map[string]string `json:"bound_labels,omitempty"` + + // IAMAliasType specifies the alias name to use with IAM roles. Can be either "unique_id" (default) or "role_id" + IAMAliasType string `json:"iam_alias,omitempty"` + + // GCEAliasType specifies the alias name to use with GCE roles. Can be either "instance_id" (default) or "role_id" + GCEAliasType string `json:"gce_alias,omitempty"` + + // Version indicates the version of this configuration. Allows for more advanced logic around + // upgrades and different behavior between config versions. + Version int `json:"version,omitempty"` + + // Deprecated fields + // TODO: Remove in 0.5.0+ + ProjectId string `json:"project_id,omitempty"` + BoundRegion string `json:"bound_region,omitempty"` + BoundZone string `json:"bound_zone,omitempty"` + BoundInstanceGroup string `json:"bound_instance_group,omitempty"` +} + +// updateRole updates the given role with values parsed/validated from given FieldData. +// Exactly one of the response and error will be nil. The response is only used to pass back warnings. +// This method does not validate the role. Validation is done before storage. +func (role *gcpRole) updateRole(sys logical.SystemView, req *logical.Request, data *framework.FieldData) (warnings []string, err error) { + if e := role.ParseTokenFields(req, data); e != nil { + return nil, e + } + + // Handle token field upgrades + { + if e := tokenutil.UpgradeValue(data, "policies", "token_policies", &role.Policies, &role.TokenPolicies); e != nil { + return nil, e + } + + if e := tokenutil.UpgradeValue(data, "ttl", "token_ttl", &role.TTL, &role.TokenTTL); e != nil { + return nil, e + } + + if e := tokenutil.UpgradeValue(data, "max_ttl", "token_max_ttl", &role.MaxTTL, &role.TokenMaxTTL); e != nil { + return nil, e + } + + if e := tokenutil.UpgradeValue(data, "period", "token_period", &role.Period, &role.TokenPeriod); e != nil { + return nil, e + } + } + + // Set role type + if rt, ok := data.GetOk("type"); ok { + roleType := rt.(string) + if role.RoleType != roleType && req.Operation == logical.UpdateOperation { + return nil, fmt.Errorf("role type cannot be changed for an existing role") + } + role.RoleType = roleType + } else if req.Operation == logical.CreateOperation { + return nil, fmt.Errorf(errEmptyRoleType) + } + + def := sys.DefaultLeaseTTL() + if role.TokenTTL > def { + warnings = append(warnings, fmt.Sprintf(`Given token ttl of %q is greater `+ + `than the maximum system/mount TTL of %q. The TTL will be capped at `+ + `%q during login.`, role.TokenTTL, def, def)) + } + + // Update token Max TTL. + def = sys.MaxLeaseTTL() + if role.TokenMaxTTL > def { + warnings = append(warnings, fmt.Sprintf(`Given token max ttl of %q is greater `+ + `than the maximum system/mount MaxTTL of %q. The MaxTTL will be `+ + `capped at %q during login.`, role.TokenMaxTTL, def, def)) + } + if role.TokenPeriod > def { + warnings = append(warnings, fmt.Sprintf(`Given token period of %q is greater `+ + `than the maximum system/mount period of %q. The period will be `+ + `capped at %q during login.`, role.TokenPeriod, def, def)) + } + + // Update bound GCP service accounts. + if sa, ok := data.GetOk("bound_service_accounts"); ok { + role.BoundServiceAccounts = sa.([]string) + } else { + // Check for older version of param name + if sa, ok := data.GetOk("service_accounts"); ok { + warnings = append(warnings, `The "service_accounts" field is deprecated. `+ + `Please use "bound_service_accounts" instead. The "service_accounts" `+ + `field will be removed in a later release, so please update accordingly.`) + role.BoundServiceAccounts = sa.([]string) + } + } + if len(role.BoundServiceAccounts) > 0 { + role.BoundServiceAccounts = strutil.TrimStrings(role.BoundServiceAccounts) + role.BoundServiceAccounts = strutil.RemoveDuplicates(role.BoundServiceAccounts, false) + } + + // Update bound GCP projects. + boundProjects, givenBoundProj := data.GetOk("bound_projects") + if givenBoundProj { + role.BoundProjects = boundProjects.([]string) + } + if projectId, ok := data.GetOk("project_id"); ok { + if givenBoundProj { + return warnings, errors.New("only one of 'bound_projects' or 'project_id' can be given") + } + warnings = append(warnings, + `The "project_id" (singular) field is deprecated. `+ + `Please use plural "bound_projects" instead to bind required GCP projects. `+ + `The "project_id" field will be removed in a later release, so please update accordingly.`) + role.BoundProjects = []string{projectId.(string)} + } + if len(role.BoundProjects) > 0 { + role.BoundProjects = strutil.TrimStrings(role.BoundProjects) + role.BoundProjects = strutil.RemoveDuplicates(role.BoundProjects, false) + } + + // Update bound GCP projects. + addGroupAliases, ok := data.GetOk("add_group_aliases") + if ok { + role.AddGroupAliases = addGroupAliases.(bool) + } + + // Update fields specific to this type + switch role.RoleType { + case iamRoleType: + if err = checkInvalidRoleTypeArgs(data, gceOnlyFieldSchema); err != nil { + return warnings, err + } + if warnings, err = role.updateIamFields(data, req.Operation); err != nil { + return warnings, err + } + iamAliasType, ok := data.GetOk("iam_alias") + if ok { + role.IAMAliasType = iamAliasType.(string) + } + case gceRoleType: + if err = checkInvalidRoleTypeArgs(data, iamOnlyFieldSchema); err != nil { + return warnings, err + } + if warnings, err = role.updateGceFields(data, req.Operation); err != nil { + return warnings, err + } + gceAliasType, ok := data.GetOk("gce_alias") + if ok { + role.GCEAliasType = gceAliasType.(string) + } + } + + return warnings, nil +} + +func (role *gcpRole) validate(sys logical.SystemView) (warnings []string, err error) { + warnings = []string{} + + switch role.RoleType { + case iamRoleType: + if warnings, err = role.validateForIAM(); err != nil { + return warnings, err + } + case gceRoleType: + if warnings, err = role.validateForGCE(); err != nil { + return warnings, err + } + case "": + return warnings, errors.New(errEmptyRoleType) + default: + return warnings, fmt.Errorf("role type '%s' is invalid", role.RoleType) + } + + defaultLeaseTTL := sys.DefaultLeaseTTL() + if role.TokenTTL > defaultLeaseTTL { + warnings = append(warnings, fmt.Sprintf( + "Given ttl of %d seconds greater than current mount/system default of %d seconds; ttl will be capped at login time", + role.TokenTTL/time.Second, defaultLeaseTTL/time.Second)) + } + + defaultMaxTTL := sys.MaxLeaseTTL() + if role.TokenMaxTTL > defaultMaxTTL { + warnings = append(warnings, fmt.Sprintf( + "Given max_ttl of %d seconds greater than current mount/system default of %d seconds; max_ttl will be capped at login time", + role.TokenMaxTTL/time.Second, defaultMaxTTL/time.Second)) + } + if role.TokenMaxTTL < time.Duration(0) { + return warnings, errors.New("max_ttl cannot be negative") + } + if role.TokenMaxTTL != 0 && role.TokenMaxTTL < role.TokenTTL { + return warnings, errors.New("ttl should be shorter than max_ttl") + } + + if role.TokenPeriod > sys.MaxLeaseTTL() { + return warnings, fmt.Errorf("'period' of '%s' is greater than the backend's maximum lease TTL of '%s'", role.TokenPeriod.String(), sys.MaxLeaseTTL().String()) + } + + if _, exists := allowedIAMAliases[role.IAMAliasType]; !exists { + return warnings, fmt.Errorf("iam_alias must be one of: %s", strings.Join(allowedIAMAliasesSlice, ", ")) + } + if _, exists := allowedGCEAliases[role.GCEAliasType]; !exists { + return warnings, fmt.Errorf("gce_alias must be one of: %s", strings.Join(allowedGCEAliasesSlice, ", ")) + } + + return warnings, nil +} + +// updateIamFields updates IAM-only fields for a role. +func (role *gcpRole) updateIamFields(data *framework.FieldData, op logical.Operation) (warnings []string, err error) { + if allowGCEInference, ok := data.GetOk("allow_gce_inference"); ok { + role.AllowGCEInference = allowGCEInference.(bool) + } else if op == logical.CreateOperation { + role.AllowGCEInference = data.Get("allow_gce_inference").(bool) + } + + if maxJwtExp, ok := data.GetOk("max_jwt_exp"); ok { + role.MaxJwtExp = time.Duration(maxJwtExp.(int)) * time.Second + } else if op == logical.CreateOperation { + role.MaxJwtExp = time.Duration(defaultIamMaxJwtExpMinutes) * time.Minute + } + + if role.IAMAliasType == "" { + role.IAMAliasType = defaultIAMAlias + } + + return warnings, nil +} + +// updateGceFields updates GCE-only fields for a role. +func (role *gcpRole) updateGceFields(data *framework.FieldData, op logical.Operation) (warnings []string, err error) { + if regions, ok := data.GetOk("bound_regions"); ok { + role.BoundRegions = regions.([]string) + } else if op == logical.CreateOperation { + role.BoundRegions = data.Get("bound_regions").([]string) + } + + if zones, ok := data.GetOk("bound_zones"); ok { + role.BoundZones = zones.([]string) + } else if op == logical.CreateOperation { + role.BoundZones = data.Get("bound_zones").([]string) + } + + if instanceGroups, ok := data.GetOk("bound_instance_groups"); ok { + role.BoundInstanceGroups = instanceGroups.([]string) + } else if op == logical.CreateOperation { + role.BoundInstanceGroups = data.Get("bound_instance_groups").([]string) + } + + if boundRegion, ok := data.GetOk("bound_region"); ok { + if _, ok := data.GetOk("bound_regions"); ok { + return warnings, fmt.Errorf(`cannot specify both "bound_region" and "bound_regions"`) + } + + warnings = append(warnings, `The "bound_region" field is deprecated. `+ + `Please use "bound_regions" (plural) instead. You can still specify a `+ + `single region, but multiple regions are also now supported. The `+ + `"bound_region" field will be removed in a later release, so please `+ + `update accordingly.`) + role.BoundRegions = append(role.BoundRegions, boundRegion.(string)) + } + + if boundZone, ok := data.GetOk("bound_zone"); ok { + if _, ok := data.GetOk("bound_zones"); ok { + return warnings, fmt.Errorf(`cannot specify both "bound_zone" and "bound_zones"`) + } + + warnings = append(warnings, `The "bound_zone" field is deprecated. `+ + `Please use "bound_zones" (plural) instead. You can still specify a `+ + `single zone, but multiple zones are also now supported. The `+ + `"bound_zone" field will be removed in a later release, so please `+ + `update accordingly.`) + role.BoundZones = append(role.BoundZones, boundZone.(string)) + } + + if boundInstanceGroup, ok := data.GetOk("bound_instance_group"); ok { + if _, ok := data.GetOk("bound_instance_groups"); ok { + return warnings, fmt.Errorf(`cannot specify both "bound_instance_group" and "bound_instance_groups"`) + } + + warnings = append(warnings, `The "bound_instance_group" field is deprecated. `+ + `Please use "bound_instance_groups" (plural) instead. You can still specify a `+ + `single instance group, but multiple instance groups are also now supported. The `+ + `"bound_instance_group" field will be removed in a later release, so please `+ + `update accordingly.`) + role.BoundInstanceGroups = append(role.BoundInstanceGroups, boundInstanceGroup.(string)) + } + + if labelsRaw, ok := data.GetOk("bound_labels"); ok { + labels, invalidLabels := gcputil.ParseGcpLabels(labelsRaw.([]string)) + if len(invalidLabels) > 0 { + return warnings, fmt.Errorf("invalid labels given: %q", invalidLabels) + } + role.BoundLabels = labels + } + + if len(role.Policies) > 0 { + role.Policies = strutil.TrimStrings(role.Policies) + role.Policies = strutil.RemoveDuplicates(role.Policies, false) + } + + if len(role.BoundRegions) > 0 { + role.BoundRegions = strutil.TrimStrings(role.BoundRegions) + role.BoundRegions = strutil.RemoveDuplicates(role.BoundRegions, false) + } + + if len(role.BoundZones) > 0 { + role.BoundZones = strutil.TrimStrings(role.BoundZones) + role.BoundZones = strutil.RemoveDuplicates(role.BoundZones, false) + } + + if len(role.BoundInstanceGroups) > 0 { + role.BoundInstanceGroups = strutil.TrimStrings(role.BoundInstanceGroups) + role.BoundInstanceGroups = strutil.RemoveDuplicates(role.BoundInstanceGroups, false) + } + + if role.GCEAliasType == "" { + role.GCEAliasType = defaultGCEAlias + } + + return warnings, nil +} + +// validateIamFields validates the IAM-only fields for a role. +func (role *gcpRole) validateForIAM() (warnings []string, err error) { + if len(role.BoundServiceAccounts) == 0 { + return []string{}, errors.New(errEmptyIamServiceAccounts) + } + + if len(role.BoundServiceAccounts) > 1 && strutil.StrListContains(role.BoundServiceAccounts, serviceAccountsWildcard) { + return []string{}, fmt.Errorf("cannot provide IAM service account wildcard '%s' (for all service accounts) with other service accounts", serviceAccountsWildcard) + } + + maxMaxJwtExp := time.Duration(maxJwtExpMaxMinutes) * time.Minute + if role.MaxJwtExp > maxMaxJwtExp { + return warnings, fmt.Errorf("max_jwt_exp cannot be more than %d minutes", maxJwtExpMaxMinutes) + } + + return []string{}, nil +} + +// validateGceFields validates the GCE-only fields for a role. +func (role *gcpRole) validateForGCE() (warnings []string, err error) { + warnings = []string{} + + hasRegion := len(role.BoundRegions) > 0 + hasZone := len(role.BoundZones) > 0 + hasRegionOrZone := hasRegion || hasZone + + hasInstanceGroup := len(role.BoundInstanceGroups) > 0 + + if hasInstanceGroup && !hasRegionOrZone { + return warnings, errors.New(`region or zone information must be specified if an instance group is given`) + } + + if hasRegion && hasZone { + warnings = append(warnings, `Given both "bound_regions" and "bound_zones" `+ + `fields for role type "gce", "bound_regions" will be ignored in favor `+ + `of the more specific "bound_zones" field. To fix this warning, update `+ + `the role to remove either the "bound_regions" or "bound_zones" field.`) + } + + return warnings, nil +} diff --git a/plugin/mocks_test.go b/plugin/mocks_test.go new file mode 100644 index 00000000..6785ab2b --- /dev/null +++ b/plugin/mocks_test.go @@ -0,0 +1,308 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/hashicorp/vault/sdk/logical (interfaces: SystemView,Storage) + +// Package gcpauth is a generated GoMock package. +package gcpauth + +import ( + context "context" + gomock "github.com/golang/mock/gomock" + consts "github.com/hashicorp/vault/sdk/helper/consts" + license "github.com/hashicorp/vault/sdk/helper/license" + pluginutil "github.com/hashicorp/vault/sdk/helper/pluginutil" + wrapping "github.com/hashicorp/vault/sdk/helper/wrapping" + logical "github.com/hashicorp/vault/sdk/logical" + reflect "reflect" + time "time" +) + +// MockSystemView is a mock of SystemView interface +type MockSystemView struct { + ctrl *gomock.Controller + recorder *MockSystemViewMockRecorder +} + +// MockSystemViewMockRecorder is the mock recorder for MockSystemView +type MockSystemViewMockRecorder struct { + mock *MockSystemView +} + +// NewMockSystemView creates a new mock instance +func NewMockSystemView(ctrl *gomock.Controller) *MockSystemView { + mock := &MockSystemView{ctrl: ctrl} + mock.recorder = &MockSystemViewMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockSystemView) EXPECT() *MockSystemViewMockRecorder { + return m.recorder +} + +// CachingDisabled mocks base method +func (m *MockSystemView) CachingDisabled() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CachingDisabled") + ret0, _ := ret[0].(bool) + return ret0 +} + +// CachingDisabled indicates an expected call of CachingDisabled +func (mr *MockSystemViewMockRecorder) CachingDisabled() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CachingDisabled", reflect.TypeOf((*MockSystemView)(nil).CachingDisabled)) +} + +// DefaultLeaseTTL mocks base method +func (m *MockSystemView) DefaultLeaseTTL() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DefaultLeaseTTL") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// DefaultLeaseTTL indicates an expected call of DefaultLeaseTTL +func (mr *MockSystemViewMockRecorder) DefaultLeaseTTL() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DefaultLeaseTTL", reflect.TypeOf((*MockSystemView)(nil).DefaultLeaseTTL)) +} + +// EntityInfo mocks base method +func (m *MockSystemView) EntityInfo(arg0 string) (*logical.Entity, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EntityInfo", arg0) + ret0, _ := ret[0].(*logical.Entity) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EntityInfo indicates an expected call of EntityInfo +func (mr *MockSystemViewMockRecorder) EntityInfo(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EntityInfo", reflect.TypeOf((*MockSystemView)(nil).EntityInfo), arg0) +} + +// GroupsForEntity mocks base method +func (m *MockSystemView) GroupsForEntity(arg0 string) ([]*logical.Group, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GroupsForEntity", arg0) + ret0, _ := ret[0].([]*logical.Group) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GroupsForEntity indicates an expected call of GroupsForEntity +func (mr *MockSystemViewMockRecorder) GroupsForEntity(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupsForEntity", reflect.TypeOf((*MockSystemView)(nil).GroupsForEntity), arg0) +} + +// HasFeature mocks base method +func (m *MockSystemView) HasFeature(arg0 license.Features) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasFeature", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasFeature indicates an expected call of HasFeature +func (mr *MockSystemViewMockRecorder) HasFeature(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasFeature", reflect.TypeOf((*MockSystemView)(nil).HasFeature), arg0) +} + +// LocalMount mocks base method +func (m *MockSystemView) LocalMount() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalMount") + ret0, _ := ret[0].(bool) + return ret0 +} + +// LocalMount indicates an expected call of LocalMount +func (mr *MockSystemViewMockRecorder) LocalMount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalMount", reflect.TypeOf((*MockSystemView)(nil).LocalMount)) +} + +// LookupPlugin mocks base method +func (m *MockSystemView) LookupPlugin(arg0 context.Context, arg1 string, arg2 consts.PluginType) (*pluginutil.PluginRunner, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LookupPlugin", arg0, arg1, arg2) + ret0, _ := ret[0].(*pluginutil.PluginRunner) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LookupPlugin indicates an expected call of LookupPlugin +func (mr *MockSystemViewMockRecorder) LookupPlugin(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupPlugin", reflect.TypeOf((*MockSystemView)(nil).LookupPlugin), arg0, arg1, arg2) +} + +// MaxLeaseTTL mocks base method +func (m *MockSystemView) MaxLeaseTTL() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MaxLeaseTTL") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// MaxLeaseTTL indicates an expected call of MaxLeaseTTL +func (mr *MockSystemViewMockRecorder) MaxLeaseTTL() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaxLeaseTTL", reflect.TypeOf((*MockSystemView)(nil).MaxLeaseTTL)) +} + +// MlockEnabled mocks base method +func (m *MockSystemView) MlockEnabled() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MlockEnabled") + ret0, _ := ret[0].(bool) + return ret0 +} + +// MlockEnabled indicates an expected call of MlockEnabled +func (mr *MockSystemViewMockRecorder) MlockEnabled() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MlockEnabled", reflect.TypeOf((*MockSystemView)(nil).MlockEnabled)) +} + +// PluginEnv mocks base method +func (m *MockSystemView) PluginEnv(arg0 context.Context) (*logical.PluginEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PluginEnv", arg0) + ret0, _ := ret[0].(*logical.PluginEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PluginEnv indicates an expected call of PluginEnv +func (mr *MockSystemViewMockRecorder) PluginEnv(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PluginEnv", reflect.TypeOf((*MockSystemView)(nil).PluginEnv), arg0) +} + +// ReplicationState mocks base method +func (m *MockSystemView) ReplicationState() consts.ReplicationState { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReplicationState") + ret0, _ := ret[0].(consts.ReplicationState) + return ret0 +} + +// ReplicationState indicates an expected call of ReplicationState +func (mr *MockSystemViewMockRecorder) ReplicationState() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplicationState", reflect.TypeOf((*MockSystemView)(nil).ReplicationState)) +} + +// ResponseWrapData mocks base method +func (m *MockSystemView) ResponseWrapData(arg0 context.Context, arg1 map[string]interface{}, arg2 time.Duration, arg3 bool) (*wrapping.ResponseWrapInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResponseWrapData", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*wrapping.ResponseWrapInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ResponseWrapData indicates an expected call of ResponseWrapData +func (mr *MockSystemViewMockRecorder) ResponseWrapData(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseWrapData", reflect.TypeOf((*MockSystemView)(nil).ResponseWrapData), arg0, arg1, arg2, arg3) +} + +// Tainted mocks base method +func (m *MockSystemView) Tainted() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Tainted") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Tainted indicates an expected call of Tainted +func (mr *MockSystemViewMockRecorder) Tainted() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tainted", reflect.TypeOf((*MockSystemView)(nil).Tainted)) +} + +// MockStorage is a mock of Storage interface +type MockStorage struct { + ctrl *gomock.Controller + recorder *MockStorageMockRecorder +} + +// MockStorageMockRecorder is the mock recorder for MockStorage +type MockStorageMockRecorder struct { + mock *MockStorage +} + +// NewMockStorage creates a new mock instance +func NewMockStorage(ctrl *gomock.Controller) *MockStorage { + mock := &MockStorage{ctrl: ctrl} + mock.recorder = &MockStorageMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStorage) EXPECT() *MockStorageMockRecorder { + return m.recorder +} + +// Delete mocks base method +func (m *MockStorage) Delete(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete +func (mr *MockStorageMockRecorder) Delete(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStorage)(nil).Delete), arg0, arg1) +} + +// Get mocks base method +func (m *MockStorage) Get(arg0 context.Context, arg1 string) (*logical.StorageEntry, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0, arg1) + ret0, _ := ret[0].(*logical.StorageEntry) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get +func (mr *MockStorageMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStorage)(nil).Get), arg0, arg1) +} + +// List mocks base method +func (m *MockStorage) List(arg0 context.Context, arg1 string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List +func (mr *MockStorageMockRecorder) List(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockStorage)(nil).List), arg0, arg1) +} + +// Put mocks base method +func (m *MockStorage) Put(arg0 context.Context, arg1 *logical.StorageEntry) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Put", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Put indicates an expected call of Put +func (mr *MockStorageMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockStorage)(nil).Put), arg0, arg1) +} diff --git a/plugin/path_config.go b/plugin/path_config.go index 856b9b6b..309e7b42 100644 --- a/plugin/path_config.go +++ b/plugin/path_config.go @@ -2,9 +2,9 @@ package gcpauth import ( "context" - "errors" - "encoding/json" + "errors" + "net/http" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-gcp-common/gcputil" @@ -44,7 +44,7 @@ Deprecated. This field does nothing and be removed in a future release`, func (b *GcpAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { if err := validateFields(req, d); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } c, err := b.config(ctx, req.Storage) @@ -58,7 +58,7 @@ func (b *GcpAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque changed, err := c.Update(d) if err != nil { - return nil, logical.CodedError(400, err.Error()) + return nil, logical.CodedError(http.StatusBadRequest, err.Error()) } // Only do the following if the config is different @@ -83,7 +83,7 @@ func (b *GcpAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque func (b *GcpAuthBackend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { if err := validateFields(req, d); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } config, err := b.config(ctx, req.Storage) diff --git a/plugin/path_login.go b/plugin/path_login.go index 35a2d01e..855f86bb 100644 --- a/plugin/path_login.go +++ b/plugin/path_login.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/http" "strconv" "strings" "time" @@ -59,7 +60,7 @@ GCE identity metadata token ('iam', 'gce' roles).`, func (b *GcpAuthBackend) pathLogin(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { // Validate we didn't get extraneous fields if err := validateFields(req, data); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } loginInfo, err := b.parseAndValidateJwt(ctx, req, data) @@ -315,14 +316,20 @@ func (b *GcpAuthBackend) pathIamLogin(ctx context.Context, req *logical.Request, return nil, errors.New("service account is empty") } + alias, err := getIAMAlias(role, serviceAccount) + if err != nil { + return logical.ErrorResponse("unable to create alias: %s", err), nil + } + if req.Operation == logical.AliasLookaheadOperation { - return &logical.Response{ + resp := &logical.Response{ Auth: &logical.Auth{ Alias: &logical.Alias{ - Name: serviceAccount.UniqueId, + Name: alias, }, }, - }, nil + } + return resp, nil } // Validate service account can login against role. @@ -332,7 +339,7 @@ func (b *GcpAuthBackend) pathIamLogin(ctx context.Context, req *logical.Request, auth := &logical.Auth{ Alias: &logical.Alias{ - Name: serviceAccount.UniqueId, + Name: alias, }, Metadata: authMetadata(loginInfo, serviceAccount), DisplayName: serviceAccount.Email, @@ -449,11 +456,16 @@ func (b *GcpAuthBackend) pathGceLogin(ctx context.Context, req *logical.Request, return logical.ErrorResponse(err.Error()), nil } + alias, err := getGCEAlias(role, instance) + if err != nil { + return logical.ErrorResponse("unable to create alias: %s", err), nil + } + if req.Operation == logical.AliasLookaheadOperation { return &logical.Response{ Auth: &logical.Auth{ Alias: &logical.Alias{ - Name: fmt.Sprintf("gce-%s", strconv.FormatUint(instance.Id, 10)), + Name: alias, }, }, }, nil @@ -475,7 +487,7 @@ func (b *GcpAuthBackend) pathGceLogin(ctx context.Context, req *logical.Request, auth := &logical.Auth{ InternalData: map[string]interface{}{}, Alias: &logical.Alias{ - Name: fmt.Sprintf("gce-%s", strconv.FormatUint(instance.Id, 10)), + Name: alias, }, Metadata: authMetadata(loginInfo, serviceAccount), DisplayName: instance.Name, diff --git a/plugin/path_login_test.go b/plugin/path_login_test.go index 9085800f..650b15f5 100644 --- a/plugin/path_login_test.go +++ b/plugin/path_login_test.go @@ -160,7 +160,7 @@ func TestLogin_IAM(t *testing.T) { } for _, tc := range cases { - tc := tc + tc := tc // Since the t.Run is parallel, this is needed to prevent scope sharing between loops t.Run(tc.name, func(t *testing.T) { t.Parallel() diff --git a/plugin/path_role.go b/plugin/path_role.go index 3acc6b90..e8335223 100644 --- a/plugin/path_role.go +++ b/plugin/path_role.go @@ -2,15 +2,14 @@ package gcpauth import ( "context" - "errors" "fmt" + "net/http" "strings" - "time" "github.com/hashicorp/go-gcp-common/gcputil" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/framework" vaultconsts "github.com/hashicorp/vault/sdk/helper/consts" - "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -105,6 +104,11 @@ var iamOnlyFieldSchema = map[string]*framework.FieldSchema{ Default: true, Description: `'iam' roles only. If false, Vault will not not allow GCE instances to login in against this role`, }, + "iam_alias": { + Type: framework.TypeString, + Default: defaultIAMAlias, + Description: "Indicates what value to use when generating an alias for IAM authentications.", + }, } var gceOnlyFieldSchema = map[string]*framework.FieldSchema{ @@ -136,6 +140,11 @@ var gceOnlyFieldSchema = map[string]*framework.FieldSchema{ "\"key:value\" strings that must be present on the GCE instance " + "in order to authenticate. This option only applies to \"gce\" roles.", }, + "gce_alias": { + Type: framework.TypeString, + Default: defaultGCEAlias, + Description: "Indicates what value to use when generating an alias for GCE authentications.", + }, } // pathsRole creates paths for listing roles and CRUD operations. @@ -274,64 +283,69 @@ func (b *GcpAuthBackend) pathRoleRead(ctx context.Context, req *logical.Request, return nil, nil } - resp := make(map[string]interface{}) - role.PopulateTokenData(resp) + respData := make(map[string]interface{}) + role.PopulateTokenData(respData) + + respData["role_id"] = role.RoleID if role.RoleType != "" { - resp["type"] = role.RoleType + respData["type"] = role.RoleType } if len(role.BoundServiceAccounts) > 0 { - resp["bound_service_accounts"] = role.BoundServiceAccounts + respData["bound_service_accounts"] = role.BoundServiceAccounts } if len(role.BoundProjects) > 0 { - resp["bound_projects"] = role.BoundProjects + respData["bound_projects"] = role.BoundProjects } - resp["add_group_aliases"] = role.AddGroupAliases + respData["add_group_aliases"] = role.AddGroupAliases switch role.RoleType { case iamRoleType: if role.MaxJwtExp != 0 { - resp["max_jwt_exp"] = int64(role.MaxJwtExp.Seconds()) + respData["max_jwt_exp"] = int64(role.MaxJwtExp.Seconds()) } - resp["allow_gce_inference"] = role.AllowGCEInference + respData["allow_gce_inference"] = role.AllowGCEInference + respData["iam_alias"] = role.IAMAliasType case gceRoleType: if len(role.BoundRegions) > 0 { - resp["bound_regions"] = role.BoundRegions + respData["bound_regions"] = role.BoundRegions } if len(role.BoundZones) > 0 { - resp["bound_zones"] = role.BoundZones + respData["bound_zones"] = role.BoundZones } if len(role.BoundInstanceGroups) > 0 { - resp["bound_instance_groups"] = role.BoundInstanceGroups + respData["bound_instance_groups"] = role.BoundInstanceGroups } if len(role.BoundLabels) > 0 { - resp["bound_labels"] = role.BoundLabels + respData["bound_labels"] = role.BoundLabels } + respData["gce_alias"] = role.GCEAliasType } // Upgrade vals if len(role.Policies) > 0 { - resp["policies"] = resp["token_policies"] + respData["policies"] = respData["token_policies"] } if role.TTL > 0 { - resp["ttl"] = int64(role.TTL.Seconds()) + respData["ttl"] = int64(role.TTL.Seconds()) } if role.MaxTTL > 0 { - resp["max_ttl"] = int64(role.MaxTTL.Seconds()) + respData["max_ttl"] = int64(role.MaxTTL.Seconds()) } if role.Period > 0 { - resp["period"] = int64(role.Period.Seconds()) + respData["period"] = int64(role.Period.Seconds()) } - return &logical.Response{ - Data: resp, - }, nil + resp := &logical.Response{ + Data: respData, + } + return resp, nil } func (b *GcpAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { // Validate we didn't get extraneous fields if err := validateFields(req, data); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } name := strings.ToLower(data.Get("name").(string)) @@ -347,6 +361,14 @@ func (b *GcpAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical. role = &gcpRole{} } + if role.RoleID == "" { + roleID, err := uuid.GenerateUUID() + if err != nil { + return nil, logical.CodedError(http.StatusInternalServerError, fmt.Sprintf("unable to generate roleID: %s", err)) + } + role.RoleID = roleID + } + warnings, err := role.updateRole(b.System(), req, data) if err != nil { resp := logical.ErrorResponse(err.Error()) @@ -380,7 +402,7 @@ const pathListRolesHelpDesc = `Lists all roles under the GCP backends by name.` func (b *GcpAuthBackend) pathRoleEditIamServiceAccounts(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { // Validate we didn't get extraneous fields if err := validateFields(req, data); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } var warnings []string @@ -437,7 +459,7 @@ func editStringValues(initial []string, toAdd []string, toRemove []string) []str func (b *GcpAuthBackend) pathRoleEditGceLabels(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { // Validate we didn't get extraneous fields if err := validateFields(req, data); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } var warnings []string @@ -480,7 +502,7 @@ func (b *GcpAuthBackend) pathRoleEditGceLabels(ctx context.Context, req *logical return b.storeRole(ctx, req.Storage, roleName, role, warnings) } -// role reads a gcpRole from storage. This assumes the caller has already obtained the role lock. +// role from storage. This assumes the caller has already obtained the role lock. func (b *GcpAuthBackend) role(ctx context.Context, s logical.Storage, name string) (*gcpRole, error) { name = strings.ToLower(name) @@ -540,6 +562,25 @@ func (b *GcpAuthBackend) role(ctx context.Context, s logical.Storage, name strin modified = true } + if role.RoleType == "iam" && role.IAMAliasType == "" { + role.IAMAliasType = defaultIAMAlias + modified = true + } + if role.RoleType == "gce" && role.GCEAliasType == "" { + role.GCEAliasType = defaultGCEAlias + modified = true + } + + // Ensure the role has a RoleID + if role.RoleID == "" { + roleID, err := uuid.GenerateUUID() + if err != nil { + return nil, logical.CodedError(http.StatusInternalServerError, fmt.Sprintf("unable to generate roleID for role missing an ID: %s", err)) + } + role.RoleID = roleID + modified = true + } + if modified && (b.System().LocalMount() || !b.System().ReplicationState().HasState(vaultconsts.ReplicationPerformanceSecondary)) { b.Logger().Info("upgrading role to new schema", "role", name) @@ -576,6 +617,19 @@ func (b *GcpAuthBackend) storeRole(ctx context.Context, s logical.Storage, roleN return logical.ErrorResponse(err.Error()), nil } + // Set default alias names + if role.IAMAliasType == "" { + role.IAMAliasType = defaultIAMAlias + } + if role.GCEAliasType == "" { + role.GCEAliasType = defaultGCEAlias + } + + // Ensure a version is specified + if role.Version == 0 { + role.Version = currentGCPRoleVersion + } + entry, err := logical.StorageEntryJSON(fmt.Sprintf("role/%s", roleName), role) if err != nil { return nil, err @@ -587,387 +641,6 @@ func (b *GcpAuthBackend) storeRole(ctx context.Context, s logical.Storage, roleN return &resp, nil } -type gcpRole struct { - tokenutil.TokenParams - - // Type of this role. See path_role constants for currently supported types. - RoleType string `json:"role_type,omitempty"` - - // Policies for Vault to assign to authorized entities. - Policies []string `json:"policies,omitempty"` - - // TTL of Vault auth leases under this role. - TTL time.Duration `json:"ttl,omitempty"` - - // Max total TTL including renewals, of Vault auth leases under this role. - MaxTTL time.Duration `json:"max_ttl,omitempty"` - - // Period, If set, indicates that this token should not expire and - // should be automatically renewed within this time period - // with TTL equal to this value. - Period time.Duration `json:"period,omitempty"` - - // Projects that entities must belong to - BoundProjects []string `json:"bound_projects,omitempty"` - - // Service accounts allowed to login under this role. - BoundServiceAccounts []string `json:"bound_service_accounts,omitempty"` - - // AddGroupAliases adds Vault group aliases to the response. - AddGroupAliases bool `json:"add_group_aliases,omitempty"` - - // --| IAM-only attributes |-- - // MaxJwtExp is the duration from time of authentication that a JWT used to authenticate to role must expire within. - // TODO(emilymye): Allow this to be updated for GCE roles once 'exp' parameter has been allowed for GCE metadata. - MaxJwtExp time.Duration `json:"max_jwt_exp,omitempty"` - - // AllowGCEInference, if false, does not allow a GCE instance to login under this 'iam' role. If true (default), - // a service account is inferred from the instance metadata and used as the authenticating instance. - AllowGCEInference bool `json:"allow_gce_inference,omitempty"` - - // --| GCE-only attributes |-- - // BoundRegions that instances must belong to in order to login under this role. - BoundRegions []string `json:"bound_regions,omitempty"` - - // BoundZones that instances must belong to in order to login under this role. - BoundZones []string `json:"bound_zones,omitempty"` - - // BoundInstanceGroups are the instance group that instances must belong to in order to login under this role. - BoundInstanceGroups []string `json:"bound_instance_groups,omitempty"` - - // BoundLabels that instances must currently have set in order to login under this role. - BoundLabels map[string]string `json:"bound_labels,omitempty"` - - // Deprecated fields - // TODO: Remove in 0.5.0+ - ProjectId string `json:"project_id,omitempty"` - BoundRegion string `json:"bound_region,omitempty"` - BoundZone string `json:"bound_zone,omitempty"` - BoundInstanceGroup string `json:"bound_instance_group,omitempty"` -} - -// Update updates the given role with values parsed/validated from given FieldData. -// Exactly one of the response and error will be nil. The response is only used to pass back warnings. -// This method does not validate the role. Validation is done before storage. -func (role *gcpRole) updateRole(sys logical.SystemView, req *logical.Request, data *framework.FieldData) (warnings []string, err error) { - if e := role.ParseTokenFields(req, data); e != nil { - return nil, e - } - - // Handle token field upgrades - { - if e := tokenutil.UpgradeValue(data, "policies", "token_policies", &role.Policies, &role.TokenPolicies); e != nil { - return nil, e - } - - if e := tokenutil.UpgradeValue(data, "ttl", "token_ttl", &role.TTL, &role.TokenTTL); e != nil { - return nil, e - } - - if e := tokenutil.UpgradeValue(data, "max_ttl", "token_max_ttl", &role.MaxTTL, &role.TokenMaxTTL); e != nil { - return nil, e - } - - if e := tokenutil.UpgradeValue(data, "period", "token_period", &role.Period, &role.TokenPeriod); e != nil { - return nil, e - } - } - - // Set role type - if rt, ok := data.GetOk("type"); ok { - roleType := rt.(string) - if role.RoleType != roleType && req.Operation == logical.UpdateOperation { - err = errors.New("role type cannot be changed for an existing role") - return - } - role.RoleType = roleType - } else if req.Operation == logical.CreateOperation { - err = errors.New(errEmptyRoleType) - return - } - - def := sys.DefaultLeaseTTL() - if role.TokenTTL > def { - warnings = append(warnings, fmt.Sprintf(`Given token ttl of %q is greater `+ - `than the maximum system/mount TTL of %q. The TTL will be capped at `+ - `%q during login.`, role.TokenTTL, def, def)) - } - - // Update token Max TTL. - def = sys.MaxLeaseTTL() - if role.TokenMaxTTL > def { - warnings = append(warnings, fmt.Sprintf(`Given token max ttl of %q is greater `+ - `than the maximum system/mount MaxTTL of %q. The MaxTTL will be `+ - `capped at %q during login.`, role.TokenMaxTTL, def, def)) - } - if role.TokenPeriod > def { - warnings = append(warnings, fmt.Sprintf(`Given token period of %q is greater `+ - `than the maximum system/mount period of %q. The period will be `+ - `capped at %q during login.`, role.TokenPeriod, def, def)) - } - - // Update bound GCP service accounts. - if sa, ok := data.GetOk("bound_service_accounts"); ok { - role.BoundServiceAccounts = sa.([]string) - } else { - // Check for older version of param name - if sa, ok := data.GetOk("service_accounts"); ok { - warnings = append(warnings, `The "service_accounts" field is deprecated. `+ - `Please use "bound_service_accounts" instead. The "service_accounts" `+ - `field will be removed in a later release, so please update accordingly.`) - role.BoundServiceAccounts = sa.([]string) - } - } - if len(role.BoundServiceAccounts) > 0 { - role.BoundServiceAccounts = strutil.TrimStrings(role.BoundServiceAccounts) - role.BoundServiceAccounts = strutil.RemoveDuplicates(role.BoundServiceAccounts, false) - } - - // Update bound GCP projects. - boundProjects, givenBoundProj := data.GetOk("bound_projects") - if givenBoundProj { - role.BoundProjects = boundProjects.([]string) - } - if projectId, ok := data.GetOk("project_id"); ok { - if givenBoundProj { - return warnings, errors.New("only one of 'bound_projects' or 'project_id' can be given") - } - warnings = append(warnings, - `The "project_id" (singular) field is deprecated. `+ - `Please use plural "bound_projects" instead to bind required GCP projects. `+ - `The "project_id" field will be removed in a later release, so please update accordingly.`) - role.BoundProjects = []string{projectId.(string)} - } - if len(role.BoundProjects) > 0 { - role.BoundProjects = strutil.TrimStrings(role.BoundProjects) - role.BoundProjects = strutil.RemoveDuplicates(role.BoundProjects, false) - } - - // Update bound GCP projects. - addGroupAliases, ok := data.GetOk("add_group_aliases") - if ok { - role.AddGroupAliases = addGroupAliases.(bool) - } - - // Update fields specific to this type - switch role.RoleType { - case iamRoleType: - if err = checkInvalidRoleTypeArgs(data, gceOnlyFieldSchema); err != nil { - return - } - if warnings, err = role.updateIamFields(data, req.Operation); err != nil { - return - } - case gceRoleType: - if err = checkInvalidRoleTypeArgs(data, iamOnlyFieldSchema); err != nil { - return - } - if warnings, err = role.updateGceFields(data, req.Operation); err != nil { - return - } - } - - return -} - -func (role *gcpRole) validate(sys logical.SystemView) (warnings []string, err error) { - warnings = []string{} - - switch role.RoleType { - case iamRoleType: - if warnings, err = role.validateForIAM(); err != nil { - return warnings, err - } - case gceRoleType: - if warnings, err = role.validateForGCE(); err != nil { - return warnings, err - } - case "": - return warnings, errors.New(errEmptyRoleType) - default: - return warnings, fmt.Errorf("role type '%s' is invalid", role.RoleType) - } - - defaultLeaseTTL := sys.DefaultLeaseTTL() - if role.TokenTTL > defaultLeaseTTL { - warnings = append(warnings, fmt.Sprintf( - "Given ttl of %d seconds greater than current mount/system default of %d seconds; ttl will be capped at login time", - role.TokenTTL/time.Second, defaultLeaseTTL/time.Second)) - } - - defaultMaxTTL := sys.MaxLeaseTTL() - if role.TokenMaxTTL > defaultMaxTTL { - warnings = append(warnings, fmt.Sprintf( - "Given max_ttl of %d seconds greater than current mount/system default of %d seconds; max_ttl will be capped at login time", - role.TokenMaxTTL/time.Second, defaultMaxTTL/time.Second)) - } - if role.TokenMaxTTL < time.Duration(0) { - return warnings, errors.New("max_ttl cannot be negative") - } - if role.TokenMaxTTL != 0 && role.TokenMaxTTL < role.TokenTTL { - return warnings, errors.New("ttl should be shorter than max_ttl") - } - - if role.TokenPeriod > sys.MaxLeaseTTL() { - return warnings, fmt.Errorf("'period' of '%s' is greater than the backend's maximum lease TTL of '%s'", role.TokenPeriod.String(), sys.MaxLeaseTTL().String()) - } - - return warnings, nil -} - -// updateIamFields updates IAM-only fields for a role. -func (role *gcpRole) updateIamFields(data *framework.FieldData, op logical.Operation) (warnings []string, err error) { - if allowGCEInference, ok := data.GetOk("allow_gce_inference"); ok { - role.AllowGCEInference = allowGCEInference.(bool) - } else if op == logical.CreateOperation { - role.AllowGCEInference = data.Get("allow_gce_inference").(bool) - } - - if maxJwtExp, ok := data.GetOk("max_jwt_exp"); ok { - role.MaxJwtExp = time.Duration(maxJwtExp.(int)) * time.Second - } else if op == logical.CreateOperation { - role.MaxJwtExp = time.Duration(defaultIamMaxJwtExpMinutes) * time.Minute - } - - return -} - -// updateGceFields updates GCE-only fields for a role. -func (role *gcpRole) updateGceFields(data *framework.FieldData, op logical.Operation) (warnings []string, err error) { - if regions, ok := data.GetOk("bound_regions"); ok { - role.BoundRegions = regions.([]string) - } else if op == logical.CreateOperation { - role.BoundRegions = data.Get("bound_regions").([]string) - } - - if zones, ok := data.GetOk("bound_zones"); ok { - role.BoundZones = zones.([]string) - } else if op == logical.CreateOperation { - role.BoundZones = data.Get("bound_zones").([]string) - } - - if instanceGroups, ok := data.GetOk("bound_instance_groups"); ok { - role.BoundInstanceGroups = instanceGroups.([]string) - } else if op == logical.CreateOperation { - role.BoundInstanceGroups = data.Get("bound_instance_groups").([]string) - } - - if boundRegion, ok := data.GetOk("bound_region"); ok { - if _, ok := data.GetOk("bound_regions"); ok { - err = fmt.Errorf(`cannot specify both "bound_region" and "bound_regions"`) - return - } - - warnings = append(warnings, `The "bound_region" field is deprecated. `+ - `Please use "bound_regions" (plural) instead. You can still specify a `+ - `single region, but multiple regions are also now supported. The `+ - `"bound_region" field will be removed in a later release, so please `+ - `update accordingly.`) - role.BoundRegions = append(role.BoundRegions, boundRegion.(string)) - } - - if boundZone, ok := data.GetOk("bound_zone"); ok { - if _, ok := data.GetOk("bound_zones"); ok { - err = fmt.Errorf(`cannot specify both "bound_zone" and "bound_zones"`) - return - } - - warnings = append(warnings, `The "bound_zone" field is deprecated. `+ - `Please use "bound_zones" (plural) instead. You can still specify a `+ - `single zone, but multiple zones are also now supported. The `+ - `"bound_zone" field will be removed in a later release, so please `+ - `update accordingly.`) - role.BoundZones = append(role.BoundZones, boundZone.(string)) - } - - if boundInstanceGroup, ok := data.GetOk("bound_instance_group"); ok { - if _, ok := data.GetOk("bound_instance_groups"); ok { - err = fmt.Errorf(`cannot specify both "bound_instance_group" and "bound_instance_groups"`) - return - } - - warnings = append(warnings, `The "bound_instance_group" field is deprecated. `+ - `Please use "bound_instance_groups" (plural) instead. You can still specify a `+ - `single instance group, but multiple instance groups are also now supported. The `+ - `"bound_instance_group" field will be removed in a later release, so please `+ - `update accordingly.`) - role.BoundInstanceGroups = append(role.BoundInstanceGroups, boundInstanceGroup.(string)) - } - - if labelsRaw, ok := data.GetOk("bound_labels"); ok { - labels, invalidLabels := gcputil.ParseGcpLabels(labelsRaw.([]string)) - if len(invalidLabels) > 0 { - err = fmt.Errorf("invalid labels given: %q", invalidLabels) - return - } - role.BoundLabels = labels - } - - if len(role.Policies) > 0 { - role.Policies = strutil.TrimStrings(role.Policies) - role.Policies = strutil.RemoveDuplicates(role.Policies, false) - } - - if len(role.BoundRegions) > 0 { - role.BoundRegions = strutil.TrimStrings(role.BoundRegions) - role.BoundRegions = strutil.RemoveDuplicates(role.BoundRegions, false) - } - - if len(role.BoundZones) > 0 { - role.BoundZones = strutil.TrimStrings(role.BoundZones) - role.BoundZones = strutil.RemoveDuplicates(role.BoundZones, false) - } - - if len(role.BoundInstanceGroups) > 0 { - role.BoundInstanceGroups = strutil.TrimStrings(role.BoundInstanceGroups) - role.BoundInstanceGroups = strutil.RemoveDuplicates(role.BoundInstanceGroups, false) - } - - return -} - -// validateIamFields validates the IAM-only fields for a role. -func (role *gcpRole) validateForIAM() (warnings []string, err error) { - if len(role.BoundServiceAccounts) == 0 { - return []string{}, errors.New(errEmptyIamServiceAccounts) - } - - if len(role.BoundServiceAccounts) > 1 && strutil.StrListContains(role.BoundServiceAccounts, serviceAccountsWildcard) { - return []string{}, fmt.Errorf("cannot provide IAM service account wildcard '%s' (for all service accounts) with other service accounts", serviceAccountsWildcard) - } - - maxMaxJwtExp := time.Duration(maxJwtExpMaxMinutes) * time.Minute - if role.MaxJwtExp > maxMaxJwtExp { - return warnings, fmt.Errorf("max_jwt_exp cannot be more than %d minutes", maxJwtExpMaxMinutes) - } - - return []string{}, nil -} - -// validateGceFields validates the GCE-only fields for a role. -func (role *gcpRole) validateForGCE() (warnings []string, err error) { - warnings = []string{} - - hasRegion := len(role.BoundRegions) > 0 - hasZone := len(role.BoundZones) > 0 - hasRegionOrZone := hasRegion || hasZone - - hasInstanceGroup := len(role.BoundInstanceGroups) > 0 - - if hasInstanceGroup && !hasRegionOrZone { - return warnings, errors.New(`region or zone information must be specified if an instance group is given`) - } - - if hasRegion && hasZone { - warnings = append(warnings, `Given both "bound_regions" and "bound_zones" `+ - `fields for role type "gce", "bound_regions" will be ignored in favor `+ - `of the more specific "bound_zones" field. To fix this warning, update `+ - `the role to remove either the "bound_regions" or "bound_zones" field.`) - } - - return warnings, nil -} - // checkInvalidRoleTypeArgs checks that the data provided does not contain arguments // for a different role type. If it does find some, it will return an error with the // invalid args. diff --git a/plugin/path_role_test.go b/plugin/path_role_test.go index 4e51a420..29282d0e 100644 --- a/plugin/path_role_test.go +++ b/plugin/path_role_test.go @@ -2,16 +2,19 @@ package gcpauth import ( "context" + "encoding/json" "fmt" "math/rand" + "reflect" "strings" "testing" "time" - "reflect" - + "github.com/golang/mock/gomock" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/policyutil" "github.com/hashicorp/vault/sdk/helper/strutil" + "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -64,6 +67,7 @@ func TestRoleUpdateIam(t *testing.T) { "name": roleName, "type": iamRoleType, "bound_service_accounts": serviceAccounts, + "iam_alias": defaultIAMAlias, }) serviceAccounts = append(serviceAccounts, "testaccount@google.com") @@ -93,6 +97,7 @@ func TestRoleUpdateIam(t *testing.T) { "allow_gce_inference": false, "add_group_aliases": true, "bound_service_accounts": serviceAccounts, + "iam_alias": defaultIAMAlias, }) } @@ -123,6 +128,7 @@ func TestRoleIam_Wildcard(t *testing.T) { testRoleRead(t, b, reqStorage, roleName, map[string]interface{}{ "type": iamRoleType, "bound_service_accounts": serviceAccounts, + "iam_alias": defaultIAMAlias, }) } @@ -145,6 +151,7 @@ func TestRoleIam_EditServiceAccounts(t *testing.T) { "type": iamRoleType, "bound_projects": projects, "bound_service_accounts": initial, + "iam_alias": defaultIAMAlias, } testRoleCreate(t, b, reqStorage, data) @@ -233,6 +240,7 @@ func TestRoleGce(t *testing.T) { "type": gceRoleType, "bound_projects": []string{}, "bound_service_accounts": []string{}, + "gce_alias": defaultGCEAlias, }) serviceAccounts := []string{"aserviceaccountid", "testaccount@google.com"} @@ -270,6 +278,7 @@ func TestRoleGce(t *testing.T) { "bound_instance_groups": []string{"devGroup"}, "bound_service_accounts": serviceAccounts, "add_group_aliases": true, + "gce_alias": defaultGCEAlias, }) } @@ -295,6 +304,7 @@ func TestRoleGce_EditLabels(t *testing.T) { "type": gceRoleType, "bound_projects": []string{projectId}, "bound_labels": labels, + "gce_alias": defaultGCEAlias, }) testRoleEditLabels(t, b, reqStorage, map[string]interface{}{ @@ -308,6 +318,7 @@ func TestRoleGce_EditLabels(t *testing.T) { "type": gceRoleType, "bound_projects": []string{projectId}, "bound_labels": labels, + "gce_alias": defaultGCEAlias, }) testRoleEditLabels(t, b, reqStorage, map[string]interface{}{ @@ -323,6 +334,7 @@ func TestRoleGce_EditLabels(t *testing.T) { "type": gceRoleType, "bound_projects": []string{projectId}, "bound_labels": labels, + "gce_alias": defaultGCEAlias, }) } @@ -354,6 +366,7 @@ func TestRoleGce_DeprecatedFields(t *testing.T) { "bound_regions": []string{"us-east1"}, "bound_zones": []string{"us-east1-a"}, "bound_instance_groups": []string{"my-ig"}, + "gce_alias": defaultGCEAlias, }) }) @@ -451,6 +464,511 @@ func TestRole_InvalidRoleType(t *testing.T) { }, []string{"role type", invalidRoleType, "is invalid"}) } +func TestRetrieveRole(t *testing.T) { + type testCase struct { + name string + + getName string + getResp *logical.StorageEntry + getErr error + getTimes int + + localMount bool + localMountTimes int + replicationState consts.ReplicationState + replicationStateTimes int + + putTimes int + + expectedRole *gcpRole + expectErr bool + } + + tests := map[string]testCase{ + "not found": { + name: "testrole", + + getName: "role/testrole", + getResp: nil, + getErr: nil, + getTimes: 1, + + localMountTimes: 0, + replicationStateTimes: 0, + + expectedRole: nil, + expectErr: false, + }, + "storage error": { + name: "testrole", + + getName: "role/testrole", + getResp: nil, + getErr: fmt.Errorf("test error"), + getTimes: 1, + + localMountTimes: 0, + replicationStateTimes: 0, + + expectedRole: nil, + expectErr: true, + }, + "bad data": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "role/testrole", + Value: []byte("asdfhoiasndf"), + }, + getErr: nil, + getTimes: 1, + + localMountTimes: 0, + replicationStateTimes: 0, + + expectedRole: nil, + expectErr: true, + }, + "iam nothing modified": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + RoleType: "iam", + IAMAliasType: defaultIAMAlias, + }), + }, + getErr: nil, + getTimes: 1, + + localMountTimes: 0, + replicationStateTimes: 0, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + RoleType: "iam", + IAMAliasType: defaultIAMAlias, + }, + expectErr: false, + }, + "gce nothing modified": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + RoleType: "gce", + GCEAliasType: defaultGCEAlias, + }), + }, + getErr: nil, + getTimes: 1, + + localMountTimes: 0, + replicationStateTimes: 0, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + RoleType: "gce", + GCEAliasType: defaultGCEAlias, + }, + expectErr: false, + }, + "projectID upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + ProjectId: "projectID", + BoundProjects: []string{}, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + BoundProjects: []string{"projectID"}, + }, + expectErr: false, + }, + "boundRegion upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + BoundRegion: "boundRegion", + BoundRegions: []string{}, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + BoundRegions: []string{"boundRegion"}, + }, + expectErr: false, + }, + "boundZone upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + BoundZone: "boundZone", + BoundZones: []string{}, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + BoundZones: []string{"boundZone"}, + }, + expectErr: false, + }, + "boundInstanceGroup upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + BoundInstanceGroup: "boundInstanceGroup", + BoundInstanceGroups: []string{}, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + BoundInstanceGroups: []string{"boundInstanceGroup"}, + }, + expectErr: false, + }, + "TTL upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{}, + TTL: 1 * time.Second, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{ + TokenTTL: 1 * time.Second, + }, + TTL: 1 * time.Second, + }, + expectErr: false, + }, + "MaxTTL upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{}, + MaxTTL: 1 * time.Second, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{ + TokenMaxTTL: 1 * time.Second, + }, + MaxTTL: 1 * time.Second, + }, + expectErr: false, + }, + "TokenPeriod upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{}, + Period: 1 * time.Second, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{ + TokenPeriod: 1 * time.Second, + }, + Period: 1 * time.Second, + }, + expectErr: false, + }, + "TokenPolicies upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{}, + Policies: []string{"policy1", "policy2"}, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{ + TokenPolicies: []string{"policy1", "policy2"}, + }, + Policies: []string{"policy1", "policy2"}, + }, + expectErr: false, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + storage := NewMockStorage(ctrl) + storage.EXPECT().Get(ctx, test.getName).Return(test.getResp, test.getErr).Times(test.getTimes) + + putReq := &logical.StorageEntry{ + Key: fmt.Sprintf("role/%s", test.name), + Value: append(toJSON(t, test.expectedRole), '\n'), // Add a newline because StorageEntryJSON somehow adds it + } + storage.EXPECT().Put(ctx, putReq).Return(nil).Times(test.putTimes) + + systemView := NewMockSystemView(ctrl) + systemView.EXPECT().LocalMount().Return(test.localMount).Times(test.localMountTimes) + systemView.EXPECT().ReplicationState().Return(test.replicationState).Times(test.replicationStateTimes) + + be, err := Factory(ctx, &logical.BackendConfig{System: systemView}) + if err != nil { + t.Fatalf("no error expected, got: %s", err) + } + b := be.(*GcpAuthBackend) + + actualResp, err := b.role(ctx, storage, test.name) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + if !reflect.DeepEqual(actualResp, test.expectedRole) { + t.Fatalf("Actual role: %#v\nExpected role: %#v", actualResp, test.expectedRole) + } + }) + } + + t.Run("storage put error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + name := "testrole" + + getResp := &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + ProjectId: "projectID", + }), + } + + putReq := &logical.StorageEntry{ + Key: fmt.Sprintf("role/%s", name), + Value: append(toJSON(t, + gcpRole{ + RoleID: "testroleid", + BoundProjects: []string{"projectID"}, + }, + ), '\n'), // Add a newline because StorageEntryJSON somehow adds it + } + + storage := NewMockStorage(ctrl) + storage.EXPECT().Get(ctx, fmt.Sprintf("role/%s", name)).Return(getResp, nil) + storage.EXPECT().Put(ctx, putReq).Return(fmt.Errorf("test error")) + + systemView := NewMockSystemView(ctrl) + systemView.EXPECT().LocalMount().Return(true) + + be, err := Factory(ctx, &logical.BackendConfig{System: systemView}) + if err != nil { + t.Fatalf("no error expected, got: %s", err) + } + b := be.(*GcpAuthBackend) + + actualResp, err := b.role(ctx, storage, name) + if err == nil { + t.Fatalf("err expected, got nil") + } + if actualResp != nil { + t.Fatalf("no role expected, but got: %#v", actualResp) + } + }) + + t.Run("roleID is generated when one does not exist", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + name := "testrole" + + getResp := &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{}), + } + + storage := NewMockStorage(ctrl) + storage.EXPECT().Get(ctx, fmt.Sprintf("role/%s", name)).Return(getResp, nil) + + var actualRawPut *logical.StorageEntry + storage.EXPECT().Put(ctx, gomock.Any()).DoAndReturn(func(_ context.Context, put *logical.StorageEntry) error { + actualRawPut = put + return nil + }) + + systemView := NewMockSystemView(ctrl) + systemView.EXPECT().LocalMount().Return(true) + + be, err := Factory(ctx, &logical.BackendConfig{System: systemView}) + if err != nil { + t.Fatalf("no error expected, got: %s", err) + } + b := be.(*GcpAuthBackend) + + actualRole, err := b.role(ctx, storage, name) + if err != nil { + t.Fatalf("no err expected, got: %s", err) + } + + if actualRole.RoleID == "" { + t.Fatalf("RoleID not set on returned role") + } + + expectedPutKey := fmt.Sprintf("role/%s", name) + if actualRawPut.Key != expectedPutKey { + t.Fatalf("Actual put key: %s Expected put key: %s", actualRawPut.Key, expectedPutKey) + } + + putRole := gcpRole{} + err = json.Unmarshal(actualRawPut.Value, &putRole) + if err != nil { + t.Fatalf("no err expected, got: %s", err) + } + + if putRole.RoleID != actualRole.RoleID { + t.Fatalf("Saved RoleID [%s] does not match returned RoleID [%s]", putRole.RoleID, actualRole.RoleID) + } + }) +} + //-- Utils -- func testRoleCreate(tb testing.TB, b logical.Backend, s logical.Storage, d map[string]interface{}) { tb.Helper() @@ -558,6 +1076,13 @@ func testRoleRead(tb testing.TB, b logical.Backend, s logical.Storage, roleName tb.Fatal(resp.Error()) } + // Because role_id is generated, ensure that it exists but don't worry about the specific value + roleID, exists := resp.Data["role_id"] + if !exists || roleID == "" { + tb.Fatal("missing or empty role_id") + } + delete(resp.Data, "role_id") + if err := checkData(resp, expected, expectedDefaults); err != nil { tb.Fatal(err) } @@ -642,3 +1167,13 @@ func testRoleAndProject(tb testing.TB) (string, string) { return roleName, projectId } + +func toJSON(t testing.TB, val interface{}) []byte { + t.Helper() + + b, err := json.Marshal(val) + if err != nil { + t.Fatalf("Failed to marshal to JSON: %s", err) + } + return b +}