Skip to content

Commit

Permalink
Merge pull request #252 from XenitAB/errorhandler-does-http
Browse files Browse the repository at this point in the history
Error handlers are responsible for HTTP response
  • Loading branch information
bittrance committed May 10, 2023
2 parents b13ec3c + fa3ed69 commit 88c83b7
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 83 deletions.
32 changes: 29 additions & 3 deletions README.md
Expand Up @@ -319,11 +319,35 @@ oidcHandler := oidcgin.New(

### Custom error handler

It is possible to add a custom function to handle errors. It will not be possible to change anything using it, but you will be able to add logic for logging as an example.
It is possible to add a custom function to handle errors. The error handler can return an `options.Response` which will be rendered by the middleware. Returning `nil` will result in a default 400/401 error.

```go
errorHandler := func(description options.ErrorDescription, err error) {
fmt.Printf("Description: %s\tError: %v\n", description, err)
type Message struct {
Message string `json:"message"`
Url string `json:"url"`
}

func errorHandler(ctx context.Context, oidcErr *options.OidcError) *options.Response {
message := Message{
Message: string(oidcErr.Status),
Url: oidcErr.Url.String(),
}
var headers map[string]string
json, err := json.Marshal(message)
if err != nil {
headers["Content-Type"] = "text/plain"
return &options.Response{
StatusCode: 500,
Headers: headers,
Body: []byte("Internal encoding failure\r\n"),
}
}
headers["Content-Type"] = "text/plain"
return &options.Response{
StatusCode: 418,
Headers: headers,
Body: json,
}
}

oidcHandler := oidcgin.New(
Expand All @@ -334,6 +358,8 @@ oidcHandler := oidcgin.New(
)
```

This error handling interface was changed in v0.0.42. The previous interface was `func(description ErrorDescription, err error)`. In order to retain the same behavior, you need to update your error handler to read `desctiption` and `err` from `oidcErr` and return `nil`.

### Testing with the middleware enabled

There's a small package that simulates an OpenID Provider that can be used with tests.
Expand Down
137 changes: 81 additions & 56 deletions internal/oidctesting/tests.go
@@ -1,11 +1,10 @@
package oidctesting

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -313,65 +312,91 @@ func runTestErrorHandler(t *testing.T, testName string, tester tester) {
op := optest.NewTesting(t)
defer op.Close(t)

var info struct {
sync.RWMutex
description options.ErrorDescription
err error
}

setInfo := func(description options.ErrorDescription, err error) {
info.Lock()
info.description = description
info.err = err
info.Unlock()
}

getInfo := func() (description options.ErrorDescription, err error) {
info.RLock()
defer info.RUnlock()
return info.description, info.err
}

errorHandler := func(description options.ErrorDescription, err error) {
t.Logf("Description: %s\tError: %v", description, err)
setInfo(description, err)
}

opts := []options.Option{
options.WithIssuer(op.GetURL(t)),
options.WithRequiredAudience("test-client"),
options.WithRequiredTokenType("JWT+AT"),
options.WithErrorHandler(errorHandler),
cases := []struct {
testDescription string
errorHandler options.ErrorHandler
expectStatusCode int
expectHeaders map[string]string
expectBodyContains []byte
}{
{
testDescription: "no output",
errorHandler: func(ctx context.Context, oidcErr *options.OidcError) *options.Response { return nil },
expectStatusCode: http.StatusBadRequest,
expectHeaders: map[string]string{},
expectBodyContains: []byte{},
},
{
testDescription: "basic propagation",
errorHandler: func(ctx context.Context, oidcErr *options.OidcError) *options.Response {
return &options.Response{
StatusCode: 418,
Headers: map[string]string{},
Body: []byte("badness"),
}
},
expectStatusCode: http.StatusTeapot,
expectHeaders: map[string]string{
"Content-Type": "application/octet-stream",
},
expectBodyContains: []byte("bad"),
},
{
testDescription: "additional header",
errorHandler: func(ctx context.Context, oidcErr *options.OidcError) *options.Response {
return &options.Response{
StatusCode: 418,
Headers: map[string]string{"some": "header"},
Body: []byte("badness"),
}
},
expectStatusCode: http.StatusTeapot,
expectHeaders: map[string]string{
"Some": "header",
"Content-Type": "application/octet-stream",
},
expectBodyContains: []byte{},
},
{
testDescription: "content type",
errorHandler: func(ctx context.Context, oidcErr *options.OidcError) *options.Response {
return &options.Response{
StatusCode: 418,
Headers: map[string]string{"content-type": "application/json"},
Body: []byte("{}"),
}
},
expectStatusCode: http.StatusTeapot,
expectHeaders: map[string]string{
"Content-Type": "application/json",
},
expectBodyContains: []byte("{}"),
},
}
for i := range cases {
c := cases[i]
t.Logf("Test iteration %d: %s", i, c.testDescription)
opts := []options.Option{
options.WithIssuer(op.GetURL(t)),
options.WithRequiredAudience("test-client"),
options.WithRequiredTokenType("JWT+AT"),
options.WithErrorHandler(c.errorHandler),
}

oidcHandler, err := oidc.NewHandler[TestClaims](nil, opts...)
require.NoError(t, err)

handler := tester.ToHandlerFn(oidcHandler.ParseToken, opts...)

// Test without token
reqNoAuth := httptest.NewRequest(http.MethodGet, "/", nil)
recNoAuth := httptest.NewRecorder()
handler.ServeHTTP(recNoAuth, reqNoAuth)

require.Equal(t, http.StatusBadRequest, recNoAuth.Result().StatusCode)
oidcHandler, err := oidc.NewHandler[TestClaims](nil, opts...)
require.NoError(t, err)

d, e := getInfo()
handler := tester.ToHandlerFn(oidcHandler.ParseToken, opts...)

if !strings.Contains(t.Name(), "OidcEchoJwt") {
require.Equal(t, options.GetTokenErrorDescription, d)
require.EqualError(t, e, "unable to extract token: Authorization header empty")
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
handler.ServeHTTP(res, req)
require.Equal(t, c.expectStatusCode, res.Result().StatusCode)
for k, v := range c.expectHeaders {
require.Equal(t, []string{v}, res.Result().Header[k])
}
require.Subset(t, res.Body.Bytes(), c.expectBodyContains)
}

// Test with fake token
token := op.GetToken(t)
token.AccessToken = "foobar"
testHttpWithAuthenticationFailure(t, token, handler)

d, e = getInfo()

require.Equal(t, options.ParseTokenErrorDescription, d)
require.EqualError(t, e, "unable to parse token signature: invalid compact serialization format: invalid number of segments")
})
}

Expand Down
31 changes: 24 additions & 7 deletions oidcecho/echo.go
Expand Up @@ -19,10 +19,29 @@ func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...opt
return toEchoMiddleware(h.ParseToken, setters...)
}

