Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FetchJWTSVIDs function for workloadapi and jwtSource #187

Merged
merged 7 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
56 changes: 50 additions & 6 deletions v2/workloadapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,29 @@ func (c *Client) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwts
return nil, err
}

if len(resp.Svids) == 0 {
return nil, errors.New("there were no SVIDs in the response")
svids, err := parseJWTSVIDs(resp, audience, true)
if err != nil {
return nil, err
}

return svids[0], nil
}

// FetchJWTSVIDs fetches all JWT-SVIDs.
func (c *Client) FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params) ([]*jwtsvid.SVID, error) {
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

audience := append([]string{params.Audience}, params.ExtraAudiences...)
resp, err := c.wlClient.FetchJWTSVID(ctx, &workload.JWTSVIDRequest{
SpiffeId: params.Subject.String(),
Audience: audience,
})
if err != nil {
return nil, err
}
return jwtsvid.ParseInsecure(resp.Svids[0].Svid, audience)

return parseJWTSVIDs(resp, audience, false)
}

// FetchJWTBundles fetches the JWT bundles for JWT-SVID validation, keyed
Expand Down Expand Up @@ -357,6 +376,9 @@ func parseX509Context(resp *workload.X509SVIDResponse) (*X509Context, error) {
// Otherwise all SVIDs are parsed and returned.
func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svid.SVID, error) {
n := len(resp.Svids)
if n == 0 {
return nil, errors.New("no SVIDs in response")
}
if firstOnly {
n = 1
}
Expand All @@ -371,9 +393,6 @@ func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svi
svids = append(svids, s)
}

if len(svids) == 0 {
return nil, errors.New("no SVIDs in response")
}
return svids, nil
}

Expand Down Expand Up @@ -413,6 +432,31 @@ func parseX509Bundle(spiffeID string, bundle []byte) (*x509bundle.Bundle, error)
return x509bundle.FromX509Authorities(td, certs), nil
}

// parseJWTSVIDs parses one or all of the SVIDs in the response. If firstOnly
// is true, then only the first SVID in the response is parsed and returned.
// Otherwise all SVIDs are parsed and returned.
func parseJWTSVIDs(resp *workload.JWTSVIDResponse, audience []string, firstOnly bool) ([]*jwtsvid.SVID, error) {
n := len(resp.Svids)
if n == 0 {
return nil, errors.New("there were no SVIDs in the response")
}
if firstOnly {
n = 1
}

svids := make([]*jwtsvid.SVID, 0, n)
for i := 0; i < n; i++ {
svid := resp.Svids[i]
s, err := jwtsvid.ParseInsecure(svid.Svid, audience)
if err != nil {
return nil, err
}
svids = append(svids, s)
}

return svids, nil
}

func parseJWTSVIDBundles(resp *workload.JWTBundlesResponse) (*jwtbundle.Set, error) {
bundles := []*jwtbundle.Bundle{}

Expand Down
49 changes: 40 additions & 9 deletions v2/workloadapi/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ func TestFetchJWTSVID(t *testing.T) {
subjectID := spiffeid.RequireFromPath(td, "/subject")
audienceID := spiffeid.RequireFromPath(td, "/audience")
extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience")
token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()})
respJWT := makeJWTSVIDResponse(subjectID.String(), token)
token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal()
respJWT := makeJWTSVIDResponse([]string{token}, subjectID)
wl.SetJWTSVIDResponse(respJWT)

