Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FetchJWTSVIDs function for workloadapi and jwtSource #187

Merged
merged 7 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 47 additions & 3 deletions v2/workloadapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,29 @@ func (c *Client) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwts
return nil, err
}

if len(resp.Svids) == 0 {
return nil, errors.New("there were no SVIDs in the response")
svids, err := parseJWTSVIDs(resp, audience, true)
if err != nil {
return nil, err
}

return svids[0], nil
}

// FetchJWTSVIDs fetches all JWT-SVIDs.
func (c *Client) FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params) ([]*jwtsvid.SVID, error) {
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

audience := append([]string{params.Audience}, params.ExtraAudiences...)
resp, err := c.wlClient.FetchJWTSVID(ctx, &workload.JWTSVIDRequest{
SpiffeId: params.Subject.String(),
Audience: audience,
})
if err != nil {
return nil, err
}
return jwtsvid.ParseInsecure(resp.Svids[0].Svid, audience)

return parseJWTSVIDs(resp, audience, false)
}

// FetchJWTBundles fetches the JWT bundles for JWT-SVID validation, keyed
Expand Down Expand Up @@ -412,6 +431,31 @@ func parseX509Bundle(spiffeID string, bundle []byte) (*x509bundle.Bundle, error)
return x509bundle.FromX509Authorities(td, certs), nil
}

// parseJWTSVIDs parses one or all of the SVIDs in the response. If firstOnly
// is true, then only the first SVID in the response is parsed and returned.
// Otherwise all SVIDs are parsed and returned.
func parseJWTSVIDs(resp *workload.JWTSVIDResponse, audience []string, firstOnly bool) ([]*jwtsvid.SVID, error) {
n := len(resp.Svids)
if firstOnly {
n = 1
}

svids := make([]*jwtsvid.SVID, 0, n)
for i := 0; i < n; i++ {
svid := resp.Svids[i]
s, err := jwtsvid.ParseInsecure(svid.Svid, audience)
if err != nil {
return nil, err
}
svids = append(svids, s)
}

if len(svids) == 0 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check should move above the for loop and check the length of SVIDs in the response, otherwise, if the response has no SVIDs and firstOnly is true, we will panic on line 445.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I directly copied the code of parseX509SVIDs and modified it. I didn't notice it here. I think parseX509SVIDs also has this problem, so I modified it at the same time.
In addition, I would like to ask about java-spiffe, this spiffe/java-spiffe#90 not assigned to reviewer, any idea?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to ask about java-spiffe, this spiffe/java-spiffe#90 not assigned to reviewer, any idea?

@loveyana I'm reviewing it today.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to ask about java-spiffe, this spiffe/java-spiffe#90 not assigned to reviewer, any idea?

@loveyana I'm reviewing it today.

In addition, I may need help to look at the HewlettPackard/py-spiffe#105. It seems that the lint in CI on the py-spiffe side is wrong, but I passed it locally, and it seems that the previous pr also has the same problem. I hope you can help and review it at the same time

return nil, errors.New("there were no SVIDs in the response")
}
return svids, nil
}

func parseJWTSVIDBundles(resp *workload.JWTBundlesResponse) (*jwtbundle.Set, error) {
bundles := []*jwtbundle.Bundle{}

Expand Down
49 changes: 40 additions & 9 deletions v2/workloadapi/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ func TestFetchJWTSVID(t *testing.T) {
subjectID := spiffeid.RequireFromPath(td, "/subject")
audienceID := spiffeid.RequireFromPath(td, "/audience")
extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience")
token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()})
respJWT := makeJWTSVIDResponse(subjectID.String(), token)
token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal()
respJWT := makeJWTSVIDResponse(ca, []string{token}, subjectID)
wl.SetJWTSVIDResponse(respJWT)