func onError(errorHandler options.ErrorHandler, description options.ErrorDescription, err error) {
if errorHandler != nil {
errorHandler(description, err)
func onError(c echo.Context, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) error {
if errorHandler == nil {
c.Logger().Error(err)
return c.NoContent(statusCode)
}
oidcErr := options.OidcError{
Url: c.Request().URL,
Headers: c.Request().Header,
Status: description,
Error: err,
}
response := errorHandler(c.Request().Context(), &oidcErr)
if response == nil {
c.Logger().Error(err)
return c.NoContent(statusCode)
}
for k, v := range response.Headers {
c.Response().Header().Set(k, v)
}
c.Response().Header().Set(echo.HeaderContentType, response.ContentType())
c.Response().WriteHeader(response.StatusCode)
_, err = c.Response().Write(response.Body)
return err
}

func toEchoMiddleware[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) echo.MiddlewareFunc {
Expand All @@ -34,14 +53,12 @@ func toEchoMiddleware[T any](parseToken oidc.ParseTokenFunc[T], setters ...optio

tokenString, err := oidc.GetTokenString(c.Request().Header.Get, opts.TokenString)
if err != nil {
onError(opts.ErrorHandler, options.GetTokenErrorDescription, err)
return echo.ErrBadRequest
return onError(c, opts.ErrorHandler, echo.ErrBadRequest.Code, options.GetTokenErrorDescription, err)
}

claims, err := parseToken(ctx, tokenString)
if err != nil {
onError(opts.ErrorHandler, options.ParseTokenErrorDescription, err)
return echo.ErrUnauthorized
return onError(c, opts.ErrorHandler, echo.ErrUnauthorized.Code, options.ParseTokenErrorDescription, err)
}
c.Set(string(opts.ClaimsContextKeyName), claims)
return next(c)
Expand Down
27 changes: 23 additions & 4 deletions oidcfiber/fiber.go
Expand Up @@ -2,6 +2,7 @@ package oidcfiber

import (
"fmt"
"net/url"

"github.com/gofiber/fiber/v2"
"github.com/xenitab/go-oidc-middleware/internal/oidc"
Expand All @@ -20,11 +21,29 @@ func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...opt
}

func onError(c *fiber.Ctx, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) error {
if errorHandler != nil {
errorHandler(description, err)
if errorHandler == nil {
return c.SendStatus(statusCode)
}

return c.SendStatus(statusCode)
url, _ := url.Parse(c.OriginalURL())
headers := make(map[string][]string, 1)
for k, v := range c.GetReqHeaders() {
headers[k] = []string{v}
}
oidcErr := options.OidcError{
Url: url,
Headers: headers,
Status: description,
Error: err,
}
response := errorHandler(c.Context(), &oidcErr)
if response == nil {
return c.SendStatus(statusCode)
}
for k, v := range response.Headers {
c.Response().Header.Set(k, v)
}
c.Set("Content-Type", response.ContentType())
return c.Status(response.StatusCode).Send(response.Body)
}

func toFiberHandler[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) fiber.Handler {
Expand Down
23 changes: 18 additions & 5 deletions oidcgin/gin.go
Expand Up @@ -20,13 +20,26 @@ func New[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...opt
return toGinHandler(oidcHandler.ParseToken, setters...)
}

func onError(c *gin.Context, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) {
if errorHandler != nil {
errorHandler(description, err)
func onError(c *gin.Context, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) error {
if errorHandler == nil {
return c.AbortWithError(statusCode, err)
}

//nolint:errcheck // false positive
c.AbortWithError(statusCode, err)
oidcErr := options.OidcError{
Url: c.Request.URL,
Headers: c.Request.Header,
Status: description,
Error: err,
}
response := errorHandler(c.Request.Context(), &oidcErr)
if response == nil {
return c.AbortWithError(statusCode, err)
}
for k, v := range response.Headers {
c.Header(k, v)
}
c.Data(response.StatusCode, response.ContentType(), response.Body)
return nil
}

func toGinHandler[T any](parseToken oidc.ParseTokenFunc[T], setters ...options.Option) gin.HandlerFunc {
Expand Down
30 changes: 23 additions & 7 deletions oidchttp/http.go
Expand Up @@ -20,12 +20,28 @@ func New[T any](h http.Handler, claimsValidationFn options.ClaimsValidationFn[T]
return toHttpHandler(h, oidcHandler.ParseToken, setters...)
}

func onError(w http.ResponseWriter, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) {
if errorHandler != nil {
errorHandler(description, err)
func onError(r *http.Request, w http.ResponseWriter, errorHandler options.ErrorHandler, statusCode int, description options.ErrorDescription, err error) {
if errorHandler == nil {
w.WriteHeader(statusCode)
return
}

w.WriteHeader(statusCode)
oidcErr := options.OidcError{
Url: r.URL,
Headers: r.Header,
Status: description,
Error: err,
}
response := errorHandler(r.Context(), &oidcErr)
if response == nil {
w.WriteHeader(statusCode)
return
}
for k, v := range response.Headers {
w.Header().Add(k, v)
}
w.Header().Set("Content-Type", response.ContentType())
w.WriteHeader(response.StatusCode)
w.Write(response.Body)
}

func toHttpHandler[T any](h http.Handler, parseToken oidc.ParseTokenFunc[T], setters ...options.Option) http.Handler {
Expand All @@ -36,13 +52,13 @@ func toHttpHandler[T any](h http.Handler, parseToken oidc.ParseTokenFunc[T], set

tokenString, err := oidc.GetTokenString(r.Header.Get, opts.TokenString)
if err != nil {
onError(w, opts.ErrorHandler, http.StatusBadRequest, options.GetTokenErrorDescription, err)
onError(r, w, opts.ErrorHandler, http.StatusBadRequest, options.GetTokenErrorDescription, err)
return
}

claims, err := parseToken(ctx, tokenString)
if err != nil {
onError(w, opts.ErrorHandler, http.StatusUnauthorized, options.ParseTokenErrorDescription, err)
onError(r, w, opts.ErrorHandler, http.StatusUnauthorized, options.ParseTokenErrorDescription, err)
return
}

Expand Down

0 comments on commit 88c83b7

Please sign in to comment.