diff --git a/README.md b/README.md index 92c9ea6..60958dd 100644 --- a/README.md +++ b/README.md @@ -63,3 +63,15 @@ Override the token (useful if you don't have gcloud installed): ```sh cloud-run-proxy -token "yc..." ``` + +Specify a custom audience: + +```sh +cloud-run-proxy -audience "https://my-service-daga283.run.app" +``` + +Note: when running on Compute Engine or other services with a metadata service, the audience defaults to the host URL. If you are accessing your Cloud Run service through a load balancer with a vanity domain, you must specify the audience value as the non-vanity URL of your service: + +```sh +cloud-run-proxy -host "https://custom-domain.com" -audience "https://my-service-daga283.run.app" +``` diff --git a/go.mod b/go.mod index bbdd096..42b2a7d 100644 --- a/go.mod +++ b/go.mod @@ -16,12 +16,21 @@ module github.com/GoogleCloudPlatform/cloud-run-proxy go 1.17 -require golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 +require ( + golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 + google.golang.org/api v0.61.0 +) require ( cloud.google.com/go v0.99.0 // indirect + github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect github.com/golang/protobuf v1.5.2 // indirect + go.opencensus.io v0.23.0 // indirect golang.org/x/net v0.0.0-20211209124913-491a49abca63 // indirect + golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 // indirect + golang.org/x/text v0.3.6 // indirect google.golang.org/appengine v1.6.7 // indirect + google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0 // indirect + google.golang.org/grpc v1.40.0 // indirect google.golang.org/protobuf v1.27.1 // indirect ) diff --git a/go.sum b/go.sum index b4a29af..c6fb49d 100644 --- a/go.sum +++ b/go.sum @@ -75,6 +75,7 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2 github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -174,6 +175,7 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= +go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -326,6 +328,7 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 h1:TyHqChC80pFkXWraUUf6RuB5IqFdQieMLwwCJokV2pc= golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -335,6 +338,7 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -424,6 +428,7 @@ google.golang.org/api v0.54.0/go.mod h1:7C4bFFOvVDGXjfDTAsgGwDgAxRDeQ4X8NvUedIt6 google.golang.org/api v0.55.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE= google.golang.org/api v0.56.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE= google.golang.org/api v0.57.0/go.mod h1:dVPlbZyBo2/OjBpmvNdpn2GRm6rPy75jyU7bmhdrMgI= +google.golang.org/api v0.61.0 h1:TXXKS1slM3b2bZNJwD5DV/Tp6/M2cLzLOLh9PjDhrw8= google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -490,6 +495,7 @@ google.golang.org/genproto v0.0.0-20210903162649-d08c68adba83/go.mod h1:eFjDcFEc google.golang.org/genproto v0.0.0-20210909211513-a8c4777a87af/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= google.golang.org/genproto v0.0.0-20210924002016-3dee208752a0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0 h1:c7yRRmuQiVMo+YppNj5MUREXUyc2lPo3DrtYMwaWQ28= google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= @@ -515,6 +521,7 @@ google.golang.org/grpc v1.37.1/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQ google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= +google.golang.org/grpc v1.40.0 h1:AGJ0Ih4mHjSeibYkFGh1dD9KJ/eOtZ93I6hoHhukQ5Q= google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/main.go b/main.go index 25947b2..0a57a95 100644 --- a/main.go +++ b/main.go @@ -29,10 +29,12 @@ import ( "os/signal" "runtime" "strings" + "syscall" "time" "golang.org/x/oauth2" "golang.org/x/oauth2/google" + "google.golang.org/api/idtoken" ) type contextKey string @@ -48,18 +50,24 @@ const UserAgent = "cloud-run-proxy/" + Version + " (" + OSArch + ")" var ( flagHost = flag.String("host", "", "Cloud Run host for which to proxy") flagBind = flag.String("bind", "127.0.0.1:8080", "local host:port on which to listen") + flagAudience = flag.String("audience", "", "override JWT audience value (aud)") flagToken = flag.String("token", "", "override OIDC token") flagPrependUserAgent = flag.Bool("prepend-user-agent", true, "prepend a custom User-Agent header to requests") ) func main() { - if err := realMain(); err != nil { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + if err := realMain(ctx); err != nil { + cancel() + fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1) } } -func realMain() error { +func realMain(ctx context.Context) error { // Parse flags. flag.Parse() if *flagHost == "" { @@ -69,18 +77,27 @@ func realMain() error { return fmt.Errorf("missing -bind") } - // Get the best token source. - tokenSource, err := findTokenSource(*flagToken) - if err != nil { - return fmt.Errorf("failed to find token source: %w", err) - } - // Build the remote host URL. host, err := smartBuildHost(*flagHost) if err != nil { return fmt.Errorf("failed to parse host URL: %w", err) } + // Compute the audience, default to the host. However, there might be cases + // where you want to specify a custom aud (such as when accessing through a + // load balancer). + audience := *flagAudience + if audience == "" { + audience = host.String() + } + + // Get the best token source. Cloud Run expects the audience parameter to be + // the URL of the service. + tokenSource, err := findTokenSource(ctx, *flagToken, audience) + if err != nil { + return fmt.Errorf("failed to find token source: %w", err) + } + // Build the local bind URL. bindHost, bindPort, err := net.SplitHostPort(*flagBind) if err != nil { @@ -112,15 +129,11 @@ func realMain() error { } }() - // Signal on stop. - stop := make(chan os.Signal, 1) - signal.Notify(stop, os.Interrupt) - - // Wait for error or signal. + // Wait for stop select { case err := <-errCh: return fmt.Errorf("server error: %w", err) - case <-stop: + case <-ctx.Done(): fmt.Fprint(os.Stderr, "\nserver is shutting down...\n") } @@ -160,19 +173,13 @@ func buildProxy(host, bind *url.URL, tokenSource oauth2.TokenSource) *httputil.R return } - // Get the id_token from the oauth token. - idTokenRaw := token.Extra("id_token") - if idTokenRaw == nil { + // Get the id_token. + idToken := token.AccessToken + if idToken == "" { *r = *r.WithContext(context.WithValue(ctx, contextKeyError, fmt.Errorf("missing id_token"))) return } - idToken, ok := idTokenRaw.(string) - if !ok { - *r = *r.WithContext(context.WithValue(ctx, contextKeyError, - fmt.Errorf("id_token is not a string: %T", idTokenRaw))) - return - } // Set a custom user-agent header. if *flagPrependUserAgent { @@ -219,28 +226,64 @@ func buildProxy(host, bind *url.URL, tokenSource oauth2.TokenSource) *httputil.R return proxy } -// findTokenSource fetches the reusable/cached oauth2 token source. If t is -// provided, that token is used as a static value. Othwerise, this attempts to -// get the renewable token from the environment (including via Application -// Default Credentials). -func findTokenSource(t string) (oauth2.TokenSource, error) { +// findTokenSource fetches the reusable/cached oauth2 token source. If rawToken +// is provided, that token is used as a static value and the audience parameter +// is ignored. Othwerise, this attempts to get the renewable token from the +// environment (via Application Default Credentials). +func findTokenSource(ctx context.Context, rawToken, audience string) (oauth2.TokenSource, error) { // Prefer supplied value, usually from the flag. - if t != "" { - token := new(oauth2.Token).WithExtra(map[string]interface{}{ - "id_token": t, - }) + if rawToken != "" { + token := &oauth2.Token{AccessToken: rawToken} return oauth2.StaticTokenSource(token), nil } - // Try and find the default token from ADC. - ctx := context.Background() - tokenSource, err := google.DefaultTokenSource(ctx, cloudPlatformScope) + // Try to use the idtoken package, which will use the metadata service. + // However, the idtoken package does not work with gcloud's ADC, so we need to + // handle that case by falling back to default ADC search. However, the + // default ADC has a token at a different path, so we construct a custom token + // source for this edge case. + tokenSource, err := idtoken.NewTokenSource(ctx, audience) if err != nil { - return nil, fmt.Errorf("failed to get default token source: %w", err) + // Return any unexpected error. + if !strings.Contains(err.Error(), "credential must be service_account") { + return nil, fmt.Errorf("failed to get idtoken source: %w", err) + } + + // If we got this far, it means that we found ADC, but the ADC was supplied + // by a gcloud "authorized_user" instead of a service account. Thus we + // fallback to the default ADC search. + tokenSource, err = google.DefaultTokenSource(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get default token source: %w", err) + } + tokenSource = &idTokenFromDefaultTokenSource{TokenSource: tokenSource} } return oauth2.ReuseTokenSource(nil, tokenSource), nil } +type idTokenFromDefaultTokenSource struct { + TokenSource oauth2.TokenSource +} + +// Token extracts the id_token field from ADC from a default token source and +// puts the value into the AccessToken field. +func (s *idTokenFromDefaultTokenSource) Token() (*oauth2.Token, error) { + token, err := s.TokenSource.Token() + if err != nil { + return nil, err + } + + idToken, ok := token.Extra("id_token").(string) + if !ok { + return nil, fmt.Errorf("missing id_token") + } + + return &oauth2.Token{ + AccessToken: idToken, + Expiry: token.Expiry, + }, nil +} + // smartBuildHost parses the URL, handling the case where it's a real URL // (https://foo.bar) or just a host (foo.bar). If it's just a host, the URL is // assumed to be TLS. diff --git a/main_test.go b/main_test.go index 8a1df87..a5b99f3 100644 --- a/main_test.go +++ b/main_test.go @@ -15,6 +15,7 @@ package main import ( + "context" "fmt" "net" "net/http" @@ -43,6 +44,8 @@ func testRandomPort(tb testing.TB) int { func TestBuildProxy(t *testing.T) { t.Parallel() + ctx := context.Background() + mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if got, want := r.Header.Get("Authorization"), "Bearer mytoken"; got != want { @@ -62,7 +65,7 @@ func TestBuildProxy(t *testing.T) { Host: fmt.Sprintf("localhost:%d", testRandomPort(t)), } - src, err := findTokenSource("mytoken") + src, err := findTokenSource(ctx, "mytoken", "aud") if err != nil { t.Fatal(err) } @@ -81,10 +84,12 @@ func TestBuildProxy(t *testing.T) { func TestFindTokenSource(t *testing.T) { t.Parallel() + ctx := context.Background() + t.Run("static", func(t *testing.T) { t.Parallel() - src, err := findTokenSource("mytoken") + src, err := findTokenSource(ctx, "mytoken", "aud") if err != nil { t.Fatal(err) } @@ -94,7 +99,7 @@ func TestFindTokenSource(t *testing.T) { t.Fatal(err) } - if got, want := token.Extra("id_token"), "mytoken"; got != want { + if got, want := token.AccessToken, "mytoken"; got != want { t.Errorf("expected %q to be %q", got, want) } })