diff --git a/example/server/default/default.go b/example/server/default/default.go index e5b9d0f9..421c7f76 100644 --- a/example/server/default/default.go +++ b/example/server/default/default.go @@ -15,20 +15,27 @@ import ( func main() { ctx := context.Background() + port := "9998" config := &op.Config{ Issuer: "http://localhost:9998/", CryptoKey: sha256.Sum256([]byte("test")), - Port: "9998", } storage := mock.NewAuthStorage() handler, err := op.NewDefaultOP(ctx, config, storage, op.WithCustomTokenEndpoint(op.NewEndpoint("test"))) if err != nil { log.Fatal(err) } - router := handler.HttpHandler().Handler.(*mux.Router) + router := handler.HttpHandler().(*mux.Router) router.Methods("GET").Path("/login").HandlerFunc(HandleLogin) router.Methods("POST").Path("/login").HandlerFunc(HandleCallback) - op.Start(ctx, handler) + server := &http.Server{ + Addr: ":" + port, + Handler: router, + } + err = server.ListenAndServe() + if err != nil { + log.Fatal(err) + } <-ctx.Done() } diff --git a/pkg/op/config.go b/pkg/op/config.go index 1b047db7..c52609a9 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -16,8 +16,6 @@ type Configuration interface { KeysEndpoint() Endpoint AuthMethodPostSupported() bool - - Port() string } func ValidateIssuer(issuer string) error { diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index ed117687..a16d4d3f 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -10,6 +10,7 @@ import ( "gopkg.in/square/go-jose.v2" "github.com/caos/logging" + "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/rp" ) @@ -45,7 +46,7 @@ type DefaultOP struct { signer Signer verifier rp.Verifier crypto Crypto - http *http.Server + http http.Handler decoder *schema.Decoder encoder *schema.Encoder interceptor HttpInterceptor @@ -64,7 +65,6 @@ type Config struct { // IdTokenSigningAlgValuesSupported: []string{keys.SigningAlgorithm}, // SubjectTypesSupported: []string{"public"}, // TokenEndpointAuthMethodsSupported: - Port string } type endpoints struct { @@ -180,13 +180,10 @@ func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts . p.signer = NewDefaultSigner(ctx, storage, keyCh) go p.ensureKey(ctx, storage, keyCh, p.timer) - p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience()) + p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience(), rp.WithIgnoreExpiration()) + + p.http = CreateRouter(p, p.interceptor) - router := CreateRouter(p, p.interceptor) - p.http = &http.Server{ - Addr: ":" + config.Port, - Handler: router, - } p.decoder = schema.NewDecoder() p.decoder.IgnoreUnknownKeys(true) @@ -225,11 +222,7 @@ func (p *DefaultOP) AuthMethodPostSupported() bool { return true //TODO: config } -func (p *DefaultOP) Port() string { - return p.config.Port -} - -func (p *DefaultOP) HttpHandler() *http.Server { +func (p *DefaultOP) HttpHandler() http.Handler { return p.http } diff --git a/pkg/op/op.go b/pkg/op/op.go index a926d341..732a9332 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -1,12 +1,10 @@ package op import ( - "context" "net/http" "github.com/gorilla/handlers" "github.com/gorilla/mux" - "github.com/sirupsen/logrus" "github.com/caos/oidc/pkg/oidc" ) @@ -26,7 +24,7 @@ type OpenIDProvider interface { HandleUserinfo(w http.ResponseWriter, r *http.Request) HandleEndSession(w http.ResponseWriter, r *http.Request) HandleKeys(w http.ResponseWriter, r *http.Request) - HttpHandler() *http.Server + HttpHandler() http.Handler } type HttpInterceptor func(http.HandlerFunc) http.HandlerFunc @@ -54,21 +52,3 @@ func CreateRouter(o OpenIDProvider, h HttpInterceptor) *mux.Router { router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys) return router } - -func Start(ctx context.Context, o OpenIDProvider) { - go func() { - <-ctx.Done() - err := o.HttpHandler().Shutdown(ctx) - if err != nil { - logrus.Error("graceful shutdown of oidc server failed") - } - }() - - go func() { - err := o.HttpHandler().ListenAndServe() - if err != nil { - logrus.Panicf("oidc server serve failed: %v", err) - } - }() - logrus.Infof("oidc server is listening on %s", o.Port()) -} diff --git a/pkg/op/session.go b/pkg/op/session.go index 96ec1bfa..c274bf09 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -27,7 +27,11 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { RequestError(w, r, err) return } - err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.Client.GetID()) + var clientID string + if session.Client != nil { + clientID = session.Client.GetID() + } + err = ender.Storage().TerminateSession(r.Context(), session.UserID, clientID) if err != nil { RequestError(w, r, ErrServerError("error terminating session")) return @@ -50,6 +54,9 @@ func ParseEndSessionRequest(r *http.Request, decoder *schema.Decoder) (*oidc.End func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, ender SessionEnder) (*EndSessionRequest, error) { session := new(EndSessionRequest) + if req.IdTokenHint == "" { + return session, nil + } claims, err := ender.IDTokenVerifier().Verify(ctx, "", req.IdTokenHint) if err != nil { return nil, ErrInvalidRequest("id_token_hint invalid") diff --git a/pkg/rp/default_verifier.go b/pkg/rp/default_verifier.go index 64ecaa0c..db599e3e 100644 --- a/pkg/rp/default_verifier.go +++ b/pkg/rp/default_verifier.go @@ -46,13 +46,20 @@ func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts .. return &DefaultVerifier{config: conf, keySet: keySet} } -//WithIgnoreAudience will turn off audience claim (should only be used for id_token_hints) +//WithIgnoreAudience will turn off validation for audience claim (should only be used for id_token_hints) func WithIgnoreAudience() func(*verifierConfig) { return func(conf *verifierConfig) { conf.ignoreAudience = true } } +//WithIgnoreExpiration will turn off validation for expiration claim (should only be used for id_token_hints) +func WithIgnoreExpiration() func(*verifierConfig) { + return func(conf *verifierConfig) { + conf.ignoreExpiration = true + } +} + //WithIgnoreIssuedAt will turn off iat claim verification func WithIgnoreIssuedAt() func(*verifierConfig) { return func(conf *verifierConfig) { @@ -108,6 +115,7 @@ type verifierConfig struct { clientID string nonce string ignoreAudience bool + ignoreExpiration bool iat *iatConfig acr ACRVerifier maxAge time.Duration @@ -275,10 +283,10 @@ func (v *DefaultVerifier) checkSignature(ctx context.Context, idTokenString stri return "", err } if len(jws.Signatures) == 0 { - return "", nil //TODO: error + return "", ErrSignatureMissing() } if len(jws.Signatures) > 1 { - return "", nil //TODO: error + return "", ErrSignatureMultiple() } sig := jws.Signatures[0] supportedSigAlgs := v.config.supportedSignAlgs @@ -292,16 +300,18 @@ func (v *DefaultVerifier) checkSignature(ctx context.Context, idTokenString stri signedPayload, err := v.keySet.VerifySignature(ctx, jws) if err != nil { return "", err - //TODO: } if !bytes.Equal(signedPayload, payload) { - return "", ErrSignatureInvalidPayload() //TODO: err + return "", ErrSignatureInvalidPayload() } return jose.SignatureAlgorithm(sig.Header.Algorithm), nil } func (v *DefaultVerifier) checkExpiration(expiration time.Time) error { + if v.config.ignoreExpiration { + return nil + } expiration = expiration.Round(time.Second) if !v.now().Before(expiration) { return ErrExpInvalid(expiration) @@ -362,8 +372,8 @@ func (v *DefaultVerifier) decryptToken(tokenString string) (string, error) { } func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error { - if atHash == "" { - return nil //TODO: return error + if accessToken == "" { + return nil } actual, err := oidc.ClaimHash(accessToken, sigAlgorithm) @@ -371,7 +381,7 @@ func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgor return err } if actual != atHash { - return nil //TODO: error + return ErrAtHash() } return nil } diff --git a/pkg/rp/error.go b/pkg/rp/error.go index 038aa4a0..fa0ece9d 100644 --- a/pkg/rp/error.go +++ b/pkg/rp/error.go @@ -40,9 +40,18 @@ var ( ErrAuthTimeToOld = func(maxAge, authTime time.Time) *validationError { return ValidationError("Auth Time of token must not be older than %v, but was %v (%v to old)", maxAge, authTime, maxAge.Sub(authTime)) } + ErrSignatureMissing = func() *validationError { + return ValidationError("id_token does not contain a signature") + } + ErrSignatureMultiple = func() *validationError { + return ValidationError("id_token contains multiple signatures") + } ErrSignatureInvalidPayload = func() *validationError { return ValidationError("Signature does not match Payload") } + ErrAtHash = func() *validationError { + return ValidationError("at_hash does not correspond to access token") + } ) func ValidationError(message string, args ...interface{}) *validationError {