From cf674bd0ec6dc7db02d8bf08350795247a17e329 Mon Sep 17 00:00:00 2001 From: Garren Date: Fri, 16 Sep 2022 07:32:37 +0200 Subject: [PATCH] fix(qe) OCC fixes for update/delete many MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This makes sure that the correct where clauses get passed to the update and delete clauses for an operation so that when a user is implementing application level optimistic concurrency control (occ) that is will work as expected. Co-authored-by: Tom Houlé --- Cargo.lock | 2 +- .../query-engine-tests/Cargo.toml | 2 +- .../query-engine-tests/tests/new/mod.rs | 1 + .../tests/new/multi_schema.rs | 1 + .../query-engine-tests/tests/new/occ.rs | 362 ++++++++++++++++++ .../tests/new/occ_simple.prisma | 18 + .../new/ref_actions/on_update/no_action.rs | 16 + .../writes/top_level_mutations/create.rs | 2 +- .../writes/top_level_mutations/update_many.rs | 15 + .../query-tests-setup/src/query_result.rs | 8 +- .../src/interface/connection.rs | 37 +- .../src/interface/transaction.rs | 27 +- .../src/root_queries/write.rs | 7 +- .../query-connector/src/interface.rs | 10 + .../connectors/query-connector/src/lib.rs | 10 + .../src/database/connection.rs | 19 +- .../src/database/operations/write.rs | 25 +- .../src/database/transaction.rs | 19 +- .../src/filter_conversion.rs | 95 +++-- .../sql-query-connector/src/join_utils.rs | 4 +- .../src/query_builder/read.rs | 4 +- .../src/query_builder/write.rs | 14 +- .../sql-query-connector/src/query_ext.rs | 2 +- .../interpreter/query_interpreters/write.rs | 4 +- 24 files changed, 648 insertions(+), 56 deletions(-) create mode 100644 query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs create mode 100644 query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ_simple.prisma diff --git a/Cargo.lock b/Cargo.lock index ac937db65b11..ad4974af34ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3086,7 +3086,7 @@ dependencies = [ [[package]] name = "quaint" version = "0.2.0-alpha.13" -source = "git+https://github.com/prisma/quaint#2baa684d8c4d2512c9d99dc84fdafbdffc638b6a" +source = "git+https://github.com/prisma/quaint#ffe1979c7b1761931a8ece6aad47a4627af1bc82" dependencies = [ "async-trait", "base64 0.12.3", diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml b/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml index df629dce6c68..0b9f9afe9e7b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml +++ b/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml @@ -17,7 +17,7 @@ chrono = "0.4" datamodel-connector = { path = "../../../libs/datamodel/connectors/datamodel-connector" } base64 = "0.13" uuid = "1" -tokio = "1.8" +tokio = "1.21.0" prisma-value = { path = "../../../libs/prisma-value" } query-engine-metrics = { path = "../../metrics"} diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/mod.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/mod.rs index c65fc7ef45be..b2f0a08506e7 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/mod.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/mod.rs @@ -5,5 +5,6 @@ mod interactive_tx; mod metrics; mod multi_schema; mod native_types; +mod occ; mod ref_actions; mod regressions; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs index f58e754c5bf0..34a6504afd5b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/multi_schema.rs @@ -1,4 +1,5 @@ use query_engine_tests::test_suite; + #[test_suite(capabilities(MultiSchema), exclude(Mysql))] mod multi_schema { use query_engine_tests::*; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs new file mode 100644 index 000000000000..242ff691d9ec --- /dev/null +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ.rs @@ -0,0 +1,362 @@ +use query_engine_tests::*; +use std::sync::Arc; + +#[test_suite] +mod occ { + pub fn occ_simple() -> String { + include_str!("occ_simple.prisma").to_owned() + } + + async fn create_one_seat(runner: Arc) { + runner + .query(r#"mutation { createOneSeat(data: { movie: "zardoz", id: 1 }) { id } }"#) + .await + .unwrap() + .assert_success(); + } + + async fn create_one_user(user_id: u64, runner: Arc) { + let query = format!(r#"mutation {{ createOneUser(data: {{ id: {user_id} }}) {{ id }} }}"#); + runner.query(query).await.unwrap().assert_success(); + } + + async fn find_unclaimed_seat(runner: Arc) -> (u64, u64) { + let seat_query = "query { findFirstSeat(where: { movie: \"zardoz\", userId: null }) { id version } }"; + let seat_result = runner.query(seat_query).await.unwrap().to_json_value(); + let available_seat = &seat_result["data"]["findFirstSeat"]; + match available_seat { + serde_json::Value::Null => (0, 0), + other => (other["id"].as_u64().unwrap(), other["version"].as_u64().unwrap()), + } + } + + async fn book_unclaimed_seat(user_id: u64, seat_id: u64, runner: Arc) -> (u64, u64) { + let query = indoc::formatdoc!( + r##" + mutation {{ + updateManySeat( + data: {{ userId: {user_id}, version: {{ increment: 1 }} }}, + where: {{ id: {seat_id}, version: 0 }} + ) + {{ count }} + }} + "## + ); + let response = runner.query(query).await.unwrap().to_json_value(); + let seat_count = response["data"]["updateManySeat"]["count"].as_u64().unwrap(); + (user_id, seat_count) + } + + async fn book_seat_for_user(user_id: u64, runner: Arc) -> (u64, u64) { + let (seat_id, _version) = find_unclaimed_seat(runner.clone()).await; + book_unclaimed_seat(user_id, seat_id, runner).await + } + + async fn delete_seats(runner: Arc) { + let delete_seats = r#" + mutation { + deleteManySeat(where: {}) { + count + } + } + "#; + runner.query(delete_seats).await.unwrap().assert_success(); + } + + async fn delete_users(runner: Arc) { + let delete_users = r#" + mutation { + deleteManyUser(where: {}) { + count + } + } + "#; + runner.query(delete_users).await.unwrap().assert_success(); + } + + async fn run_occ_reproduce_test(runner: Arc) { + const USERS_COUNT: u64 = 5; + + create_one_seat(runner.clone()).await; + + for i in 0..=USERS_COUNT { + create_one_user(i, runner.clone()).await; + } + + let mut set = tokio::task::JoinSet::new(); + for user_id in 0..=USERS_COUNT { + set.spawn(book_seat_for_user(user_id, runner.clone())); + } + + let mut booked_user_id = 100; + let mut total_booked = 0; + while let Some(res) = set.join_next().await { + let (user_id, count) = res.unwrap(); + + if count > 0 { + total_booked += count; + booked_user_id = user_id; + } + } + + assert_eq!(total_booked, 1); + + let booked_seat = runner + .query("query { findFirstSeat { id version userId } }") + .await + .unwrap() + .to_json_value(); + + let found_booked_user_id = booked_seat["data"]["findFirstSeat"]["userId"].as_u64().unwrap(); + + assert_eq!(booked_user_id, found_booked_user_id); + } + + #[connector_test(schema(occ_simple), exclude(MongoDB, CockroachDb))] + async fn occ_update_many_test(runner: Runner) -> TestResult<()> { + let runner = Arc::new(runner); + + // This test can give false positives so we run it a few times + // to make sure. + for _ in 0..=5 { + delete_seats(runner.clone()).await; + delete_users(runner.clone()).await; + run_occ_reproduce_test(runner.clone()).await; + } + + Ok(()) + } + + #[connector_test(schema(occ_simple), exclude(CockroachDb))] + async fn occ_update_test(runner: Runner) -> TestResult<()> { + let runner = Arc::new(runner); + + create_one_resource(runner.clone()).await; + + let mut set = tokio::task::JoinSet::new(); + + set.spawn(update_one_resource(runner.clone())); + set.spawn(update_one_resource(runner.clone())); + set.spawn(update_one_resource(runner.clone())); + set.spawn(update_one_resource(runner.clone())); + + while (set.join_next().await).is_some() {} + + let res = find_one_resource(runner).await; + + let expected = serde_json::json!({ + "data": { + "findFirstResource": { + "occStamp": 1, + "id": 1 + } + } + }); + + assert_eq!(res, expected); + + Ok(()) + } + + #[connector_test(schema(occ_simple))] + async fn occ_delete_test(runner: Runner) -> TestResult<()> { + let runner = Arc::new(runner); + + create_one_resource(runner.clone()).await; + + let mut set = tokio::task::JoinSet::new(); + + set.spawn(update_and_delete(runner.clone())); + set.spawn(update_and_delete(runner.clone())); + set.spawn(update_and_delete(runner.clone())); + set.spawn(update_and_delete(runner.clone())); + set.spawn(update_and_delete(runner.clone())); + + while (set.join_next().await).is_some() {} + + let res = find_one_resource(runner).await; + + let expected = serde_json::json!({ + "data": { + "findFirstResource": { + "occStamp": 1, + "id": 1 + } + } + }); + + assert_eq!(res, expected); + + Ok(()) + } + + #[connector_test(schema(occ_simple))] + async fn occ_delete_many_test(runner: Runner) -> TestResult<()> { + let runner = Arc::new(runner); + + create_one_resource(runner.clone()).await; + + let mut set = tokio::task::JoinSet::new(); + + set.spawn(delete_many_resource(runner.clone())); + set.spawn(delete_many_resource(runner.clone())); + set.spawn(delete_many_resource(runner.clone())); + set.spawn(delete_many_resource(runner.clone())); + set.spawn(delete_many_resource(runner.clone())); + + let mut num_deleted: u64 = 0; + while let Some(res) = set.join_next().await { + if let Ok(row_count) = res { + if row_count > 0 { + num_deleted += 1; + } + } + } + + assert_eq!(num_deleted, 1); + let res = find_one_resource(runner).await; + + let expected = serde_json::json!({ + "data": { + "findFirstResource": serde_json::Value::Null + } + }); + assert_eq!(res, expected); + + Ok(()) + } + + // Because of the way upsert works this test is a little bit flaky. Ignoring until we fix upsert + #[allow(dead_code)] + #[ignore] + async fn occ_upsert_test(runner: Runner) -> TestResult<()> { + let runner = Arc::new(runner); + + let mut set = tokio::task::JoinSet::new(); + + set.spawn(upsert_one_resource(runner.clone())); + set.spawn(upsert_one_resource(runner.clone())); + set.spawn(upsert_one_resource(runner.clone())); + set.spawn(upsert_one_resource(runner.clone())); + set.spawn(upsert_one_resource(runner.clone())); + + while (set.join_next().await).is_some() {} + + let res = find_one_resource(runner.clone()).await; + + // MongoDB is different here and seems to only do one create with all the upserts + // where as all the sql databases will do one create and one upsert + let expected = if matches!(runner.connector(), ConnectorTag::MongoDb(_)) { + serde_json::json!({ + "data": { + "findFirstResource": { + "occStamp": 0, + "id": 1 + } + } + }) + } else { + serde_json::json!({ + "data": { + "findFirstResource": { + "occStamp": 1, + "id": 1 + } + } + }) + }; + assert_eq!(res, expected); + + Ok(()) + } + + async fn update_and_delete(runner: Arc) { + update_one_resource(runner.clone()).await; + delete_one_resource(runner).await; + } + + async fn create_one_resource(runner: Arc) { + let create_one_resource = r#" + mutation { + createOneResource(data: {id: 1}) { + id + } + }"#; + + runner.query(create_one_resource).await.unwrap().to_json_value(); + } + + async fn update_one_resource(runner: Arc) -> serde_json::Value { + let update_one_resource = r#" + mutation { + updateOneResource(data: {occStamp: {increment: 1}}, where: {occStamp: 0}) { + occStamp, + id + } + } + "#; + + runner.query(update_one_resource).await.unwrap().to_json_value() + } + + #[allow(dead_code)] + async fn upsert_one_resource(runner: Arc) -> serde_json::Value { + let upsert_one_resource = r#" + mutation { + upsertOneResource(where: {occStamp: 0}, + create: { + occStamp: 0, + id: 1 + }, + update: { + occStamp: {increment: 1} + }) { + id, + occStamp + } + } + "#; + + runner.query(upsert_one_resource).await.unwrap().to_json_value() + } + + async fn delete_one_resource(runner: Arc) -> serde_json::Value { + let delete_one_resource = r#" + mutation { + deleteOneResource(where: {occStamp: 0}) { + occStamp, + id + } + } + "#; + + runner.query(delete_one_resource).await.unwrap().to_json_value() + } + + async fn delete_many_resource(runner: Arc) -> u64 { + let delete_many_resource = r#" + mutation { + deleteManyResource(where: {occStamp: 0}) { + count + } + } + "#; + + let res = runner.query(delete_many_resource).await.unwrap().to_json_value(); + + res["data"]["deleteManyResource"]["count"].as_u64().unwrap() + } + + async fn find_one_resource(runner: Arc) -> serde_json::Value { + let find_one_resource = r#" + { + findFirstResource(where: {}) { + occStamp, + id + } + } + "#; + + runner.query(find_one_resource).await.unwrap().to_json_value() + } +} diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ_simple.prisma b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ_simple.prisma new file mode 100644 index 000000000000..cc81d0d2f4eb --- /dev/null +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/occ_simple.prisma @@ -0,0 +1,18 @@ +model User { + #id(id, Int, @id) + Seat Seat? +} + +model Seat { + #id(id, Int, @id) + movie String @unique + userId Int? @unique + claimedBy User? @relation(fields: [userId], references: [id]) + version Int @default(0) + @@unique([id, version]) +} + +model Resource { + #id(id, Int, @id) + occStamp Int @default(0) @unique +} diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/no_action.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/no_action.rs index 8c7a281ca0c8..137b5cd0ce59 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/no_action.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/no_action.rs @@ -176,6 +176,22 @@ mod one2one_opt { @r###"{"data":{"updateOneParent":{"id":1}}}"### ); + Ok(()) + } + + /// Updating the parent succeeds if no child is connected. + #[connector_test] + async fn update_parent_with_many(runner: Runner) -> TestResult<()> { + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, uniq: "1" }) { id }}"#), + @r###"{"data":{"createOneParent":{"id":1}}}"### + ); + + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneParent(data: { id: 2, uniq: "2" }) { id }}"#), + @r###"{"data":{"createOneParent":{"id":2}}}"### + ); + insta::assert_snapshot!( run_query!(&runner, r#"mutation { updateManyParent(where: { id: 1 }, data: { uniq: "u1" }) { count }}"#), @r###"{"data":{"updateManyParent":{"count":1}}}"### diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs index 3641fa6188b7..524711550ba8 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create.rs @@ -122,7 +122,7 @@ mod create { } // A Create Mutation should create and return item with explicit null values after previous mutation with explicit non-null values - #[connector_test] + #[connector_test(exclude(CockroachDb))] async fn return_item_non_null_attrs_then_explicit_null_attrs(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs index 54aa68322874..edfb80795741 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/update_many.rs @@ -333,11 +333,26 @@ mod json_update_many { @r###"{"data":{"updateManyTestModel":{"count":1}}}"### ); + insta::assert_snapshot!( + run_query!(&runner, r#"{ findFirstTestModel(where: {id: 1}) { id, json } }"#), + @r###"{"data":{"findFirstTestModel":{"id":1,"json":null}}}"### + ); + + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { updateManyTestModel(where: { id: 1 }, data: { json: "{}" }) { count }}"#), + @r###"{"data":{"updateManyTestModel":{"count":1}}}"### + ); + insta::assert_snapshot!( run_query!(&runner, r#"mutation { updateManyTestModel(where: { id: 1 }, data: { json: null }) { count }}"#), @r###"{"data":{"updateManyTestModel":{"count":1}}}"### ); + insta::assert_snapshot!( + run_query!(&runner, r#"{ findFirstTestModel(where: {id: 1}) { id, json } }"#), + @r###"{"data":{"findFirstTestModel":{"id":1,"json":null}}}"### + ); + Ok(()) } diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/query_result.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/query_result.rs index b83d2a22d64f..e0ef4637819a 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/query_result.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/query_result.rs @@ -15,10 +15,8 @@ impl QueryResult { /// Asserts absence of errors in the result. Panics with assertion error. pub fn assert_success(&self) { if self.failed() { - dbg!(self.errors()); + panic!("{}", self.to_string()); } - - assert!(!self.failed()) } /// Asserts presence of errors in the result. @@ -71,6 +69,10 @@ impl QueryResult { } } + pub fn to_json_value(&self) -> serde_json::Value { + serde_json::to_value(&self.response).unwrap() + } + pub fn to_string_pretty(&self) -> String { serde_json::to_string_pretty(&self.response).unwrap() } diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs b/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs index 84fd0b1ebf93..a273349d4432 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs @@ -6,7 +6,8 @@ use crate::{ }; use async_trait::async_trait; use connector_interface::{ - Connection, ConnectionLike, ReadOperations, RelAggregationSelection, Transaction, WriteArgs, WriteOperations, + Connection, ConnectionLike, ReadOperations, RelAggregationSelection, Transaction, UpdateType, WriteArgs, + WriteOperations, }; use mongodb::{ClientSession, Database}; use prisma_models::{prelude::*, SelectionResult}; @@ -76,8 +77,40 @@ impl WriteOperations for MongoDbConnection { args: WriteArgs, _trace_id: Option, ) -> connector_interface::Result> { - catch(async move { write::update_records(&self.database, &mut self.session, model, record_filter, args).await }) + catch(async move { + write::update_records( + &self.database, + &mut self.session, + model, + record_filter, + args, + UpdateType::Many, + ) .await + }) + .await + } + + async fn update_record( + &mut self, + model: &ModelRef, + record_filter: connector_interface::RecordFilter, + args: WriteArgs, + _trace_id: Option, + ) -> connector_interface::Result> { + catch(async move { + let mut res = write::update_records( + &self.database, + &mut self.session, + model, + record_filter, + args, + UpdateType::One, + ) + .await?; + Ok(res.pop()) + }) + .await } async fn delete_records( diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs index 6460d031425d..a780ae7c7438 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs @@ -3,7 +3,9 @@ use crate::{ error::MongoError, root_queries::{aggregate, read, write}, }; -use connector_interface::{ConnectionLike, ReadOperations, RelAggregationSelection, Transaction, WriteOperations}; +use connector_interface::{ + ConnectionLike, ReadOperations, RelAggregationSelection, Transaction, UpdateType, WriteOperations, +}; use mongodb::options::{Acknowledgment, ReadConcern, TransactionOptions, WriteConcern}; use prisma_models::SelectionResult; use query_engine_metrics::{decrement_gauge, increment_gauge, metrics, PRISMA_CLIENT_QUERIES_ACTIVE}; @@ -113,12 +115,35 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { model, record_filter, args, + UpdateType::Many, ) .await }) .await } + async fn update_record( + &mut self, + model: &ModelRef, + record_filter: connector_interface::RecordFilter, + args: connector_interface::WriteArgs, + _trace_id: Option, + ) -> connector_interface::Result> { + catch(async move { + let mut res = write::update_records( + &self.connection.database, + &mut self.connection.session, + model, + record_filter, + args, + UpdateType::One, + ) + .await?; + Ok(res.pop()) + }) + .await + } + async fn delete_records( &mut self, model: &ModelRef, diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs index 5c1b7675dc9b..af15774d5abc 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs @@ -142,6 +142,7 @@ pub async fn update_records<'conn>( model: &ModelRef, record_filter: RecordFilter, mut args: WriteArgs, + update_type: UpdateType, ) -> crate::Result> { let coll = database.collection::(model.db_name()); @@ -198,9 +199,13 @@ pub async fn update_records<'conn>( if !update_docs.is_empty() { logger::log_update_many_vec(coll.name(), &filter, &update_docs); - metrics(|| coll.update_many_with_session(filter, update_docs, None, session)) + let res = metrics(|| coll.update_many_with_session(filter, update_docs, None, session)) .instrument(span) .await?; + + if update_type == UpdateType::Many && res.modified_count == 0 { + return Ok(Vec::new()); + } } let ids = ids diff --git a/query-engine/connectors/query-connector/src/interface.rs b/query-engine/connectors/query-connector/src/interface.rs index 41a182419f7f..964c164568bc 100644 --- a/query-engine/connectors/query-connector/src/interface.rs +++ b/query-engine/connectors/query-connector/src/interface.rs @@ -301,6 +301,16 @@ pub trait WriteOperations { trace_id: Option, ) -> crate::Result>; + /// Update record in the `Model` with the given `WriteArgs` filtered by the + /// `Filter`. + async fn update_record( + &mut self, + model: &ModelRef, + record_filter: RecordFilter, + args: WriteArgs, + trace_id: Option, + ) -> crate::Result>; + /// Delete records in the `Model` with the given `Filter`. async fn delete_records( &mut self, diff --git a/query-engine/connectors/query-connector/src/lib.rs b/query-engine/connectors/query-connector/src/lib.rs index 6a529c25a123..7d70517516a0 100644 --- a/query-engine/connectors/query-connector/src/lib.rs +++ b/query-engine/connectors/query-connector/src/lib.rs @@ -18,3 +18,13 @@ pub use query_arguments::*; pub use write_args::*; pub type Result = std::result::Result; + +/// When we write a single record using this update_records function, we always +/// want the id of the changed record back. Even if the row wasn't updated. This can happen in situations where +/// we could increment a null value and the update count would be zero for mysql. +/// However when we updating any records we want to return an empty array if zero items were updated +#[derive(PartialEq)] +pub enum UpdateType { + Many, + One, +} diff --git a/query-engine/connectors/sql-query-connector/src/database/connection.rs b/query-engine/connectors/sql-query-connector/src/database/connection.rs index 08c7e65d67fe..01df9cf84d17 100644 --- a/query-engine/connectors/sql-query-connector/src/database/connection.rs +++ b/query-engine/connectors/sql-query-connector/src/database/connection.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use connector::{ConnectionLike, RelAggregationSelection}; use connector_interface::{ self as connector, filter::Filter, AggregationRow, AggregationSelection, Connection, QueryArguments, - ReadOperations, RecordFilter, Transaction, WriteArgs, WriteOperations, + ReadOperations, RecordFilter, Transaction, UpdateType, WriteArgs, WriteOperations, }; use prisma_models::{prelude::*, SelectionResult}; use prisma_value::PrismaValue; @@ -209,7 +209,22 @@ where trace_id: Option, ) -> connector::Result> { catch(self.connection_info.clone(), async move { - write::update_records(&self.inner, model, record_filter, args, trace_id).await + write::update_records(&self.inner, model, record_filter, args, UpdateType::Many, trace_id).await + }) + .await + } + + async fn update_record( + &mut self, + model: &ModelRef, + record_filter: RecordFilter, + args: WriteArgs, + trace_id: Option, + ) -> connector::Result> { + catch(self.connection_info.clone(), async move { + let mut res = + write::update_records(&self.inner, model, record_filter, args, UpdateType::One, trace_id).await?; + Ok(res.pop()) }) .await } diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index fd59485a7a70..316eeba111ef 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -1,3 +1,4 @@ +use crate::filter_conversion::AliasedCondition; use crate::sql_trace::SqlTraceComment; use crate::{error::SqlError, model_extensions::*, query_builder::write, sql_info::SqlInfo, QueryExt}; use connector_interface::*; @@ -305,8 +306,10 @@ pub async fn update_records( model: &ModelRef, record_filter: RecordFilter, args: WriteArgs, + update_type: UpdateType, trace_id: Option, ) -> crate::Result> { + let filter_condition = record_filter.clone().filter.aliased_condition_from(None, false); let ids = conn.filter_selectors(model, record_filter, trace_id.clone()).await?; let id_args = pick_args(&model.primary_identifier().into(), &args); @@ -316,11 +319,18 @@ pub async fn update_records( let updates = { let ids: Vec<&SelectionResult> = ids.iter().collect(); - write::update_many(model, ids.as_slice(), args, trace_id)? + write::update_many(model, ids.as_slice(), args, filter_condition, trace_id)? }; + let mut count = 0; for update in updates { - conn.query(update).await?; + let update_count = conn.execute(update).await?; + + count += update_count; + } + + if update_type == UpdateType::Many && count == 0 { + return Ok(Vec::new()); } Ok(merge_write_args(ids, id_args)) @@ -333,6 +343,7 @@ pub async fn delete_records( record_filter: RecordFilter, trace_id: Option, ) -> crate::Result { + let filter_condition = record_filter.clone().filter.aliased_condition_from(None, false); let ids = conn.filter_selectors(model, record_filter, trace_id.clone()).await?; let ids: Vec<&SelectionResult> = ids.iter().collect(); let count = ids.len(); @@ -341,11 +352,15 @@ pub async fn delete_records( return Ok(count); } - for delete in write::delete_many(model, ids.as_slice(), trace_id) { - conn.query(delete).await?; + let mut row_count = 0; + for delete in write::delete_many(model, ids.as_slice(), filter_condition, trace_id) { + row_count += conn.execute(delete).await?; } - Ok(count) + match usize::try_from(row_count) { + Ok(row_count) => Ok(row_count), + Err(_) => Ok(count), + } } /// Connect relations defined in `child_ids` to a parent defined in `parent_id`. diff --git a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index f2c834cfffcb..9fb13cee996f 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use connector::{ConnectionLike, RelAggregationSelection}; use connector_interface::{ self as connector, filter::Filter, AggregationRow, AggregationSelection, QueryArguments, ReadOperations, - RecordFilter, Transaction, WriteArgs, WriteOperations, + RecordFilter, Transaction, UpdateType, WriteArgs, WriteOperations, }; use prisma_models::{prelude::*, SelectionResult}; use prisma_value::PrismaValue; @@ -190,7 +190,22 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { trace_id: Option, ) -> connector::Result> { catch(self.connection_info.clone(), async move { - write::update_records(&self.inner, model, record_filter, args, trace_id).await + write::update_records(&self.inner, model, record_filter, args, UpdateType::Many, trace_id).await + }) + .await + } + + async fn update_record( + &mut self, + model: &ModelRef, + record_filter: RecordFilter, + args: WriteArgs, + trace_id: Option, + ) -> connector::Result> { + catch(self.connection_info.clone(), async move { + let mut res = + write::update_records(&self.inner, model, record_filter, args, UpdateType::One, trace_id).await?; + Ok(res.pop()) }) .await } diff --git a/query-engine/connectors/sql-query-connector/src/filter_conversion.rs b/query-engine/connectors/sql-query-connector/src/filter_conversion.rs index 0d4ce037471f..b35979a80f47 100644 --- a/query-engine/connectors/sql-query-connector/src/filter_conversion.rs +++ b/query-engine/connectors/sql-query-connector/src/filter_conversion.rs @@ -57,13 +57,44 @@ impl Alias { } } +#[derive(Clone)] +pub struct ConditionState { + reverse: bool, + alias: Option, +} + +impl ConditionState { + fn new(alias: Option, reverse: bool) -> Self { + Self { reverse, alias } + } + + fn invert_reverse(self) -> Self { + Self::new(self.alias, !self.reverse) + } + + fn alias(&self) -> Option { + self.alias + } + + fn reverse(&self) -> bool { + self.reverse + } +} + pub trait AliasedCondition { /// Conversion to a query condition tree. Columns will point to the given /// alias if provided, otherwise using the fully qualified path. /// /// Alias should be used only when nesting, making the top level queries /// more explicit. - fn aliased_cond(self, alias: Option, reverse: bool) -> ConditionTree<'static>; + fn aliased_cond(self, state: ConditionState) -> ConditionTree<'static>; + + fn aliased_condition_from(&self, alias: Option, reverse: bool) -> ConditionTree<'static> + where + Self: Sized + Clone, + { + self.clone().aliased_cond(ConditionState::new(alias, reverse)) + } } trait AliasedSelect { @@ -100,15 +131,15 @@ impl AliasedColumn for Column<'static> { impl AliasedCondition for Filter { /// Conversion from a `Filter` to a query condition tree. Aliased when in a nested `SELECT`. - fn aliased_cond(self, alias: Option, reverse: bool) -> ConditionTree<'static> { + fn aliased_cond(self, state: ConditionState) -> ConditionTree<'static> { match self { Filter::And(mut filters) => match filters.len() { n if n == 0 => ConditionTree::NoCondition, - n if n == 1 => filters.pop().unwrap().aliased_cond(alias, reverse), + n if n == 1 => filters.pop().unwrap().aliased_cond(state), _ => { let exprs = filters .into_iter() - .map(|f| f.aliased_cond(alias, reverse)) + .map(|f| f.aliased_cond(state.clone())) .map(Expression::from) .collect(); @@ -117,11 +148,11 @@ impl AliasedCondition for Filter { }, Filter::Or(mut filters) => match filters.len() { n if n == 0 => ConditionTree::NegativeCondition, - n if n == 1 => filters.pop().unwrap().aliased_cond(alias, reverse), + n if n == 1 => filters.pop().unwrap().aliased_cond(state), _ => { let exprs = filters .into_iter() - .map(|f| f.aliased_cond(alias, reverse)) + .map(|f| f.aliased_cond(state.clone())) .map(Expression::from) .collect(); @@ -130,20 +161,20 @@ impl AliasedCondition for Filter { }, Filter::Not(mut filters) => match filters.len() { n if n == 0 => ConditionTree::NoCondition, - n if n == 1 => filters.pop().unwrap().aliased_cond(alias, !reverse).not(), + n if n == 1 => filters.pop().unwrap().aliased_cond(state.invert_reverse()).not(), _ => { let exprs = filters .into_iter() - .map(|f| f.aliased_cond(alias, !reverse).not()) + .map(|f| f.aliased_cond(state.clone().invert_reverse()).not()) .map(Expression::from) .collect(); ConditionTree::And(exprs) } }, - Filter::Scalar(filter) => filter.aliased_cond(alias, reverse), - Filter::OneRelationIsNull(filter) => filter.aliased_cond(alias, reverse), - Filter::Relation(filter) => filter.aliased_cond(alias, reverse), + Filter::Scalar(filter) => filter.aliased_cond(state), + Filter::OneRelationIsNull(filter) => filter.aliased_cond(state), + Filter::Relation(filter) => filter.aliased_cond(state), Filter::BoolFilter(b) => { if b { ConditionTree::NoCondition @@ -151,8 +182,8 @@ impl AliasedCondition for Filter { ConditionTree::NegativeCondition } } - Filter::Aggregation(filter) => filter.aliased_cond(alias, reverse), - Filter::ScalarList(filter) => filter.aliased_cond(alias, reverse), + Filter::Aggregation(filter) => filter.aliased_cond(state), + Filter::ScalarList(filter) => filter.aliased_cond(state), Filter::Empty => ConditionTree::NoCondition, Filter::Composite(_) => unimplemented!("SQL connectors do not support composites yet."), } @@ -161,7 +192,7 @@ impl AliasedCondition for Filter { impl AliasedCondition for ScalarFilter { /// Conversion from a `ScalarFilter` to a query condition tree. Aliased when in a nested `SELECT`. - fn aliased_cond(self, alias: Option, reverse: bool) -> ConditionTree<'static> { + fn aliased_cond(self, state: ConditionState) -> ConditionTree<'static> { match self.condition { ScalarCondition::Search(_, _) | ScalarCondition::NotSearch(_, _) => { let mut projections = match self.condition.clone() { @@ -175,7 +206,7 @@ impl AliasedCondition for ScalarFilter { let columns: Vec = projections .into_iter() .map(|p| match p { - ScalarProjection::Single(field) => field.aliased_col(alias), + ScalarProjection::Single(field) => field.aliased_col(state.alias()), ScalarProjection::Compound(_) => { unreachable!("Full-text search does not support compound fields") } @@ -184,9 +215,17 @@ impl AliasedCondition for ScalarFilter { let comparable: Expression = text_search(columns.as_slice()).into(); - convert_scalar_filter(comparable, self.condition, reverse, self.mode, &[], alias, false) + convert_scalar_filter( + comparable, + self.condition, + state.reverse(), + self.mode, + &[], + state.alias(), + false, + ) } - _ => scalar_filter_aliased_cond(self, alias, reverse), + _ => scalar_filter_aliased_cond(self, state.alias(), state.reverse()), } } } @@ -219,10 +258,10 @@ fn scalar_filter_aliased_cond(sf: ScalarFilter, alias: Option, reverse: b } impl AliasedCondition for ScalarListFilter { - fn aliased_cond(self, alias: Option, _reverse: bool) -> ConditionTree<'static> { - let comparable: Expression = self.field.aliased_col(alias).into(); + fn aliased_cond(self, state: ConditionState) -> ConditionTree<'static> { + let comparable: Expression = self.field.aliased_col(state.alias()).into(); - convert_scalar_list_filter(comparable, self.condition, &self.field, alias) + convert_scalar_list_filter(comparable, self.condition, &self.field, state.alias()) } } @@ -263,12 +302,12 @@ fn convert_scalar_list_filter( impl AliasedCondition for RelationFilter { /// Conversion from a `RelationFilter` to a query condition tree. Aliased when in a nested `SELECT`. - fn aliased_cond(self, alias: Option, _reverse: bool) -> ConditionTree<'static> { + fn aliased_cond(self, state: ConditionState) -> ConditionTree<'static> { let ids = ModelProjection::from(self.field.model().primary_identifier()).as_columns(); - let columns: Vec> = ids.map(|col| col.aliased_col(alias)).collect(); + let columns: Vec> = ids.map(|col| col.aliased_col(state.alias())).collect(); let condition = self.condition; - let sub_select = self.aliased_sel(alias.map(|a| a.inc(AliasMode::Table))); + let sub_select = self.aliased_sel(state.alias().map(|a| a.inc(AliasMode::Table))); let comparison = match condition { RelationCondition::AtLeastOneRelatedRecord => Row::from(columns).in_selection(sub_select), @@ -304,7 +343,7 @@ impl AliasedSelect for RelationFilter { let nested_conditions = self .nested_filter - .aliased_cond(Some(alias.flip(AliasMode::Join)), false) + .aliased_condition_from(Some(alias.flip(AliasMode::Join)), false) .invert_if(condition.invert_of_subselect()); let conditions = selected_identifier @@ -325,8 +364,8 @@ impl AliasedSelect for RelationFilter { impl AliasedCondition for OneRelationIsNullFilter { /// Conversion from a `OneRelationIsNullFilter` to a query condition tree. Aliased when in a nested `SELECT`. - fn aliased_cond(self, alias: Option, _reverse: bool) -> ConditionTree<'static> { - let alias = alias.map(|a| a.to_string(None)); + fn aliased_cond(self, state: ConditionState) -> ConditionTree<'static> { + let alias = state.alias().map(|a| a.to_string(None)); let condition = if self.field.relation_is_inlined_in_parent() { self.field.as_columns().fold(ConditionTree::NoCondition, |acc, column| { @@ -386,7 +425,9 @@ impl AliasedCondition for OneRelationIsNullFilter { impl AliasedCondition for AggregationFilter { /// Conversion from an `AggregationFilter` to a query condition tree. Aliased when in a nested `SELECT`. - fn aliased_cond(self, alias: Option, reverse: bool) -> ConditionTree<'static> { + fn aliased_cond(self, state: ConditionState) -> ConditionTree<'static> { + let alias = state.alias(); + let reverse = state.reverse(); match self { AggregationFilter::Count(filter) => aggregate_conditions(*filter, alias, reverse, |x| count(x).into()), AggregationFilter::Average(filter) => aggregate_conditions(*filter, alias, reverse, |x| avg(x).into()), diff --git a/query-engine/connectors/sql-query-connector/src/join_utils.rs b/query-engine/connectors/sql-query-connector/src/join_utils.rs index b047428ac4e8..3b2f33efc3f0 100644 --- a/query-engine/connectors/sql-query-connector/src/join_utils.rs +++ b/query-engine/connectors/sql-query-connector/src/join_utils.rs @@ -74,7 +74,7 @@ fn compute_aggr_join_one2m( }; let select_columns = right_fields.iter().map(|f| f.as_column()); let conditions: ConditionTree = filter - .map(|f| f.aliased_cond(None, false)) + .map(|f| f.aliased_condition_from(None, false)) .unwrap_or(ConditionTree::NoCondition); // + SELECT Child. FROM Child WHERE @@ -151,7 +151,7 @@ fn compute_aggr_join_m2m( let parent_ids: ModelProjection = rf.model().primary_identifier().into(); // Rendered filters let conditions: ConditionTree = filter - .map(|f| f.aliased_cond(None, false)) + .map(|f| f.aliased_condition_from(None, false)) .unwrap_or(ConditionTree::NoCondition); // + SELECT _ParentToChild.ChildId FROM Child WHERE diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs index d149b58770fa..2febe154c049 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs @@ -67,7 +67,7 @@ impl SelectDefinition for QueryArguments { let filter: ConditionTree = self .filter - .map(|f| f.aliased_cond(None, false)) + .map(|f| f.aliased_condition_from(None, false)) .unwrap_or(ConditionTree::NoCondition); let conditions = match (filter, cursor_condition) { @@ -253,7 +253,7 @@ pub fn group_by_aggregate( ); match having { - Some(filter) => grouped.having(filter.aliased_cond(None, false)), + Some(filter) => grouped.having(filter.aliased_condition_from(None, false)), None => grouped, } } diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/write.rs b/query-engine/connectors/sql-query-connector/src/query_builder/write.rs index 1ba47e1d69ec..8282e958729f 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/write.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/write.rs @@ -102,6 +102,7 @@ pub fn update_many( model: &ModelRef, ids: &[&SelectionResult], args: WriteArgs, + filter_condition: ConditionTree<'static>, trace_id: Option, ) -> crate::Result>> { if args.args.is_empty() || ids.is_empty() { @@ -161,17 +162,24 @@ pub fn update_many( let query = query.append_trace(&Span::current()).add_trace_id(trace_id); let columns: Vec<_> = ModelProjection::from(model.primary_identifier()).as_columns().collect(); - let result: Vec = super::chunked_conditions(&columns, ids, |conditions| query.clone().so_that(conditions)); + let result: Vec = super::chunked_conditions(&columns, ids, |conditions| { + query.clone().so_that(conditions.and(filter_condition.clone())) + }); Ok(result) } -pub fn delete_many(model: &ModelRef, ids: &[&SelectionResult], trace_id: Option) -> Vec> { +pub fn delete_many( + model: &ModelRef, + ids: &[&SelectionResult], + filter_condition: ConditionTree<'static>, + trace_id: Option, +) -> Vec> { let columns: Vec<_> = ModelProjection::from(model.primary_identifier()).as_columns().collect(); super::chunked_conditions(&columns, ids, |conditions| { Delete::from_table(model.as_table()) - .so_that(conditions) + .so_that(conditions.and(filter_condition.clone())) .append_trace(&Span::current()) .add_trace_id(trace_id.clone()) }) diff --git a/query-engine/connectors/sql-query-connector/src/query_ext.rs b/query-engine/connectors/sql-query-connector/src/query_ext.rs index 3d52e22af2e7..fa6bf168e157 100644 --- a/query-engine/connectors/sql-query-connector/src/query_ext.rs +++ b/query-engine/connectors/sql-query-connector/src/query_ext.rs @@ -161,7 +161,7 @@ pub trait QueryExt: Queryable + Send + Sync { .columns(id_cols) .append_trace(&Span::current()) .add_trace_id(trace_id.clone()) - .so_that(filter.aliased_cond(None, false)); + .so_that(filter.aliased_condition_from(None, false)); self.select_ids(select, model_id, trace_id).await } diff --git a/query-engine/core/src/interpreter/query_interpreters/write.rs b/query-engine/core/src/interpreter/query_interpreters/write.rs index b3fde0fbc94a..c68f684731ed 100644 --- a/query-engine/core/src/interpreter/query_interpreters/write.rs +++ b/query-engine/core/src/interpreter/query_interpreters/write.rs @@ -62,9 +62,9 @@ async fn update_one( q: UpdateRecord, trace_id: Option, ) -> InterpretationResult { - let mut res = tx.update_records(&q.model, q.record_filter, q.args, trace_id).await?; + let res = tx.update_record(&q.model, q.record_filter, q.args, trace_id).await?; - Ok(QueryResult::Id(res.pop())) + Ok(QueryResult::Id(res)) } async fn delete_one(