Skip to content

Commit

Permalink
Add Google client initialisation helper method to generate ID tokens (#…
Browse files Browse the repository at this point in the history
…78)

* Add google auth methods

* Add service account unit tests

* Fix lint comments

* Fix unit tests

* Refactor unit tests

* Fix lint comments

* Fix broke unit test when no credentials are found

* Refactor unit tests to have proper dummy credentials

* Shorten string literals

* Fix lint comment on redundant declaration

* Remove redundant step to look for credentials again

* Update project names

* Remove audience parameter for InitGoogleClient

* Set default audience as a constant

* Remove redundant SkipDialSettingsValidation parameter
  • Loading branch information
deadlycoconuts committed Mar 28, 2023
1 parent 05ede13 commit 45464ea
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 1 deletion.
2 changes: 1 addition & 1 deletion api/api/router.go
Expand Up @@ -98,7 +98,7 @@ func (route Route) HandlerFunc(validate *validator.Validate) http.HandlerFunc {

response := func() *Response {
vars["user"] = r.Header.Get("User-Email")
var body interface{} = nil
var body interface{}

if bodyType != nil {
body = reflect.New(bodyType).Interface()
Expand Down
76 changes: 76 additions & 0 deletions api/pkg/auth/google_auth.go
@@ -0,0 +1,76 @@
package auth

import (
"context"
"encoding/json"
"fmt"
"net/http"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/idtoken"
"google.golang.org/api/option"
htransport "google.golang.org/api/transport/http"
)

const (
defaultCaraMLAudience = "api.caraml"
// JSON key file types
serviceAccountKey = "service_account"
)

type credentialsFile struct {
Type string `json:"type"`
}

// idTokenSource is an oauth2.TokenSource that wraps another TokenSource
// It takes the id_token from TokenSource and passes that on as a bearer token
type idTokenSource struct {
TokenSource oauth2.TokenSource
}

func (s *idTokenSource) Token() (*oauth2.Token, error) {
token, err := s.TokenSource.Token()
if err != nil {
return nil, err
}

idToken, ok := token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("token did not contain an id_token")
}

return &oauth2.Token{
AccessToken: idToken,
TokenType: "Bearer",
Expiry: token.Expiry,
}, nil
}

// InitGoogleClient is a helper method to be used by CaraML components to initialise a Google Client that appends ID
// tokens to the headers of all outgoing requests with ID tokens, regardless of the type of credentials used
func InitGoogleClient(ctx context.Context) (*http.Client, error) {
cred, err := google.FindDefaultCredentials(ctx)
if err != nil {
return nil, err
}

var f credentialsFile
if err := json.Unmarshal(cred.JSON, &f); err != nil {
return nil, err
}

if f.Type == serviceAccountKey {
return idtoken.NewClient(ctx, defaultCaraMLAudience)
}

tokenSource := oauth2.ReuseTokenSource(nil, &idTokenSource{TokenSource: cred.TokenSource})

var opts []idtoken.ClientOption
opts = append(opts, option.WithTokenSource(tokenSource))
t, err := htransport.NewTransport(ctx, http.DefaultTransport, opts...)
if err != nil {
return nil, err
}
return &http.Client{Transport: t}, nil
}
121 changes: 121 additions & 0 deletions api/pkg/auth/google_auth_test.go
@@ -0,0 +1,121 @@
package auth

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

// testSetupDummyGoogleCredentials creates a temporary file containing dummy credentials JSON
// then set the environment variable GOOGLE_APPLICATION_CREDENTIALS to point to the file.
//
// This is useful for tests that assume Google Cloud Client libraries can automatically find
// the service account credentials in any environment.
//
// At the end of the test, the returned function can be called to perform cleanup.
func testSetupDummyGoogleCredentials(t *testing.T, dummyCredentials []byte) (reset func()) {
file, err := os.CreateTemp("", "dummy-credentials")
if err != nil {
t.Fatal(err)
}

err = os.WriteFile(file.Name(), dummyCredentials, 0644)
if err != nil {
t.Fatal(err)
}

err = os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", file.Name())
if err != nil {
t.Fatal(err)
}

return func() {
err := os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS")
if err != nil {
t.Log("Cleanup failed", err)
}
err = os.Remove(file.Name())
if err != nil {
t.Log("Cleanup failed", err)
}
}
}

func TestTestInitGoogleClient(t *testing.T) {
// Define tests
tests := map[string]struct {
dummyCredential string
err string
}{
"failure | no default credentials found": {
err: "google: could not find default credentials. See " +
"https://developers.google.com/accounts/docs/application-default-credentials for more information.",
},
"failure | invalid json file": {
dummyCredential: `{`,
err: "google: error getting credentials using GOOGLE_APPLICATION_CREDENTIALS " +
"environment variable: unexpected end of JSON input",
},
"failure | json file with invalid credentials": {
dummyCredential: `{}`,
err: "google: error getting credentials using GOOGLE_APPLICATION_CREDENTIALS " +
"environment variable: missing 'type' field in credentials",
},
"failure | service account not found": {
//nolint:lll // the private key is in a string literal and cannot contain arbitrary line breaks
dummyCredential: `{
"type": "service_account",
"project_id": "example-project",
"private_key_id": "1",
"private_key": "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj\n7wZgkdmM7oVK2OfgrSj/FCTkInKPqaCR0gD7K80q+mLBrN3PUkDrJQZpvRZIff3/\nxmVU1WeruQLFJjnFb2dqu0s/FY/2kWiJtBCakXvXEOb7zfbINuayL+MSsCGSdVYs\nSliS5qQpgyDap+8b5fpXZVJkq92hrcNtbkg7hCYUJczt8n9hcCTJCfUpApvaFQ18\npe+zpyl4+WzkP66I28hniMQyUlA1hBiskT7qiouq0m8IOodhv2fagSZKjOTTU2xk\nSBc//fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQABAoIBAQDGGHzQxGKX+ANk\nnQi53v/c6632dJKYXVJC+PDAz4+bzU800Y+n/bOYsWf/kCp94XcG4Lgsdd0Gx+Zq\nHD9CI1IcqqBRR2AFscsmmX6YzPLTuEKBGMW8twaYy3utlFxElMwoUEsrSWRcCA1y\nnHSDzTt871c7nxCXHxuZ6Nm/XCL7Bg8uidRTSC1sQrQyKgTPhtQdYrPQ4WZ1A4J9\nIisyDYmZodSNZe5P+LTJ6M1SCgH8KH9ZGIxv3diMwzNNpk3kxJc9yCnja4mjiGE2\nYCNusSycU5IhZwVeCTlhQGcNeV/skfg64xkiJE34c2y2ttFbdwBTPixStGaF09nU\nZ422D40BAoGBAPvVyRRsC3BF+qZdaSMFwI1yiXY7vQw5+JZh01tD28NuYdRFzjcJ\nvzT2n8LFpj5ZfZFvSMLMVEFVMgQvWnN0O6xdXvGov6qlRUSGaH9u+TCPNnIldjMP\nB8+xTwFMqI7uQr54wBB+Poq7dVRP+0oHb0NYAwUBXoEuvYo3c/nDoRcZAoGBAOWl\naLHjMv4CJbArzT8sPfic/8waSiLV9Ixs3Re5YREUTtnLq7LoymqB57UXJB3BNz/2\neCueuW71avlWlRtE/wXASj5jx6y5mIrlV4nZbVuyYff0QlcG+fgb6pcJQuO9DxMI\naqFGrWP3zye+LK87a6iR76dS9vRU+bHZpSVvGMKJAoGAFGt3TIKeQtJJyqeUWNSk\nklORNdcOMymYMIlqG+JatXQD1rR6ThgqOt8sgRyJqFCVT++YFMOAqXOBBLnaObZZ\nCFbh1fJ66BlSjoXff0W+SuOx5HuJJAa5+WtFHrPajwxeuRcNa8jwxUsB7n41wADu\nUqWWSRedVBg4Ijbw3nWwYDECgYB0pLew4z4bVuvdt+HgnJA9n0EuYowVdadpTEJg\nsoBjNHV4msLzdNqbjrAqgz6M/n8Ztg8D2PNHMNDNJPVHjJwcR7duSTA6w2p/4k28\nbvvk/45Ta3XmzlxZcZSOct3O31Cw0i2XDVc018IY5be8qendDYM08icNo7vQYkRH\n504kQQKBgQDjx60zpz8ozvm1XAj0wVhi7GwXe+5lTxiLi9Fxq721WDxPMiHDW2XL\nYXfFVy/9/GIMvEiGYdmarK1NW+VhWl1DC5xhDg0kvMfxplt4tynoq1uTsQTY31Mx\nBeF5CT/JuNYk3bEBF0H/Q3VGO1/ggVS+YezdFbLWIRoMnLj6XCFEGg==\n-----END RSA PRIVATE KEY-----\n",
"client_email": "service-account@example.com",
"client_id": "1234",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://accounts.google.com/o/oauth2/token"
}`,
err: "oauth2: cannot fetch token: 400 Bad Request\nResponse: " +
"{\"error\":\"invalid_grant\",\"error_description\":\"Invalid grant: account not found\"}",
},
"failure | invalid credentials": {
dummyCredential: `{
"client_id": "dummyclientid.apps.googleusercontent.com",
"client_secret": "dummy-secret",
"quota_project_id": "test-project",
"refresh_token": "dummy-token",
"type": "unauthorized_user"
}`,
err: "google: error getting credentials using GOOGLE_APPLICATION_CREDENTIALS environment variable: " +
"unknown credential type: \"unauthorized_user\"",
},
"success | user account": {
dummyCredential: `{
"client_id": "dummyclientid.apps.googleusercontent.com",
"client_secret": "dummy-secret",
"quota_project_id": "test-project",
"refresh_token": "dummy-token",
"type": "authorized_user"
}`,
},
}

// Run tests
for name, data := range tests {
t.Run(name, func(t *testing.T) {
if data.dummyCredential != "" {
reset := testSetupDummyGoogleCredentials(t, []byte(data.dummyCredential))
defer reset()
}

client, err := InitGoogleClient(context.Background())
if data.err != "" {
assert.EqualError(t, err, data.err)
assert.Nil(t, client)
} else {
assert.NoError(t, err)
assert.NotNil(t, client)
}
})
}
}
3 changes: 3 additions & 0 deletions go.mod
Expand Up @@ -31,6 +31,7 @@ require (
github.com/uber/jaeger-client-go v2.16.0+incompatible
go.uber.org/zap v1.17.0
golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5
google.golang.org/api v0.75.0
gopkg.in/errgo.v2 v2.1.0
gopkg.in/yaml.v2 v2.4.0
k8s.io/client-go v0.26.0
Expand Down Expand Up @@ -58,6 +59,7 @@ require (
github.com/go-openapi/validate v0.19.7 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/golang/snappy v0.0.3 // indirect
github.com/google/gofuzz v1.1.0 // indirect
Expand Down Expand Up @@ -92,6 +94,7 @@ require (
github.com/uber-go/atomic v1.4.0 // indirect
github.com/uber/jaeger-lib v2.0.0+incompatible // indirect
go.mongodb.org/mongo-driver v1.1.2 // indirect
go.opencensus.io v0.23.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
golang.org/x/crypto v0.1.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Expand Up @@ -293,6 +293,8 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
Expand Down Expand Up @@ -713,6 +715,7 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
Expand Down Expand Up @@ -1074,6 +1077,7 @@ google.golang.org/api v0.67.0/go.mod h1:ShHKP8E60yPsKNw/w8w+VYaj9H6buA5UqDp8dhbQ
google.golang.org/api v0.70.0/go.mod h1:Bs4ZM2HGifEvXwd50TtW70ovgJffJYw2oRCOFU/SkfA=
google.golang.org/api v0.71.0/go.mod h1:4PyU6e6JogV1f9eA4voyrTY2batOLdgZ5qZ5HOCc4j8=
google.golang.org/api v0.74.0/go.mod h1:ZpfMZOVRMywNyvJFeqL9HRWBgAuRfSjJFpe9QtRRyDs=
google.golang.org/api v0.75.0 h1:0AYh/ae6l9TDUvIQrDw5QRpM100P6oHgD+o3dYHMzJg=
google.golang.org/api v0.75.0/go.mod h1:pU9QmyHLnzlpar1Mjt4IbapUCy8J+6HD6GeELN69ljA=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
Expand Down

0 comments on commit 45464ea

Please sign in to comment.