params := jwtsvid.Params{
Expand All @@ -245,7 +245,36 @@ func TestFetchJWTSVID(t *testing.T) {
jwtSvid, err := c.FetchJWTSVID(context.Background(), params)

require.NoError(t, err)
assertJWTSVID(t, jwtSvid, subjectID, token.Marshal(), audienceID.String(), extraAudienceID.String())
assertJWTSVID(t, jwtSvid, subjectID, token, audienceID.String(), extraAudienceID.String())
}

func TestFetchJWTSVIDs(t *testing.T) {
ca := test.NewCA(t, td)
wl := fakeworkloadapi.New(t)
defer wl.Stop()
c, _ := New(context.Background(), WithAddr(wl.Addr()))
defer c.Close()

subjectID := spiffeid.RequireFromPath(td, "/subject")
extraSubjectID := spiffeid.RequireFromPath(td, "/extra_subject")
audienceID := spiffeid.RequireFromPath(td, "/audience")
extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience")
subjectIDToken := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal()
extraSubjectIDToken := ca.CreateJWTSVID(extraSubjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal()
respJWT := makeJWTSVIDResponse(ca, []string{subjectIDToken, extraSubjectIDToken}, subjectID, extraSubjectID)
wl.SetJWTSVIDResponse(respJWT)

params := jwtsvid.Params{
Subject: subjectID,
Audience: audienceID.String(),
ExtraAudiences: []string{extraAudienceID.String()},
}

jwtSvid, err := c.FetchJWTSVIDs(context.Background(), params)

require.NoError(t, err)
assertJWTSVID(t, jwtSvid[0], subjectID, subjectIDToken, audienceID.String(), extraAudienceID.String())
assertJWTSVID(t, jwtSvid[1], extraSubjectID, extraSubjectIDToken, audienceID.String(), extraAudienceID.String())
}

func TestFetchJWTBundles(t *testing.T) {
Expand Down Expand Up @@ -357,12 +386,14 @@ func makeX509SVIDs(ca *test.CA, ids ...spiffeid.ID) []*x509svid.SVID {
return svids
}

func makeJWTSVIDResponse(spiffeID string, token *jwtsvid.SVID) *workload.JWTSVIDResponse {
svids := []*workload.JWTSVID{
{
SpiffeId: spiffeID,
Svid: token.Marshal(),
},
func makeJWTSVIDResponse(ca *test.CA, token []string, ids ...spiffeid.ID) *workload.JWTSVIDResponse {
svids := []*workload.JWTSVID{}
for i, id := range ids {
svid := &workload.JWTSVID{
SpiffeId: id.String(),
Svid: token[i],
}
svids = append(svids, svid)
}
return &workload.JWTSVIDResponse{
Svids: svids,
Expand Down
10 changes: 10 additions & 0 deletions v2/workloadapi/convenience.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ func FetchJWTSVID(ctx context.Context, params jwtsvid.Params, options ...ClientO
return c.FetchJWTSVID(ctx, params)
}

// FetchJWTSVID fetches all JWT-SVIDs.
func FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params, options ...ClientOption) ([]*jwtsvid.SVID, error) {
c, err := New(ctx, options...)
if err != nil {
return nil, err
}
defer c.Close()
return c.FetchJWTSVIDs(ctx, params)
}

// FetchJWTBundles fetches the JWT bundles for JWT-SVID validation, keyed
// by a SPIFFE ID of the trust domain to which they belong.
func FetchJWTBundles(ctx context.Context, options ...ClientOption) (*jwtbundle.Set, error) {
Expand Down
9 changes: 9 additions & 0 deletions v2/workloadapi/jwtsource.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ func (s *JWTSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*j
return s.watcher.client.FetchJWTSVID(ctx, params)
}

// FetchJWTSVIDs fetches all JWT-SVIDs from the source with the given parameters.
// It implements the jwtsvid.Source interface.
func (s *JWTSource) FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params) ([]*jwtsvid.SVID, error) {
if err := s.checkClosed(); err != nil {
return nil, err
}
return s.watcher.client.FetchJWTSVIDs(ctx, params)
}

// GetJWTBundleForTrustDomain returns the JWT bundle for the given trust
// domain. It implements the jwtbundle.Source interface.
func (s *JWTSource) GetJWTBundleForTrustDomain(trustDomain spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
Expand Down
1 change: 1 addition & 0 deletions v2/workloadapi/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type sourceClient interface {
WatchX509Context(context.Context, X509ContextWatcher) error
WatchJWTBundles(context.Context, JWTBundleWatcher) error
FetchJWTSVID(context.Context, jwtsvid.Params) (*jwtsvid.SVID, error)
FetchJWTSVIDs(context.Context, jwtsvid.Params) ([]*jwtsvid.SVID, error)
Close() error
}

Expand Down