Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method WatchX509Bundles #192

Merged
merged 3 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions v2/bundle/x509bundle/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ func Parse(trustDomain spiffeid.TrustDomain, b []byte) (*Bundle, error) {
return bundle, nil
}

// ParseRaw parses a bundle from bytes. The certificate must be ASN.1 DER (concatenated
// with no intermediate padding if there are more than one certificate)
func ParseRaw(trustDomain spiffeid.TrustDomain, b []byte) (*Bundle, error) {
bundle := New(trustDomain)
certs, err := x509.ParseCertificates(b)
if err != nil {
return nil, x509bundleErr.New("cannot parse certificate: %v", err)
}
if len(certs) == 0 {
return nil, x509bundleErr.New("no certificates found")
}
for _, cert := range certs {
bundle.AddX509Authority(cert)
}
return bundle, nil
}

// TrustDomain returns the trust domain that the bundle belongs to.
func (b *Bundle) TrustDomain() spiffeid.TrustDomain {
return b.trustDomain
Expand Down
58 changes: 58 additions & 0 deletions v2/bundle/x509bundle/bundle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/internal/pemutil"
"github.com/spiffe/go-spiffe/v2/internal/test"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -136,6 +137,49 @@ func TestParse(t *testing.T) {
}
}

func TestParseRaw(t *testing.T) {
tests := []struct {
name string
trustDomain spiffeid.TrustDomain
path string
expNumAuthorities int
expErrContains string
}{
{
name: "Parse multiple certificates should succeed",
path: "testdata/certs.pem",
expNumAuthorities: 2,
},
{
name: "Parse single certificate should succeed",
path: "testdata/cert.pem",
expNumAuthorities: 1,
},
{
name: "Parse should fail if no certificate block is is found",
path: "testdata/key.pem",
expErrContains: "x509bundle: no certificates found",
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
certsBytes := loadRawCertificates(t, test.path)
bundle, err := x509bundle.ParseRaw(td, certsBytes)

if test.expErrContains != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), test.expErrContains)
return
}
require.NoError(t, err)
assert.NotNil(t, bundle)
assert.Len(t, bundle.X509Authorities(), test.expNumAuthorities)
})
}
}

