diff --git a/Makefile b/Makefile index 8428eca6..a20899de 100644 --- a/Makefile +++ b/Makefile @@ -51,5 +51,8 @@ fmtcheck: fmt: gofmt -w $(GOFMT_FILES) +mocks: + mockgen -destination ${CURDIR}/plugin/mocks_test.go -package gcpauth github.com/hashicorp/vault/sdk/logical SystemView,Storage + .PHONY: bin default generate test vet bootstrap fmt fmtcheck diff --git a/go.mod b/go.mod index 87647f25..c1193bd6 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,12 @@ module github.com/hashicorp/vault-plugin-auth-gcp go 1.12 require ( + github.com/golang/mock v1.4.3 github.com/hashicorp/errwrap v1.0.0 github.com/hashicorp/go-cleanhttp v0.5.1 github.com/hashicorp/go-gcp-common v0.6.0 github.com/hashicorp/go-hclog v0.12.0 + github.com/hashicorp/go-uuid v1.0.2 github.com/hashicorp/vault/api v1.0.5-0.20200317185738-82f498082f02 github.com/hashicorp/vault/sdk v0.1.14-0.20200317185738-82f498082f02 github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d // indirect diff --git a/go.sum b/go.sum index 650821c0..1224e306 100644 --- a/go.sum +++ b/go.sum @@ -43,7 +43,10 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekf github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -215,6 +218,7 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqY golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191008105621-543471e840be h1:QAcqgptGM8IQBC9K/RC4o+O9YmqEm0diQn9QmZw/0mU= golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20181227161524-e6919f6577db h1:6/JqlYfC1CCaLnGceQTI+sDGhC9UBSPAsBqI0Gun6kU= golang.org/x/text v0.3.1-0.20181227161524-e6919f6577db/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -227,6 +231,8 @@ golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135 h1:5Beo0mZN8dRzgrMMkDp0jc8YXQKx9DiJ2k1dkvGsn5A= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= google.golang.org/api v0.0.0-20181220000619-583d854617af/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= google.golang.org/api v0.2.0/go.mod h1:IfRCZScioGtypHNTlz3gFk67J8uePVW7uDTBzXuIkhU= @@ -262,3 +268,7 @@ honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20180920025451-e3ad64cb4ed3/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +rsc.io/quote/v3 v3.1.0 h1:9JKUTTIUgS6kzR9mK1YuGKv6Nl+DijDNIc0ghT58FaY= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0 h1:7uVkIFmeBqHfdjD+gZwtXXI+RODJ2Wc4O7MPEh/QiW4= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/plugin/aliasing.go b/plugin/aliasing.go new file mode 100644 index 00000000..b5617468 --- /dev/null +++ b/plugin/aliasing.go @@ -0,0 +1,91 @@ +package gcpauth + +import ( + "fmt" + "sort" + "strconv" + "strings" + + "google.golang.org/api/compute/v1" + "google.golang.org/api/iam/v1" +) + +type iamAliaser func(role *gcpRole, svcAccount *iam.ServiceAccount) (alias string) +type gceAliaser func(role *gcpRole, instance *compute.Instance) (alias string) + +const ( + defaultIAMAlias = "unique_id" + defaultGCEAlias = "instance_id" +) + +var ( + allowedIAMAliases = map[string]iamAliaser{ + defaultIAMAlias: getIAMSvcAccountUniqueID, + "": getIAMSvcAccountUniqueID, // For backwards compatibility + + "role_id": getIAMRoleID, + } + allowedGCEAliases = map[string]gceAliaser{ + defaultGCEAlias: getGCEInstanceID, + "": getGCEInstanceID, // For backwards compatibility + + "role_id": getGCERoleID, + } + + allowedIAMAliasesSlice = iamMapKeyToSlice(allowedIAMAliases) + allowedGCEAliasesSlice = gceMapKeyToSlice(allowedGCEAliases) +) + +func iamMapKeyToSlice(m map[string]iamAliaser) (s []string) { + for key := range m { + if key == "" { + continue + } + s = append(s, key) + } + sort.Strings(s) + return s +} + +func gceMapKeyToSlice(m map[string]gceAliaser) (s []string) { + for key := range m { + if key == "" { + continue + } + s = append(s, key) + } + sort.Strings(s) + return s +} + +func getIAMSvcAccountUniqueID(_ *gcpRole, svcAccount *iam.ServiceAccount) (alias string) { + return svcAccount.UniqueId +} + +func getIAMRoleID(role *gcpRole, _ *iam.ServiceAccount) (alias string) { + return role.RoleID +} + +func getGCEInstanceID(_ *gcpRole, instance *compute.Instance) (alias string) { + return fmt.Sprintf("gce-%s", strconv.FormatUint(instance.Id, 10)) +} + +func getGCERoleID(role *gcpRole, _ *compute.Instance) (alias string) { + return role.RoleID +} + +func getIAMAlias(role *gcpRole, svcAccount *iam.ServiceAccount) (alias string, err error) { + aliaser, exists := allowedIAMAliases[role.IAMAliasType] + if !exists { + return "", fmt.Errorf("invalid IAM alias type: must be one of: %s", strings.Join(allowedIAMAliasesSlice, ", ")) + } + return aliaser(role, svcAccount), nil +} + +func getGCEAlias(role *gcpRole, instance *compute.Instance) (alias string, err error) { + aliaser, exists := allowedGCEAliases[role.GCEAliasType] + if !exists { + return "", fmt.Errorf("invalid GCE alias type: must be one of: %s", strings.Join(allowedIAMAliasesSlice, ", ")) + } + return aliaser(role, instance), nil +} diff --git a/plugin/aliasing_test.go b/plugin/aliasing_test.go new file mode 100644 index 00000000..22d8005e --- /dev/null +++ b/plugin/aliasing_test.go @@ -0,0 +1,150 @@ +package gcpauth + +import ( + "testing" + + "google.golang.org/api/compute/v1" + "google.golang.org/api/iam/v1" +) + +func TestGetIAMAlias(t *testing.T) { + type testCase struct { + role *gcpRole + svcAccount *iam.ServiceAccount + expectedAlias string + expectErr bool + } + + tests := map[string]testCase{ + "invalid type": { + role: &gcpRole{ + IAMAliasType: "bogus", + RoleID: "testRoleID", + }, + svcAccount: &iam.ServiceAccount{ + UniqueId: "iamUniqueID", + }, + expectedAlias: "", + expectErr: true, + }, + "empty type goes to default": { + role: &gcpRole{ + IAMAliasType: "", + RoleID: "testRoleID", + }, + svcAccount: &iam.ServiceAccount{ + UniqueId: "iamUniqueID", + }, + expectedAlias: "iamUniqueID", + expectErr: false, + }, + "default type": { + role: &gcpRole{ + IAMAliasType: defaultIAMAlias, + RoleID: "testRoleID", + }, + svcAccount: &iam.ServiceAccount{ + UniqueId: "iamUniqueID", + }, + expectedAlias: "iamUniqueID", + expectErr: false, + }, + "role_id": { + role: &gcpRole{ + IAMAliasType: "role_id", + RoleID: "testRoleID", + }, + svcAccount: &iam.ServiceAccount{ + UniqueId: "iamUniqueID", + }, + expectedAlias: "testRoleID", + expectErr: false, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + actualAlias, err := getIAMAlias(test.role, test.svcAccount) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + if actualAlias != test.expectedAlias { + t.Fatalf("Actual alias: %s Expected Alias: %s", actualAlias, test.expectedAlias) + } + }) + } +} + +func TestGetGCEAlias(t *testing.T) { + type testCase struct { + role *gcpRole + instance *compute.Instance + expectedAlias string + expectErr bool + } + + tests := map[string]testCase{ + "invalid type": { + role: &gcpRole{ + GCEAliasType: "bogus", + RoleID: "testRoleID", + }, + instance: &compute.Instance{ + Id: 123, + }, + expectedAlias: "", + expectErr: true, + }, + "empty type goes to default": { + role: &gcpRole{ + GCEAliasType: "", + RoleID: "testRoleID", + }, + instance: &compute.Instance{ + Id: 123, + }, + expectedAlias: "gce-123", + expectErr: false, + }, + "default type": { + role: &gcpRole{ + GCEAliasType: defaultGCEAlias, + RoleID: "testRoleID", + }, + instance: &compute.Instance{ + Id: 123, + }, + expectedAlias: "gce-123", + expectErr: false, + }, + "role_id": { + role: &gcpRole{ + GCEAliasType: "role_id", + RoleID: "testRoleID", + }, + instance: &compute.Instance{ + Id: 123, + }, + expectedAlias: "testRoleID", + expectErr: false, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + actualAlias, err := getGCEAlias(test.role, test.instance) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + if actualAlias != test.expectedAlias { + t.Fatalf("Actual alias: %s Expected Alias: %s", actualAlias, test.expectedAlias) + } + }) + } +} diff --git a/plugin/gcp_role.go b/plugin/gcp_role.go new file mode 100644 index 00000000..8ec4d292 --- /dev/null +++ b/plugin/gcp_role.go @@ -0,0 +1,429 @@ +package gcpauth + +import ( + "errors" + "fmt" + "strings" + "time" + + "github.com/hashicorp/go-gcp-common/gcputil" + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/strutil" + "github.com/hashicorp/vault/sdk/helper/tokenutil" + "github.com/hashicorp/vault/sdk/logical" +) + +const ( + currentGCPRoleVersion = 1 +) + +type gcpRole struct { + tokenutil.TokenParams + + // RoleID is a unique identifier for this role. + RoleID string `json:"role_id"` + + // Type of this role. See path_role constants for currently supported types. + RoleType string `json:"role_type,omitempty"` + + // Policies for Vault to assign to authorized entities. + Policies []string `json:"policies,omitempty"` + + // TTL of Vault auth leases under this role. + TTL time.Duration `json:"ttl,omitempty"` + + // Max total TTL including renewals, of Vault auth leases under this role. + MaxTTL time.Duration `json:"max_ttl,omitempty"` + + // Period, If set, indicates that this token should not expire and + // should be automatically renewed within this time period + // with TTL equal to this value. + Period time.Duration `json:"period,omitempty"` + + // Projects that entities must belong to + BoundProjects []string `json:"bound_projects,omitempty"` + + // Service accounts allowed to login under this role. + BoundServiceAccounts []string `json:"bound_service_accounts,omitempty"` + + // AddGroupAliases adds Vault group aliases to the response. + AddGroupAliases bool `json:"add_group_aliases,omitempty"` + + // --| IAM-only attributes |-- + // MaxJwtExp is the duration from time of authentication that a JWT used to authenticate to role must expire within. + // TODO(emilymye): Allow this to be updated for GCE roles once 'exp' parameter has been allowed for GCE metadata. + MaxJwtExp time.Duration `json:"max_jwt_exp,omitempty"` + + // AllowGCEInference, if false, does not allow a GCE instance to login under this 'iam' role. If true (default), + // a service account is inferred from the instance metadata and used as the authenticating instance. + AllowGCEInference bool `json:"allow_gce_inference,omitempty"` + + // --| GCE-only attributes |-- + // BoundRegions that instances must belong to in order to login under this role. + BoundRegions []string `json:"bound_regions,omitempty"` + + // BoundZones that instances must belong to in order to login under this role. + BoundZones []string `json:"bound_zones,omitempty"` + + // BoundInstanceGroups are the instance group that instances must belong to in order to login under this role. + BoundInstanceGroups []string `json:"bound_instance_groups,omitempty"` + + // BoundLabels that instances must currently have set in order to login under this role. + BoundLabels map[string]string `json:"bound_labels,omitempty"` + + // IAMAliasType specifies the alias name to use with IAM roles. Can be either "unique_id" (default) or "role_id" + IAMAliasType string `json:"iam_alias,omitempty"` + + // GCEAliasType specifies the alias name to use with GCE roles. Can be either "instance_id" (default) or "role_id" + GCEAliasType string `json:"gce_alias,omitempty"` + + // Version indicates the version of this configuration. Allows for more advanced logic around + // upgrades and different behavior between config versions. + Version int `json:"version,omitempty"` + + // Deprecated fields + // TODO: Remove in 0.5.0+ + ProjectId string `json:"project_id,omitempty"` + BoundRegion string `json:"bound_region,omitempty"` + BoundZone string `json:"bound_zone,omitempty"` + BoundInstanceGroup string `json:"bound_instance_group,omitempty"` +} + +// updateRole updates the given role with values parsed/validated from given FieldData. +// Exactly one of the response and error will be nil. The response is only used to pass back warnings. +// This method does not validate the role. Validation is done before storage. +func (role *gcpRole) updateRole(sys logical.SystemView, req *logical.Request, data *framework.FieldData) (warnings []string, err error) { + if e := role.ParseTokenFields(req, data); e != nil { + return nil, e + } + + // Handle token field upgrades + { + if e := tokenutil.UpgradeValue(data, "policies", "token_policies", &role.Policies, &role.TokenPolicies); e != nil { + return nil, e + } + + if e := tokenutil.UpgradeValue(data, "ttl", "token_ttl", &role.TTL, &role.TokenTTL); e != nil { + return nil, e + } + + if e := tokenutil.UpgradeValue(data, "max_ttl", "token_max_ttl", &role.MaxTTL, &role.TokenMaxTTL); e != nil { + return nil, e + } + + if e := tokenutil.UpgradeValue(data, "period", "token_period", &role.Period, &role.TokenPeriod); e != nil { + return nil, e + } + } + + // Set role type + if rt, ok := data.GetOk("type"); ok { + roleType := rt.(string) + if role.RoleType != roleType && req.Operation == logical.UpdateOperation { + return nil, fmt.Errorf("role type cannot be changed for an existing role") + } + role.RoleType = roleType + } else if req.Operation == logical.CreateOperation { + return nil, fmt.Errorf(errEmptyRoleType) + } + + def := sys.DefaultLeaseTTL() + if role.TokenTTL > def { + warnings = append(warnings, fmt.Sprintf(`Given token ttl of %q is greater `+ + `than the maximum system/mount TTL of %q. The TTL will be capped at `+ + `%q during login.`, role.TokenTTL, def, def)) + } + + // Update token Max TTL. + def = sys.MaxLeaseTTL() + if role.TokenMaxTTL > def { + warnings = append(warnings, fmt.Sprintf(`Given token max ttl of %q is greater `+ + `than the maximum system/mount MaxTTL of %q. The MaxTTL will be `+ + `capped at %q during login.`, role.TokenMaxTTL, def, def)) + } + if role.TokenPeriod > def { + warnings = append(warnings, fmt.Sprintf(`Given token period of %q is greater `+ + `than the maximum system/mount period of %q. The period will be `+ + `capped at %q during login.`, role.TokenPeriod, def, def)) + } + + // Update bound GCP service accounts. + if sa, ok := data.GetOk("bound_service_accounts"); ok { + role.BoundServiceAccounts = sa.([]string) + } else { + // Check for older version of param name + if sa, ok := data.GetOk("service_accounts"); ok { + warnings = append(warnings, `The "service_accounts" field is deprecated. `+ + `Please use "bound_service_accounts" instead. The "service_accounts" `+ + `field will be removed in a later release, so please update accordingly.`) + role.BoundServiceAccounts = sa.([]string) + } + } + if len(role.BoundServiceAccounts) > 0 { + role.BoundServiceAccounts = strutil.TrimStrings(role.BoundServiceAccounts) + role.BoundServiceAccounts = strutil.RemoveDuplicates(role.BoundServiceAccounts, false) + } + + // Update bound GCP projects. + boundProjects, givenBoundProj := data.GetOk("bound_projects") + if givenBoundProj { + role.BoundProjects = boundProjects.([]string) + } + if projectId, ok := data.GetOk("project_id"); ok { + if givenBoundProj { + return warnings, errors.New("only one of 'bound_projects' or 'project_id' can be given") + } + warnings = append(warnings, + `The "project_id" (singular) field is deprecated. `+ + `Please use plural "bound_projects" instead to bind required GCP projects. `+ + `The "project_id" field will be removed in a later release, so please update accordingly.`) + role.BoundProjects = []string{projectId.(string)} + } + if len(role.BoundProjects) > 0 { + role.BoundProjects = strutil.TrimStrings(role.BoundProjects) + role.BoundProjects = strutil.RemoveDuplicates(role.BoundProjects, false) + } + + // Update bound GCP projects. + addGroupAliases, ok := data.GetOk("add_group_aliases") + if ok { + role.AddGroupAliases = addGroupAliases.(bool) + } + + // Update fields specific to this type + switch role.RoleType { + case iamRoleType: + if err = checkInvalidRoleTypeArgs(data, gceOnlyFieldSchema); err != nil { + return warnings, err + } + if warnings, err = role.updateIamFields(data, req.Operation); err != nil { + return warnings, err + } + iamAliasType, ok := data.GetOk("iam_alias") + if ok { + role.IAMAliasType = iamAliasType.(string) + } + case gceRoleType: + if err = checkInvalidRoleTypeArgs(data, iamOnlyFieldSchema); err != nil { + return warnings, err + } + if warnings, err = role.updateGceFields(data, req.Operation); err != nil { + return warnings, err + } + gceAliasType, ok := data.GetOk("gce_alias") + if ok { + role.GCEAliasType = gceAliasType.(string) + } + } + + return warnings, nil +} + +func (role *gcpRole) validate(sys logical.SystemView) (warnings []string, err error) { + warnings = []string{} + + switch role.RoleType { + case iamRoleType: + if warnings, err = role.validateForIAM(); err != nil { + return warnings, err + } + case gceRoleType: + if warnings, err = role.validateForGCE(); err != nil { + return warnings, err + } + case "": + return warnings, errors.New(errEmptyRoleType) + default: + return warnings, fmt.Errorf("role type '%s' is invalid", role.RoleType) + } + + defaultLeaseTTL := sys.DefaultLeaseTTL() + if role.TokenTTL > defaultLeaseTTL { + warnings = append(warnings, fmt.Sprintf( + "Given ttl of %d seconds greater than current mount/system default of %d seconds; ttl will be capped at login time", + role.TokenTTL/time.Second, defaultLeaseTTL/time.Second)) + } + + defaultMaxTTL := sys.MaxLeaseTTL() + if role.TokenMaxTTL > defaultMaxTTL { + warnings = append(warnings, fmt.Sprintf( + "Given max_ttl of %d seconds greater than current mount/system default of %d seconds; max_ttl will be capped at login time", + role.TokenMaxTTL/time.Second, defaultMaxTTL/time.Second)) + } + if role.TokenMaxTTL < time.Duration(0) { + return warnings, errors.New("max_ttl cannot be negative") + } + if role.TokenMaxTTL != 0 && role.TokenMaxTTL < role.TokenTTL { + return warnings, errors.New("ttl should be shorter than max_ttl") + } + + if role.TokenPeriod > sys.MaxLeaseTTL() { + return warnings, fmt.Errorf("'period' of '%s' is greater than the backend's maximum lease TTL of '%s'", role.TokenPeriod.String(), sys.MaxLeaseTTL().String()) + } + + if _, exists := allowedIAMAliases[role.IAMAliasType]; !exists { + return warnings, fmt.Errorf("iam_alias must be one of: %s", strings.Join(allowedIAMAliasesSlice, ", ")) + } + if _, exists := allowedGCEAliases[role.GCEAliasType]; !exists { + return warnings, fmt.Errorf("gce_alias must be one of: %s", strings.Join(allowedGCEAliasesSlice, ", ")) + } + + return warnings, nil +} + +// updateIamFields updates IAM-only fields for a role. +func (role *gcpRole) updateIamFields(data *framework.FieldData, op logical.Operation) (warnings []string, err error) { + if allowGCEInference, ok := data.GetOk("allow_gce_inference"); ok { + role.AllowGCEInference = allowGCEInference.(bool) + } else if op == logical.CreateOperation { + role.AllowGCEInference = data.Get("allow_gce_inference").(bool) + } + + if maxJwtExp, ok := data.GetOk("max_jwt_exp"); ok { + role.MaxJwtExp = time.Duration(maxJwtExp.(int)) * time.Second + } else if op == logical.CreateOperation { + role.MaxJwtExp = time.Duration(defaultIamMaxJwtExpMinutes) * time.Minute + } + + if role.IAMAliasType == "" { + role.IAMAliasType = defaultIAMAlias + } + + return warnings, nil +} + +// updateGceFields updates GCE-only fields for a role. +func (role *gcpRole) updateGceFields(data *framework.FieldData, op logical.Operation) (warnings []string, err error) { + if regions, ok := data.GetOk("bound_regions"); ok { + role.BoundRegions = regions.([]string) + } else if op == logical.CreateOperation { + role.BoundRegions = data.Get("bound_regions").([]string) + } + + if zones, ok := data.GetOk("bound_zones"); ok { + role.BoundZones = zones.([]string) + } else if op == logical.CreateOperation { + role.BoundZones = data.Get("bound_zones").([]string) + } + + if instanceGroups, ok := data.GetOk("bound_instance_groups"); ok { + role.BoundInstanceGroups = instanceGroups.([]string) + } else if op == logical.CreateOperation { + role.BoundInstanceGroups = data.Get("bound_instance_groups").([]string) + } + + if boundRegion, ok := data.GetOk("bound_region"); ok { + if _, ok := data.GetOk("bound_regions"); ok { + return warnings, fmt.Errorf(`cannot specify both "bound_region" and "bound_regions"`) + } + + warnings = append(warnings, `The "bound_region" field is deprecated. `+ + `Please use "bound_regions" (plural) instead. You can still specify a `+ + `single region, but multiple regions are also now supported. The `+ + `"bound_region" field will be removed in a later release, so please `+ + `update accordingly.`) + role.BoundRegions = append(role.BoundRegions, boundRegion.(string)) + } + + if boundZone, ok := data.GetOk("bound_zone"); ok { + if _, ok := data.GetOk("bound_zones"); ok { + return warnings, fmt.Errorf(`cannot specify both "bound_zone" and "bound_zones"`) + } + + warnings = append(warnings, `The "bound_zone" field is deprecated. `+ + `Please use "bound_zones" (plural) instead. You can still specify a `+ + `single zone, but multiple zones are also now supported. The `+ + `"bound_zone" field will be removed in a later release, so please `+ + `update accordingly.`) + role.BoundZones = append(role.BoundZones, boundZone.(string)) + } + + if boundInstanceGroup, ok := data.GetOk("bound_instance_group"); ok { + if _, ok := data.GetOk("bound_instance_groups"); ok { + return warnings, fmt.Errorf(`cannot specify both "bound_instance_group" and "bound_instance_groups"`) + } + + warnings = append(warnings, `The "bound_instance_group" field is deprecated. `+ + `Please use "bound_instance_groups" (plural) instead. You can still specify a `+ + `single instance group, but multiple instance groups are also now supported. The `+ + `"bound_instance_group" field will be removed in a later release, so please `+ + `update accordingly.`) + role.BoundInstanceGroups = append(role.BoundInstanceGroups, boundInstanceGroup.(string)) + } + + if labelsRaw, ok := data.GetOk("bound_labels"); ok { + labels, invalidLabels := gcputil.ParseGcpLabels(labelsRaw.([]string)) + if len(invalidLabels) > 0 { + return warnings, fmt.Errorf("invalid labels given: %q", invalidLabels) + } + role.BoundLabels = labels + } + + if len(role.Policies) > 0 { + role.Policies = strutil.TrimStrings(role.Policies) + role.Policies = strutil.RemoveDuplicates(role.Policies, false) + } + + if len(role.BoundRegions) > 0 { + role.BoundRegions = strutil.TrimStrings(role.BoundRegions) + role.BoundRegions = strutil.RemoveDuplicates(role.BoundRegions, false) + } + + if len(role.BoundZones) > 0 { + role.BoundZones = strutil.TrimStrings(role.BoundZones) + role.BoundZones = strutil.RemoveDuplicates(role.BoundZones, false) + } + + if len(role.BoundInstanceGroups) > 0 { + role.BoundInstanceGroups = strutil.TrimStrings(role.BoundInstanceGroups) + role.BoundInstanceGroups = strutil.RemoveDuplicates(role.BoundInstanceGroups, false) + } + + if role.GCEAliasType == "" { + role.GCEAliasType = defaultGCEAlias + } + + return warnings, nil +} + +// validateIamFields validates the IAM-only fields for a role. +func (role *gcpRole) validateForIAM() (warnings []string, err error) { + if len(role.BoundServiceAccounts) == 0 { + return []string{}, errors.New(errEmptyIamServiceAccounts) + } + + if len(role.BoundServiceAccounts) > 1 && strutil.StrListContains(role.BoundServiceAccounts, serviceAccountsWildcard) { + return []string{}, fmt.Errorf("cannot provide IAM service account wildcard '%s' (for all service accounts) with other service accounts", serviceAccountsWildcard) + } + + maxMaxJwtExp := time.Duration(maxJwtExpMaxMinutes) * time.Minute + if role.MaxJwtExp > maxMaxJwtExp { + return warnings, fmt.Errorf("max_jwt_exp cannot be more than %d minutes", maxJwtExpMaxMinutes) + } + + return []string{}, nil +} + +// validateGceFields validates the GCE-only fields for a role. +func (role *gcpRole) validateForGCE() (warnings []string, err error) { + warnings = []string{} + + hasRegion := len(role.BoundRegions) > 0 + hasZone := len(role.BoundZones) > 0 + hasRegionOrZone := hasRegion || hasZone + + hasInstanceGroup := len(role.BoundInstanceGroups) > 0 + + if hasInstanceGroup && !hasRegionOrZone { + return warnings, errors.New(`region or zone information must be specified if an instance group is given`) + } + + if hasRegion && hasZone { + warnings = append(warnings, `Given both "bound_regions" and "bound_zones" `+ + `fields for role type "gce", "bound_regions" will be ignored in favor `+ + `of the more specific "bound_zones" field. To fix this warning, update `+ + `the role to remove either the "bound_regions" or "bound_zones" field.`) + } + + return warnings, nil +} diff --git a/plugin/mocks_test.go b/plugin/mocks_test.go new file mode 100644 index 00000000..6785ab2b --- /dev/null +++ b/plugin/mocks_test.go @@ -0,0 +1,308 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/hashicorp/vault/sdk/logical (interfaces: SystemView,Storage) + +// Package gcpauth is a generated GoMock package. +package gcpauth + +import ( + context "context" + gomock "github.com/golang/mock/gomock" + consts "github.com/hashicorp/vault/sdk/helper/consts" + license "github.com/hashicorp/vault/sdk/helper/license" + pluginutil "github.com/hashicorp/vault/sdk/helper/pluginutil" + wrapping "github.com/hashicorp/vault/sdk/helper/wrapping" + logical "github.com/hashicorp/vault/sdk/logical" + reflect "reflect" + time "time" +) + +// MockSystemView is a mock of SystemView interface +type MockSystemView struct { + ctrl *gomock.Controller + recorder *MockSystemViewMockRecorder +} + +// MockSystemViewMockRecorder is the mock recorder for MockSystemView +type MockSystemViewMockRecorder struct { + mock *MockSystemView +} + +// NewMockSystemView creates a new mock instance +func NewMockSystemView(ctrl *gomock.Controller) *MockSystemView { + mock := &MockSystemView{ctrl: ctrl} + mock.recorder = &MockSystemViewMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockSystemView) EXPECT() *MockSystemViewMockRecorder { + return m.recorder +} + +// CachingDisabled mocks base method +func (m *MockSystemView) CachingDisabled() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CachingDisabled") + ret0, _ := ret[0].(bool) + return ret0 +} + +// CachingDisabled indicates an expected call of CachingDisabled +func (mr *MockSystemViewMockRecorder) CachingDisabled() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CachingDisabled", reflect.TypeOf((*MockSystemView)(nil).CachingDisabled)) +} + +// DefaultLeaseTTL mocks base method +func (m *MockSystemView) DefaultLeaseTTL() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DefaultLeaseTTL") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// DefaultLeaseTTL indicates an expected call of DefaultLeaseTTL +func (mr *MockSystemViewMockRecorder) DefaultLeaseTTL() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DefaultLeaseTTL", reflect.TypeOf((*MockSystemView)(nil).DefaultLeaseTTL)) +} + +// EntityInfo mocks base method +func (m *MockSystemView) EntityInfo(arg0 string) (*logical.Entity, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EntityInfo", arg0) + ret0, _ := ret[0].(*logical.Entity) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EntityInfo indicates an expected call of EntityInfo +func (mr *MockSystemViewMockRecorder) EntityInfo(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EntityInfo", reflect.TypeOf((*MockSystemView)(nil).EntityInfo), arg0) +} + +// GroupsForEntity mocks base method +func (m *MockSystemView) GroupsForEntity(arg0 string) ([]*logical.Group, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GroupsForEntity", arg0) + ret0, _ := ret[0].([]*logical.Group) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GroupsForEntity indicates an expected call of GroupsForEntity +func (mr *MockSystemViewMockRecorder) GroupsForEntity(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupsForEntity", reflect.TypeOf((*MockSystemView)(nil).GroupsForEntity), arg0) +} + +// HasFeature mocks base method +func (m *MockSystemView) HasFeature(arg0 license.Features) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasFeature", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasFeature indicates an expected call of HasFeature +func (mr *MockSystemViewMockRecorder) HasFeature(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasFeature", reflect.TypeOf((*MockSystemView)(nil).HasFeature), arg0) +} + +// LocalMount mocks base method +func (m *MockSystemView) LocalMount() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalMount") + ret0, _ := ret[0].(bool) + return ret0 +} + +// LocalMount indicates an expected call of LocalMount +func (mr *MockSystemViewMockRecorder) LocalMount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalMount", reflect.TypeOf((*MockSystemView)(nil).LocalMount)) +} + +// LookupPlugin mocks base method +func (m *MockSystemView) LookupPlugin(arg0 context.Context, arg1 string, arg2 consts.PluginType) (*pluginutil.PluginRunner, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LookupPlugin", arg0, arg1, arg2) + ret0, _ := ret[0].(*pluginutil.PluginRunner) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LookupPlugin indicates an expected call of LookupPlugin +func (mr *MockSystemViewMockRecorder) LookupPlugin(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupPlugin", reflect.TypeOf((*MockSystemView)(nil).LookupPlugin), arg0, arg1, arg2) +} + +// MaxLeaseTTL mocks base method +func (m *MockSystemView) MaxLeaseTTL() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MaxLeaseTTL") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// MaxLeaseTTL indicates an expected call of MaxLeaseTTL +func (mr *MockSystemViewMockRecorder) MaxLeaseTTL() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaxLeaseTTL", reflect.TypeOf((*MockSystemView)(nil).MaxLeaseTTL)) +} + +// MlockEnabled mocks base method +func (m *MockSystemView) MlockEnabled() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MlockEnabled") + ret0, _ := ret[0].(bool) + return ret0 +} + +// MlockEnabled indicates an expected call of MlockEnabled +func (mr *MockSystemViewMockRecorder) MlockEnabled() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MlockEnabled", reflect.TypeOf((*MockSystemView)(nil).MlockEnabled)) +} + +// PluginEnv mocks base method +func (m *MockSystemView) PluginEnv(arg0 context.Context) (*logical.PluginEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PluginEnv", arg0) + ret0, _ := ret[0].(*logical.PluginEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PluginEnv indicates an expected call of PluginEnv +func (mr *MockSystemViewMockRecorder) PluginEnv(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PluginEnv", reflect.TypeOf((*MockSystemView)(nil).PluginEnv), arg0) +} + +// ReplicationState mocks base method +func (m *MockSystemView) ReplicationState() consts.ReplicationState { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReplicationState") + ret0, _ := ret[0].(consts.ReplicationState) + return ret0 +} + +// ReplicationState indicates an expected call of ReplicationState +func (mr *MockSystemViewMockRecorder) ReplicationState() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplicationState", reflect.TypeOf((*MockSystemView)(nil).ReplicationState)) +} + +// ResponseWrapData mocks base method +func (m *MockSystemView) ResponseWrapData(arg0 context.Context, arg1 map[string]interface{}, arg2 time.Duration, arg3 bool) (*wrapping.ResponseWrapInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResponseWrapData", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*wrapping.ResponseWrapInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ResponseWrapData indicates an expected call of ResponseWrapData +func (mr *MockSystemViewMockRecorder) ResponseWrapData(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseWrapData", reflect.TypeOf((*MockSystemView)(nil).ResponseWrapData), arg0, arg1, arg2, arg3) +} + +// Tainted mocks base method +func (m *MockSystemView) Tainted() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Tainted") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Tainted indicates an expected call of Tainted +func (mr *MockSystemViewMockRecorder) Tainted() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tainted", reflect.TypeOf((*MockSystemView)(nil).Tainted)) +} + +// MockStorage is a mock of Storage interface +type MockStorage struct { + ctrl *gomock.Controller + recorder *MockStorageMockRecorder +} + +// MockStorageMockRecorder is the mock recorder for MockStorage +type MockStorageMockRecorder struct { + mock *MockStorage +} + +// NewMockStorage creates a new mock instance +func NewMockStorage(ctrl *gomock.Controller) *MockStorage { + mock := &MockStorage{ctrl: ctrl} + mock.recorder = &MockStorageMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStorage) EXPECT() *MockStorageMockRecorder { + return m.recorder +} + +// Delete mocks base method +func (m *MockStorage) Delete(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete +func (mr *MockStorageMockRecorder) Delete(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStorage)(nil).Delete), arg0, arg1) +} + +// Get mocks base method +func (m *MockStorage) Get(arg0 context.Context, arg1 string) (*logical.StorageEntry, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0, arg1) + ret0, _ := ret[0].(*logical.StorageEntry) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get +func (mr *MockStorageMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStorage)(nil).Get), arg0, arg1) +} + +// List mocks base method +func (m *MockStorage) List(arg0 context.Context, arg1 string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List +func (mr *MockStorageMockRecorder) List(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockStorage)(nil).List), arg0, arg1) +} + +// Put mocks base method +func (m *MockStorage) Put(arg0 context.Context, arg1 *logical.StorageEntry) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Put", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Put indicates an expected call of Put +func (mr *MockStorageMockRecorder) Put(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockStorage)(nil).Put), arg0, arg1) +} diff --git a/plugin/path_config.go b/plugin/path_config.go index 856b9b6b..309e7b42 100644 --- a/plugin/path_config.go +++ b/plugin/path_config.go @@ -2,9 +2,9 @@ package gcpauth import ( "context" - "errors" - "encoding/json" + "errors" + "net/http" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-gcp-common/gcputil" @@ -44,7 +44,7 @@ Deprecated. This field does nothing and be removed in a future release`, func (b *GcpAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { if err := validateFields(req, d); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } c, err := b.config(ctx, req.Storage) @@ -58,7 +58,7 @@ func (b *GcpAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque changed, err := c.Update(d) if err != nil { - return nil, logical.CodedError(400, err.Error()) + return nil, logical.CodedError(http.StatusBadRequest, err.Error()) } // Only do the following if the config is different @@ -83,7 +83,7 @@ func (b *GcpAuthBackend) pathConfigWrite(ctx context.Context, req *logical.Reque func (b *GcpAuthBackend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { if err := validateFields(req, d); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } config, err := b.config(ctx, req.Storage) diff --git a/plugin/path_login.go b/plugin/path_login.go index 35a2d01e..855f86bb 100644 --- a/plugin/path_login.go +++ b/plugin/path_login.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/http" "strconv" "strings" "time" @@ -59,7 +60,7 @@ GCE identity metadata token ('iam', 'gce' roles).`, func (b *GcpAuthBackend) pathLogin(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { // Validate we didn't get extraneous fields if err := validateFields(req, data); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } loginInfo, err := b.parseAndValidateJwt(ctx, req, data) @@ -315,14 +316,20 @@ func (b *GcpAuthBackend) pathIamLogin(ctx context.Context, req *logical.Request, return nil, errors.New("service account is empty") } + alias, err := getIAMAlias(role, serviceAccount) + if err != nil { + return logical.ErrorResponse("unable to create alias: %s", err), nil + } + if req.Operation == logical.AliasLookaheadOperation { - return &logical.Response{ + resp := &logical.Response{ Auth: &logical.Auth{ Alias: &logical.Alias{ - Name: serviceAccount.UniqueId, + Name: alias, }, }, - }, nil + } + return resp, nil } // Validate service account can login against role. @@ -332,7 +339,7 @@ func (b *GcpAuthBackend) pathIamLogin(ctx context.Context, req *logical.Request, auth := &logical.Auth{ Alias: &logical.Alias{ - Name: serviceAccount.UniqueId, + Name: alias, }, Metadata: authMetadata(loginInfo, serviceAccount), DisplayName: serviceAccount.Email, @@ -449,11 +456,16 @@ func (b *GcpAuthBackend) pathGceLogin(ctx context.Context, req *logical.Request, return logical.ErrorResponse(err.Error()), nil } + alias, err := getGCEAlias(role, instance) + if err != nil { + return logical.ErrorResponse("unable to create alias: %s", err), nil + } + if req.Operation == logical.AliasLookaheadOperation { return &logical.Response{ Auth: &logical.Auth{ Alias: &logical.Alias{ - Name: fmt.Sprintf("gce-%s", strconv.FormatUint(instance.Id, 10)), + Name: alias, }, }, }, nil @@ -475,7 +487,7 @@ func (b *GcpAuthBackend) pathGceLogin(ctx context.Context, req *logical.Request, auth := &logical.Auth{ InternalData: map[string]interface{}{}, Alias: &logical.Alias{ - Name: fmt.Sprintf("gce-%s", strconv.FormatUint(instance.Id, 10)), + Name: alias, }, Metadata: authMetadata(loginInfo, serviceAccount), DisplayName: instance.Name, diff --git a/plugin/path_login_test.go b/plugin/path_login_test.go index 9085800f..650b15f5 100644 --- a/plugin/path_login_test.go +++ b/plugin/path_login_test.go @@ -160,7 +160,7 @@ func TestLogin_IAM(t *testing.T) { } for _, tc := range cases { - tc := tc + tc := tc // Since the t.Run is parallel, this is needed to prevent scope sharing between loops t.Run(tc.name, func(t *testing.T) { t.Parallel() diff --git a/plugin/path_role.go b/plugin/path_role.go index 3acc6b90..e8335223 100644 --- a/plugin/path_role.go +++ b/plugin/path_role.go @@ -2,15 +2,14 @@ package gcpauth import ( "context" - "errors" "fmt" + "net/http" "strings" - "time" "github.com/hashicorp/go-gcp-common/gcputil" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/framework" vaultconsts "github.com/hashicorp/vault/sdk/helper/consts" - "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -105,6 +104,11 @@ var iamOnlyFieldSchema = map[string]*framework.FieldSchema{ Default: true, Description: `'iam' roles only. If false, Vault will not not allow GCE instances to login in against this role`, }, + "iam_alias": { + Type: framework.TypeString, + Default: defaultIAMAlias, + Description: "Indicates what value to use when generating an alias for IAM authentications.", + }, } var gceOnlyFieldSchema = map[string]*framework.FieldSchema{ @@ -136,6 +140,11 @@ var gceOnlyFieldSchema = map[string]*framework.FieldSchema{ "\"key:value\" strings that must be present on the GCE instance " + "in order to authenticate. This option only applies to \"gce\" roles.", }, + "gce_alias": { + Type: framework.TypeString, + Default: defaultGCEAlias, + Description: "Indicates what value to use when generating an alias for GCE authentications.", + }, } // pathsRole creates paths for listing roles and CRUD operations. @@ -274,64 +283,69 @@ func (b *GcpAuthBackend) pathRoleRead(ctx context.Context, req *logical.Request, return nil, nil } - resp := make(map[string]interface{}) - role.PopulateTokenData(resp) + respData := make(map[string]interface{}) + role.PopulateTokenData(respData) + + respData["role_id"] = role.RoleID if role.RoleType != "" { - resp["type"] = role.RoleType + respData["type"] = role.RoleType } if len(role.BoundServiceAccounts) > 0 { - resp["bound_service_accounts"] = role.BoundServiceAccounts + respData["bound_service_accounts"] = role.BoundServiceAccounts } if len(role.BoundProjects) > 0 { - resp["bound_projects"] = role.BoundProjects + respData["bound_projects"] = role.BoundProjects } - resp["add_group_aliases"] = role.AddGroupAliases + respData["add_group_aliases"] = role.AddGroupAliases switch role.RoleType { case iamRoleType: if role.MaxJwtExp != 0 { - resp["max_jwt_exp"] = int64(role.MaxJwtExp.Seconds()) + respData["max_jwt_exp"] = int64(role.MaxJwtExp.Seconds()) } - resp["allow_gce_inference"] = role.AllowGCEInference + respData["allow_gce_inference"] = role.AllowGCEInference + respData["iam_alias"] = role.IAMAliasType case gceRoleType: if len(role.BoundRegions) > 0 { - resp["bound_regions"] = role.BoundRegions + respData["bound_regions"] = role.BoundRegions } if len(role.BoundZones) > 0 { - resp["bound_zones"] = role.BoundZones + respData["bound_zones"] = role.BoundZones } if len(role.BoundInstanceGroups) > 0 { - resp["bound_instance_groups"] = role.BoundInstanceGroups + respData["bound_instance_groups"] = role.BoundInstanceGroups } if len(role.BoundLabels) > 0 { - resp["bound_labels"] = role.BoundLabels + respData["bound_labels"] = role.BoundLabels } + respData["gce_alias"] = role.GCEAliasType } // Upgrade vals if len(role.Policies) > 0 { - resp["policies"] = resp["token_policies"] + respData["policies"] = respData["token_policies"] } if role.TTL > 0 { - resp["ttl"] = int64(role.TTL.Seconds()) + respData["ttl"] = int64(role.TTL.Seconds()) } if role.MaxTTL > 0 { - resp["max_ttl"] = int64(role.MaxTTL.Seconds()) + respData["max_ttl"] = int64(role.MaxTTL.Seconds()) } if role.Period > 0 { - resp["period"] = int64(role.Period.Seconds()) + respData["period"] = int64(role.Period.Seconds()) } - return &logical.Response{ - Data: resp, - }, nil + resp := &logical.Response{ + Data: respData, + } + return resp, nil } func (b *GcpAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { // Validate we didn't get extraneous fields if err := validateFields(req, data); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } name := strings.ToLower(data.Get("name").(string)) @@ -347,6 +361,14 @@ func (b *GcpAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical. role = &gcpRole{} } + if role.RoleID == "" { + roleID, err := uuid.GenerateUUID() + if err != nil { + return nil, logical.CodedError(http.StatusInternalServerError, fmt.Sprintf("unable to generate roleID: %s", err)) + } + role.RoleID = roleID + } + warnings, err := role.updateRole(b.System(), req, data) if err != nil { resp := logical.ErrorResponse(err.Error()) @@ -380,7 +402,7 @@ const pathListRolesHelpDesc = `Lists all roles under the GCP backends by name.` func (b *GcpAuthBackend) pathRoleEditIamServiceAccounts(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { // Validate we didn't get extraneous fields if err := validateFields(req, data); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } var warnings []string @@ -437,7 +459,7 @@ func editStringValues(initial []string, toAdd []string, toRemove []string) []str func (b *GcpAuthBackend) pathRoleEditGceLabels(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { // Validate we didn't get extraneous fields if err := validateFields(req, data); err != nil { - return nil, logical.CodedError(422, err.Error()) + return nil, logical.CodedError(http.StatusUnprocessableEntity, err.Error()) } var warnings []string @@ -480,7 +502,7 @@ func (b *GcpAuthBackend) pathRoleEditGceLabels(ctx context.Context, req *logical return b.storeRole(ctx, req.Storage, roleName, role, warnings) } -// role reads a gcpRole from storage. This assumes the caller has already obtained the role lock. +// role from storage. This assumes the caller has already obtained the role lock. func (b *GcpAuthBackend) role(ctx context.Context, s logical.Storage, name string) (*gcpRole, error) { name = strings.ToLower(name) @@ -540,6 +562,25 @@ func (b *GcpAuthBackend) role(ctx context.Context, s logical.Storage, name strin modified = true } + if role.RoleType == "iam" && role.IAMAliasType == "" { + role.IAMAliasType = defaultIAMAlias + modified = true + } + if role.RoleType == "gce" && role.GCEAliasType == "" { + role.GCEAliasType = defaultGCEAlias + modified = true + } + + // Ensure the role has a RoleID + if role.RoleID == "" { + roleID, err := uuid.GenerateUUID() + if err != nil { + return nil, logical.CodedError(http.StatusInternalServerError, fmt.Sprintf("unable to generate roleID for role missing an ID: %s", err)) + } + role.RoleID = roleID + modified = true + } + if modified && (b.System().LocalMount() || !b.System().ReplicationState().HasState(vaultconsts.ReplicationPerformanceSecondary)) { b.Logger().Info("upgrading role to new schema", "role", name) @@ -576,6 +617,19 @@ func (b *GcpAuthBackend) storeRole(ctx context.Context, s logical.Storage, roleN return logical.ErrorResponse(err.Error()), nil } + // Set default alias names + if role.IAMAliasType == "" { + role.IAMAliasType = defaultIAMAlias + } + if role.GCEAliasType == "" { + role.GCEAliasType = defaultGCEAlias + } + + // Ensure a version is specified + if role.Version == 0 { + role.Version = currentGCPRoleVersion + } + entry, err := logical.StorageEntryJSON(fmt.Sprintf("role/%s", roleName), role) if err != nil { return nil, err @@ -587,387 +641,6 @@ func (b *GcpAuthBackend) storeRole(ctx context.Context, s logical.Storage, roleN return &resp, nil } -type gcpRole struct { - tokenutil.TokenParams - - // Type of this role. See path_role constants for currently supported types. - RoleType string `json:"role_type,omitempty"` - - // Policies for Vault to assign to authorized entities. - Policies []string `json:"policies,omitempty"` - - // TTL of Vault auth leases under this role. - TTL time.Duration `json:"ttl,omitempty"` - - // Max total TTL including renewals, of Vault auth leases under this role. - MaxTTL time.Duration `json:"max_ttl,omitempty"` - - // Period, If set, indicates that this token should not expire and - // should be automatically renewed within this time period - // with TTL equal to this value. - Period time.Duration `json:"period,omitempty"` - - // Projects that entities must belong to - BoundProjects []string `json:"bound_projects,omitempty"` - - // Service accounts allowed to login under this role. - BoundServiceAccounts []string `json:"bound_service_accounts,omitempty"` - - // AddGroupAliases adds Vault group aliases to the response. - AddGroupAliases bool `json:"add_group_aliases,omitempty"` - - // --| IAM-only attributes |-- - // MaxJwtExp is the duration from time of authentication that a JWT used to authenticate to role must expire within. - // TODO(emilymye): Allow this to be updated for GCE roles once 'exp' parameter has been allowed for GCE metadata. - MaxJwtExp time.Duration `json:"max_jwt_exp,omitempty"` - - // AllowGCEInference, if false, does not allow a GCE instance to login under this 'iam' role. If true (default), - // a service account is inferred from the instance metadata and used as the authenticating instance. - AllowGCEInference bool `json:"allow_gce_inference,omitempty"` - - // --| GCE-only attributes |-- - // BoundRegions that instances must belong to in order to login under this role. - BoundRegions []string `json:"bound_regions,omitempty"` - - // BoundZones that instances must belong to in order to login under this role. - BoundZones []string `json:"bound_zones,omitempty"` - - // BoundInstanceGroups are the instance group that instances must belong to in order to login under this role. - BoundInstanceGroups []string `json:"bound_instance_groups,omitempty"` - - // BoundLabels that instances must currently have set in order to login under this role. - BoundLabels map[string]string `json:"bound_labels,omitempty"` - - // Deprecated fields - // TODO: Remove in 0.5.0+ - ProjectId string `json:"project_id,omitempty"` - BoundRegion string `json:"bound_region,omitempty"` - BoundZone string `json:"bound_zone,omitempty"` - BoundInstanceGroup string `json:"bound_instance_group,omitempty"` -} - -// Update updates the given role with values parsed/validated from given FieldData. -// Exactly one of the response and error will be nil. The response is only used to pass back warnings. -// This method does not validate the role. Validation is done before storage. -func (role *gcpRole) updateRole(sys logical.SystemView, req *logical.Request, data *framework.FieldData) (warnings []string, err error) { - if e := role.ParseTokenFields(req, data); e != nil { - return nil, e - } - - // Handle token field upgrades - { - if e := tokenutil.UpgradeValue(data, "policies", "token_policies", &role.Policies, &role.TokenPolicies); e != nil { - return nil, e - } - - if e := tokenutil.UpgradeValue(data, "ttl", "token_ttl", &role.TTL, &role.TokenTTL); e != nil { - return nil, e - } - - if e := tokenutil.UpgradeValue(data, "max_ttl", "token_max_ttl", &role.MaxTTL, &role.TokenMaxTTL); e != nil { - return nil, e - } - - if e := tokenutil.UpgradeValue(data, "period", "token_period", &role.Period, &role.TokenPeriod); e != nil { - return nil, e - } - } - - // Set role type - if rt, ok := data.GetOk("type"); ok { - roleType := rt.(string) - if role.RoleType != roleType && req.Operation == logical.UpdateOperation { - err = errors.New("role type cannot be changed for an existing role") - return - } - role.RoleType = roleType - } else if req.Operation == logical.CreateOperation { - err = errors.New(errEmptyRoleType) - return - } - - def := sys.DefaultLeaseTTL() - if role.TokenTTL > def { - warnings = append(warnings, fmt.Sprintf(`Given token ttl of %q is greater `+ - `than the maximum system/mount TTL of %q. The TTL will be capped at `+ - `%q during login.`, role.TokenTTL, def, def)) - } - - // Update token Max TTL. - def = sys.MaxLeaseTTL() - if role.TokenMaxTTL > def { - warnings = append(warnings, fmt.Sprintf(`Given token max ttl of %q is greater `+ - `than the maximum system/mount MaxTTL of %q. The MaxTTL will be `+ - `capped at %q during login.`, role.TokenMaxTTL, def, def)) - } - if role.TokenPeriod > def { - warnings = append(warnings, fmt.Sprintf(`Given token period of %q is greater `+ - `than the maximum system/mount period of %q. The period will be `+ - `capped at %q during login.`, role.TokenPeriod, def, def)) - } - - // Update bound GCP service accounts. - if sa, ok := data.GetOk("bound_service_accounts"); ok { - role.BoundServiceAccounts = sa.([]string) - } else { - // Check for older version of param name - if sa, ok := data.GetOk("service_accounts"); ok { - warnings = append(warnings, `The "service_accounts" field is deprecated. `+ - `Please use "bound_service_accounts" instead. The "service_accounts" `+ - `field will be removed in a later release, so please update accordingly.`) - role.BoundServiceAccounts = sa.([]string) - } - } - if len(role.BoundServiceAccounts) > 0 { - role.BoundServiceAccounts = strutil.TrimStrings(role.BoundServiceAccounts) - role.BoundServiceAccounts = strutil.RemoveDuplicates(role.BoundServiceAccounts, false) - } - - // Update bound GCP projects. - boundProjects, givenBoundProj := data.GetOk("bound_projects") - if givenBoundProj { - role.BoundProjects = boundProjects.([]string) - } - if projectId, ok := data.GetOk("project_id"); ok { - if givenBoundProj { - return warnings, errors.New("only one of 'bound_projects' or 'project_id' can be given") - } - warnings = append(warnings, - `The "project_id" (singular) field is deprecated. `+ - `Please use plural "bound_projects" instead to bind required GCP projects. `+ - `The "project_id" field will be removed in a later release, so please update accordingly.`) - role.BoundProjects = []string{projectId.(string)} - } - if len(role.BoundProjects) > 0 { - role.BoundProjects = strutil.TrimStrings(role.BoundProjects) - role.BoundProjects = strutil.RemoveDuplicates(role.BoundProjects, false) - } - - // Update bound GCP projects. - addGroupAliases, ok := data.GetOk("add_group_aliases") - if ok { - role.AddGroupAliases = addGroupAliases.(bool) - } - - // Update fields specific to this type - switch role.RoleType { - case iamRoleType: - if err = checkInvalidRoleTypeArgs(data, gceOnlyFieldSchema); err != nil { - return - } - if warnings, err = role.updateIamFields(data, req.Operation); err != nil { - return - } - case gceRoleType: - if err = checkInvalidRoleTypeArgs(data, iamOnlyFieldSchema); err != nil { - return - } - if warnings, err = role.updateGceFields(data, req.Operation); err != nil { - return - } - } - - return -} - -func (role *gcpRole) validate(sys logical.SystemView) (warnings []string, err error) { - warnings = []string{} - - switch role.RoleType { - case iamRoleType: - if warnings, err = role.validateForIAM(); err != nil { - return warnings, err - } - case gceRoleType: - if warnings, err = role.validateForGCE(); err != nil { - return warnings, err - } - case "": - return warnings, errors.New(errEmptyRoleType) - default: - return warnings, fmt.Errorf("role type '%s' is invalid", role.RoleType) - } - - defaultLeaseTTL := sys.DefaultLeaseTTL() - if role.TokenTTL > defaultLeaseTTL { - warnings = append(warnings, fmt.Sprintf( - "Given ttl of %d seconds greater than current mount/system default of %d seconds; ttl will be capped at login time", - role.TokenTTL/time.Second, defaultLeaseTTL/time.Second)) - } - - defaultMaxTTL := sys.MaxLeaseTTL() - if role.TokenMaxTTL > defaultMaxTTL { - warnings = append(warnings, fmt.Sprintf( - "Given max_ttl of %d seconds greater than current mount/system default of %d seconds; max_ttl will be capped at login time", - role.TokenMaxTTL/time.Second, defaultMaxTTL/time.Second)) - } - if role.TokenMaxTTL < time.Duration(0) { - return warnings, errors.New("max_ttl cannot be negative") - } - if role.TokenMaxTTL != 0 && role.TokenMaxTTL < role.TokenTTL { - return warnings, errors.New("ttl should be shorter than max_ttl") - } - - if role.TokenPeriod > sys.MaxLeaseTTL() { - return warnings, fmt.Errorf("'period' of '%s' is greater than the backend's maximum lease TTL of '%s'", role.TokenPeriod.String(), sys.MaxLeaseTTL().String()) - } - - return warnings, nil -} - -// updateIamFields updates IAM-only fields for a role. -func (role *gcpRole) updateIamFields(data *framework.FieldData, op logical.Operation) (warnings []string, err error) { - if allowGCEInference, ok := data.GetOk("allow_gce_inference"); ok { - role.AllowGCEInference = allowGCEInference.(bool) - } else if op == logical.CreateOperation { - role.AllowGCEInference = data.Get("allow_gce_inference").(bool) - } - - if maxJwtExp, ok := data.GetOk("max_jwt_exp"); ok { - role.MaxJwtExp = time.Duration(maxJwtExp.(int)) * time.Second - } else if op == logical.CreateOperation { - role.MaxJwtExp = time.Duration(defaultIamMaxJwtExpMinutes) * time.Minute - } - - return -} - -// updateGceFields updates GCE-only fields for a role. -func (role *gcpRole) updateGceFields(data *framework.FieldData, op logical.Operation) (warnings []string, err error) { - if regions, ok := data.GetOk("bound_regions"); ok { - role.BoundRegions = regions.([]string) - } else if op == logical.CreateOperation { - role.BoundRegions = data.Get("bound_regions").([]string) - } - - if zones, ok := data.GetOk("bound_zones"); ok { - role.BoundZones = zones.([]string) - } else if op == logical.CreateOperation { - role.BoundZones = data.Get("bound_zones").([]string) - } - - if instanceGroups, ok := data.GetOk("bound_instance_groups"); ok { - role.BoundInstanceGroups = instanceGroups.([]string) - } else if op == logical.CreateOperation { - role.BoundInstanceGroups = data.Get("bound_instance_groups").([]string) - } - - if boundRegion, ok := data.GetOk("bound_region"); ok { - if _, ok := data.GetOk("bound_regions"); ok { - err = fmt.Errorf(`cannot specify both "bound_region" and "bound_regions"`) - return - } - - warnings = append(warnings, `The "bound_region" field is deprecated. `+ - `Please use "bound_regions" (plural) instead. You can still specify a `+ - `single region, but multiple regions are also now supported. The `+ - `"bound_region" field will be removed in a later release, so please `+ - `update accordingly.`) - role.BoundRegions = append(role.BoundRegions, boundRegion.(string)) - } - - if boundZone, ok := data.GetOk("bound_zone"); ok { - if _, ok := data.GetOk("bound_zones"); ok { - err = fmt.Errorf(`cannot specify both "bound_zone" and "bound_zones"`) - return - } - - warnings = append(warnings, `The "bound_zone" field is deprecated. `+ - `Please use "bound_zones" (plural) instead. You can still specify a `+ - `single zone, but multiple zones are also now supported. The `+ - `"bound_zone" field will be removed in a later release, so please `+ - `update accordingly.`) - role.BoundZones = append(role.BoundZones, boundZone.(string)) - } - - if boundInstanceGroup, ok := data.GetOk("bound_instance_group"); ok { - if _, ok := data.GetOk("bound_instance_groups"); ok { - err = fmt.Errorf(`cannot specify both "bound_instance_group" and "bound_instance_groups"`) - return - } - - warnings = append(warnings, `The "bound_instance_group" field is deprecated. `+ - `Please use "bound_instance_groups" (plural) instead. You can still specify a `+ - `single instance group, but multiple instance groups are also now supported. The `+ - `"bound_instance_group" field will be removed in a later release, so please `+ - `update accordingly.`) - role.BoundInstanceGroups = append(role.BoundInstanceGroups, boundInstanceGroup.(string)) - } - - if labelsRaw, ok := data.GetOk("bound_labels"); ok { - labels, invalidLabels := gcputil.ParseGcpLabels(labelsRaw.([]string)) - if len(invalidLabels) > 0 { - err = fmt.Errorf("invalid labels given: %q", invalidLabels) - return - } - role.BoundLabels = labels - } - - if len(role.Policies) > 0 { - role.Policies = strutil.TrimStrings(role.Policies) - role.Policies = strutil.RemoveDuplicates(role.Policies, false) - } - - if len(role.BoundRegions) > 0 { - role.BoundRegions = strutil.TrimStrings(role.BoundRegions) - role.BoundRegions = strutil.RemoveDuplicates(role.BoundRegions, false) - } - - if len(role.BoundZones) > 0 { - role.BoundZones = strutil.TrimStrings(role.BoundZones) - role.BoundZones = strutil.RemoveDuplicates(role.BoundZones, false) - } - - if len(role.BoundInstanceGroups) > 0 { - role.BoundInstanceGroups = strutil.TrimStrings(role.BoundInstanceGroups) - role.BoundInstanceGroups = strutil.RemoveDuplicates(role.BoundInstanceGroups, false) - } - - return -} - -// validateIamFields validates the IAM-only fields for a role. -func (role *gcpRole) validateForIAM() (warnings []string, err error) { - if len(role.BoundServiceAccounts) == 0 { - return []string{}, errors.New(errEmptyIamServiceAccounts) - } - - if len(role.BoundServiceAccounts) > 1 && strutil.StrListContains(role.BoundServiceAccounts, serviceAccountsWildcard) { - return []string{}, fmt.Errorf("cannot provide IAM service account wildcard '%s' (for all service accounts) with other service accounts", serviceAccountsWildcard) - } - - maxMaxJwtExp := time.Duration(maxJwtExpMaxMinutes) * time.Minute - if role.MaxJwtExp > maxMaxJwtExp { - return warnings, fmt.Errorf("max_jwt_exp cannot be more than %d minutes", maxJwtExpMaxMinutes) - } - - return []string{}, nil -} - -// validateGceFields validates the GCE-only fields for a role. -func (role *gcpRole) validateForGCE() (warnings []string, err error) { - warnings = []string{} - - hasRegion := len(role.BoundRegions) > 0 - hasZone := len(role.BoundZones) > 0 - hasRegionOrZone := hasRegion || hasZone - - hasInstanceGroup := len(role.BoundInstanceGroups) > 0 - - if hasInstanceGroup && !hasRegionOrZone { - return warnings, errors.New(`region or zone information must be specified if an instance group is given`) - } - - if hasRegion && hasZone { - warnings = append(warnings, `Given both "bound_regions" and "bound_zones" `+ - `fields for role type "gce", "bound_regions" will be ignored in favor `+ - `of the more specific "bound_zones" field. To fix this warning, update `+ - `the role to remove either the "bound_regions" or "bound_zones" field.`) - } - - return warnings, nil -} - // checkInvalidRoleTypeArgs checks that the data provided does not contain arguments // for a different role type. If it does find some, it will return an error with the // invalid args. diff --git a/plugin/path_role_test.go b/plugin/path_role_test.go index 4e51a420..29282d0e 100644 --- a/plugin/path_role_test.go +++ b/plugin/path_role_test.go @@ -2,16 +2,19 @@ package gcpauth import ( "context" + "encoding/json" "fmt" "math/rand" + "reflect" "strings" "testing" "time" - "reflect" - + "github.com/golang/mock/gomock" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/policyutil" "github.com/hashicorp/vault/sdk/helper/strutil" + "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -64,6 +67,7 @@ func TestRoleUpdateIam(t *testing.T) { "name": roleName, "type": iamRoleType, "bound_service_accounts": serviceAccounts, + "iam_alias": defaultIAMAlias, }) serviceAccounts = append(serviceAccounts, "testaccount@google.com") @@ -93,6 +97,7 @@ func TestRoleUpdateIam(t *testing.T) { "allow_gce_inference": false, "add_group_aliases": true, "bound_service_accounts": serviceAccounts, + "iam_alias": defaultIAMAlias, }) } @@ -123,6 +128,7 @@ func TestRoleIam_Wildcard(t *testing.T) { testRoleRead(t, b, reqStorage, roleName, map[string]interface{}{ "type": iamRoleType, "bound_service_accounts": serviceAccounts, + "iam_alias": defaultIAMAlias, }) } @@ -145,6 +151,7 @@ func TestRoleIam_EditServiceAccounts(t *testing.T) { "type": iamRoleType, "bound_projects": projects, "bound_service_accounts": initial, + "iam_alias": defaultIAMAlias, } testRoleCreate(t, b, reqStorage, data) @@ -233,6 +240,7 @@ func TestRoleGce(t *testing.T) { "type": gceRoleType, "bound_projects": []string{}, "bound_service_accounts": []string{}, + "gce_alias": defaultGCEAlias, }) serviceAccounts := []string{"aserviceaccountid", "testaccount@google.com"} @@ -270,6 +278,7 @@ func TestRoleGce(t *testing.T) { "bound_instance_groups": []string{"devGroup"}, "bound_service_accounts": serviceAccounts, "add_group_aliases": true, + "gce_alias": defaultGCEAlias, }) } @@ -295,6 +304,7 @@ func TestRoleGce_EditLabels(t *testing.T) { "type": gceRoleType, "bound_projects": []string{projectId}, "bound_labels": labels, + "gce_alias": defaultGCEAlias, }) testRoleEditLabels(t, b, reqStorage, map[string]interface{}{ @@ -308,6 +318,7 @@ func TestRoleGce_EditLabels(t *testing.T) { "type": gceRoleType, "bound_projects": []string{projectId}, "bound_labels": labels, + "gce_alias": defaultGCEAlias, }) testRoleEditLabels(t, b, reqStorage, map[string]interface{}{ @@ -323,6 +334,7 @@ func TestRoleGce_EditLabels(t *testing.T) { "type": gceRoleType, "bound_projects": []string{projectId}, "bound_labels": labels, + "gce_alias": defaultGCEAlias, }) } @@ -354,6 +366,7 @@ func TestRoleGce_DeprecatedFields(t *testing.T) { "bound_regions": []string{"us-east1"}, "bound_zones": []string{"us-east1-a"}, "bound_instance_groups": []string{"my-ig"}, + "gce_alias": defaultGCEAlias, }) }) @@ -451,6 +464,511 @@ func TestRole_InvalidRoleType(t *testing.T) { }, []string{"role type", invalidRoleType, "is invalid"}) } +func TestRetrieveRole(t *testing.T) { + type testCase struct { + name string + + getName string + getResp *logical.StorageEntry + getErr error + getTimes int + + localMount bool + localMountTimes int + replicationState consts.ReplicationState + replicationStateTimes int + + putTimes int + + expectedRole *gcpRole + expectErr bool + } + + tests := map[string]testCase{ + "not found": { + name: "testrole", + + getName: "role/testrole", + getResp: nil, + getErr: nil, + getTimes: 1, + + localMountTimes: 0, + replicationStateTimes: 0, + + expectedRole: nil, + expectErr: false, + }, + "storage error": { + name: "testrole", + + getName: "role/testrole", + getResp: nil, + getErr: fmt.Errorf("test error"), + getTimes: 1, + + localMountTimes: 0, + replicationStateTimes: 0, + + expectedRole: nil, + expectErr: true, + }, + "bad data": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "role/testrole", + Value: []byte("asdfhoiasndf"), + }, + getErr: nil, + getTimes: 1, + + localMountTimes: 0, + replicationStateTimes: 0, + + expectedRole: nil, + expectErr: true, + }, + "iam nothing modified": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + RoleType: "iam", + IAMAliasType: defaultIAMAlias, + }), + }, + getErr: nil, + getTimes: 1, + + localMountTimes: 0, + replicationStateTimes: 0, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + RoleType: "iam", + IAMAliasType: defaultIAMAlias, + }, + expectErr: false, + }, + "gce nothing modified": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + RoleType: "gce", + GCEAliasType: defaultGCEAlias, + }), + }, + getErr: nil, + getTimes: 1, + + localMountTimes: 0, + replicationStateTimes: 0, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + RoleType: "gce", + GCEAliasType: defaultGCEAlias, + }, + expectErr: false, + }, + "projectID upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + ProjectId: "projectID", + BoundProjects: []string{}, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + BoundProjects: []string{"projectID"}, + }, + expectErr: false, + }, + "boundRegion upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + BoundRegion: "boundRegion", + BoundRegions: []string{}, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + BoundRegions: []string{"boundRegion"}, + }, + expectErr: false, + }, + "boundZone upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + BoundZone: "boundZone", + BoundZones: []string{}, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + BoundZones: []string{"boundZone"}, + }, + expectErr: false, + }, + "boundInstanceGroup upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + BoundInstanceGroup: "boundInstanceGroup", + BoundInstanceGroups: []string{}, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + BoundInstanceGroups: []string{"boundInstanceGroup"}, + }, + expectErr: false, + }, + "TTL upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{}, + TTL: 1 * time.Second, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{ + TokenTTL: 1 * time.Second, + }, + TTL: 1 * time.Second, + }, + expectErr: false, + }, + "MaxTTL upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{}, + MaxTTL: 1 * time.Second, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{ + TokenMaxTTL: 1 * time.Second, + }, + MaxTTL: 1 * time.Second, + }, + expectErr: false, + }, + "TokenPeriod upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{}, + Period: 1 * time.Second, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{ + TokenPeriod: 1 * time.Second, + }, + Period: 1 * time.Second, + }, + expectErr: false, + }, + "TokenPolicies upgrade": { + name: "testrole", + + getName: "role/testrole", + getResp: &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{}, + Policies: []string{"policy1", "policy2"}, + }), + }, + getErr: nil, + getTimes: 1, + + localMount: true, + localMountTimes: 1, + replicationStateTimes: 0, + + putTimes: 1, + + expectedRole: &gcpRole{ + RoleID: "testroleid", + TokenParams: tokenutil.TokenParams{ + TokenPolicies: []string{"policy1", "policy2"}, + }, + Policies: []string{"policy1", "policy2"}, + }, + expectErr: false, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + storage := NewMockStorage(ctrl) + storage.EXPECT().Get(ctx, test.getName).Return(test.getResp, test.getErr).Times(test.getTimes) + + putReq := &logical.StorageEntry{ + Key: fmt.Sprintf("role/%s", test.name), + Value: append(toJSON(t, test.expectedRole), '\n'), // Add a newline because StorageEntryJSON somehow adds it + } + storage.EXPECT().Put(ctx, putReq).Return(nil).Times(test.putTimes) + + systemView := NewMockSystemView(ctrl) + systemView.EXPECT().LocalMount().Return(test.localMount).Times(test.localMountTimes) + systemView.EXPECT().ReplicationState().Return(test.replicationState).Times(test.replicationStateTimes) + + be, err := Factory(ctx, &logical.BackendConfig{System: systemView}) + if err != nil { + t.Fatalf("no error expected, got: %s", err) + } + b := be.(*GcpAuthBackend) + + actualResp, err := b.role(ctx, storage, test.name) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) + } + if !reflect.DeepEqual(actualResp, test.expectedRole) { + t.Fatalf("Actual role: %#v\nExpected role: %#v", actualResp, test.expectedRole) + } + }) + } + + t.Run("storage put error", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + name := "testrole" + + getResp := &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{ + RoleID: "testroleid", + ProjectId: "projectID", + }), + } + + putReq := &logical.StorageEntry{ + Key: fmt.Sprintf("role/%s", name), + Value: append(toJSON(t, + gcpRole{ + RoleID: "testroleid", + BoundProjects: []string{"projectID"}, + }, + ), '\n'), // Add a newline because StorageEntryJSON somehow adds it + } + + storage := NewMockStorage(ctrl) + storage.EXPECT().Get(ctx, fmt.Sprintf("role/%s", name)).Return(getResp, nil) + storage.EXPECT().Put(ctx, putReq).Return(fmt.Errorf("test error")) + + systemView := NewMockSystemView(ctrl) + systemView.EXPECT().LocalMount().Return(true) + + be, err := Factory(ctx, &logical.BackendConfig{System: systemView}) + if err != nil { + t.Fatalf("no error expected, got: %s", err) + } + b := be.(*GcpAuthBackend) + + actualResp, err := b.role(ctx, storage, name) + if err == nil { + t.Fatalf("err expected, got nil") + } + if actualResp != nil { + t.Fatalf("no role expected, but got: %#v", actualResp) + } + }) + + t.Run("roleID is generated when one does not exist", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + name := "testrole" + + getResp := &logical.StorageEntry{ + Key: "testrole", + Value: toJSON(t, + gcpRole{}), + } + + storage := NewMockStorage(ctrl) + storage.EXPECT().Get(ctx, fmt.Sprintf("role/%s", name)).Return(getResp, nil) + + var actualRawPut *logical.StorageEntry + storage.EXPECT().Put(ctx, gomock.Any()).DoAndReturn(func(_ context.Context, put *logical.StorageEntry) error { + actualRawPut = put + return nil + }) + + systemView := NewMockSystemView(ctrl) + systemView.EXPECT().LocalMount().Return(true) + + be, err := Factory(ctx, &logical.BackendConfig{System: systemView}) + if err != nil { + t.Fatalf("no error expected, got: %s", err) + } + b := be.(*GcpAuthBackend) + + actualRole, err := b.role(ctx, storage, name) + if err != nil { + t.Fatalf("no err expected, got: %s", err) + } + + if actualRole.RoleID == "" { + t.Fatalf("RoleID not set on returned role") + } + + expectedPutKey := fmt.Sprintf("role/%s", name) + if actualRawPut.Key != expectedPutKey { + t.Fatalf("Actual put key: %s Expected put key: %s", actualRawPut.Key, expectedPutKey) + } + + putRole := gcpRole{} + err = json.Unmarshal(actualRawPut.Value, &putRole) + if err != nil { + t.Fatalf("no err expected, got: %s", err) + } + + if putRole.RoleID != actualRole.RoleID { + t.Fatalf("Saved RoleID [%s] does not match returned RoleID [%s]", putRole.RoleID, actualRole.RoleID) + } + }) +} + //-- Utils -- func testRoleCreate(tb testing.TB, b logical.Backend, s logical.Storage, d map[string]interface{}) { tb.Helper() @@ -558,6 +1076,13 @@ func testRoleRead(tb testing.TB, b logical.Backend, s logical.Storage, roleName tb.Fatal(resp.Error()) } + // Because role_id is generated, ensure that it exists but don't worry about the specific value + roleID, exists := resp.Data["role_id"] + if !exists || roleID == "" { + tb.Fatal("missing or empty role_id") + } + delete(resp.Data, "role_id") + if err := checkData(resp, expected, expectedDefaults); err != nil { tb.Fatal(err) } @@ -642,3 +1167,13 @@ func testRoleAndProject(tb testing.TB) (string, string) { return roleName, projectId } + +func toJSON(t testing.TB, val interface{}) []byte { + t.Helper() + + b, err := json.Marshal(val) + if err != nil { + t.Fatalf("Failed to marshal to JSON: %s", err) + } + return b +}