diff --git a/cmd/cleanup/background.go b/cmd/cleanup/background.go new file mode 100644 index 000000000000..80265cb1ba6f --- /dev/null +++ b/cmd/cleanup/background.go @@ -0,0 +1,33 @@ +package cleanup + +import ( + cx "context" + "time" + + "github.com/ory/graceful" + "github.com/ory/kratos/driver" +) + +func BackgroundCleanup(ctx cx.Context, r driver.Registry) { + ctx, cancel := cx.WithCancel(ctx) + + r.Logger().Println("Cleanup worker started.") + if err := graceful.Graceful(func() error { + for { + select { + case <-time.After(r.Config(ctx).DatabaseCleanupSleepBackground()): + err := r.Persister().CleanupDatabase(ctx, r.Config(ctx).DatabaseCleanupSleepTables()) + r.Logger().Error(err) + case <-ctx.Done(): + return nil + } + } + }, func(_ cx.Context) error { + cancel() + return nil + }); err != nil { + r.Logger().WithError(err).Fatalf("Failed to run cleanup worker.") + } + + r.Logger().Println("Background cleanup worker was shutdown gracefully.") +} diff --git a/cmd/cleanup/root.go b/cmd/cleanup/root.go new file mode 100644 index 000000000000..5305b679a1c0 --- /dev/null +++ b/cmd/cleanup/root.go @@ -0,0 +1,21 @@ +package cleanup + +import ( + "github.com/ory/x/configx" + "github.com/spf13/cobra" +) + +func NewCleanupCmd() *cobra.Command { + c := &cobra.Command{ + Use: "cleanup", + Short: "Various cleanup helpers", + } + configx.RegisterFlags(c.PersistentFlags()) + return c +} + +func RegisterCommandRecursive(parent *cobra.Command) { + c := NewCleanupCmd() + parent.AddCommand(c) + c.AddCommand(NewCleanupSQLCmd()) +} diff --git a/cmd/cleanup/sql.go b/cmd/cleanup/sql.go new file mode 100644 index 000000000000..ea1f90b38826 --- /dev/null +++ b/cmd/cleanup/sql.go @@ -0,0 +1,50 @@ +/* +Copyright © 2019 NAME HERE +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package cleanup + +import ( + "time" + + "github.com/ory/kratos/driver/config" + "github.com/spf13/cobra" + + "github.com/ory/kratos/cmd/cliclient" + "github.com/ory/x/configx" +) + +// cleanupSqlCmd represents the sql command +func NewCleanupSQLCmd() *cobra.Command { + c := &cobra.Command{ + Use: "sql ", + Short: "Cleanup sql database from expired flows and sessions", + Long: `Run this command as frequently as you need. +It is recommended to run this command close to the SQL instance (e.g. same subnet) instead of over the public internet. +This decreases risk of failure and decreases time required. +You can read in the database URL using the -e flag, for example: + export DSN=... + kratos cleanup sql -e +### WARNING ### +Before running this command on an existing database, create a back up! +`, + Run: func(cmd *cobra.Command, args []string) { + cliclient.NewCleanupHandler().CleanupSQL(cmd, args) + }, + } + + configx.RegisterFlags(c.PersistentFlags()) + c.Flags().BoolP("read-from-env", "e", false, "If set, reads the database connection string from the environment variable DSN or config file key dsn.") + c.Flags().IntP(config.ViperKeyDatabaseCleanupBatchSize, "b", 100, "Set the number of records to be cleaned per run") + c.Flags().Duration(config.ViperKeyDatabaseCleanupSleepBackground, 30*time.Minute, "How long to wait between each cleanup run") + c.Flags().Duration(config.ViperKeyDatabaseCleanupSleepTables, time.Minute, "How long to wait between each table cleanup") + return c +} diff --git a/cmd/root.go b/cmd/root.go index ac744e5b3dc9..f3808a1402c8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,6 +5,8 @@ import ( "fmt" "os" + "github.com/ory/kratos/cmd/cleanup" + "github.com/ory/kratos/driver/config" "github.com/ory/kratos/cmd/courier" @@ -36,6 +38,7 @@ func NewRootCmd() (cmd *cobra.Command) { cmd.AddCommand(identities.NewListCmd(cmd)) migrate.RegisterCommandRecursive(cmd) serve.RegisterCommandRecursive(cmd) + cleanup.RegisterCommandRecursive(cmd) remote.RegisterCommandRecursive(cmd) cmd.AddCommand(identities.NewValidateCmd()) cmd.AddCommand(cmdx.Version(&config.Version, &config.Commit, &config.Date)) diff --git a/continuity/persistence.go b/continuity/persistence.go index 5a2af3553d48..b536b692df68 100644 --- a/continuity/persistence.go +++ b/continuity/persistence.go @@ -2,6 +2,7 @@ package continuity import ( "context" + "time" "github.com/gofrs/uuid" ) @@ -14,4 +15,5 @@ type Persister interface { SaveContinuitySession(ctx context.Context, c *Container) error GetContinuitySession(ctx context.Context, id uuid.UUID) (*Container, error) DeleteContinuitySession(ctx context.Context, id uuid.UUID) error + DeleteExpiredContinuitySessions(context.Context, time.Time, int) error } diff --git a/driver/config/config.go b/driver/config/config.go index 51ce7d670838..5f9c94608ed3 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -15,6 +15,8 @@ import ( "testing" "time" + "github.com/santhosh-tekuri/jsonschema" + "github.com/ory/jsonschema/v3/httploader" "github.com/ory/x/httpx" @@ -153,6 +155,9 @@ const ( ViperKeyHasherArgon2ConfigDedicatedMemory = "hashers.argon2.dedicated_memory" ViperKeyHasherBcryptCost = "hashers.bcrypt.cost" ViperKeyCipherAlgorithm = "ciphers.algorithm" + ViperKeyDatabaseCleanupBatchSize = "database.cleanup.batch_size" + ViperKeyDatabaseCleanupSleepBackground = "database.cleanup.sleep.background" + ViperKeyDatabaseCleanupSleepTables = "database.cleanup.sleep.tables" ViperKeyLinkLifespan = "selfservice.methods.link.config.lifespan" ViperKeyLinkBaseURL = "selfservice.methods.link.config.base_url" ViperKeyPasswordHaveIBeenPwnedHost = "selfservice.methods.password.config.haveibeenpwned_host" @@ -1080,6 +1085,18 @@ func (p *Config) SelfServiceLinkMethodBaseURL() *url.URL { return p.p.RequestURIF(ViperKeyLinkBaseURL, p.SelfPublicURL()) } +func (p *Config) DatabaseCleanupBatchSize() int { + return p.p.IntF(ViperKeyDatabaseCleanupBatchSize, 100) +} + +func (p *Config) DatabaseCleanupSleepBackground() time.Duration { + return p.p.DurationF(ViperKeyDatabaseCleanupSleepBackground, 30*time.Minute) +} + +func (p *Config) DatabaseCleanupSleepTables() time.Duration { + return p.p.DurationF(ViperKeyDatabaseCleanupSleepTables, 1*time.Minute) +} + func (p *Config) SelfServiceFlowRecoveryAfterHooks(strategy string) []SelfServiceHook { return p.selfServiceHooks(HookStrategyKey(ViperKeySelfServiceRecoveryAfter, strategy)) } diff --git a/embedx/config.schema.json b/embedx/config.schema.json index d9b3129d9ee3..8b11aad33011 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -1433,6 +1433,49 @@ } } }, + "database": { + "type": "object", + "title": "Database related configuration", + "description": "Miscellaneous settings used in database related tasks (cleanup, etc.)", + "properties": { + "cleanup": { + "type": "object", + "title": "Database cleanup settings", + "description": "Settings that controls how the database cleanup process is configured (delays, batch size, etc.)", + "properties": { + "batch_size" : { + "type": "integer", + "title": "Number of records to clean in one iteration", + "description": "Controls how many records should be purged from one table during database cleanup task", + "minimum": 1, + "default": 100 + }, + "sleep": { + "type": "object", + "title": "Delays between various database cleanup phases", + "description": "Configures delays between each step of the cleanup process. It is useful to tune the process so it will be efficient and performant.", + "properties": { + "background": { + "type": "string", + "title": "Delay between each background runs", + "description": "When running the task in the background this parameter controls how long to wait before staring a new database cleanup iteration", + "pattern": "^[0-9]+(ns|us|ms|s|m|h)$", + "default": "30m" + }, + "tables": { + "type": "string", + "title": "Delay between each table cleanups", + "description": "Controls the delay time between cleaning each table in one cleanup iteration", + "pattern": "^[0-9]+(ns|us|ms|s|m|h)$", + "default": "1m" + } + } + } + } + } + }, + "additionalProperties": false + }, "dsn": { "type": "string", "title": "Data Source Name", diff --git a/internal/driver.go b/internal/driver.go index 9a52bba439d2..95482a74c096 100644 --- a/internal/driver.go +++ b/internal/driver.go @@ -4,6 +4,7 @@ import ( "context" "os" "testing" + "time" "github.com/ory/kratos/corp" @@ -45,6 +46,9 @@ func NewConfigurationWithDefaults(t *testing.T) *config.Config { config.ViperKeyCourierSMTPURL: "smtp://foo:bar@baz.com/", config.ViperKeySelfServiceBrowserDefaultReturnTo: "https://www.ory.sh/redirect-not-set", config.ViperKeySecretsCipher: []string{"secret-thirty-two-character-long"}, + config.ViperKeyDatabaseCleanupBatchSize: 100, + config.ViperKeyDatabaseCleanupSleepBackground: 30 * time.Minute, + config.ViperKeyDatabaseCleanupSleepTables: 1 * time.Minute, }), configx.SkipValidation(), ) diff --git a/persistence/reference.go b/persistence/reference.go index 7afa06685c45..1754ed5933b4 100644 --- a/persistence/reference.go +++ b/persistence/reference.go @@ -2,6 +2,7 @@ package persistence import ( "context" + "time" "github.com/ory/x/networkx" @@ -43,6 +44,7 @@ type Persister interface { link.RecoveryTokenPersister link.VerificationTokenPersister + CleanupDatabase(context.Context, time.Duration) error Close(context.Context) error Ping() error MigrationStatus(c context.Context) (popx.MigrationStatuses, error) diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index a0b620603956..f563ff8c5f13 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -4,6 +4,7 @@ import ( "context" "embed" "fmt" + "time" "github.com/ory/x/fsx" @@ -135,6 +136,57 @@ type node interface { GetNID() uuid.UUID } +func (p *Persister) CleanupDatabase(ctx context.Context, wait time.Duration) error { + currentTime := time.Now() + deleteLimit := p.r.Config(ctx).DatabaseCleanupBatchSize() + p.r.Logger().Printf("Cleaning up first %d records older than %s\n", deleteLimit, currentTime) + + p.r.Logger().Println("Cleaning up expired sessions") + if err := p.DeleteExpiredSessions(ctx, currentTime, deleteLimit); err != nil { + return err + } + time.Sleep(wait) + + p.r.Logger().Println("Cleaning up expired continuity containers") + if err := p.DeleteExpiredContinuitySessions(ctx, currentTime, deleteLimit); err != nil { + return err + } + time.Sleep(wait) + + p.r.Logger().Println("Cleaning up expired login flows") + if err := p.DeleteExpiredLoginFlows(ctx, currentTime, deleteLimit); err != nil { + return err + } + time.Sleep(wait) + + p.r.Logger().Println("Cleaning up expired recovery flows") + if err := p.DeleteExpiredRecoveryFlows(ctx, currentTime, deleteLimit); err != nil { + return err + } + time.Sleep(wait) + + p.r.Logger().Println("Cleaning up expired registation flows") + if err := p.DeleteExpiredRegistrationFlows(ctx, currentTime, deleteLimit); err != nil { + return err + } + time.Sleep(wait) + + p.r.Logger().Println("Cleaning up expired settings flows") + if err := p.DeleteExpiredSettingsFlows(ctx, currentTime, deleteLimit); err != nil { + return err + } + time.Sleep(wait) + + p.r.Logger().Println("Cleaning up expired verification flows") + if err := p.DeleteExpiredVerificationFlows(ctx, currentTime, deleteLimit); err != nil { + return err + } + time.Sleep(wait) + + p.r.Logger().Println("Successfully cleaned up the SQL database!") + return nil +} + func (p *Persister) update(ctx context.Context, v node, columnNames ...string) error { c := p.GetConnection(ctx) quoter, ok := c.Dialect.(quotable) diff --git a/persistence/sql/persister_continuity.go b/persistence/sql/persister_continuity.go index 04cf60c6e43c..8f7d27fb6696 100644 --- a/persistence/sql/persister_continuity.go +++ b/persistence/sql/persister_continuity.go @@ -3,6 +3,7 @@ package sql import ( "context" "fmt" + "time" "github.com/pkg/errors" @@ -41,3 +42,18 @@ func (p *Persister) DeleteContinuitySession(ctx context.Context, id uuid.UUID) e } return nil } + +func (p *Persister) DeleteExpiredContinuitySessions(ctx context.Context, expiresAt time.Time, limit int) error { + // #nosec G201 + err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( + "DELETE FROM %s WHERE expires_at <= ? LIMIT ?", + new(continuity.Container).TableName(ctx), + ), + expiresAt, + limit, + ).Exec() + if err != nil { + return sqlcon.HandleError(err) + } + return nil +} diff --git a/persistence/sql/persister_login.go b/persistence/sql/persister_login.go index 96016d99ec50..258103070600 100644 --- a/persistence/sql/persister_login.go +++ b/persistence/sql/persister_login.go @@ -2,6 +2,8 @@ package sql import ( "context" + "fmt" + "time" "github.com/ory/kratos/corp" @@ -51,3 +53,18 @@ func (p *Persister) ForceLoginFlow(ctx context.Context, id uuid.UUID) error { return tx.Save(lr, "nid") }) } + +func (p *Persister) DeleteExpiredLoginFlows(ctx context.Context, expiresAt time.Time, limit int) error { + // #nosec G201 + err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( + "DELETE FROM %s WHERE expires_at <= ? LIMIT ?", + new(login.Flow).TableName(ctx), + ), + expiresAt, + limit, + ).Exec() + if err != nil { + return sqlcon.HandleError(err) + } + return nil +} diff --git a/persistence/sql/persister_recovery.go b/persistence/sql/persister_recovery.go index 9ff3a2b4fae6..900e7eed0a39 100644 --- a/persistence/sql/persister_recovery.go +++ b/persistence/sql/persister_recovery.go @@ -93,3 +93,18 @@ func (p *Persister) DeleteRecoveryToken(ctx context.Context, token string) error /* #nosec G201 TableName is static */ return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE token=? AND nid = ?", new(link.RecoveryToken).TableName(ctx)), token, corp.ContextualizeNID(ctx, p.nid)).Exec() } + +func (p *Persister) DeleteExpiredRecoveryFlows(ctx context.Context, expiresAt time.Time, limit int) error { + // #nosec G201 + err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( + "DELETE FROM %s WHERE expires_at <= ? LIMIT ?", + new(recovery.Flow).TableName(ctx), + ), + expiresAt, + limit, + ).Exec() + if err != nil { + return sqlcon.HandleError(err) + } + return nil +} diff --git a/persistence/sql/persister_registration.go b/persistence/sql/persister_registration.go index 0895f0791e9f..b95778ea1190 100644 --- a/persistence/sql/persister_registration.go +++ b/persistence/sql/persister_registration.go @@ -2,6 +2,8 @@ package sql import ( "context" + "fmt" + "time" "github.com/ory/kratos/corp" @@ -34,3 +36,18 @@ func (p *Persister) GetRegistrationFlow(ctx context.Context, id uuid.UUID) (*reg return &r, nil } + +func (p *Persister) DeleteExpiredRegistrationFlows(ctx context.Context, expiresAt time.Time, limit int) error { + // #nosec G201 + err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( + "DELETE FROM %s WHERE expires_at <= ? LIMIT ?", + new(registration.Flow).TableName(ctx), + ), + expiresAt, + limit, + ).Exec() + if err != nil { + return sqlcon.HandleError(err) + } + return nil +} diff --git a/persistence/sql/persister_session.go b/persistence/sql/persister_session.go index 27243c19acfb..ef205c9931f5 100644 --- a/persistence/sql/persister_session.go +++ b/persistence/sql/persister_session.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "time" "github.com/gobuffalo/pop/v6" @@ -196,3 +197,18 @@ func (p *Persister) RevokeSessionsIdentityExcept(ctx context.Context, iID, sID u } return count, nil } + +func (p *Persister) DeleteExpiredSessions(ctx context.Context, expiresAt time.Time, limit int) error { + // #nosec G201 + err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( + "DELETE FROM %s WHERE expires_at <= ? LIMIT ?", + corp.ContextualizeTableName(ctx, "sessions"), + ), + expiresAt, + limit, + ).Exec() + if err != nil { + return sqlcon.HandleError(err) + } + return nil +} diff --git a/persistence/sql/persister_settings.go b/persistence/sql/persister_settings.go index 5bb3c2593a9e..b41bb94fe8b2 100644 --- a/persistence/sql/persister_settings.go +++ b/persistence/sql/persister_settings.go @@ -2,6 +2,8 @@ package sql import ( "context" + "fmt" + "time" "github.com/gofrs/uuid" @@ -42,3 +44,18 @@ func (p *Persister) UpdateSettingsFlow(ctx context.Context, r *settings.Flow) er cp.NID = corp.ContextualizeNID(ctx, p.nid) return p.update(ctx, cp) } + +func (p *Persister) DeleteExpiredSettingsFlows(ctx context.Context, expiresAt time.Time, limit int) error { + // #nosec G201 + err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( + "DELETE FROM %s WHERE expires_at <= ? LIMIT ?", + new(settings.Flow).TableName(ctx), + ), + expiresAt, + limit, + ).Exec() + if err != nil { + return sqlcon.HandleError(err) + } + return nil +} diff --git a/persistence/sql/persister_verification.go b/persistence/sql/persister_verification.go index 06a6e3030854..277ae79af616 100644 --- a/persistence/sql/persister_verification.go +++ b/persistence/sql/persister_verification.go @@ -95,3 +95,18 @@ func (p *Persister) DeleteVerificationToken(ctx context.Context, token string) e /* #nosec G201 TableName is static */ return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE token=? AND nid = ?", new(link.VerificationToken).TableName(ctx)), token, nid).Exec() } + +func (p *Persister) DeleteExpiredVerificationFlows(ctx context.Context, expiresAt time.Time, limit int) error { + // #nosec G201 + err := p.GetConnection(ctx).RawQuery(fmt.Sprintf( + "DELETE FROM %s WHERE expires_at <= ? LIMIT ?", + new(verification.Flow).TableName(ctx), + ), + expiresAt, + limit, + ).Exec() + if err != nil { + return sqlcon.HandleError(err) + } + return nil +} diff --git a/selfservice/flow/login/persistence.go b/selfservice/flow/login/persistence.go index 68884a4dc0c5..fac9daab56ed 100644 --- a/selfservice/flow/login/persistence.go +++ b/selfservice/flow/login/persistence.go @@ -2,6 +2,7 @@ package login import ( "context" + "time" "github.com/gofrs/uuid" ) @@ -12,6 +13,7 @@ type ( CreateLoginFlow(context.Context, *Flow) error GetLoginFlow(context.Context, uuid.UUID) (*Flow, error) ForceLoginFlow(ctx context.Context, id uuid.UUID) error + DeleteExpiredLoginFlows(context.Context, time.Time, int) error } FlowPersistenceProvider interface { LoginFlowPersister() FlowPersister diff --git a/selfservice/flow/recovery/persistence.go b/selfservice/flow/recovery/persistence.go index d64dce3196b4..35f482cc329f 100644 --- a/selfservice/flow/recovery/persistence.go +++ b/selfservice/flow/recovery/persistence.go @@ -2,6 +2,7 @@ package recovery import ( "context" + "time" "github.com/gofrs/uuid" ) @@ -11,6 +12,7 @@ type ( CreateRecoveryFlow(context.Context, *Flow) error GetRecoveryFlow(ctx context.Context, id uuid.UUID) (*Flow, error) UpdateRecoveryFlow(context.Context, *Flow) error + DeleteExpiredRecoveryFlows(context.Context, time.Time, int) error } FlowPersistenceProvider interface { RecoveryFlowPersister() FlowPersister diff --git a/selfservice/flow/registration/persistence.go b/selfservice/flow/registration/persistence.go index 54af2006d1f4..f19965789a23 100644 --- a/selfservice/flow/registration/persistence.go +++ b/selfservice/flow/registration/persistence.go @@ -2,6 +2,7 @@ package registration import ( "context" + "time" "github.com/gofrs/uuid" ) @@ -10,6 +11,7 @@ type FlowPersister interface { UpdateRegistrationFlow(context.Context, *Flow) error CreateRegistrationFlow(context.Context, *Flow) error GetRegistrationFlow(context.Context, uuid.UUID) (*Flow, error) + DeleteExpiredRegistrationFlows(context.Context, time.Time, int) error } type FlowPersistenceProvider interface { diff --git a/selfservice/flow/settings/persistence.go b/selfservice/flow/settings/persistence.go index 167260ff2803..fe7d78108f6d 100644 --- a/selfservice/flow/settings/persistence.go +++ b/selfservice/flow/settings/persistence.go @@ -2,6 +2,7 @@ package settings import ( "context" + "time" "github.com/gofrs/uuid" ) @@ -11,6 +12,7 @@ type ( CreateSettingsFlow(context.Context, *Flow) error GetSettingsFlow(ctx context.Context, id uuid.UUID) (*Flow, error) UpdateSettingsFlow(context.Context, *Flow) error + DeleteExpiredSettingsFlows(context.Context, time.Time, int) error } FlowPersistenceProvider interface { SettingsFlowPersister() FlowPersister diff --git a/selfservice/flow/verification/persistence.go b/selfservice/flow/verification/persistence.go index 275f860ff2ae..f8898a1d71a4 100644 --- a/selfservice/flow/verification/persistence.go +++ b/selfservice/flow/verification/persistence.go @@ -2,6 +2,7 @@ package verification import ( "context" + "time" "github.com/gofrs/uuid" ) @@ -14,5 +15,6 @@ type ( CreateVerificationFlow(context.Context, *Flow) error GetVerificationFlow(ctx context.Context, id uuid.UUID) (*Flow, error) UpdateVerificationFlow(context.Context, *Flow) error + DeleteExpiredVerificationFlows(context.Context, time.Time, int) error } ) diff --git a/session/persistence.go b/session/persistence.go index 075209f7e866..48bb98c7d6e9 100644 --- a/session/persistence.go +++ b/session/persistence.go @@ -3,6 +3,7 @@ package session import ( "context" "testing" + "time" "github.com/bxcodec/faker/v3" "github.com/gofrs/uuid" @@ -40,6 +41,9 @@ type Persister interface { // instead of a session ID. GetSessionByToken(context.Context, string) (*Session, error) + // DeleteExpiredSessions deletes sessions that expired before the given time. + DeleteExpiredSessions(context.Context, time.Time, int) error + // DeleteSessionByToken deletes a session associated with the given token. // // Functionality is similar to DeleteSession but accepts a session token