Skip to content

Commit

Permalink
Implement ErrorSink method (#421)
Browse files Browse the repository at this point in the history
* Implement FetchErrorChannel method

* Use a struct to encapsulate the URL as well

* Rename method and error

FetchErrorChannel -> ErrorSink
AutoRefreshFetchError -> AutoRefreshError

* Add comment

* Update Changes
  • Loading branch information
lestrrat committed Jul 29, 2021
1 parent 9204f79 commit c76a3f5
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 41 deletions.
3 changes: 3 additions & 0 deletions Changes
Expand Up @@ -6,6 +6,9 @@ v1.2.5 (UNRELEASED)
* Implement RFC7797. The value of the header field `b64` changes
how the payload is treated in JWS
* Implement detached payloads for JWS
* Implement (jwk.AutoRefresh).ErrorSink() to register a channel
where you can receive errors from fetches and parses that occur during
JWK(s) retrieval.

v1.2.4 15 Jul 2021
[Bug fixes]
Expand Down
5 changes: 5 additions & 0 deletions jwk/interface.go
Expand Up @@ -111,3 +111,8 @@ type HTTPClient interface {

type DecodeCtx = json.DecodeCtx
type KeyWithDecodeCtx = json.DecodeCtxContainer

type AutoRefreshError struct {
Error error
URL string
}
142 changes: 101 additions & 41 deletions jwk/refresh.go
Expand Up @@ -26,9 +26,12 @@ import (
// All JWKS objects that are retrieved via the auto-fetch mechanism should be
// treated read-only, as they are shared among the consumers and this object.
type AutoRefresh struct {
errDst chan AutoRefreshError // user-specified error sink
errSink chan AutoRefreshError // AutoRefresh's error sink
cache map[string]Set
configureCh chan struct{}
fetching map[string]chan struct{}
muErrDst sync.Mutex
muCache sync.RWMutex
muFetching sync.Mutex
muRegistry sync.RWMutex
Expand Down Expand Up @@ -108,13 +111,15 @@ type resetTimerReq struct {
// }
func NewAutoRefresh(ctx context.Context) *AutoRefresh {
af := &AutoRefresh{
errSink: make(chan AutoRefreshError, 1),
cache: make(map[string]Set),
configureCh: make(chan struct{}),
fetching: make(map[string]chan struct{}),
registry: make(map[string]*target),
resetTimerCh: make(chan *resetTimerReq),
}
go af.refreshLoop(ctx)
go af.drainErrSink(ctx)
return af
}

Expand Down Expand Up @@ -445,30 +450,49 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB

res, err := fetch(ctx, url, options...)
if err == nil {
defer res.Body.Close()
keyset, parseErr := ParseReader(res.Body)
if parseErr == nil {
// Got a new key set. replace the keyset in the target
af.muCache.Lock()
af.cache[url] = keyset
af.muCache.Unlock()
nextInterval := calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval)
rtr := &resetTimerReq{
t: t,
d: nextInterval,
}
select {
case <-ctx.Done():
return ctx.Err()
case af.resetTimerCh <- rtr:
if res.StatusCode != http.StatusOK {
// now, can there be a remote resource that responds with a status code
// other than 200 and still be valid...? naaaaaaahhhhhh....
err = errors.Errorf(`bad response status code (%d)`, res.StatusCode)
} else {
defer res.Body.Close()
keyset, parseErr := ParseReader(res.Body)
if parseErr == nil {
// Got a new key set. replace the keyset in the target
af.muCache.Lock()
af.cache[url] = keyset
af.muCache.Unlock()
nextInterval := calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval)
rtr := &resetTimerReq{
t: t,
d: nextInterval,
}
select {
case <-ctx.Done():
return ctx.Err()
case af.resetTimerCh <- rtr:
}

now := time.Now()
t.lastRefresh = now.Local()
t.nextRefresh = now.Add(nextInterval).Local()
return nil
}
err = parseErr
}
}

now := time.Now()
t.lastRefresh = now.Local()
t.nextRefresh = now.Add(nextInterval).Local()
return nil
// At this point if err != nil, we know that there was something wrong
// in either the fetching or the parsing. Send this error to be processed,
// but take the extra mileage to not block regular processing by
// sending the error to a "proxy" sink, and not directly at the user-specified sink
// (see drainErrSink)
if err != nil && af.errSink != nil {
select {
case af.errSink <- AutoRefreshError{Error: err, URL: url}:
default:
panic("af.errSink is not draining")
}
err = parseErr
}

// We either failed to perform the HTTP GET, or we failed to parse the
Expand All @@ -479,7 +503,7 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB
// If we failed to get a single time, then queue another fetch in the future.
rtr := &resetTimerReq{
t: t,
d: t.minRefreshInterval,
d: calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval),
}
select {
case <-ctx.Done():
Expand All @@ -490,38 +514,74 @@ func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableB
return err
}

// drainErrSink is used proxy the errors that were sent to the main
// error sink (af.errSink) to the user specified error sink
func (af *AutoRefresh) drainErrSink(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case err := <-af.errSink:
af.muErrDst.Lock()
dst := af.errDst
af.muErrDst.Unlock()
if dst != nil {
// This will block if the user isn't properly draining the channel.
// It is the user's responsibility to drain it once they
// requested the errors to be streamed
dst <- err
}
}
}
}

// ErrorSink sets a channel to receive JWK fetch errors, if any.
// Only the errors that occurred *after* the channel was set will be sent.
//
// The user is responsible for properly draining the channel. If the channel
// is not drained, the fetch operation will block on repeated errors.
//
// To disable, set a nil channel.
func (af *AutoRefresh) ErrorSink(ch chan AutoRefreshError) {
af.muErrDst.Lock()
af.errDst = ch
af.muErrDst.Unlock()
}

func calculateRefreshDuration(res *http.Response, refreshInterval *time.Duration, minRefreshInterval time.Duration) time.Duration {
// This always has precedence
if refreshInterval != nil {
return *refreshInterval
}

if v := res.Header.Get(`Cache-Control`); v != "" {
dir, err := httpcc.ParseResponse(v)
if err == nil {
maxAge, ok := dir.MaxAge()
if ok {
resDuration := time.Duration(maxAge) * time.Second
if resDuration > minRefreshInterval {
return resDuration
if res != nil {
if v := res.Header.Get(`Cache-Control`); v != "" {
dir, err := httpcc.ParseResponse(v)
if err == nil {
maxAge, ok := dir.MaxAge()
if ok {
resDuration := time.Duration(maxAge) * time.Second
if resDuration > minRefreshInterval {
return resDuration
}
return minRefreshInterval
}
return minRefreshInterval
// fallthrough
}
// fallthrough
}
// fallthrough
}

if v := res.Header.Get(`Expires`); v != "" {
expires, err := http.ParseTime(v)
if err == nil {
resDuration := time.Until(expires)
if resDuration > minRefreshInterval {
return resDuration
if v := res.Header.Get(`Expires`); v != "" {
expires, err := http.ParseTime(v)
if err == nil {
resDuration := time.Until(expires)
if resDuration > minRefreshInterval {
return resDuration
}
return minRefreshInterval
}
return minRefreshInterval
// fallthrough
}
// fallthrough
}

// Previous fallthroughs are a little redandunt, but hey, it's all good.
Expand Down
60 changes: 60 additions & 0 deletions jwk/refresh_test.go
Expand Up @@ -281,3 +281,63 @@ func TestRefreshSnapshot(t *testing.T) {
t.Logf("%s last refreshed at %s, next refresh at %s", target.URL, target.LastRefresh, target.NextRefresh)
}
}

func TestErrorSink(t *testing.T) {
t.Parallel()

testcases := []struct {
Name string
Handler http.Handler
}{
{
Name: "non-200 response",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
}),
},
{
Name: "invalid JWK",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"empty": "nonthingness"}`))
}),
},
}

for _, tc := range testcases {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
srv := httptest.NewServer(tc.Handler)
defer srv.Close()

ar := jwk.NewAutoRefresh(ctx)
ar.Configure(srv.URL, jwk.WithRefreshInterval(500*time.Millisecond))
ch := make(chan jwk.AutoRefreshError, 256) // big buffer
ar.ErrorSink(ch)
ar.Fetch(ctx, srv.URL)

timer := time.NewTimer(3 * time.Second)

select {
case <-ctx.Done():
t.Errorf(`ctx.Done before timer`)
case <-timer.C:
}

cancel() // forcefully end context, and thus the AutoRefresh

// timing issues can cause this to be non-deterministic...
// we'll say it's okay as long as we're in +/- 1 range
l := len(ch)
if !assert.True(t, l <= 7, "number of errors shold be less than or equal to 7 (%d)", l) {
return
}
if !assert.True(t, l >= 5, "number of errors shold be greather than or equal to 5 (%d)", l) {
return
}
})
}
}

0 comments on commit c76a3f5

Please sign in to comment.