Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ErrorSink method #421

Merged
merged 5 commits into from Jul 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions Changes
Expand Up @@ -6,6 +6,9 @@ v1.2.5 (UNRELEASED)
* Implement RFC7797. The value of the header field `b64` changes
how the payload is treated in JWS
* Implement detached payloads for JWS
* Implement (jwk.AutoRefresh).ErrorSink() to register a channel
where you can receive errors from fetches and parses that occur during
JWK(s) retrieval.

v1.2.4 15 Jul 2021
[Bug fixes]
Expand Down
5 changes: 5 additions & 0 deletions jwk/interface.go
Expand Up @@ -111,3 +111,8 @@ type HTTPClient interface {

type DecodeCtx = json.DecodeCtx
type KeyWithDecodeCtx = json.DecodeCtxContainer

type AutoRefreshError struct {
Error error
URL string
}
142 changes: 101 additions & 41 deletions jwk/refresh.go
Expand Up @@ -26,9 +26,12 @@ import (
// All JWKS objects that are retrieved via the auto-fetch mechanism should be
// treated read-only, as they are shared among the consumers and this object.
type AutoRefresh struct {
errDst chan AutoRefreshError // user-specified error sink
errSink chan AutoRefreshError // AutoRefresh's error sink
cache map[string]Set
configureCh chan struct{}
fetching map[string]chan struct{}
muErrDst sync.Mutex
muCache sync.RWMutex
muFetching sync.Mutex
muRegistry sync.RWMutex
Expand Down Expand Up @@ -108,13 +111,15 @@ type resetTimerReq struct {
// }
func NewAutoRefresh(ctx context.Context) *AutoRefresh {
af := &AutoRefresh{
errSink: make(chan AutoRefreshError, 1),
cache: make(map[string]Set),
configureCh: make(chan struct{}),
fetching: make(map[string]chan struct{}),
registry: make(map[string]*target),
resetTimerCh: make(chan *resetTimerReq),
}
go af.refreshLoop(ctx)
go af.drainErrSink(ctx)
return af
}

Expand Down Expand Up @@ -445,30 +450,49 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB

res, err := fetch(ctx, url, options...)
if err == nil {
defer res.Body.Close()
keyset, parseErr := ParseReader(res.Body)
if parseErr == nil {
// Got a new key set. replace the keyset in the target
af.muCache.Lock()
af.cache[url] = keyset
af.muCache.Unlock()
nextInterval := calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval)
rtr := &resetTimerReq{
t: t,
d: nextInterval,
}
select {
case <-ctx.Done():
return ctx.Err()
case af.resetTimerCh <- rtr:
if res.StatusCode != http.StatusOK {
// now, can there be a remote resource that responds with a status code
// other than 200 and still be valid...? naaaaaaahhhhhh....
err = errors.Errorf(`bad response status code (%d)`, res.StatusCode)
} else {
defer res.Body.Close()
keyset, parseErr := ParseReader(res.Body)
if parseErr == nil {
// Got a new key set. replace the keyset in the target
af.muCache.Lock()
af.cache[url] = keyset
af.muCache.Unlock()
nextInterval := calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval)
rtr := &resetTimerReq{
t: t,
d: nextInterval,
}
select {
case <-ctx.Done():
return ctx.Err()
case af.resetTimerCh <- rtr:
}

now := time.Now()
t.lastRefresh = now.Local()
t.nextRefresh = now.Add(nextInterval).Local()
return nil
}
err = parseErr
}
}

now := time.Now()
t.lastRefresh = now.Local()
t.nextRefresh = now.Add(nextInterval).Local()
return nil
// At this point if err != nil, we know that there was something wrong
// in either the fetching or the parsing. Send this error to be processed,
// but take the extra mileage to not block regular processing by
// sending the error to a "proxy" sink, and not directly at the user-specified sink
// (see drainErrSink)
if err != nil && af.errSink != nil {
select {
case af.errSink <- AutoRefreshError{Error: err, URL: url}:
default:
panic("af.errSink is not draining")
}
err = parseErr
}

// We either failed to perform the HTTP GET, or we failed to parse the
Expand All @@ -479,7 +503,7 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB
// If we failed to get a single time, then queue another fetch in the future.
rtr := &resetTimerReq{
t: t,
d: t.minRefreshInterval,
d: calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval),
}
select {
case <-ctx.Done():
Expand All @@ -490,38 +514,74 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB
return err
}

// drainErrSink is used proxy the errors that were sent to the main
// error sink (af.errSink) to the user specified error sink
func (af *AutoRefresh) drainErrSink(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case err := <-af.errSink:
af.muErrDst.Lock()
dst := af.errDst
af.muErrDst.Unlock()
if dst != nil {
// This will block if the user isn't properly draining the channel.
// It is the user's responsibility to drain it once they
// requested the errors to be streamed
dst <- err
}
}
}
}

// ErrorSink sets a channel to receive JWK fetch errors, if any.
// Only the errors that occurred *after* the channel was set will be sent.
//
// The user is responsible for properly draining the channel. If the channel
// is not drained, the fetch operation will block on repeated errors.
//
// To disable, set a nil channel.
func (af *AutoRefresh) ErrorSink(ch chan AutoRefreshError) {
af.muErrDst.Lock()
af.errDst = ch
af.muErrDst.Unlock()
}

func calculateRefreshDuration(res *http.Response, refreshInterval *time.Duration, minRefreshInterval time.Duration) time.Duration {
// This always has precedence
if refreshInterval != nil {
return *refreshInterval
}

if v := res.Header.Get(`Cache-Control`); v != "" {
dir, err := httpcc.ParseResponse(v)
if err == nil {
maxAge, ok := dir.MaxAge()
if ok {
resDuration := time.Duration(maxAge) * time.Second
if resDuration > minRefreshInterval {
return resDuration
if res != nil {
if v := res.Header.Get(`Cache-Control`); v != "" {
dir, err := httpcc.ParseResponse(v)
if err == nil {
maxAge, ok := dir.MaxAge()
if ok {
resDuration := time.Duration(maxAge) * time.Second
if resDuration > minRefreshInterval {
return resDuration
}
return minRefreshInterval
}
return minRefreshInterval
// fallthrough
}
// fallthrough
}
// fallthrough
}

if v := res.Header.Get(`Expires`); v != "" {
expires, err := http.ParseTime(v)
if err == nil {
resDuration := time.Until(expires)
if resDuration > minRefreshInterval {
return resDuration
if v := res.Header.Get(`Expires`); v != "" {
expires, err := http.ParseTime(v)
if err == nil {
resDuration := time.Until(expires)
if resDuration > minRefreshInterval {
return resDuration
}
return minRefreshInterval
}
return minRefreshInterval
// fallthrough
}
// fallthrough
}

// Previous fallthroughs are a little redandunt, but hey, it's all good.
Expand Down
60 changes: 60 additions & 0 deletions jwk/refresh_test.go
Expand Up @@ -281,3 +281,63 @@ func TestRefreshSnapshot(t *testing.T) {
t.Logf("%s last refreshed at %s, next refresh at %s", target.URL, target.LastRefresh, target.NextRefresh)
}
}

func TestErrorSink(t *testing.T) {
t.Parallel()

testcases := []struct {
Name string
Handler http.Handler
}{
{
Name: "non-200 response",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
}),
},
{
Name: "invalid JWK",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"empty": "nonthingness"}`))
}),
},
}

for _, tc := range testcases {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
srv := httptest.NewServer(tc.Handler)
defer srv.Close()

ar := jwk.NewAutoRefresh(ctx)
ar.Configure(srv.URL, jwk.WithRefreshInterval(500*time.Millisecond))
ch := make(chan jwk.AutoRefreshError, 256) // big buffer
ar.ErrorSink(ch)
ar.Fetch(ctx, srv.URL)

timer := time.NewTimer(3 * time.Second)

select {
case <-ctx.Done():
t.Errorf(`ctx.Done before timer`)
case <-timer.C:
}

cancel() // forcefully end context, and thus the AutoRefresh

// timing issues can cause this to be non-deterministic...
// we'll say it's okay as long as we're in +/- 1 range
l := len(ch)
if !assert.True(t, l <= 7, "number of errors shold be less than or equal to 7 (%d)", l) {
return
}
if !assert.True(t, l >= 5, "number of errors shold be greather than or equal to 5 (%d)", l) {
return
}
})
}
}