params := jwtsvid.Params{
Expand All @@ -245,7 +245,36 @@ func TestFetchJWTSVID(t *testing.T) {
jwtSvid, err := c.FetchJWTSVID(context.Background(), params)

require.NoError(t, err)
assertJWTSVID(t, jwtSvid, subjectID, token.Marshal(), audienceID.String(), extraAudienceID.String())
assertJWTSVID(t, jwtSvid, subjectID, token, audienceID.String(), extraAudienceID.String())
}

func TestFetchJWTSVIDs(t *testing.T) {
ca := test.NewCA(t, td)
wl := fakeworkloadapi.New(t)
defer wl.Stop()
c, _ := New(context.Background(), WithAddr(wl.Addr()))
defer c.Close()

subjectID := spiffeid.RequireFromPath(td, "/subject")
extraSubjectID := spiffeid.RequireFromPath(td, "/extra_subject")
audienceID := spiffeid.RequireFromPath(td, "/audience")
extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience")
subjectIDToken := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal()
extraSubjectIDToken := ca.CreateJWTSVID(extraSubjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal()
respJWT := makeJWTSVIDResponse([]string{subjectIDToken, extraSubjectIDToken}, subjectID, extraSubjectID)
wl.SetJWTSVIDResponse(respJWT)

params := jwtsvid.Params{
Subject: subjectID,
Audience: audienceID.String(),
ExtraAudiences: []string{extraAudienceID.String()},
}

jwtSvid, err := c.FetchJWTSVIDs(context.Background(), params)

require.NoError(t, err)
assertJWTSVID(t, jwtSvid[0], subjectID, subjectIDToken, audienceID.String(), extraAudienceID.String())
assertJWTSVID(t, jwtSvid[1], extraSubjectID, extraSubjectIDToken, audienceID.String(), extraAudienceID.String())
}

func TestFetchJWTBundles(t *testing.T) {
Expand Down Expand Up @@ -357,12 +386,14 @@ func makeX509SVIDs(ca *test.CA, ids ...spiffeid.ID) []*x509svid.SVID {
return svids
}

func makeJWTSVIDResponse(spiffeID string, token *jwtsvid.SVID) *workload.JWTSVIDResponse {
svids := []*workload.JWTSVID{
{
SpiffeId: spiffeID,
Svid: token.Marshal(),
},
func makeJWTSVIDResponse(token []string, ids ...spiffeid.ID) *workload.JWTSVIDResponse {
svids := []*workload.JWTSVID{}
for i, id := range ids {
svid := &workload.JWTSVID{
SpiffeId: id.String(),
Svid: token[i],
}
svids = append(svids, svid)
}
return &workload.JWTSVIDResponse{
Svids: svids,
Expand Down
10 changes: 10 additions & 0 deletions v2/workloadapi/convenience.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ func FetchJWTSVID(ctx context.Context, params jwtsvid.Params, options ...ClientO
return c.FetchJWTSVID(ctx, params)
}

// FetchJWTSVID fetches all JWT-SVIDs.
func FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params, options ...ClientOption) ([]*jwtsvid.SVID, error) {
c, err := New(ctx, options...)
if err != nil {
return nil, err
}
defer c.Close()
return c.FetchJWTSVIDs(ctx, params)
}

// FetchJWTBundles fetches the JWT bundles for JWT-SVID validation, keyed
// by a SPIFFE ID of the trust domain to which they belong.
func FetchJWTBundles(ctx context.Context, options ...ClientOption) (*jwtbundle.Set, error) {
Expand Down
9 changes: 9 additions & 0 deletions v2/workloadapi/jwtsource.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ func (s *JWTSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*j
return s.watcher.client.FetchJWTSVID(ctx, params)
}

// FetchJWTSVIDs fetches all JWT-SVIDs from the source with the given parameters.
// It implements the jwtsvid.Source interface.
func (s *JWTSource) FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params) ([]*jwtsvid.SVID, error) {
if err := s.checkClosed(); err != nil {
return nil, err
}
return s.watcher.client.FetchJWTSVIDs(ctx, params)
}

// GetJWTBundleForTrustDomain returns the JWT bundle for the given trust
// domain. It implements the jwtbundle.Source interface.
func (s *JWTSource) GetJWTBundleForTrustDomain(trustDomain spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
Expand Down
1 change: 1 addition & 0 deletions v2/workloadapi/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type sourceClient interface {
WatchX509Context(context.Context, X509ContextWatcher) error
WatchJWTBundles(context.Context, JWTBundleWatcher) error
FetchJWTSVID(context.Context, jwtsvid.Params) (*jwtsvid.SVID, error)
FetchJWTSVIDs(context.Context, jwtsvid.Params) ([]*jwtsvid.SVID, error)
Close() error
}

Expand Down