Skip to content

Commit

Permalink
Delete using subselect, ISSUE-952
Browse files Browse the repository at this point in the history
  • Loading branch information
abador committed May 17, 2022
1 parent a571e3b commit 6448602
Show file tree
Hide file tree
Showing 25 changed files with 116 additions and 39 deletions.
3 changes: 2 additions & 1 deletion cmd/cleanup/root.go
@@ -1,8 +1,9 @@
package cleanup

import (
"github.com/ory/x/configx"
"github.com/spf13/cobra"

"github.com/ory/x/configx"
)

func NewCleanupCmd() *cobra.Command {
Expand Down
4 changes: 3 additions & 1 deletion cmd/cleanup/sql.go
Expand Up @@ -16,9 +16,10 @@ import (
"fmt"
"time"

"github.com/ory/kratos/driver/config"
"github.com/spf13/cobra"

"github.com/ory/kratos/driver/config"

"github.com/ory/kratos/cmd/cliclient"
"github.com/ory/x/configx"
)
Expand Down Expand Up @@ -48,6 +49,7 @@ Before running this command on an existing database, create a back up!
configx.RegisterFlags(c.PersistentFlags())
c.Flags().BoolP("read-from-env", "e", true, "If set, reads the database connection string from the environment variable DSN or config file key dsn.")
c.Flags().Duration(config.ViperKeyDatabaseCleanupSleepTables, time.Minute, "How long to wait between each table cleanup")
c.Flags().IntP(config.ViperKeyDatabaseCleanupBatchSize, "b", 100, "Set the number of records to be cleaned per run")
c.Flags().Duration("keep-last", 0, "Don't remove records younger than")
return c
}
6 changes: 5 additions & 1 deletion cmd/cliclient/cleanup.go
Expand Up @@ -56,7 +56,11 @@ func (h *CleanupHandler) CleanupSQL(cmd *cobra.Command, args []string) error {

keepLast := flagx.MustGetDuration(cmd, "keep-last")

err = d.Persister().CleanupDatabase(cmd.Context(), d.Config(cmd.Context()).DatabaseCleanupSleepTables(), keepLast)
err = d.Persister().CleanupDatabase(
cmd.Context(),
d.Config(cmd.Context()).DatabaseCleanupSleepTables(),
keepLast,
d.Config(cmd.Context()).DatabaseCleanupBatchSize())
if err != nil {
return errors.Wrap(err, "An error occurred while cleaning up expired data")
}
Expand Down
2 changes: 1 addition & 1 deletion continuity/persistence.go
Expand Up @@ -15,5 +15,5 @@ type Persister interface {
SaveContinuitySession(ctx context.Context, c *Container) error
GetContinuitySession(ctx context.Context, id uuid.UUID) (*Container, error)
DeleteContinuitySession(ctx context.Context, id uuid.UUID) error
DeleteExpiredContinuitySessions(context.Context, time.Time) error
DeleteExpiredContinuitySessions(context.Context, time.Time, int) error
}
23 changes: 23 additions & 0 deletions continuity/test/persistence.go
Expand Up @@ -94,5 +94,28 @@ func TestPersister(ctx context.Context, p interface {
require.ErrorIs(t, err, sqlcon.ErrNoRows)
})
})

