Skip to content

Commit

Permalink
fix: mysql slice delete
Browse files Browse the repository at this point in the history
- Add a workaround for [mysql slice delete](gobuffalo/pop#699)
- Optimize logout verification (save 1 db rountrip)
- Update a test to use StaticContextualizer & revert CleanAndMigrate workaround
- Ensure a Client generated with faker satisfies the DB schema
- Remove unused argument from HandleConsentRequest
  • Loading branch information
grantzvolsky authored and aeneasr committed Aug 1, 2022
1 parent c39d19a commit 0fc784a
Show file tree
Hide file tree
Showing 16 changed files with 56 additions and 51 deletions.
10 changes: 5 additions & 5 deletions client/client.go
Expand Up @@ -125,7 +125,7 @@ type Client struct {

// SubjectType requested for responses to this Client. The subject_types_supported Discovery parameter contains a
// list of the supported subject_type values for this server. Valid types include `pairwise` and `public`.
SubjectType string `json:"subject_type" db:"subject_type"`
SubjectType string `json:"subject_type" db:"subject_type" faker:"len=15"`

// URL using the https scheme to be used in calculating Pseudonymous Identifiers by the OP. The URL references a
// file with a single JSON array of redirect_uri values.
Expand All @@ -152,10 +152,10 @@ type Client struct {

// Requested Client Authentication method for the Token Endpoint. The options are client_secret_post,
// client_secret_basic, private_key_jwt, and none.
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty" db:"token_endpoint_auth_method"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty" db:"token_endpoint_auth_method" faker:"len=25"`

// Requested Client Authentication signing algorithm for the Token Endpoint.
TokenEndpointAuthSigningAlgorithm string `json:"token_endpoint_auth_signing_alg,omitempty" db:"token_endpoint_auth_signing_alg"`
TokenEndpointAuthSigningAlgorithm string `json:"token_endpoint_auth_signing_alg,omitempty" db:"token_endpoint_auth_signing_alg" faker:"len=10"`

// Array of request_uri values that are pre-registered by the RP for use at the OP. Servers MAY cache the
// contents of the files referenced by these URIs and not retrieve them at the time they are used in a request.
Expand All @@ -165,12 +165,12 @@ type Client struct {

// JWS [JWS] alg algorithm [JWA] that MUST be used for signing Request Objects sent to the OP. All Request Objects
// from this Client MUST be rejected, if not signed with this algorithm.
RequestObjectSigningAlgorithm string `json:"request_object_signing_alg,omitempty" db:"request_object_signing_alg"`
RequestObjectSigningAlgorithm string `json:"request_object_signing_alg,omitempty" db:"request_object_signing_alg" faker:"len=10"`

// JWS alg algorithm [JWA] REQUIRED for signing UserInfo Responses. If this is specified, the response will be JWT
// [JWT] serialized, and signed using JWS. The default, if omitted, is for the UserInfo Response to return the Claims
// as a UTF-8 encoded JSON object using the application/json content-type.
UserinfoSignedResponseAlg string `json:"userinfo_signed_response_alg,omitempty" db:"userinfo_signed_response_alg"`
UserinfoSignedResponseAlg string `json:"userinfo_signed_response_alg,omitempty" db:"userinfo_signed_response_alg" faker:"len=10"`

// CreatedAt returns the timestamp of the client's creation.
CreatedAt time.Time `json:"created_at,omitempty" db:"created_at"`
Expand Down
4 changes: 3 additions & 1 deletion client/manager_test_helpers.go
Expand Up @@ -132,7 +132,7 @@ func TestHelperCreateGetUpdateDeleteClientNext(t *testing.T, m Storage, networks
t.Run(fmt.Sprintf("nid=%s", nid), func(t *testing.T) {
var client Client
require.NoError(t, faker.FakeData(&client))
client.CreatedAt = time.Now()
client.CreatedAt = time.Now().Truncate(time.Second).UTC()

t.Run("lifecycle=does not exist", func(t *testing.T) {
_, err := m.GetClient(ctx, "1234")
Expand All @@ -148,6 +148,7 @@ func TestHelperCreateGetUpdateDeleteClientNext(t *testing.T, m Storage, networks
assertx.EqualAsJSONExcept(t, &client, c, []string{
"registration_access_token",
"registration_client_uri",
"updated_at",
})

n, err := m.CountClients(ctx)
Expand All @@ -165,6 +166,7 @@ func TestHelperCreateGetUpdateDeleteClientNext(t *testing.T, m Storage, networks
assertx.EqualAsJSONExcept(t, &client, c, []string{
"registration_access_token",
"registration_client_uri",
"updated_at",
})
resources[nid] = append(resources[nid], client)
})
Expand Down
4 changes: 2 additions & 2 deletions consent/handler.go
Expand Up @@ -574,7 +574,7 @@ func (h *Handler) AcceptConsentRequest(w http.ResponseWriter, r *http.Request, p
p.RequestedAt = cr.RequestedAt
p.HandledAt = sqlxx.NullTime(time.Now().UTC())

hr, err := h.r.ConsentManager().HandleConsentRequest(r.Context(), challenge, &p)
hr, err := h.r.ConsentManager().HandleConsentRequest(r.Context(), &p)
if err != nil {
h.r.Writer().WriteError(w, r, errorsx.WithStack(err))
return
Expand Down Expand Up @@ -651,7 +651,7 @@ func (h *Handler) RejectConsentRequest(w http.ResponseWriter, r *http.Request, p
return
}

request, err := h.r.ConsentManager().HandleConsentRequest(r.Context(), challenge, &HandledConsentRequest{
request, err := h.r.ConsentManager().HandleConsentRequest(r.Context(), &HandledConsentRequest{
Error: &p,
ID: challenge,
RequestedAt: hr.RequestedAt,
Expand Down
2 changes: 1 addition & 1 deletion consent/handler_test.go
Expand Up @@ -191,7 +191,7 @@ func TestGetConsentRequest(t *testing.T) {
}))

if tc.handled {
_, err := reg.ConsentManager().HandleConsentRequest(context.Background(), challenge, &HandledConsentRequest{
_, err := reg.ConsentManager().HandleConsentRequest(context.Background(), &HandledConsentRequest{
ID: challenge,
WasHandled: true,
HandledAt: sqlxx.NullTime(time.Now()),
Expand Down
2 changes: 1 addition & 1 deletion consent/manager.go
Expand Up @@ -42,7 +42,7 @@ func (_ ForcedObfuscatedLoginSession) TableName() string {
type Manager interface {
CreateConsentRequest(ctx context.Context, req *ConsentRequest) error
GetConsentRequest(ctx context.Context, challenge string) (*ConsentRequest, error)
HandleConsentRequest(ctx context.Context, challenge string, r *HandledConsentRequest) (*ConsentRequest, error)
HandleConsentRequest(ctx context.Context, r *HandledConsentRequest) (*ConsentRequest, error)
RevokeSubjectConsentSession(ctx context.Context, user string) error
RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error

Expand Down
16 changes: 8 additions & 8 deletions consent/manager_test_helpers.go
Expand Up @@ -197,7 +197,7 @@ func SaneMockHandleConsentRequest(t *testing.T, m Manager, c *ConsentRequest, au
HandledAt: sqlxx.NullTime(time.Now().UTC().Add(-time.Minute)),
}

_, err := m.HandleConsentRequest(context.Background(), c.ID, h)
_, err := m.HandleConsentRequest(context.Background(), h)
require.NoError(t, err)
return h
}
Expand Down Expand Up @@ -487,13 +487,13 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
compareConsentRequest(t, c, got1)
assert.False(t, got1.WasHandled)

got1, err = m.HandleConsentRequest(context.Background(), consentChallenge, h)
got1, err = m.HandleConsentRequest(context.Background(), h)
require.NoError(t, err)
require.Equal(t, time.Now().UTC().Round(time.Minute), time.Time(h.HandledAt).Round(time.Minute))
compareConsentRequest(t, c, got1)

h.GrantedAudience = sqlxx.StringSlicePipeDelimiter{"new-audience"}
_, err = m.HandleConsentRequest(context.Background(), consentChallenge, h)
_, err = m.HandleConsentRequest(context.Background(), h)
require.NoError(t, err)

got2, err := m.VerifyAndInvalidateConsentRequest(context.Background(), makeID("verifier", tenant, tc.key))
Expand All @@ -504,7 +504,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit

// Trying to update this again should return an error because the consent request was used.
h.GrantedAudience = sqlxx.StringSlicePipeDelimiter{"new-audience", "new-audience-2"}
_, err = m.HandleConsentRequest(context.Background(), consentChallenge, h)
_, err = m.HandleConsentRequest(context.Background(), h)
require.Error(t, err)

if tc.hasError {
Expand Down Expand Up @@ -605,9 +605,9 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit

require.NoError(t, m.CreateConsentRequest(context.Background(), cr1))
require.NoError(t, m.CreateConsentRequest(context.Background(), cr2))
_, err := m.HandleConsentRequest(context.Background(), challengerv1, hcr1)
_, err := m.HandleConsentRequest(context.Background(), hcr1)
require.NoError(t, err)
_, err = m.HandleConsentRequest(context.Background(), challengerv2, hcr2)
_, err = m.HandleConsentRequest(context.Background(), hcr2)
require.NoError(t, err)

require.NoError(t, fositeManager.CreateAccessTokenSession(context.Background(), makeID("", tenant, "trva1"), &fosite.Request{Client: cr1.Client, ID: challengerv1, RequestedAt: time.Now()}))
Expand Down Expand Up @@ -678,9 +678,9 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit

require.NoError(t, m.CreateConsentRequest(context.Background(), cr1))
require.NoError(t, m.CreateConsentRequest(context.Background(), cr2))
_, err := m.HandleConsentRequest(context.Background(), challengerv1, hcr1)
_, err := m.HandleConsentRequest(context.Background(), hcr1)
require.NoError(t, err)
_, err = m.HandleConsentRequest(context.Background(), challengerv2, hcr2)
_, err = m.HandleConsentRequest(context.Background(), hcr2)
require.NoError(t, err)

for i, tc := range []struct {
Expand Down
8 changes: 4 additions & 4 deletions consent/sdk_test.go
Expand Up @@ -101,13 +101,13 @@ func TestSDK(t *testing.T) {
require.NoError(t, m.CreateConsentRequest(context.Background(), cr2))
require.NoError(t, m.CreateConsentRequest(context.Background(), cr3))
require.NoError(t, m.CreateConsentRequest(context.Background(), cr4))
_, err := m.HandleConsentRequest(context.Background(), makeID("challenge", tenant, "1"), hcr1)
_, err := m.HandleConsentRequest(context.Background(), hcr1)
require.NoError(t, err)
_, err = m.HandleConsentRequest(context.Background(), makeID("challenge", tenant, "2"), hcr2)
_, err = m.HandleConsentRequest(context.Background(), hcr2)
require.NoError(t, err)
_, err = m.HandleConsentRequest(context.Background(), makeID("challenge", tenant, "3"), hcr3)
_, err = m.HandleConsentRequest(context.Background(), hcr3)
require.NoError(t, err)
_, err = m.HandleConsentRequest(context.Background(), makeID("challenge", tenant, "4"), hcr4)
_, err = m.HandleConsentRequest(context.Background(), hcr4)
require.NoError(t, err)

lur1 := MockLogoutRequest("testsdk-1", true, tenant)
Expand Down
2 changes: 1 addition & 1 deletion driver/registry_base.go
Expand Up @@ -188,7 +188,7 @@ func (m *RegistryBase) AuditLogger() *logrusx.Logger {

func (m *RegistryBase) ClientHasher() fosite.Hasher {
if m.fh == nil {
if m.Tracer(context.TODO()).IsLoaded() {
if m.Tracer(contextx.RootContext).IsLoaded() {
m.fh = &tracing.TracedBCrypt{WorkFactor: m.Config(contextx.RootContext).BCryptCost()}
} else {
m.fh = x.NewBCrypt(m.Config(contextx.RootContext))
Expand Down
7 changes: 0 additions & 7 deletions internal/driver.go
Expand Up @@ -7,7 +7,6 @@ import (
"testing"

"github.com/ory/x/configx"
"github.com/ory/x/networkx"

"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -78,14 +77,8 @@ func newRegistryDefault(t *testing.T, url string, c *config.Provider, migrate bo

func CleanAndMigrate(reg driver.Registry) func(*testing.T) {
return func(t *testing.T) {
net := &networkx.Network{}
recreateNetwork := reg.Persister().Connection(context.Background()).First(net) == nil
x.CleanSQLPop(t, reg.Persister().Connection(context.Background()))
require.NoError(t, reg.Persister().MigrateUp(context.Background()))
if recreateNetwork {
require.NoError(t, reg.Persister().Connection(context.Background()).RawQuery("DELETE FROM networks").Exec())
require.NoError(t, reg.Persister().Connection(context.Background()).Create(net))
}
t.Log("clean and migrate done")
}
}
Expand Down
6 changes: 3 additions & 3 deletions internal/testhelpers/janitor_test_helper.go
Expand Up @@ -274,12 +274,12 @@ func (j *JanitorConsentTestHelper) ConsentRejectionSetup(ctx context.Context, cm
for _, r := range j.flushConsentRequests {
if r.ID == j.flushConsentRequests[0].ID {
// accept this one
_, err = cm.HandleConsentRequest(ctx, r.ID, consent.NewHandledConsentRequest(
_, err = cm.HandleConsentRequest(ctx, consent.NewHandledConsentRequest(
r.ID, false, r.RequestedAt, r.AuthenticatedAt))
require.NoError(t, err)
continue
}
_, err = cm.HandleConsentRequest(ctx, r.ID, consent.NewHandledConsentRequest(
_, err = cm.HandleConsentRequest(ctx, consent.NewHandledConsentRequest(
r.ID, true, r.RequestedAt, r.AuthenticatedAt))
require.NoError(t, err)
}
Expand Down Expand Up @@ -362,7 +362,7 @@ func (j *JanitorConsentTestHelper) ConsentTimeoutSetup(ctx context.Context, cm c
}

// Create at least 1 consent request that has been accepted
_, err = cm.HandleConsentRequest(ctx, j.flushConsentRequests[0].ID, &consent.HandledConsentRequest{
_, err = cm.HandleConsentRequest(ctx, &consent.HandledConsentRequest{
ID: j.flushConsentRequests[0].ID,
WasHandled: true,
HandledAt: sqlxx.NullTime(time.Now()),
Expand Down
2 changes: 1 addition & 1 deletion oauth2/fosite_store_helpers.go
Expand Up @@ -153,7 +153,7 @@ func mockRequestForeignKey(t *testing.T, id string, x InternalRegistry, createCl

require.NoError(t, x.ConsentManager().CreateLoginRequest(context.Background(), &consent.LoginRequest{Client: cl, OpenIDConnectContext: new(consent.OpenIDConnectContext), ID: id, Verifier: id, AuthenticatedAt: sqlxx.NullTime(time.Now()), RequestedAt: time.Now()}))
require.NoError(t, x.ConsentManager().CreateConsentRequest(context.Background(), cr))
_, err := x.ConsentManager().HandleConsentRequest(context.Background(), id, &consent.HandledConsentRequest{
_, err := x.ConsentManager().HandleConsentRequest(context.Background(), &consent.HandledConsentRequest{
ConsentRequest: cr, Session: new(consent.ConsentRequestSessionData), AuthenticatedAt: sqlxx.NullTime(time.Now()),
ID: id,
RequestedAt: time.Now(),
Expand Down
3 changes: 2 additions & 1 deletion oauth2/oauth2_refresh_token_test.go
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/ory/fosite"
hc "github.com/ory/hydra/client"
"github.com/ory/hydra/driver"
"github.com/ory/hydra/internal"
"github.com/ory/hydra/oauth2"
"github.com/ory/hydra/x/contextx"
"github.com/ory/x/dbal"
Expand Down Expand Up @@ -84,7 +85,7 @@ func TestCreateRefreshTokenSessionStress(t *testing.T) {
}
net := &networkx.Network{}
require.NoError(t, dbRegistry.Persister().Connection(context.Background()).First(net))
dbRegistry.WithContextualizer(&contextx.StaticContextualizer{NID: net.ID})
dbRegistry.WithContextualizer(&contextx.StaticContextualizer{NID: net.ID, C: internal.NewConfigurationWithDefaults()})

ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
require.NoError(t, dbRegistry.OAuth2Storage().(clientCreator).CreateClient(ctx, &testClient))
Expand Down
21 changes: 9 additions & 12 deletions persistence/sql/persister_consent.go
Expand Up @@ -225,7 +225,7 @@ func (p *Persister) GetLoginRequest(ctx context.Context, login_challenge string)
})
}

func (p *Persister) HandleConsentRequest(ctx context.Context, challenge string, r *consent.HandledConsentRequest) (*consent.ConsentRequest, error) {
func (p *Persister) HandleConsentRequest(ctx context.Context, r *consent.HandledConsentRequest) (*consent.ConsentRequest, error) {
f := &flow.Flow{}

if err := sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("consent_challenge_id = ?", r.ID).First(f)); errors.Is(err, sqlcon.ErrNoRows) {
Expand Down Expand Up @@ -496,24 +496,21 @@ func (p *Persister) GetLogoutRequest(ctx context.Context, challenge string) (*co
func (p *Persister) VerifyAndInvalidateLogoutRequest(ctx context.Context, verifier string) (*consent.LogoutRequest, error) {
var lr consent.LogoutRequest
return &lr, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := p.QueryWithNetwork(ctx).Where("verifier=? AND was_used=FALSE AND accepted=TRUE AND rejected=FALSE", verifier).Select("challenge").First(&lr); err != nil {
if err == sql.ErrNoRows {
return errorsx.WithStack(x.ErrNotFound)
}

return sqlcon.HandleError(err)
}

if err := c.RawQuery("UPDATE hydra_oauth2_logout_request SET was_used=TRUE WHERE verifier=? AND nid = ?", verifier, p.NetworkID(ctx)).Exec(); err != nil {
if count, err := c.RawQuery(
"UPDATE hydra_oauth2_logout_request SET was_used=TRUE WHERE nid = ? AND verifier=? AND was_used=FALSE AND accepted=TRUE AND rejected=FALSE",
p.NetworkID(ctx),
verifier,
).ExecWithCount(); count == 0 && err == nil {
return errorsx.WithStack(x.ErrNotFound)
} else if err != nil {
return sqlcon.HandleError(err)
}

updated, err := p.GetLogoutRequest(ctx, lr.ID)
err := sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("verifier=?", verifier).First(&lr))
if err != nil {
return err
}

lr = *updated
return nil
})
}
Expand Down
9 changes: 8 additions & 1 deletion persistence/sql/persister_oauth2.go
Expand Up @@ -398,7 +398,14 @@ func (p *Persister) flushInactiveTokens(ctx context.Context, notAfter time.Time,
}

if i != j {
err = p.QueryWithNetwork(ctx).Where("signature in (?)", signatures[i:j]).Delete(&OAuth2RequestSQL{Table: table})
ss := signatures[i:j]
iss := make([]interface{}, len(ss))
for i, v := range ss { // Workaround for https://github.com/gobuffalo/pop/issues/699
iss[i] = v
}

err = p.QueryWithNetwork(ctx).Where("signature in (?)", iss...).Delete(&OAuth2RequestSQL{Table: table})

if err != nil {
return sqlcon.HandleError(err)
}
Expand Down
8 changes: 6 additions & 2 deletions persistence/sql/persister_test.go
Expand Up @@ -145,19 +145,23 @@ func TestManagers(t *testing.T) {
"memory": internal.NewRegistrySQLFromURL(t, dbal.SQLiteSharedInMemory, true, &contextx.DefaultContextualizer{}),
}

tenant2NID, _ := uuid.NewV4()
t2registries := map[string]driver.Registry{
"memory": internal.NewRegistrySQLFromURL(t, dbal.SQLiteSharedInMemory, false, &contextx.DefaultContextualizer{}),
}

if !testing.Short() {
t1registries["postgres"], t1registries["mysql"], t1registries["cockroach"], _ = internal.ConnectDatabases(t, true, &contextx.DefaultContextualizer{})
t2registries["postgres"], t2registries["mysql"], t2registries["cockroach"], _ = internal.ConnectDatabases(t, false, &contextx.DefaultContextualizer{})
t1registries["postgres"], t1registries["mysql"], t1registries["cockroach"], _ = internal.ConnectDatabases(t, true, &contextx.DefaultContextualizer{})
}

tenant1NID, _ := uuid.NewV4()
tenant2NID, _ := uuid.NewV4()

for k, t1 := range t1registries {
t2 := t2registries[k]
require.NoError(t, t1.Persister().Connection(ctx).Create(&networkx.Network{ID: tenant1NID}))
require.NoError(t, t2.Persister().Connection(ctx).Create(&networkx.Network{ID: tenant2NID}))
t1.WithContextualizer(&contextx.StaticContextualizer{NID: tenant1NID, C: t1.Config(ctx)})
t2.WithContextualizer(&contextx.StaticContextualizer{NID: tenant2NID, C: t2.Config(ctx)})
t.Run("parallel-boundary", func(t *testing.T) { testRegistry(t, ctx, k, t1, t2) })
}
Expand Down
Expand Up @@ -4,7 +4,8 @@ UPDATE hydra_oauth2_jti_blacklist SET nid = (SELECT id FROM networks LIMIT 1);
CREATE TABLE "_hydra_oauth2_jti_blacklist_tmp" (
signature VARCHAR(64) NOT NULL PRIMARY KEY,
expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
nid CHAR(36) NOT NULL
nid CHAR(36) NOT NULL,
CHECK (nid != '00000000-0000-0000-0000-000000000000')
);
INSERT INTO "_hydra_oauth2_jti_blacklist_tmp" (signature, expires_at, nid) SELECT signature, expires_at, nid FROM "hydra_oauth2_jti_blacklist";
DROP TABLE "hydra_oauth2_jti_blacklist";
Expand Down

0 comments on commit 0fc784a

Please sign in to comment.