Skip to content

Commit

Permalink
Add support for pulling ID Tokens from the metadata server (#8)
Browse files Browse the repository at this point in the history
* Add support for pulling ID Tokens from the metadata server

This adds support for cloud-run-proxy to pull ID Tokens from the metadata server instead of always assuming gcloud. This means it will work on a GCE VM or Cloud Run service.

However, this requires a user to specify an audience value for the JWT. When using the gcloud token, Cloud Run trusts the gcloud client IDs as valid aud values, but the only truly accepted value is the URL of the server. That's fine - we have the URL of the service because we need it to proxy, but it does introduce an edge case where a Cloud Run service is fronted by a Load Balancer and the Load Balancer is serving a vanity URL. In this case, the user must specify the "host" value as the Load Balancer DNS entry, but the "audience" value must be the .run.app URL.

* Address review feedback

* Finish sentence
  • Loading branch information
sethvargo committed Feb 16, 2022
1 parent e62b5db commit 263984c
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 40 deletions.
12 changes: 12 additions & 0 deletions README.md
Expand Up @@ -63,3 +63,15 @@ Override the token (useful if you don't have gcloud installed):
```sh
cloud-run-proxy -token "yc..."
```

Specify a custom audience:

```sh
cloud-run-proxy -audience "https://my-service-daga283.run.app"
```

Note: when running on Compute Engine or other services with a metadata service, the audience defaults to the host URL. If you are accessing your Cloud Run service through a load balancer with a vanity domain, you must specify the audience value as the non-vanity URL of your service:

```sh
cloud-run-proxy -host "https://custom-domain.com" -audience "https://my-service-daga283.run.app"
```
11 changes: 10 additions & 1 deletion go.mod
Expand Up @@ -16,12 +16,21 @@ module github.com/GoogleCloudPlatform/cloud-run-proxy

go 1.17

require golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8
require (
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8
google.golang.org/api v0.61.0
)

require (
cloud.google.com/go v0.99.0 // indirect
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
github.com/golang/protobuf v1.5.2 // indirect
go.opencensus.io v0.23.0 // indirect
golang.org/x/net v0.0.0-20211209124913-491a49abca63 // indirect
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 // indirect
golang.org/x/text v0.3.6 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0 // indirect
google.golang.org/grpc v1.40.0 // indirect
google.golang.org/protobuf v1.27.1 // indirect
)
7 changes: 7 additions & 0 deletions go.sum
Expand Up @@ -75,6 +75,7 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
Expand Down Expand Up @@ -174,6 +175,7 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down Expand Up @@ -326,6 +328,7 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 h1:TyHqChC80pFkXWraUUf6RuB5IqFdQieMLwwCJokV2pc=
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand All @@ -335,6 +338,7 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
Expand Down Expand Up @@ -424,6 +428,7 @@ google.golang.org/api v0.54.0/go.mod h1:7C4bFFOvVDGXjfDTAsgGwDgAxRDeQ4X8NvUedIt6
google.golang.org/api v0.55.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE=
google.golang.org/api v0.56.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE=
google.golang.org/api v0.57.0/go.mod h1:dVPlbZyBo2/OjBpmvNdpn2GRm6rPy75jyU7bmhdrMgI=
google.golang.org/api v0.61.0 h1:TXXKS1slM3b2bZNJwD5DV/Tp6/M2cLzLOLh9PjDhrw8=
google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
Expand Down Expand Up @@ -490,6 +495,7 @@ google.golang.org/genproto v0.0.0-20210903162649-d08c68adba83/go.mod h1:eFjDcFEc
google.golang.org/genproto v0.0.0-20210909211513-a8c4777a87af/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY=
google.golang.org/genproto v0.0.0-20210924002016-3dee208752a0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0 h1:c7yRRmuQiVMo+YppNj5MUREXUyc2lPo3DrtYMwaWQ28=
google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
Expand All @@ -515,6 +521,7 @@ google.golang.org/grpc v1.37.1/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQ
google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE=
google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE=
google.golang.org/grpc v1.40.0 h1:AGJ0Ih4mHjSeibYkFGh1dD9KJ/eOtZ93I6hoHhukQ5Q=
google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34=
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
Expand Down
115 changes: 79 additions & 36 deletions main.go
Expand Up @@ -29,10 +29,12 @@ import (
"os/signal"
"runtime"
"strings"
"syscall"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/idtoken"
)

type contextKey string
Expand All @@ -48,18 +50,24 @@ const UserAgent = "cloud-run-proxy/" + Version + " (" + OSArch + ")"
var (
flagHost = flag.String("host", "", "Cloud Run host for which to proxy")
flagBind = flag.String("bind", "127.0.0.1:8080", "local host:port on which to listen")
flagAudience = flag.String("audience", "", "override JWT audience value (aud)")
flagToken = flag.String("token", "", "override OIDC token")
flagPrependUserAgent = flag.Bool("prepend-user-agent", true, "prepend a custom User-Agent header to requests")
)

func main() {
if err := realMain(); err != nil {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()

if err := realMain(ctx); err != nil {
cancel()

fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
}
}

func realMain() error {
func realMain(ctx context.Context) error {
// Parse flags.
flag.Parse()
if *flagHost == "" {
Expand All @@ -69,18 +77,27 @@ func realMain() error {
return fmt.Errorf("missing -bind")
}

// Get the best token source.
tokenSource, err := findTokenSource(*flagToken)
if err != nil {
return fmt.Errorf("failed to find token source: %w", err)
}

// Build the remote host URL.
host, err := smartBuildHost(*flagHost)
if err != nil {
return fmt.Errorf("failed to parse host URL: %w", err)
}

// Compute the audience, default to the host. However, there might be cases
// where you want to specify a custom aud (such as when accessing through a
// load balancer).
audience := *flagAudience
if audience == "" {
audience = host.String()
}

// Get the best token source. Cloud Run expects the audience parameter to be
// the URL of the service.
tokenSource, err := findTokenSource(ctx, *flagToken, audience)
if err != nil {
return fmt.Errorf("failed to find token source: %w", err)
}

// Build the local bind URL.
bindHost, bindPort, err := net.SplitHostPort(*flagBind)
if err != nil {
Expand Down Expand Up @@ -112,15 +129,11 @@ func realMain() error {
}
}()

// Signal on stop.
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt)

// Wait for error or signal.
// Wait for stop
select {
case err := <-errCh:
return fmt.Errorf("server error: %w", err)
case <-stop:
case <-ctx.Done():
fmt.Fprint(os.Stderr, "\nserver is shutting down...\n")
}

Expand Down Expand Up @@ -160,19 +173,13 @@ func buildProxy(host, bind *url.URL, tokenSource oauth2.TokenSource) *httputil.R
return
}

// Get the id_token from the oauth token.
idTokenRaw := token.Extra("id_token")
if idTokenRaw == nil {
// Get the id_token.
idToken := token.AccessToken
if idToken == "" {
*r = *r.WithContext(context.WithValue(ctx, contextKeyError,
fmt.Errorf("missing id_token")))
return
}
idToken, ok := idTokenRaw.(string)
if !ok {
*r = *r.WithContext(context.WithValue(ctx, contextKeyError,
fmt.Errorf("id_token is not a string: %T", idTokenRaw)))
return
}

// Set a custom user-agent header.
if *flagPrependUserAgent {
Expand Down Expand Up @@ -219,28 +226,64 @@ func buildProxy(host, bind *url.URL, tokenSource oauth2.TokenSource) *httputil.R
return proxy
}

// findTokenSource fetches the reusable/cached oauth2 token source. If t is
// provided, that token is used as a static value. Othwerise, this attempts to
// get the renewable token from the environment (including via Application
// Default Credentials).
func findTokenSource(t string) (oauth2.TokenSource, error) {
// findTokenSource fetches the reusable/cached oauth2 token source. If rawToken
// is provided, that token is used as a static value and the audience parameter
// is ignored. Othwerise, this attempts to get the renewable token from the
// environment (via Application Default Credentials).
func findTokenSource(ctx context.Context, rawToken, audience string) (oauth2.TokenSource, error) {
// Prefer supplied value, usually from the flag.
if t != "" {
token := new(oauth2.Token).WithExtra(map[string]interface{}{
"id_token": t,
})
if rawToken != "" {
token := &oauth2.Token{AccessToken: rawToken}
return oauth2.StaticTokenSource(token), nil
}

// Try and find the default token from ADC.
ctx := context.Background()
tokenSource, err := google.DefaultTokenSource(ctx, cloudPlatformScope)
// Try to use the idtoken package, which will use the metadata service.
// However, the idtoken package does not work with gcloud's ADC, so we need to
// handle that case by falling back to default ADC search. However, the
// default ADC has a token at a different path, so we construct a custom token
// source for this edge case.
tokenSource, err := idtoken.NewTokenSource(ctx, audience)
if err != nil {
return nil, fmt.Errorf("failed to get default token source: %w", err)
// Return any unexpected error.
if !strings.Contains(err.Error(), "credential must be service_account") {
return nil, fmt.Errorf("failed to get idtoken source: %w", err)
}

// If we got this far, it means that we found ADC, but the ADC was supplied
// by a gcloud "authorized_user" instead of a service account. Thus we
// fallback to the default ADC search.
tokenSource, err = google.DefaultTokenSource(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get default token source: %w", err)
}
tokenSource = &idTokenFromDefaultTokenSource{TokenSource: tokenSource}
}
return oauth2.ReuseTokenSource(nil, tokenSource), nil
}

type idTokenFromDefaultTokenSource struct {
TokenSource oauth2.TokenSource
}

// Token extracts the id_token field from ADC from a default token source and
// puts the value into the AccessToken field.
func (s *idTokenFromDefaultTokenSource) Token() (*oauth2.Token, error) {
token, err := s.TokenSource.Token()
if err != nil {
return nil, err
}

idToken, ok := token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("missing id_token")
}

return &oauth2.Token{
AccessToken: idToken,
Expiry: token.Expiry,
}, nil
}

// smartBuildHost parses the URL, handling the case where it's a real URL
// (https://foo.bar) or just a host (foo.bar). If it's just a host, the URL is
// assumed to be TLS.
Expand Down
11 changes: 8 additions & 3 deletions main_test.go
Expand Up @@ -15,6 +15,7 @@
package main

import (
"context"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -43,6 +44,8 @@ func testRandomPort(tb testing.TB) int {
func TestBuildProxy(t *testing.T) {
t.Parallel()

ctx := context.Background()

mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if got, want := r.Header.Get("Authorization"), "Bearer mytoken"; got != want {
Expand All @@ -62,7 +65,7 @@ func TestBuildProxy(t *testing.T) {
Host: fmt.Sprintf("localhost:%d", testRandomPort(t)),
}

src, err := findTokenSource("mytoken")
src, err := findTokenSource(ctx, "mytoken", "aud")
if err != nil {
t.Fatal(err)
}
Expand All @@ -81,10 +84,12 @@ func TestBuildProxy(t *testing.T) {
func TestFindTokenSource(t *testing.T) {
t.Parallel()

ctx := context.Background()

t.Run("static", func(t *testing.T) {
t.Parallel()

src, err := findTokenSource("mytoken")
src, err := findTokenSource(ctx, "mytoken", "aud")
if err != nil {
t.Fatal(err)
}
Expand All @@ -94,7 +99,7 @@ func TestFindTokenSource(t *testing.T) {
t.Fatal(err)
}

if got, want := token.Extra("id_token"), "mytoken"; got != want {
if got, want := token.AccessToken, "mytoken"; got != want {
t.Errorf("expected %q to be %q", got, want)
}
})
Expand Down

0 comments on commit 263984c

Please sign in to comment.