Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom json and base64 encoders for Token and Parser #301

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 35 additions & 0 deletions encoder.go
@@ -0,0 +1,35 @@
package jwt

import "io"

// Base64Encoding represents an object that can encode and decode base64. A
// common example is [encoding/base64.Encoding].
type Base64Encoding interface {
EncodeToString(src []byte) string
DecodeString(s string) ([]byte, error)
}

type StrictFunc[T Base64Encoding] func() T

type Stricter[T Base64Encoding] interface {
Strict() T
}

func DoStrict[S Base64Encoding, T Stricter[S]](x T) Base64Encoding {
return x.Strict()
}

// JSONMarshalFunc is an function type that allows to implement custom JSON
// encoding algorithms.
type JSONMarshalFunc func(v any) ([]byte, error)

// JSONUnmarshalFunc is an function type that allows to implement custom JSON
// unmarshal algorithms.
type JSONUnmarshalFunc func(data []byte, v any) error

type JSONDecoder interface {
UseNumber()
Decode(v any) error
}

type JSONNewDecoderFunc[T JSONDecoder] func(r io.Reader) T
1 change: 1 addition & 0 deletions errors.go
Expand Up @@ -22,6 +22,7 @@ var (
ErrTokenInvalidId = errors.New("token has invalid id")
ErrTokenInvalidClaims = errors.New("token has invalid claims")
ErrInvalidType = errors.New("invalid type for claim")
ErrUnsupported = errors.New("operation is unsupported")
)

// joinedError is an error type that works similar to what [errors.Join]
Expand Down
96 changes: 71 additions & 25 deletions parser.go
Expand Up @@ -12,16 +12,26 @@ type Parser struct {
// If populated, only these methods will be considered valid.
validMethods []string

// Use JSON Number format in JSON decoder.
useJSONNumber bool

// Skip claims validation during token parsing.
skipClaimsValidation bool

validator *validator

decodeStrict bool
decoding
}

type decoding struct {
jsonUnmarshal JSONUnmarshalFunc
jsonNewDecoder JSONNewDecoderFunc[JSONDecoder]

rawUrlBase64Encoding Base64Encoding
urlBase64Encoding Base64Encoding
strict StrictFunc[Base64Encoding]

// Use JSON Number format in JSON decoder.
useJSONNumber bool

decodeStrict bool
decodePaddingAllowed bool
}

Expand Down Expand Up @@ -148,7 +158,18 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
if headerBytes, err = p.DecodeSegment(parts[0]); err != nil {
return token, parts, newError("could not base64 decode header", ErrTokenMalformed, err)
}
if err = json.Unmarshal(headerBytes, &token.Header); err != nil {

// Choose our JSON decoder. If no custom function is supplied, we use the standard library.
var unmarshal JSONUnmarshalFunc
if p.jsonUnmarshal != nil {
unmarshal = p.jsonUnmarshal
} else {
unmarshal = json.Unmarshal
}

// JSON Unmarshal the header
err = unmarshal(headerBytes, &token.Header)
if err != nil {
return token, parts, newError("could not JSON decode header", ErrTokenMalformed, err)
}

Expand All @@ -160,25 +181,31 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
return token, parts, newError("could not base64 decode claim", ErrTokenMalformed, err)
}

