diff --git a/idtoken/idtoken.go b/idtoken/idtoken.go index 3dce463bd61..152ac9c2007 100644 --- a/idtoken/idtoken.go +++ b/idtoken/idtoken.go @@ -8,12 +8,16 @@ import ( "context" "encoding/json" "fmt" + "log" "net/http" + "path/filepath" + "strings" "cloud.google.com/go/compute/metadata" "golang.org/x/oauth2" "golang.org/x/oauth2/google" + "google.golang.org/api/impersonate" "google.golang.org/api/internal" "google.golang.org/api/option" "google.golang.org/api/option/internaloption" @@ -67,6 +71,7 @@ func NewClient(ctx context.Context, audience string, opts ...ClientOption) (*htt // provided and configured with the supplied options. The parameter audience may // not be empty. func NewTokenSource(ctx context.Context, audience string, opts ...ClientOption) (oauth2.TokenSource, error) { + option.WithScopes() if audience == "" { return nil, fmt.Errorf("idtoken: must supply a non-empty audience") } @@ -103,45 +108,72 @@ func newTokenSource(ctx context.Context, audience string, ds *internal.DialSetti } func tokenSourceFromBytes(ctx context.Context, data []byte, audience string, ds *internal.DialSettings) (oauth2.TokenSource, error) { - if err := isServiceAccount(data); err != nil { - return nil, err - } - cfg, err := google.JWTConfigFromJSON(data, ds.GetScopes()...) + allowedType, err := getAllowedType(data) if err != nil { return nil, err } - - customClaims := ds.CustomClaims - if customClaims == nil { - customClaims = make(map[string]interface{}) - } - customClaims["target_audience"] = audience - - cfg.PrivateClaims = customClaims - cfg.UseIDToken = true - - ts := cfg.TokenSource(ctx) - tok, err := ts.Token() - if err != nil { - return nil, err + if allowedType == "service_account" { + cfg, err := google.JWTConfigFromJSON(data, ds.GetScopes()...) + if err != nil { + return nil, err + } + customClaims := ds.CustomClaims + if customClaims == nil { + customClaims = make(map[string]interface{}) + } + customClaims["target_audience"] = audience + + cfg.PrivateClaims = customClaims + cfg.UseIDToken = true + + ts := cfg.TokenSource(ctx) + tok, err := ts.Token() + if err != nil { + return nil, err + } + return oauth2.ReuseTokenSource(tok, ts), nil + } else { + // if allowedType is "impersonated_service_account": + type url struct { + ServiceAccountImpersonationURL string `json:"service_account_impersonation_url"` + } + var accountUrl *url + if err := json.Unmarshal(data, &accountUrl); err != nil { + return nil, err + } + account := filepath.Base(accountUrl.ServiceAccountImpersonationURL) + account = strings.Split(account, ":")[0] + + config := impersonate.IDTokenConfig{ + Audience: audience, + TargetPrincipal: account, + IncludeEmail: true, + } + ts, err := impersonate.IDTokenSource(ctx, config) + if err != nil { + log.Println(err) + } + return ts, nil } - return oauth2.ReuseTokenSource(tok, ts), nil } -func isServiceAccount(data []byte) error { +// isOfAllowedType returns the credentials type as a string, and an error. +// allowed types are "service_account" and "impersonated_service_account" +func getAllowedType(data []byte) (string, error) { if len(data) == 0 { - return fmt.Errorf("idtoken: credential provided is 0 bytes") + return "", fmt.Errorf("idtoken: credential provided is 0 bytes") } var f struct { Type string `json:"type"` } + // if not service account return an error if err := json.Unmarshal(data, &f); err != nil { - return err + return "", err } - if f.Type != "service_account" { - return fmt.Errorf("idtoken: credential must be service_account, found %q", f.Type) + if f.Type != "service_account" && f.Type != "impersonated_service_account" { + return "", fmt.Errorf("idtoken: credential must be service_account or impersonated_service_account, found %q", f.Type) } - return nil + return f.Type, nil } // WithCustomClaims optionally specifies custom private claims for an ID token. diff --git a/idtoken/idtoken_test.go b/idtoken/idtoken_test.go new file mode 100644 index 00000000000..367f12d8710 --- /dev/null +++ b/idtoken/idtoken_test.go @@ -0,0 +1,80 @@ +// Copyright 2020 Google LLC. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package idtoken + +import ( + "context" + "reflect" + "testing" + + "golang.org/x/oauth2" + "google.golang.org/api/internal" +) + +var TokenSource oauth2.TokenSource + +func TestNewTokenSource(t *testing.T) { + tests := []struct { + name string + ctx context.Context + audience string + want oauth2.TokenSource + wantErr bool + }{ + { + name: "works", + ctx: context.Background(), + audience: "https://apikeys.googleapis.com", + want: TokenSource, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewTokenSource(tt.ctx, tt.audience) + if (err != nil) != tt.wantErr { + t.Errorf("NewTokenSource() error = %v, wantErr %v", err, tt.wantErr) + return + } + tok, err := got.Token() + if (err != nil) != tt.wantErr { + t.Errorf("NewTokenSource() error = %v, wantErr %v", err, tt.wantErr) + return + } + _, err = Validate(tt.ctx, tok.AccessToken, tt.audience) + if err != nil { + t.Errorf("NewTokenSource() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_newTokenSource(t *testing.T) { + type args struct { + ctx context.Context + audience string + ds *internal.DialSettings + } + tests := []struct { + name string + args args + want oauth2.TokenSource + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newTokenSource(tt.args.ctx, tt.args.audience, tt.args.ds) + if (err != nil) != tt.wantErr { + t.Errorf("newTokenSource() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("newTokenSource() = %v, want %v", got, tt.want) + } + }) + } +}