t.Run("case=cleanup", func(t *testing.T) {
id := x.NewUUID()
now := time.Now().Add(-24 * time.Hour).UTC().Truncate(time.Second)
m := sqlxx.NullJSONRawMessage(`{"foo": "bar"}`)
expected := continuity.Container{Name: "foo", IdentityID: x.PointToUUID(createIdentity(t).ID),
ExpiresAt: now,
Payload: m,
}
expected.ID = id

t.Run("can cleanup", func(t *testing.T) {
require.NoError(t, p.SaveContinuitySession(ctx, &expected))

assert.EqualValues(t, id, expected.ID)
assert.EqualValues(t, nid, expected.NID)

require.NoError(t, p.DeleteExpiredContinuitySessions(ctx, time.Now(), 5))

_, err := p.GetContinuitySession(ctx, id)
require.Error(t, err)
})
})
}
}
9 changes: 6 additions & 3 deletions driver/config/config.go
Expand Up @@ -15,8 +15,6 @@ import (
"testing"
"time"

"github.com/santhosh-tekuri/jsonschema"

"github.com/ory/jsonschema/v3/httploader"
"github.com/ory/x/httpx"

Expand Down Expand Up @@ -156,6 +154,7 @@ const (
ViperKeyHasherBcryptCost = "hashers.bcrypt.cost"
ViperKeyCipherAlgorithm = "ciphers.algorithm"
ViperKeyDatabaseCleanupSleepTables = "database.cleanup.sleep.tables"
ViperKeyDatabaseCleanupBatchSize = "database.cleanup.batch_size"
ViperKeyLinkLifespan = "selfservice.methods.link.config.lifespan"
ViperKeyLinkBaseURL = "selfservice.methods.link.config.base_url"
ViperKeyPasswordHaveIBeenPwnedHost = "selfservice.methods.password.config.haveibeenpwned_host"
Expand Down Expand Up @@ -1084,7 +1083,11 @@ func (p *Config) SelfServiceLinkMethodBaseURL() *url.URL {
}

func (p *Config) DatabaseCleanupSleepTables() time.Duration {
return p.p.DurationF(ViperKeyDatabaseCleanupSleepTables, 1*time.Minute)
return p.p.DurationF(ViperKeyDatabaseCleanupSleepTables, 5*time.Second)
}

func (p *Config) DatabaseCleanupBatchSize() int {
return p.p.IntF(ViperKeyDatabaseCleanupBatchSize, 100)
}

func (p *Config) SelfServiceFlowRecoveryAfterHooks(strategy string) []SelfServiceHook {
Expand Down
16 changes: 15 additions & 1 deletion driver/config/config_test.go
Expand Up @@ -48,7 +48,7 @@ func TestViperProvider(t *testing.T) {
p := config.MustNew(t, logrusx.New("", ""), os.Stderr,
configx.WithConfigFiles("stub/.kratos.yaml"))

t.Run("gourp=client config", func(t *testing.T) {
t.Run("group=client config", func(t *testing.T) {
assert.False(t, p.ClientHTTPNoPrivateIPRanges(), "Should not have private IP ranges disabled per default")
p.MustSet(config.ViperKeyClientHTTPNoPrivateIPRanges, true)
assert.True(t, p.ClientHTTPNoPrivateIPRanges(), "Should disallow private IP ranges if set")
Expand Down Expand Up @@ -1152,3 +1152,17 @@ func TestCourierTemplatesConfig(t *testing.T) {
assert.Equal(t, courierTemplateConfig, c.CourierTemplatesHelper(config.ViperKeyCourierTemplatesRecoveryValidEmail))
})
}

func TestCleanup(t *testing.T) {
p := config.MustNew(t, logrusx.New("", ""), os.Stderr,
configx.WithConfigFiles("stub/.kratos.yaml"))

t.Run("group=cleanup config", func(t *testing.T) {
assert.Equal(t, p.DatabaseCleanupSleepTables(), 1*time.Minute)
p.MustSet(config.ViperKeyDatabaseCleanupSleepTables, time.Second)
assert.Equal(t, p.DatabaseCleanupSleepTables(), time.Second)
assert.Equal(t, p.DatabaseCleanupBatchSize(), 100)
p.MustSet(config.ViperKeyDatabaseCleanupBatchSize, 1)
assert.Equal(t, p.DatabaseCleanupSleepTables(), 1)
})
}
2 changes: 1 addition & 1 deletion go.mod
Expand Up @@ -34,7 +34,6 @@ require (
github.com/dgraph-io/ristretto v0.1.0
github.com/duo-labs/webauthn v0.0.0-20220330035159-03696f3d4499
github.com/fatih/color v1.13.0
github.com/form3tech-oss/jwt-go v3.2.3+incompatible
github.com/ghodss/yaml v1.0.0
github.com/go-errors/errors v1.0.1
github.com/go-openapi/strfmt v0.20.3
Expand Down Expand Up @@ -151,6 +150,7 @@ require (
github.com/evanphx/json-patch v4.11.0+incompatible // indirect
github.com/fatih/structs v1.1.0 // indirect
github.com/felixge/httpsnoop v1.0.1 // indirect
github.com/form3tech-oss/jwt-go v3.2.3+incompatible // indirect
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/fullstorydev/grpcurl v1.8.1 // indirect
github.com/fxamacker/cbor/v2 v2.4.0 // indirect
Expand Down
1 change: 1 addition & 0 deletions internal/driver.go
Expand Up @@ -47,6 +47,7 @@ func NewConfigurationWithDefaults(t *testing.T) *config.Config {
config.ViperKeySelfServiceBrowserDefaultReturnTo: "https://www.ory.sh/redirect-not-set",
config.ViperKeySecretsCipher: []string{"secret-thirty-two-character-long"},
config.ViperKeyDatabaseCleanupSleepTables: 1 * time.Minute,
config.ViperKeyDatabaseCleanupBatchSize: 100,
}),
configx.SkipValidation(),
)
Expand Down
2 changes: 1 addition & 1 deletion persistence/reference.go
Expand Up @@ -44,7 +44,7 @@ type Persister interface {
link.RecoveryTokenPersister
link.VerificationTokenPersister

CleanupDatabase(context.Context, time.Duration, time.Duration) error
CleanupDatabase(context.Context, time.Duration, time.Duration, int) error
Close(context.Context) error
Ping() error
MigrationStatus(c context.Context) (popx.MigrationStatuses, error)
Expand Down
16 changes: 8 additions & 8 deletions persistence/sql/persister.go
Expand Up @@ -136,48 +136,48 @@ type node interface {
GetNID() uuid.UUID
}

func (p *Persister) CleanupDatabase(ctx context.Context, wait time.Duration, older time.Duration) error {
func (p *Persister) CleanupDatabase(ctx context.Context, wait time.Duration, older time.Duration, batchSize int) error {
currentTime := time.Now().Add(-older)
p.r.Logger().Printf("Cleaning up records older than %s\n", currentTime)

p.r.Logger().Println("Cleaning up expired sessions")
if err := p.DeleteExpiredSessions(ctx, currentTime); err != nil {
if err := p.DeleteExpiredSessions(ctx, currentTime, batchSize); err != nil {
return err
}
time.Sleep(wait)

p.r.Logger().Println("Cleaning up expired continuity containers")
if err := p.DeleteExpiredContinuitySessions(ctx, currentTime); err != nil {
if err := p.DeleteExpiredContinuitySessions(ctx, currentTime, batchSize); err != nil {
return err
}
time.Sleep(wait)

p.r.Logger().Println("Cleaning up expired login flows")
if err := p.DeleteExpiredLoginFlows(ctx, currentTime); err != nil {
if err := p.DeleteExpiredLoginFlows(ctx, currentTime, batchSize); err != nil {
return err
}
time.Sleep(wait)

p.r.Logger().Println("Cleaning up expired recovery flows")
if err := p.DeleteExpiredRecoveryFlows(ctx, currentTime); err != nil {
if err := p.DeleteExpiredRecoveryFlows(ctx, currentTime, batchSize); err != nil {
return err
}
time.Sleep(wait)

p.r.Logger().Println("Cleaning up expired registation flows")
if err := p.DeleteExpiredRegistrationFlows(ctx, currentTime); err != nil {
if err := p.DeleteExpiredRegistrationFlows(ctx, currentTime, batchSize); err != nil {
return err
}
time.Sleep(wait)

p.r.Logger().Println("Cleaning up expired settings flows")
if err := p.DeleteExpiredSettingsFlows(ctx, currentTime); err != nil {
if err := p.DeleteExpiredSettingsFlows(ctx, currentTime, batchSize); err != nil {
return err
}
time.Sleep(wait)

p.r.Logger().Println("Cleaning up expired verification flows")
if err := p.DeleteExpiredVerificationFlows(ctx, currentTime); err != nil {
if err := p.DeleteExpiredVerificationFlows(ctx, currentTime, batchSize); err != nil {
return err
}
time.Sleep(wait)
Expand Down
7 changes: 5 additions & 2 deletions persistence/sql/persister_continuity.go
Expand Up @@ -43,13 +43,16 @@ func (p *Persister) DeleteContinuitySession(ctx context.Context, id uuid.UUID) e
return nil
}

func (p *Persister) DeleteExpiredContinuitySessions(ctx context.Context, expiresAt time.Time) error {
func (p *Persister) DeleteExpiredContinuitySessions(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE expires_at <= ?",
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(continuity.Container).TableName(ctx),
new(continuity.Container).TableName(ctx),
limit,
),
expiresAt,
corp.ContextualizeNID(ctx, p.nid),
).Exec()
if err != nil {
return sqlcon.HandleError(err)
Expand Down
7 changes: 5 additions & 2 deletions persistence/sql/persister_login.go
Expand Up @@ -54,13 +54,16 @@ func (p *Persister) ForceLoginFlow(ctx context.Context, id uuid.UUID) error {
})
}

func (p *Persister) DeleteExpiredLoginFlows(ctx context.Context, expiresAt time.Time) error {
func (p *Persister) DeleteExpiredLoginFlows(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE expires_at <= ?",
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(login.Flow).TableName(ctx),
new(login.Flow).TableName(ctx),
limit,
),
expiresAt,
corp.ContextualizeNID(ctx, p.nid),
).Exec()
if err != nil {
return sqlcon.HandleError(err)
Expand Down
7 changes: 5 additions & 2 deletions persistence/sql/persister_recovery.go
Expand Up @@ -94,13 +94,16 @@ func (p *Persister) DeleteRecoveryToken(ctx context.Context, token string) error
return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE token=? AND nid = ?", new(link.RecoveryToken).TableName(ctx)), token, corp.ContextualizeNID(ctx, p.nid)).Exec()
}

func (p *Persister) DeleteExpiredRecoveryFlows(ctx context.Context, expiresAt time.Time) error {
func (p *Persister) DeleteExpiredRecoveryFlows(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE expires_at <= ?",
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(recovery.Flow).TableName(ctx),
new(recovery.Flow).TableName(ctx),
limit,
),
expiresAt,
corp.ContextualizeNID(ctx, p.nid),
).Exec()
if err != nil {
return sqlcon.HandleError(err)
Expand Down
7 changes: 5 additions & 2 deletions persistence/sql/persister_registration.go
Expand Up @@ -37,13 +37,16 @@ func (p *Persister) GetRegistrationFlow(ctx context.Context, id uuid.UUID) (*reg
return &r, nil
}

func (p *Persister) DeleteExpiredRegistrationFlows(ctx context.Context, expiresAt time.Time) error {
func (p *Persister) DeleteExpiredRegistrationFlows(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE expires_at <= ?",
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(registration.Flow).TableName(ctx),
new(registration.Flow).TableName(ctx),
limit,
),
expiresAt,
corp.ContextualizeNID(ctx, p.nid),
).Exec()
if err != nil {
return sqlcon.HandleError(err)
Expand Down
8 changes: 5 additions & 3 deletions persistence/sql/persister_session.go
Expand Up @@ -198,13 +198,15 @@ func (p *Persister) RevokeSessionsIdentityExcept(ctx context.Context, iID, sID u
return count, nil
}

func (p *Persister) DeleteExpiredSessions(ctx context.Context, expiresAt time.Time) error {
// #nosec G201
func (p *Persister) DeleteExpiredSessions(ctx context.Context, expiresAt time.Time, limit int) error {
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE expires_at <= ?",
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
corp.ContextualizeTableName(ctx, "sessions"),
corp.ContextualizeTableName(ctx, "sessions"),
limit,
),
expiresAt,
corp.ContextualizeNID(ctx, p.nid),
).Exec()
if err != nil {
return sqlcon.HandleError(err)
Expand Down
7 changes: 5 additions & 2 deletions persistence/sql/persister_settings.go
Expand Up @@ -45,13 +45,16 @@ func (p *Persister) UpdateSettingsFlow(ctx context.Context, r *settings.Flow) er
return p.update(ctx, cp)
}

func (p *Persister) DeleteExpiredSettingsFlows(ctx context.Context, expiresAt time.Time) error {
func (p *Persister) DeleteExpiredSettingsFlows(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE expires_at <= ?",
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(settings.Flow).TableName(ctx),
new(settings.Flow).TableName(ctx),
limit,
),
expiresAt,
corp.ContextualizeNID(ctx, p.nid),
).Exec()
if err != nil {
return sqlcon.HandleError(err)
Expand Down
9 changes: 9 additions & 0 deletions persistence/sql/persister_test.go
Expand Up @@ -315,3 +315,12 @@ func TestPersister_Transaction(t *testing.T) {
assert.Equal(t, sqlcon.ErrNoRows.Error(), err.Error())
})
}

func TestPersister_Cleanup(t *testing.T) {
_, reg := internal.NewFastRegistryWithMocks(t)
p := reg.Persister()

t.Run("case=should not throw error on cleanup", func(t *testing.T) {
assert.Nil(t, p.CleanupDatabase(context.Background(), 0, 0, reg.Config(context.Background()).DatabaseCleanupBatchSize()))
})
}
7 changes: 5 additions & 2 deletions persistence/sql/persister_verification.go
Expand Up @@ -96,13 +96,16 @@ func (p *Persister) DeleteVerificationToken(ctx context.Context, token string) e
return p.GetConnection(ctx).RawQuery(fmt.Sprintf("DELETE FROM %s WHERE token=? AND nid = ?", new(link.VerificationToken).TableName(ctx)), token, nid).Exec()
}

func (p *Persister) DeleteExpiredVerificationFlows(ctx context.Context, expiresAt time.Time) error {
func (p *Persister) DeleteExpiredVerificationFlows(ctx context.Context, expiresAt time.Time, limit int) error {
// #nosec G201
err := p.GetConnection(ctx).RawQuery(fmt.Sprintf(
"DELETE FROM %s WHERE expires_at <= ?",
"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
new(verification.Flow).TableName(ctx),
new(verification.Flow).TableName(ctx),
limit,
),
expiresAt,
corp.ContextualizeNID(ctx, p.nid),
).Exec()
if err != nil {
return sqlcon.HandleError(err)
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/persistence.go
Expand Up @@ -13,7 +13,7 @@ type (
CreateLoginFlow(context.Context, *Flow) error
GetLoginFlow(context.Context, uuid.UUID) (*Flow, error)
ForceLoginFlow(ctx context.Context, id uuid.UUID) error
DeleteExpiredLoginFlows(context.Context, time.Time) error
DeleteExpiredLoginFlows(context.Context, time.Time, int) error
}
FlowPersistenceProvider interface {
LoginFlowPersister() FlowPersister
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/recovery/persistence.go
Expand Up @@ -12,7 +12,7 @@ type (
CreateRecoveryFlow(context.Context, *Flow) error
GetRecoveryFlow(ctx context.Context, id uuid.UUID) (*Flow, error)
UpdateRecoveryFlow(context.Context, *Flow) error
DeleteExpiredRecoveryFlows(context.Context, time.Time) error
DeleteExpiredRecoveryFlows(context.Context, time.Time, int) error
}
FlowPersistenceProvider interface {
RecoveryFlowPersister() FlowPersister
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/registration/persistence.go
Expand Up @@ -11,7 +11,7 @@ type FlowPersister interface {
UpdateRegistrationFlow(context.Context, *Flow) error
CreateRegistrationFlow(context.Context, *Flow) error
GetRegistrationFlow(context.Context, uuid.UUID) (*Flow, error)
DeleteExpiredRegistrationFlows(context.Context, time.Time) error
DeleteExpiredRegistrationFlows(context.Context, time.Time, int) error
}

type FlowPersistenceProvider interface {
Expand Down

0 comments on commit 6448602

Please sign in to comment.