// If `useJSONNumber` is enabled then we must use *json.Decoder to decode
// the claims. However, this comes with a performance penalty so only use
// it if we must and, otherwise, simple use json.Unmarshal.
if !p.useJSONNumber {
// JSON Unmarshal. Special case for map type to avoid weird pointer behavior.
if c, ok := token.Claims.(MapClaims); ok {
err = json.Unmarshal(claimBytes, &c)
} else {
err = json.Unmarshal(claimBytes, &claims)
// If `useJSONNumber` is enabled, then we must use a dedicated JSONDecoder
// to decode the claims. However, this comes with a performance penalty so
// only use it if we must and, otherwise, simple use our existing unmarshal
// function.
if p.useJSONNumber {
unmarshal = func(data []byte, v any) error {
buffer := bytes.NewBuffer(claimBytes)

var decoder JSONDecoder
if p.jsonNewDecoder != nil {
decoder = p.jsonNewDecoder(buffer)
} else {
decoder = json.NewDecoder(buffer)
}
decoder.UseNumber()
return decoder.Decode(v)
}
}

// JSON Unmarshal the claims. Special case for map type to avoid weird
// pointer behavior.
if c, ok := token.Claims.(MapClaims); ok {
err = unmarshal(claimBytes, &c)
} else {
dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
dec.UseNumber()
// JSON Decode. Special case for map type to avoid weird pointer behavior.
if c, ok := token.Claims.(MapClaims); ok {
err = dec.Decode(&c)
} else {
err = dec.Decode(&claims)
}
err = unmarshal(claimBytes, &claims)
}
if err != nil {
return token, parts, newError("could not JSON decode claim", ErrTokenMalformed, err)
Expand All @@ -200,18 +227,37 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
// take into account whether the [Parser] is configured with additional options,
// such as [WithStrictDecoding] or [WithPaddingAllowed].
func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
encoding := base64.RawURLEncoding
var encoding Base64Encoding
if p.rawUrlBase64Encoding != nil {
encoding = p.rawUrlBase64Encoding
} else {
encoding = base64.RawURLEncoding
}

if p.decodePaddingAllowed {
if l := len(seg) % 4; l > 0 {
seg += strings.Repeat("=", 4-l)
}
encoding = base64.URLEncoding

if p.urlBase64Encoding != nil {
encoding = p.urlBase64Encoding
} else {
encoding = base64.URLEncoding
}
}

if p.decodeStrict {
encoding = encoding.Strict()
if p.strict != nil {
encoding = p.strict()
} else {
stricter, ok := encoding.(Stricter[*base64.Encoding])
if !ok {
return nil, newError("WithStrictDecoding() was enabled but supplied base64 encoder does not support strict mode", ErrUnsupported)
}
encoding = stricter.Strict()
}
}

return encoding.DecodeString(seg)
}

Expand Down
77 changes: 76 additions & 1 deletion parser_option.go
@@ -1,6 +1,9 @@
package jwt

import "time"
import (
"io"
"time"
)

// ParserOption is used to implement functional-style options that modify the
// behavior of the parser. To add new options, just create a function (ideally
Expand Down Expand Up @@ -113,8 +116,80 @@ func WithPaddingAllowed() ParserOption {
// WithStrictDecoding will switch the codec used for decoding JWTs into strict
// mode. In this mode, the decoder requires that trailing padding bits are zero,
// as described in RFC 4648 section 3.5.
//
// Note: This is only supported when using [encoding/base64.Encoding], but not
// by any other decoder specified with [WithBase64Decoder].
func WithStrictDecoding() ParserOption {
return func(p *Parser) {
p.decodeStrict = true
}
}

// WithJSONDecoder supports a custom JSON decoder to use in parsing the JWT.
// There are two functions that can be supplied:
// - jsonUnmarshal is a [JSONUnmarshalFunc] that is used for the
// un-marshalling the header and claims when no other options are specified
// - jsonNewDecoder is a [JSONNewDecoderFunc] that is used to create an object
// satisfying the [JSONDecoder] interface.
//
// The latter is used when the [WithJSONNumber] option is used.
//
// If any of the supplied functions is set to nil, the defaults from the Go
// standard library, [encoding/json.Unmarshal] and [encoding/json.NewDecoder]
// are used.
//
// Example using the https://github.com/bytedance/sonic library.
//
// import (
// "github.com/bytedance/sonic"
// )
//
// var parser = jwt.NewParser(jwt.WithJSONDecoder(sonic.Unmarshal, sonic.ConfigDefault.NewDecoder))
func WithJSONDecoder[T JSONDecoder](jsonUnmarshal JSONUnmarshalFunc, jsonNewDecoder JSONNewDecoderFunc[T]) ParserOption {
return func(p *Parser) {
p.jsonUnmarshal = jsonUnmarshal
// This seems to be necessary, since we don't want to store the specific
// JSONDecoder type in our parser, but need it in the function
// interface.
p.jsonNewDecoder = func(r io.Reader) JSONDecoder {
return jsonNewDecoder(r)
}
}
}

// WithBase64Decoder supports a custom Base64 when decoding a base64 encoded
// token. Two encoding can be specified:
// - rawURL needs to contain a [Base64Encoding] that is based on base64url
// without padding. This is used for parsing tokens with the default
// options.
// - url needs to contain a [Base64Encoding] based on base64url with padding.
// The sole use of this to decode tokens when [WithPaddingAllowed] is
// enabled.
//
// If any of the supplied encodings are set to nil, the defaults from the Go
// standard library, [encoding/base64.RawURLEncoding] and
// [encoding/base64.URLEncoding] are used.
//
// Example using the https://github.com/segmentio/asm library.
//
// import (
// asmbase64 "github.com/segmentio/asm/base64"
// )
//
// var parser = jwt.NewParser(jwt.WithBase64Decoder(asmbase64.RawURLEncoding, asmbase64.URLEncoding))
func WithBase64Decoder[T Base64Encoding](rawURL Base64Encoding, url T) ParserOption {
return func(p *Parser) {
p.rawUrlBase64Encoding = rawURL
p.urlBase64Encoding = url

// Check, whether the library supports the Strict() function
stricter, ok := rawURL.(Stricter[T])
if ok {
// We need to get rid of the type parameter T, so we need to wrap it
// here
p.strict = func() Base64Encoding {
return stricter.Strict()
}
}
}
}
34 changes: 34 additions & 0 deletions parser_test.go
Expand Up @@ -3,6 +3,7 @@ package jwt_test
import (
"crypto"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -423,6 +424,39 @@ var jwtTestData = []struct {
jwt.NewParser(jwt.WithLeeway(2 * time.Minute)),
jwt.SigningMethodRS256,
},
{
"custom json encoder",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
defaultKeyFunc,
jwt.MapClaims{"foo": "bar"},
true,
nil,
jwt.NewParser(jwt.WithJSONDecoder(json.Unmarshal, json.NewDecoder)),
jwt.SigningMethodRS256,
},
{
"custom json encoder - use numbers",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
defaultKeyFunc,
jwt.MapClaims{"foo": "bar"},
true,
nil,
jwt.NewParser(
jwt.WithJSONDecoder(json.Unmarshal, json.NewDecoder),
jwt.WithJSONNumber(),
),
jwt.SigningMethodRS256,
},
{
"custom base64 encoder",
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg",
defaultKeyFunc,
jwt.MapClaims{"foo": "bar"},
true,
nil,
jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding, base64.URLEncoding)),
jwt.SigningMethodRS256,
},
}

// signToken creates and returns a signed JWT token using signingMethod.
Expand Down
35 changes: 30 additions & 5 deletions token.go
Expand Up @@ -34,6 +34,13 @@ type Token struct {
Claims Claims // Claims is the second segment of the token in decoded form
Signature []byte // Signature is the third segment of the token in decoded form. Populated when you Parse a token
Valid bool // Valid specifies if the token is valid. Populated when you Parse/Verify a token

encoders
}

type encoders struct {
jsonMarshal JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder
base64Encoding Base64Encoding // base64Encoder is the custom base64 encoding
}

// New creates a new [Token] with the specified signing method and an empty map
Expand All @@ -45,14 +52,18 @@ func New(method SigningMethod, opts ...TokenOption) *Token {
// NewWithClaims creates a new [Token] with the specified signing method and
// claims. Additional options can be specified, but are currently unused.
func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *Token {
return &Token{
t := &Token{
Header: map[string]interface{}{
"typ": "JWT",
"alg": method.Alg(),
},
Claims: claims,
Method: method,
}
for _, opt := range opts {
opt(t)
}
return t
}

// SignedString creates and returns a complete, signed JWT. The token is signed
Expand All @@ -78,12 +89,19 @@ func (t *Token) SignedString(key interface{}) (string, error) {
// of the whole deal. Unless you need this for something special, just go
// straight for the SignedString.
func (t *Token) SigningString() (string, error) {
h, err := json.Marshal(t.Header)
var marshal JSONMarshalFunc
if t.jsonMarshal != nil {
marshal = t.jsonMarshal
} else {
marshal = json.Marshal
}

h, err := marshal(t.Header)
if err != nil {
return "", err
}

c, err := json.Marshal(t.Claims)
c, err := marshal(t.Claims)
if err != nil {
return "", err
}
Expand All @@ -95,6 +113,13 @@ func (t *Token) SigningString() (string, error) {
// stripped. In the future, this function might take into account a
// [TokenOption]. Therefore, this function exists as a method of [Token], rather
// than a global function.
func (*Token) EncodeSegment(seg []byte) string {
return base64.RawURLEncoding.EncodeToString(seg)
func (t *Token) EncodeSegment(seg []byte) string {
var enc Base64Encoding
if t.base64Encoding != nil {
enc = t.base64Encoding
} else {
enc = base64.RawURLEncoding
}

return enc.EncodeToString(seg)
}