Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for pulling ID Tokens from the metadata server #8

Merged
merged 3 commits into from Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Expand Up @@ -63,3 +63,15 @@ Override the token (useful if you don't have gcloud installed):
```sh
cloud-run-proxy -token "yc..."
```

Specify a custom audience:

```sh
cloud-run-proxy -audience "https://my-service-daga283.run.app"
```

Note: when running on Compute Engine or other services with a metadata service, the audience defaults to the host URL. If you are accessing your Cloud Run service through a load balancer with a vanity domain, you must specify the audience value as the non-vanity URL of your service:

```sh
cloud-run-proxy -host "https://custom-domain.com" -audience "https://my-service-daga283.run.app"
```
11 changes: 10 additions & 1 deletion go.mod
Expand Up @@ -16,12 +16,21 @@ module github.com/GoogleCloudPlatform/cloud-run-proxy

go 1.17

require golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8
require (
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8
google.golang.org/api v0.61.0
)

require (
cloud.google.com/go v0.99.0 // indirect
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
github.com/golang/protobuf v1.5.2 // indirect
go.opencensus.io v0.23.0 // indirect
golang.org/x/net v0.0.0-20211209124913-491a49abca63 // indirect
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 // indirect
golang.org/x/text v0.3.6 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0 // indirect
google.golang.org/grpc v1.40.0 // indirect
google.golang.org/protobuf v1.27.1 // indirect
)
7 changes: 7 additions & 0 deletions go.sum
Expand Up @@ -75,6 +75,7 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
Expand Down Expand Up @@ -174,6 +175,7 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down Expand Up @@ -326,6 +328,7 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 h1:TyHqChC80pFkXWraUUf6RuB5IqFdQieMLwwCJokV2pc=
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand All @@ -335,6 +338,7 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
Expand Down Expand Up @@ -424,6 +428,7 @@ google.golang.org/api v0.54.0/go.mod h1:7C4bFFOvVDGXjfDTAsgGwDgAxRDeQ4X8NvUedIt6
google.golang.org/api v0.55.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE=
google.golang.org/api v0.56.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE=
google.golang.org/api v0.57.0/go.mod h1:dVPlbZyBo2/OjBpmvNdpn2GRm6rPy75jyU7bmhdrMgI=
google.golang.org/api v0.61.0 h1:TXXKS1slM3b2bZNJwD5DV/Tp6/M2cLzLOLh9PjDhrw8=
google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
Expand Down Expand Up @@ -490,6 +495,7 @@ google.golang.org/genproto v0.0.0-20210903162649-d08c68adba83/go.mod h1:eFjDcFEc
google.golang.org/genproto v0.0.0-20210909211513-a8c4777a87af/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY=
google.golang.org/genproto v0.0.0-20210924002016-3dee208752a0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0 h1:c7yRRmuQiVMo+YppNj5MUREXUyc2lPo3DrtYMwaWQ28=
google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
Expand All @@ -515,6 +521,7 @@ google.golang.org/grpc v1.37.1/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQ
google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM=
google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE=
google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE=
google.golang.org/grpc v1.40.0 h1:AGJ0Ih4mHjSeibYkFGh1dD9KJ/eOtZ93I6hoHhukQ5Q=
google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34=
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
Expand Down
111 changes: 75 additions & 36 deletions main.go
Expand Up @@ -29,10 +29,12 @@ import (
"os/signal"
"runtime"
"strings"
"syscall"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/idtoken"
)

type contextKey string
Expand All @@ -48,18 +50,24 @@ const UserAgent = "cloud-run-proxy/" + Version + " (" + OSArch + ")"
var (
flagHost = flag.String("host", "", "Cloud Run host for which to proxy")
flagBind = flag.String("bind", "127.0.0.1:8080", "local host:port on which to listen")
flagAudience = flag.String("audience", "", "override JWT audience value (aud)")
flagToken = flag.String("token", "", "override OIDC token")
flagPrependUserAgent = flag.Bool("prepend-user-agent", true, "prepend a custom User-Agent header to requests")
)

func main() {
if err := realMain(); err != nil {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()

if err := realMain(ctx); err != nil {
cancel()
sethvargo marked this conversation as resolved.
Show resolved Hide resolved

fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
}
}

func realMain() error {
func realMain(ctx context.Context) error {
// Parse flags.
flag.Parse()
if *flagHost == "" {
Expand All @@ -69,18 +77,27 @@ func realMain() error {
return fmt.Errorf("missing -bind")
}

// Get the best token source.
tokenSource, err := findTokenSource(*flagToken)
if err != nil {
return fmt.Errorf("failed to find token source: %w", err)
}

// Build the remote host URL.
host, err := smartBuildHost(*flagHost)
if err != nil {
return fmt.Errorf("failed to parse host URL: %w", err)
}

// Compute the audience, default to the host. However, there might be cases
// where you want to specify a custom aud (such as when accessing through a
// load balancer).
audience := *flagAudience
if audience == "" {
audience = host.String()
}

// Get the best token source. Cloud Run expects the audience parameter to be
// the URL of the service.
tokenSource, err := findTokenSource(ctx, *flagToken, audience)
if err != nil {
return fmt.Errorf("failed to find token source: %w", err)
}

// Build the local bind URL.
bindHost, bindPort, err := net.SplitHostPort(*flagBind)
if err != nil {
Expand Down Expand Up @@ -112,15 +129,11 @@ func realMain() error {
}
}()

// Signal on stop.
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt)

// Wait for error or signal.
// Wait for stop
select {
case err := <-errCh:
return fmt.Errorf("server error: %w", err)
case <-stop:
case <-ctx.Done():
fmt.Fprint(os.Stderr, "\nserver is shutting down...\n")
}

Expand Down Expand Up @@ -160,19 +173,13 @@ func buildProxy(host, bind *url.URL, tokenSource oauth2.TokenSource) *httputil.R
return
}

// Get the id_token from the oauth token.
idTokenRaw := token.Extra("id_token")
if idTokenRaw == nil {
// Get the id_token.
idToken := token.AccessToken
if idToken == "" {
*r = *r.WithContext(context.WithValue(ctx, contextKeyError,
fmt.Errorf("missing id_token")))
sethvargo marked this conversation as resolved.
Show resolved Hide resolved
return
}
idToken, ok := idTokenRaw.(string)
if !ok {
*r = *r.WithContext(context.WithValue(ctx, contextKeyError,
fmt.Errorf("id_token is not a string: %T", idTokenRaw)))
return
}

// Set a custom user-agent header.
if *flagPrependUserAgent {
Expand Down Expand Up @@ -219,28 +226,60 @@ func buildProxy(host, bind *url.URL, tokenSource oauth2.TokenSource) *httputil.R
return proxy
}

// findTokenSource fetches the reusable/cached oauth2 token source. If t is
// provided, that token is used as a static value. Othwerise, this attempts to
// get the renewable token from the environment (including via Application
// Default Credentials).
func findTokenSource(t string) (oauth2.TokenSource, error) {
// findTokenSource fetches the reusable/cached oauth2 token source. If rawToken
// is provided, that token is used as a static value and the audience parameter
// is ignored. Othwerise, this attempts to get the renewable token from the
// environment (via Application Default Credentials).
func findTokenSource(ctx context.Context, rawToken, audience string) (oauth2.TokenSource, error) {
// Prefer supplied value, usually from the flag.
if t != "" {
token := new(oauth2.Token).WithExtra(map[string]interface{}{
"id_token": t,
})
if rawToken != "" {
token := &oauth2.Token{AccessToken: rawToken}
return oauth2.StaticTokenSource(token), nil
}

// Try and find the default token from ADC.
ctx := context.Background()
tokenSource, err := google.DefaultTokenSource(ctx, cloudPlatformScope)
// Try to use the idtoken package, which will use the metadata service.
// However, the idtoken package does not work with gcloud, so we need to
sethvargo marked this conversation as resolved.
Show resolved Hide resolved
// handle that case by falling back to default ADC. However, the default ADC
// has a token at a different path, so we construct a custom token source for
// this edge case.
tokenSource, err := idtoken.NewTokenSource(ctx, audience)
if err != nil {
return nil, fmt.Errorf("failed to get default token source: %w", err)
if !strings.Contains(err.Error(), "credential must be service_account") {
sethvargo marked this conversation as resolved.
Show resolved Hide resolved
return nil, fmt.Errorf("failed to get idtoken source: %w", err)
}

tokenSource, err = google.DefaultTokenSource(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get default token source: %w", err)
}
tokenSource = &idTokenFromDefaultTokenSource{TokenSource: tokenSource}
}
return oauth2.ReuseTokenSource(nil, tokenSource), nil
}

type idTokenFromDefaultTokenSource struct {
TokenSource oauth2.TokenSource
}

// Token extracts the id_token field from ADC from a default token source and
// puts the value into the AccessToken field.
func (s *idTokenFromDefaultTokenSource) Token() (*oauth2.Token, error) {
token, err := s.TokenSource.Token()
if err != nil {
return nil, err
}

idToken, ok := token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("missing id_token")
}

return &oauth2.Token{
AccessToken: idToken,
Expiry: token.Expiry,
}, nil
}

// smartBuildHost parses the URL, handling the case where it's a real URL
// (https://foo.bar) or just a host (foo.bar). If it's just a host, the URL is
// assumed to be TLS.
Expand Down
11 changes: 8 additions & 3 deletions main_test.go
Expand Up @@ -15,6 +15,7 @@
package main

import (
"context"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -43,6 +44,8 @@ func testRandomPort(tb testing.TB) int {
func TestBuildProxy(t *testing.T) {
t.Parallel()

ctx := context.Background()

mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if got, want := r.Header.Get("Authorization"), "Bearer mytoken"; got != want {
Expand All @@ -62,7 +65,7 @@ func TestBuildProxy(t *testing.T) {
Host: fmt.Sprintf("localhost:%d", testRandomPort(t)),
}

src, err := findTokenSource("mytoken")
src, err := findTokenSource(ctx, "mytoken", "aud")
if err != nil {
t.Fatal(err)
}
Expand All @@ -81,10 +84,12 @@ func TestBuildProxy(t *testing.T) {
func TestFindTokenSource(t *testing.T) {
t.Parallel()

ctx := context.Background()

t.Run("static", func(t *testing.T) {
t.Parallel()

src, err := findTokenSource("mytoken")
src, err := findTokenSource(ctx, "mytoken", "aud")
if err != nil {
t.Fatal(err)
}
Expand All @@ -94,7 +99,7 @@ func TestFindTokenSource(t *testing.T) {
t.Fatal(err)
}

if got, want := token.Extra("id_token"), "mytoken"; got != want {
if got, want := token.AccessToken, "mytoken"; got != want {
t.Errorf("expected %q to be %q", got, want)
}
})
Expand Down