func TestX509AuthorityCRUD(t *testing.T) {
// Load bundle1, which contains a single certificate
bundle1, err := x509bundle.Load(td, "testdata/cert.pem")
Expand Down Expand Up @@ -274,3 +318,17 @@ func TestClone(t *testing.T) {
cloned := original.Clone()
require.True(t, original.Equal(cloned))
}

func loadRawCertificates(t *testing.T, path string) []byte {
certsBytes, err := ioutil.ReadFile(path)
require.NoError(t, err)

certs, err := pemutil.ParseCertificates(certsBytes)
require.NoError(t, err)

var rawBytes []byte
for _, cert := range certs {
rawBytes = append(rawBytes, cert.Raw...)
}
return rawBytes
}
102 changes: 90 additions & 12 deletions v2/internal/test/fakeworkloadapi/workload_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/internal/pemutil"
"github.com/spiffe/go-spiffe/v2/internal/x509util"
"github.com/spiffe/go-spiffe/v2/proto/spiffe/workload"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
Expand All @@ -29,22 +30,25 @@ import (
var noIdentityError = status.Error(codes.PermissionDenied, "no identity issued")

type WorkloadAPI struct {
tb testing.TB
wg sync.WaitGroup
addr string
server *grpc.Server
mu sync.Mutex
x509Resp *workload.X509SVIDResponse
x509Chans map[chan *workload.X509SVIDResponse]struct{}
jwtResp *workload.JWTSVIDResponse
jwtBundlesResp *workload.JWTBundlesResponse
jwtBundlesChans map[chan *workload.JWTBundlesResponse]struct{}
tb testing.TB
wg sync.WaitGroup
addr string
server *grpc.Server
mu sync.Mutex
x509Resp *workload.X509SVIDResponse
x509Chans map[chan *workload.X509SVIDResponse]struct{}
jwtResp *workload.JWTSVIDResponse
jwtBundlesResp *workload.JWTBundlesResponse
jwtBundlesChans map[chan *workload.JWTBundlesResponse]struct{}
x509BundlesResp *workload.X509BundlesResponse
x509BundlesChans map[chan *workload.X509BundlesResponse]struct{}
}

func New(tb testing.TB) *WorkloadAPI {
w := &WorkloadAPI{
x509Chans: make(map[chan *workload.X509SVIDResponse]struct{}),
jwtBundlesChans: make(map[chan *workload.JWTBundlesResponse]struct{}),
x509Chans: make(map[chan *workload.X509SVIDResponse]struct{}),
jwtBundlesChans: make(map[chan *workload.JWTBundlesResponse]struct{}),
x509BundlesChans: make(map[chan *workload.X509BundlesResponse]struct{}),
}

listener, err := net.Listen("tcp", "localhost:0")
Expand Down Expand Up @@ -126,6 +130,38 @@ func (w *WorkloadAPI) SetJWTBundles(jwtBundles ...*jwtbundle.Bundle) {
}
}

func (w *WorkloadAPI) SetX509Bundles(x509Bundles ...*x509bundle.Bundle) {
resp := &workload.X509BundlesResponse{
Bundles: make(map[string][]byte),
}
for _, bundle := range x509Bundles {
bundleBytes, err := bundle.Marshal()
assert.NoError(w.tb, err)
bundlePem, err := pemutil.ParseCertificates(bundleBytes)
assert.NoError(w.tb, err)

var rawBytes []byte
for _, c := range bundlePem {
rawBytes = append(rawBytes, c.Raw...)
}

resp.Bundles[bundle.TrustDomain().String()] = rawBytes
}

w.mu.Lock()
defer w.mu.Unlock()
w.x509BundlesResp = resp

for ch := range w.x509BundlesChans {
select {
case ch <- w.x509BundlesResp:
default:
<-ch
ch <- w.x509BundlesResp
}
}
}

type workloadAPIWrapper struct {
workload.UnimplementedSpiffeWorkloadAPIServer
w *WorkloadAPI
Expand All @@ -135,6 +171,10 @@ func (w *workloadAPIWrapper) FetchX509SVID(req *workload.X509SVIDRequest, stream
return w.w.fetchX509SVID(req, stream)
}

func (w *workloadAPIWrapper) FetchX509Bundles(req *workload.X509BundlesRequest, stream workload.SpiffeWorkloadAPI_FetchX509BundlesServer) error {
return w.w.fetchX509Bundles(req, stream)
}

func (w *workloadAPIWrapper) FetchJWTSVID(ctx context.Context, req *workload.JWTSVIDRequest) (*workload.JWTSVIDResponse, error) {
return w.w.fetchJWTSVID(ctx, req)
}
Expand Down Expand Up @@ -221,6 +261,44 @@ func (w *WorkloadAPI) fetchX509SVID(_ *workload.X509SVIDRequest, stream workload
}
}

func (w *WorkloadAPI) fetchX509Bundles(_ *workload.X509BundlesRequest, stream workload.SpiffeWorkloadAPI_FetchX509BundlesServer) error {
if err := checkHeader(stream.Context()); err != nil {
return err
}
ch := make(chan *workload.X509BundlesResponse, 1)
w.mu.Lock()
w.x509BundlesChans[ch] = struct{}{}
resp := w.x509BundlesResp
w.mu.Unlock()

defer func() {
w.mu.Lock()
delete(w.x509BundlesChans, ch)
w.mu.Unlock()
}()

sendResp := func(resp *workload.X509BundlesResponse) error {
if resp == nil {
return noIdentityError
}
return stream.Send(resp)
}

if err := sendResp(resp); err != nil {
return err
}
for {
select {
case resp := <-ch:
if err := sendResp(resp); err != nil {
return err
}
case <-stream.Context().Done():
return stream.Context().Err()
}
}
}

func (w *WorkloadAPI) fetchJWTSVID(ctx context.Context, req *workload.JWTSVIDRequest) (*workload.JWTSVIDResponse, error) {
if err := checkHeader(ctx); err != nil {
return nil, err
Expand Down
75 changes: 73 additions & 2 deletions v2/workloadapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (c *Client) FetchX509Bundles(ctx context.Context) (*x509bundle.Set, error)
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

stream, err := c.wlClient.FetchX509SVID(ctx, &workload.X509SVIDRequest{})
stream, err := c.wlClient.FetchX509Bundles(ctx, &workload.X509BundlesRequest{})
if err != nil {
return nil, err
}
Expand All @@ -113,7 +113,21 @@ func (c *Client) FetchX509Bundles(ctx context.Context) (*x509bundle.Set, error)
return nil, err
}

return parseX509Bundles(resp)
return parseX509BundlesResponse(resp)
}

// WatchX509Bundles watches for changes to the X.509 bundles. The watcher receives
// the updated X.509 bundles.
func (c *Client) WatchX509Bundles(ctx context.Context, watcher X509BundleWatcher) error {
backoff := newBackoff()
for {
err := c.watchX509Bundles(ctx, watcher, backoff)
watcher.OnX509BundlesWatchError(err)
err = c.handleWatchError(ctx, err, backoff)
if err != nil {
return err
}
}
}

// FetchX509Context fetches the X.509 context, which contains both X509-SVIDs
Expand Down Expand Up @@ -321,6 +335,33 @@ func (c *Client) watchJWTBundles(ctx context.Context, watcher JWTBundleWatcher,
}
}

func (c *Client) watchX509Bundles(ctx context.Context, watcher X509BundleWatcher, backoff *backoff) error {
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

c.config.log.Debugf("Watching X.509 bundles")
stream, err := c.wlClient.FetchX509Bundles(ctx, &workload.X509BundlesRequest{})
if err != nil {
return err
}

for {
resp, err := stream.Recv()
if err != nil {
return err
}

backoff.Reset()
x509bundleSet, err := parseX509BundlesResponse(resp)
if err != nil {
c.config.log.Errorf("Failed to parse X.509 bundle response: %v", err)
watcher.OnX509BundlesWatchError(err)
continue
}
watcher.OnX509BundlesUpdate(x509bundleSet)
}
}

// X509ContextWatcher receives X509Context updates from the Workload API.
type X509ContextWatcher interface {
// OnX509ContextUpdate is called with the latest X.509 context retrieved
Expand All @@ -343,6 +384,17 @@ type JWTBundleWatcher interface {
OnJWTBundlesWatchError(error)
}

// X509BundleWatcher receives X.509 bundle updates from the Workload API.
type X509BundleWatcher interface {
// OnX509BundlesUpdate is called with the latest X.509 bundle set retrieved
// from the Workload API.
OnX509BundlesUpdate(*x509bundle.Set)

// OnX509BundlesWatchError is called when there is a problem establishing
// or maintaining connectivity with the Workload API.
OnX509BundlesWatchError(error)
}

func withHeader(ctx context.Context) context.Context {
header := metadata.Pairs("workload.spiffe.io", "true")
return metadata.NewOutgoingContext(ctx, header)
Expand Down Expand Up @@ -432,6 +484,25 @@ func parseX509Bundle(spiffeID string, bundle []byte) (*x509bundle.Bundle, error)
return x509bundle.FromX509Authorities(td, certs), nil
}

func parseX509BundlesResponse(resp *workload.X509BundlesResponse) (*x509bundle.Set, error) {
bundles := []*x509bundle.Bundle{}

for tdID, b := range resp.Bundles {
td, err := spiffeid.TrustDomainFromString(tdID)
if err != nil {
return nil, err
}

b, err := x509bundle.ParseRaw(td, b)
if err != nil {
return nil, err
}
bundles = append(bundles, b)
}

return x509bundle.NewSet(bundles...), nil
}

// parseJWTSVIDs parses one or all of the SVIDs in the response. If firstOnly
// is true, then only the first SVID in the response is parsed and returned.
// Otherwise all SVIDs are parsed and returned.
Expand Down