Skip to content

Commit

Permalink
feat(idtoken): NewTokenSource allows impersonated_service_account creds
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianajg committed Dec 27, 2022
1 parent 9fb35f5 commit 5267088
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 25 deletions.
82 changes: 57 additions & 25 deletions idtoken/idtoken.go
Expand Up @@ -8,12 +8,16 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"path/filepath"
"strings"

"cloud.google.com/go/compute/metadata"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"

"google.golang.org/api/impersonate"
"google.golang.org/api/internal"
"google.golang.org/api/option"
"google.golang.org/api/option/internaloption"
Expand Down Expand Up @@ -67,6 +71,7 @@ func NewClient(ctx context.Context, audience string, opts ...ClientOption) (*htt
// provided and configured with the supplied options. The parameter audience may
// not be empty.
func NewTokenSource(ctx context.Context, audience string, opts ...ClientOption) (oauth2.TokenSource, error) {
option.WithScopes()
if audience == "" {
return nil, fmt.Errorf("idtoken: must supply a non-empty audience")
}
Expand Down Expand Up @@ -103,45 +108,72 @@ func newTokenSource(ctx context.Context, audience string, ds *internal.DialSetti
}

func tokenSourceFromBytes(ctx context.Context, data []byte, audience string, ds *internal.DialSettings) (oauth2.TokenSource, error) {
if err := isServiceAccount(data); err != nil {
return nil, err
}
cfg, err := google.JWTConfigFromJSON(data, ds.GetScopes()...)
allowedType, err := getAllowedType(data)
if err != nil {
return nil, err
}

customClaims := ds.CustomClaims
if customClaims == nil {
customClaims = make(map[string]interface{})
}
customClaims["target_audience"] = audience

cfg.PrivateClaims = customClaims
cfg.UseIDToken = true

ts := cfg.TokenSource(ctx)
tok, err := ts.Token()
if err != nil {
return nil, err
if allowedType == "service_account" {
cfg, err := google.JWTConfigFromJSON(data, ds.GetScopes()...)
if err != nil {
return nil, err
}
customClaims := ds.CustomClaims
if customClaims == nil {
customClaims = make(map[string]interface{})
}
customClaims["target_audience"] = audience

cfg.PrivateClaims = customClaims
cfg.UseIDToken = true

ts := cfg.TokenSource(ctx)
tok, err := ts.Token()
if err != nil {
return nil, err
}
return oauth2.ReuseTokenSource(tok, ts), nil
} else {
// if allowedType is "impersonated_service_account":
type url struct {
ServiceAccountImpersonationURL string `json:"service_account_impersonation_url"`
}
var accountUrl *url
if err := json.Unmarshal(data, &accountUrl); err != nil {
return nil, err
}
account := filepath.Base(accountUrl.ServiceAccountImpersonationURL)
account = strings.Split(account, ":")[0]

config := impersonate.IDTokenConfig{
Audience: audience,
TargetPrincipal: account,
IncludeEmail: true,
}
ts, err := impersonate.IDTokenSource(ctx, config)
if err != nil {
log.Println(err)
}
return ts, nil
}
return oauth2.ReuseTokenSource(tok, ts), nil
}

func isServiceAccount(data []byte) error {
// isOfAllowedType returns the credentials type as a string, and an error.
// allowed types are "service_account" and "impersonated_service_account"
func getAllowedType(data []byte) (string, error) {
if len(data) == 0 {
return fmt.Errorf("idtoken: credential provided is 0 bytes")
return "", fmt.Errorf("idtoken: credential provided is 0 bytes")
}
var f struct {
Type string `json:"type"`
}
// if not service account return an error
if err := json.Unmarshal(data, &f); err != nil {
return err
return "", err
}
if f.Type != "service_account" {
return fmt.Errorf("idtoken: credential must be service_account, found %q", f.Type)
if f.Type != "service_account" && f.Type != "impersonated_service_account" {
return "", fmt.Errorf("idtoken: credential must be service_account or impersonated_service_account, found %q", f.Type)
}
return nil
return f.Type, nil
}

// WithCustomClaims optionally specifies custom private claims for an ID token.
Expand Down
80 changes: 80 additions & 0 deletions idtoken/idtoken_test.go
@@ -0,0 +1,80 @@
// Copyright 2020 Google LLC.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package idtoken

import (
"context"
"reflect"
"testing"

"golang.org/x/oauth2"
"google.golang.org/api/internal"
)

var TokenSource oauth2.TokenSource

func TestNewTokenSource(t *testing.T) {
tests := []struct {
name string
ctx context.Context
audience string
want oauth2.TokenSource
wantErr bool
}{
{
name: "works",
ctx: context.Background(),
audience: "https://apikeys.googleapis.com",
want: TokenSource,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewTokenSource(tt.ctx, tt.audience)
if (err != nil) != tt.wantErr {
t.Errorf("NewTokenSource() error = %v, wantErr %v", err, tt.wantErr)
return
}
tok, err := got.Token()
if (err != nil) != tt.wantErr {
t.Errorf("NewTokenSource() error = %v, wantErr %v", err, tt.wantErr)
return
}
_, err = Validate(tt.ctx, tok.AccessToken, tt.audience)
if err != nil {
t.Errorf("NewTokenSource() = %v, want %v", got, tt.want)
}
})
}
}

func Test_newTokenSource(t *testing.T) {
type args struct {
ctx context.Context
audience string
ds *internal.DialSettings
}
tests := []struct {
name string
args args
want oauth2.TokenSource
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := newTokenSource(tt.args.ctx, tt.args.audience, tt.args.ds)
if (err != nil) != tt.wantErr {
t.Errorf("newTokenSource() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("newTokenSource() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit 5267088

Please sign in to comment.