/
persister_recovery.go
112 lines (94 loc) · 3.39 KB
/
persister_recovery.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package sql
import (
"context"
"errors"
"fmt"
"time"
"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/ory/kratos/corp"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow/recovery"
"github.com/ory/kratos/selfservice/strategy/link"
"github.com/ory/x/sqlcon"
)
// Compile-time assertions that *Persister satisfies the recovery flow and
// recovery token persistence interfaces.
var (
	_ recovery.FlowPersister      = (*Persister)(nil)
	_ link.RecoveryTokenPersister = (*Persister)(nil)
)
// CreateRecoveryFlow inserts a new recovery flow.
//
// r.NID is overwritten in place with the network ID resolved from ctx before
// the insert, so the caller's flow carries the NID afterwards.
func (p Persister) CreateRecoveryFlow(ctx context.Context, r *recovery.Flow) error {
	nid := corp.ContextualizeNID(ctx, p.nid)
	r.NID = nid

	conn := p.GetConnection(ctx)
	return conn.Create(r)
}
// GetRecoveryFlow loads the recovery flow with the given id, restricted to the
// network ID resolved from ctx.
//
// Database errors (including "not found") are passed through
// sqlcon.HandleError before being returned.
func (p Persister) GetRecoveryFlow(ctx context.Context, id uuid.UUID) (*recovery.Flow, error) {
	conn := p.GetConnection(ctx)
	flow := new(recovery.Flow)

	err := conn.Where("id = ? AND nid = ?", id, corp.ContextualizeNID(ctx, p.nid)).First(flow)
	if err != nil {
		return nil, sqlcon.HandleError(err)
	}
	return flow, nil
}
// UpdateRecoveryFlow writes r back to the store.
//
// The update is performed on a shallow copy of r whose NID is forced to the
// network ID resolved from ctx. Unlike CreateRecoveryFlow, the caller's flow
// is therefore not modified.
func (p Persister) UpdateRecoveryFlow(ctx context.Context, r *recovery.Flow) error {
	var scoped recovery.Flow
	scoped = *r
	scoped.NID = corp.ContextualizeNID(ctx, p.nid)

	return p.update(ctx, scoped)
}
// CreateRecoveryToken persists a recovery token for the network resolved from
// ctx.
//
// The plaintext in token.Token is not stored. For the duration of the insert,
// token.Token is replaced with the digest produced by p.hmacValue. The
// plaintext is restored before returning, on success and on failure alike, so
// the caller can still deliver it. token.NID is overwritten with the contextual
// network ID.
func (p *Persister) CreateRecoveryToken(ctx context.Context, token *link.RecoveryToken) error {
	plaintext := token.Token
	token.Token = p.hmacValue(ctx, plaintext)
	// Restore the plaintext on every return path, so a failed insert does not
	// leave the caller's struct holding the digest.
	defer func() { token.Token = plaintext }()

	token.NID = corp.ContextualizeNID(ctx, p.nid)

	// This should not create the request eagerly because otherwise we might accidentally create an address that isn't
	// supposed to be in the database.
	return p.GetConnection(ctx).Create(token)
}
// UseRecoveryToken redeems an unused recovery token identified by its
// plaintext value. The token is marked as used and returned together with its
// recovery address.
//
// The plaintext is hashed with each configured session secret in turn, and the
// first matching unused token wins. If no secret matches, the returned error
// satisfies errors.Is(err, sqlcon.ErrNoRows). The same applies when another
// caller redeems the token concurrently, or when no secrets are configured.
// If the referenced recovery address does not exist, the token is still
// returned, with RecoveryAddress pointing at a zero-value address.
func (p *Persister) UseRecoveryToken(ctx context.Context, token string) (*link.RecoveryToken, error) {
	var rt link.RecoveryToken

	nid := corp.ContextualizeNID(ctx, p.nid)
	if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
		// Start from "not found" so that an empty secret list cannot fall
		// through and redeem a zero-value token.
		var err error = sqlcon.ErrNoRows
		for _, secret := range p.r.Config(ctx).SecretsSession() {
			err = tx.Where("token = ? AND nid = ? AND NOT used", p.hmacValueWithSecret(token, secret), nid).First(&rt)
			if err == nil {
				break
			}
			if !errors.Is(sqlcon.HandleError(err), sqlcon.ErrNoRows) {
				return err
			}
		}
		if err != nil {
			return err
		}

		var ra identity.RecoveryAddress
		if err := tx.Where("id = ? AND nid = ?", rt.RecoveryAddressID, nid).First(&ra); err != nil {
			// A missing address is tolerated; only other errors abort.
			if !errors.Is(sqlcon.HandleError(err), sqlcon.ErrNoRows) {
				return err
			}
		}
		rt.RecoveryAddress = &ra

		// Flip only a still-unused row and require that this call flipped it.
		// Otherwise two concurrent redemptions could both read the token as
		// unused and both succeed.
		/* #nosec G201 TableName is static */
		affected, err := tx.RawQuery(fmt.Sprintf("UPDATE %s SET used=true, used_at=? WHERE id=? AND nid = ? AND NOT used", rt.TableName(ctx)), time.Now().UTC(), rt.ID, nid).ExecWithCount()
		if err != nil {
			return err
		}
		if affected == 0 {
			return sqlcon.ErrNoRows
		}
		return nil
	})); err != nil {
		return nil, err
	}

	return &rt, nil
}
// DeleteRecoveryToken deletes the recovery token identified by token within the
// network resolved from ctx.
//
// CreateRecoveryToken stores only an HMAC of the plaintext, so token is hashed
// with every configured session secret and each resulting digest is deleted.
// This mirrors the lookup in UseRecoveryToken. The given value is also matched
// verbatim, so callers that already pass the stored digest keep working.
// Deleting a token that does not exist is not an error.
func (p *Persister) DeleteRecoveryToken(ctx context.Context, token string) error {
	nid := corp.ContextualizeNID(ctx, p.nid)
	/* #nosec G201 TableName is static */
	query := fmt.Sprintf("DELETE FROM %s WHERE token=? AND nid = ?", new(link.RecoveryToken).TableName(ctx))

	candidates := []interface{}{token}
	for _, secret := range p.r.Config(ctx).SecretsSession() {
		candidates = append(candidates, p.hmacValueWithSecret(token, secret))
	}

	conn := p.GetConnection(ctx)
	for _, candidate := range candidates {
		if err := conn.RawQuery(query, candidate, nid).Exec(); err != nil {
			return err
		}
	}
	return nil
}
// DeleteExpiredRecoveryFlows deletes up to limit recovery flows of the network
// resolved from ctx whose expires_at is at or before expiresAt. The oldest
// flows are deleted first.
//
// Errors are passed through sqlcon.HandleError.
func (p *Persister) DeleteExpiredRecoveryFlows(ctx context.Context, expiresAt time.Time, limit int) error {
	table := new(recovery.Flow).TableName(ctx)

	// NOTE(review): the doubly nested sub-select (aliased "s") is presumably a
	// workaround for databases that reject LIMIT directly inside IN (...), or
	// selecting from the table being deleted. Verify before simplifying.
	// #nosec G201
	query := fmt.Sprintf(
		"DELETE FROM %s WHERE id in (SELECT id FROM (SELECT id FROM %s c WHERE expires_at <= ? and nid = ? ORDER BY expires_at ASC LIMIT %d ) AS s )",
		table, table, limit,
	)

	if err := p.GetConnection(ctx).RawQuery(query, expiresAt, corp.ContextualizeNID(ctx, p.nid)).Exec(); err != nil {
		return sqlcon.HandleError(err)
	}
	return nil
}