Skip to content

Commit

Permalink
improve virtual selection iterations do avoid allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
Weakky committed Mar 11, 2024
1 parent b7451de commit 1ebfa03
Show file tree
Hide file tree
Showing 20 changed files with 431 additions and 316 deletions.
4 changes: 2 additions & 2 deletions libs/prisma-value/src/lib.rs
Expand Up @@ -149,7 +149,7 @@ pub fn parse_datetime(datetime: &str) -> chrono::ParseResult<DateTime<FixedOffse
DateTime::parse_from_rfc3339(datetime)
}

pub fn stringify_bigdecimal(decimal: &BigDecimal) -> f64 {
pub fn stringify_decimal(decimal: &BigDecimal) -> f64 {
decimal.to_string().parse::<f64>().unwrap()
}

Expand Down Expand Up @@ -262,7 +262,7 @@ fn serialize_decimal<S>(decimal: &BigDecimal, serializer: S) -> Result<S::Ok, S:
where
S: Serializer,
{
decimal.to_string().parse::<f64>().unwrap().serialize(serializer)
stringify_decimal(decimal).serialize(serializer)
}

fn deserialize_decimal<'de, D>(deserializer: D) -> Result<BigDecimal, D::Error>
Expand Down
Expand Up @@ -361,14 +361,14 @@ impl MongoReadQueryBuilder {
) -> crate::Result<Self> {
for aggr in virtual_selections {
let join = match aggr {
VirtualSelection::RelationCount(rf, filter) => {
let filter = filter
.as_ref()
VirtualSelection::RelationCount(x) => {
let filter = x
.filter()
.map(|f| MongoFilterVisitor::new(FilterPrefix::default(), false).visit(f.clone()))
.transpose()?;

JoinStage {
source: rf.clone(),
source: x.field().clone(),
alias: Some(aggr.db_alias()),
nested: vec![],
filter,
Expand Down
26 changes: 14 additions & 12 deletions query-engine/connectors/sql-query-connector/src/column_metadata.rs
@@ -1,10 +1,12 @@
use query_structure::{FieldArity, FieldSelection, RelationSelection, SelectedField, TypeIdentifier};
use query_structure::{
FieldArity, FieldSelection, GroupedSelectedField, GroupedVirtualSelection, RelationSelection, TypeIdentifier,
};

#[derive(Clone, Debug, Copy)]
#[derive(Clone, Debug)]
pub enum MetadataFieldKind<'a> {
Scalar,
Relation(&'a RelationSelection),
Virtual,
Virtual(GroupedVirtualSelection<'a>),
}

/// Helps dealing with column value conversion and possible error resolution.
Expand Down Expand Up @@ -47,8 +49,8 @@ impl<'a> ColumnMetadata<'a> {
self.arity
}

pub(crate) fn kind(&self) -> MetadataFieldKind {
self.kind
pub(crate) fn kind(&self) -> &MetadataFieldKind<'_> {
&self.kind
}
}

Expand All @@ -69,23 +71,23 @@ where
.collect()
}

pub(crate) fn create_from_selection<'a, T>(
pub(crate) fn create_from_selection_for_json<'a, T>(
selection: &'a FieldSelection,
field_names: &'a [T],
) -> Vec<ColumnMetadata<'a>>
where
T: AsRef<str>,
{
selection
.selections()
.grouped_selections()
.zip(field_names.iter())
.map(|(field, name)| {
let (type_identifier, arity) = field.type_identifier_with_arity().unwrap();
let (type_identifier, arity) = field.type_identifier_with_arity_for_json();

let kind = match field {
SelectedField::Scalar(_) => MetadataFieldKind::Scalar,
SelectedField::Relation(rs) => MetadataFieldKind::Relation(rs),
SelectedField::Virtual(_) => MetadataFieldKind::Virtual,
SelectedField::Composite(_) => unreachable!(),
GroupedSelectedField::Scalar(_) => MetadataFieldKind::Scalar,
GroupedSelectedField::Relation(rs) => MetadataFieldKind::Relation(rs),
GroupedSelectedField::Virtual(vs) => MetadataFieldKind::Virtual(vs),
};

ColumnMetadata::new(type_identifier, arity, kind).set_name(name.as_ref())
Expand Down
Expand Up @@ -5,13 +5,6 @@ use std::{io, str::FromStr};

use crate::{query_arguments_ext::QueryArgumentsExt, SqlError};

#[inline]
fn fields_to_serialize(rs: &RelationSelection) -> impl Iterator<Item = &SelectedField> {
rs.result_fields
.iter()
.filter_map(|field_name| rs.selections.iter().find(|f| f.prisma_name().as_ref() == field_name))
}

pub(crate) fn coerce_json_relation_to_pv(
mut value: serde_json::Value,
rs: &RelationSelection,
Expand Down Expand Up @@ -48,25 +41,28 @@ fn internal_coerce_json_relation_to_pv(value: &mut serde_json::Value, rs: &Relat
serde_json::Value::Object(obj) => {
let mut new_obj = serde_json::Map::with_capacity(obj.len());

for field in fields_to_serialize(rs) {
let (field_name, mut obj_val) = obj.remove_entry(field.prisma_name().as_ref()).unwrap();

for field in rs.grouped_fields_to_serialize() {
match field {
SelectedField::Scalar(sf) => {
GroupedSelectedField::Scalar(sf) => {
let (field_name, mut obj_val) = obj.remove_entry(sf.name()).unwrap();

coerce_json_scalar_to_pv(&mut obj_val, sf)?;

new_obj.insert(field_name, obj_val);
}
SelectedField::Relation(nested_rs) => {
GroupedSelectedField::Relation(nested_rs) => {
let (field_name, mut obj_val) = obj.remove_entry(nested_rs.field.name()).unwrap();

internal_coerce_json_relation_to_pv(&mut obj_val, nested_rs)?;

new_obj.insert(field_name, obj_val);
}
SelectedField::Virtual(_) => {
todo!()
// let coerced_value = coerce_json_virtual_field_to_pv(&key, value)?;
// map.push((key, coerced_value));
GroupedSelectedField::Virtual(vs) => {
let (field_name, obj_val) = obj.remove_entry(vs.serialized_name().0).unwrap();

new_obj.insert(field_name, reorder_virtuals_group(obj_val, &vs));
}
_ => unreachable!(),
}

new_obj.insert(field_name, obj_val);
}

*obj = new_obj;
Expand All @@ -89,18 +85,17 @@ fn coerce_json_scalar_to_pv(value: &mut serde_json::Value, sf: &ScalarField) ->
*value = serde_json::Value::Array(vec![]);
}
}
serde_json::Value::Number(n) => match sf.type_identifier() {
serde_json::Value::Number(ref n) => match sf.type_identifier() {
TypeIdentifier::Decimal => {
let bd = n
.as_f64()
.and_then(BigDecimal::from_f64)
.map(|bd| bd.normalized())
.ok_or_else(|| {
build_conversion_error(sf, &format!("Number({n})"), &format!("{:?}", sf.type_identifier()))
})?;
let bd = parse_json_f64(n, sf)?;

*value = serde_json::Value::String(bd.normalized().to_string());
}
TypeIdentifier::Float => {
let bd = parse_json_f64(n, sf)?;

*value = serde_json::Value::Number(Number::from_f64(stringify_decimal(&bd)).unwrap());
}
TypeIdentifier::Boolean => {
let err =
|| build_conversion_error(sf, &format!("Number({n})"), &format!("{:?}", sf.type_identifier()));
Expand Down Expand Up @@ -183,31 +178,34 @@ fn coerce_json_scalar_to_pv(value: &mut serde_json::Value, sf: &ScalarField) ->
Ok(())
}

fn coerce_json_virtual_field_to_pv(key: &str, value: serde_json::Value) -> crate::Result<PrismaValue> {
match value {
serde_json::Value::Object(obj) => {
let values: crate::Result<Vec<_>> = obj
.into_iter()
.map(|(key, value)| coerce_json_virtual_field_to_pv(&key, value).map(|value| (key, value)))
.collect();
Ok(PrismaValue::Object(values?))
}
pub fn reorder_virtuals_group(val: serde_json::Value, vs: &GroupedVirtualSelection) -> serde_json::Value {
match val {
serde_json::Value::Object(mut obj) => {
let mut new_obj = serde_json::Map::with_capacity(obj.len());

match vs {
GroupedVirtualSelection::RelationCounts(rcs) => {
for rc in rcs {
let (field_name, obj_val) = obj.remove_entry(rc.field().name()).unwrap();

new_obj.insert(field_name, obj_val);
}
}
}

serde_json::Value::Number(num) => num
.as_i64()
.ok_or_else(|| {
build_generic_conversion_error(format!(
"Unexpected numeric value {num} for virtual field '{key}': only integers are supported"
))
})
.map(PrismaValue::Int),

_ => Err(build_generic_conversion_error(format!(
"Field '{key}' is not a model field and doesn't have a supported type for a virtual field"
))),
new_obj.into()
}
_ => val,
}
}

fn parse_json_f64(n: &Number, sf: &Zipper<ScalarFieldId>) -> crate::Result<BigDecimal> {
n.as_f64()
.and_then(BigDecimal::from_f64)
.map(|bd| bd.normalized())
.ok_or_else(|| build_conversion_error(sf, &format!("Number({n})"), &format!("{:?}", sf.type_identifier())))
}

fn build_conversion_error(sf: &ScalarField, from: &str, to: &str) -> SqlError {
let container_name = sf.container().name();
let field_name = sf.name();
Expand Down
Expand Up @@ -33,8 +33,8 @@ pub(crate) async fn get_single_record_joins(
ctx: &Context<'_>,
) -> crate::Result<Option<SingleRecord>> {
let selected_fields = selected_fields.to_virtuals_last();
let field_names: Vec<_> = selected_fields.prisma_names_grouping_virtuals().collect();
let meta = column_metadata::create_from_selection(&selected_fields, &field_names);
let field_names: Vec<_> = selected_fields.grouped_prisma_names();
let meta = column_metadata::create_from_selection_for_json(&selected_fields, &field_names);

let query = query_builder::select::SelectBuilder::build(
QueryArguments::from((model.clone(), filter.clone())),
Expand Down Expand Up @@ -117,8 +117,9 @@ pub(crate) async fn get_many_records_joins(
ctx: &Context<'_>,
) -> crate::Result<ManyRecords> {
let selected_fields = selected_fields.to_virtuals_last();
let field_names: Vec<_> = selected_fields.prisma_names_grouping_virtuals().collect();
let meta = column_metadata::create_from_selection(&selected_fields, &field_names);
let field_names: Vec<_> = selected_fields.grouped_prisma_names();
let meta = column_metadata::create_from_selection_for_json(&selected_fields, &field_names);
// dbg!(&meta);

let mut records = ManyRecords::new(field_names.clone());

Expand Down
Expand Up @@ -147,7 +147,7 @@ fn process_result_row(
meta: &[ColumnMetadata],
selected_fields: &ModelProjection,
) -> crate::Result<SelectionResult> {
let sql_row = row.to_sql_row(meta, &mut std::time::Duration::ZERO)?;
let sql_row = row.to_sql_row(meta)?;
let prisma_row = selected_fields.scalar_fields().zip(sql_row.values).collect_vec();

Ok(SelectionResult::new(prisma_row))
Expand Down
Expand Up @@ -33,7 +33,7 @@ pub(crate) async fn native_upsert(
let result_set = conn.query(query).await?;

let row = result_set.into_single()?;
let record = Record::from(row.to_sql_row(&meta, &mut std::time::Duration::ZERO)?);
let record = Record::from(row.to_sql_row(&meta)?);

Ok(SingleRecord { record, field_names })
}
Expand Up @@ -167,7 +167,7 @@ pub(crate) async fn create_record(
let field_names: Vec<_> = selected_fields.db_names().collect();
let idents = ModelProjection::from(&selected_fields).type_identifiers_with_arities();
let meta = column_metadata::create(&field_names, &idents);
let sql_row = row.to_sql_row(&meta, &mut std::time::Duration::ZERO)?;
let sql_row = row.to_sql_row(&meta)?;
let record = Record::from(sql_row);

Ok(SingleRecord { record, field_names })
Expand Down Expand Up @@ -273,7 +273,7 @@ pub(crate) async fn create_records_returning(
for insert in inserts {
let result_set = conn.query(insert.into()).await?;
for result_row in result_set {
let sql_row = result_row.to_sql_row(&meta, &mut std::time::Duration::ZERO)?;
let sql_row = result_row.to_sql_row(&meta)?;
let record = Record::from(sql_row);
records.push(record);
}
Expand Down Expand Up @@ -448,7 +448,7 @@ pub(crate) async fn delete_record(
let field_db_names: Vec<_> = selected_fields.db_names().collect();
let types_and_arities = selected_fields.type_identifiers_with_arities();
let meta = column_metadata::create(&field_db_names, &types_and_arities);
let sql_row = result_row.to_sql_row(&meta, &mut std::time::Duration::ZERO)?;
let sql_row = result_row.to_sql_row(&meta)?;

let record = Record::from(sql_row);
Ok(SingleRecord {
Expand Down
Expand Up @@ -22,13 +22,13 @@ pub(crate) fn build<'a>(

for (index, selection) in virtual_selections.into_iter().enumerate() {
match selection {
VirtualSelection::RelationCount(rf, filter) => {
VirtualSelection::RelationCount(rc) => {
let join_alias = format!("aggr_selection_{index}");
let aggregator_alias = selection.db_alias();
let join = compute_aggr_join(
rf,
rc.field(),
AggregationType::Count,
filter.clone(),
rc.filter().cloned(),
aggregator_alias.as_str(),
join_alias.as_str(),
None,
Expand Down
Expand Up @@ -21,7 +21,7 @@ enum VirtualSelectionKey {
impl From<&VirtualSelection> for VirtualSelectionKey {
fn from(vs: &VirtualSelection) -> Self {
match vs {
VirtualSelection::RelationCount(rf, _) => Self::RelationCount(rf.clone()),
VirtualSelection::RelationCount(rc) => Self::RelationCount(rc.field().clone()),
}
}
}
Expand Down
Expand Up @@ -295,11 +295,11 @@ pub(crate) trait JoinSelectBuilder {
ctx: &Context<'_>,
) -> Select<'static> {
match vs {
VirtualSelection::RelationCount(rf, filter) => {
if rf.relation().is_many_to_many() {
self.build_relation_count_query_m2m(vs.db_alias(), rf, filter, parent_alias, ctx)
VirtualSelection::RelationCount(rc) => {
if rc.field().relation().is_many_to_many() {
self.build_relation_count_query_m2m(vs.db_alias(), rc.field(), rc.filter(), parent_alias, ctx)
} else {
self.build_relation_count_query(vs.db_alias(), rf, filter, parent_alias, ctx)
self.build_relation_count_query(vs.db_alias(), rc.field(), rc.filter(), parent_alias, ctx)
}
}
}
Expand Down Expand Up @@ -333,7 +333,7 @@ pub(crate) trait JoinSelectBuilder {
&mut self,
selection_name: impl Into<Cow<'static, str>>,
rf: &RelationField,
filter: &Option<Filter>,
filter: Option<&Filter>,
parent_alias: Alias,
ctx: &Context<'_>,
) -> Select<'a> {
Expand All @@ -347,7 +347,7 @@ pub(crate) trait JoinSelectBuilder {
let select = Select::from_table(related_table)
.value(count(asterisk()).alias(selection_name))
.with_join_conditions(rf, parent_alias, related_table_alias, ctx)
.with_filters(filter.clone(), Some(related_table_alias), ctx);
.with_filters(filter.cloned(), Some(related_table_alias), ctx);

select
}
Expand All @@ -356,7 +356,7 @@ pub(crate) trait JoinSelectBuilder {
&mut self,
selection_name: impl Into<Cow<'static, str>>,
rf: &RelationField,
filter: &Option<Filter>,
filter: Option<&Filter>,
parent_alias: Alias,
ctx: &Context<'_>,
) -> Select<'a> {
Expand Down Expand Up @@ -395,7 +395,7 @@ pub(crate) trait JoinSelectBuilder {
.value(count(asterisk()).alias(selection_name))
.left_join(m2m_join_data)
.and_where(aggregation_join_conditions)
.with_filters(filter.clone(), Some(related_table_alias), ctx);
.with_filters(filter.cloned(), Some(related_table_alias), ctx);

select
}
Expand Down
5 changes: 1 addition & 4 deletions query-engine/connectors/sql-query-connector/src/query_ext.rs
Expand Up @@ -41,14 +41,11 @@ impl<Q: Queryable + ?Sized> QueryExt for Q {

let mut sql_rows = Vec::with_capacity(result_set.len());

let mut dur = std::time::Duration::ZERO;

let now = std::time::Instant::now();
for row in result_set {
sql_rows.push(row.to_sql_row(idents, &mut dur)?);
sql_rows.push(row.to_sql_row(idents)?);
}

println!("coerce_json_relation: {:.2?}", dur);
println!("to_row: {:.2?}", now.elapsed());

Ok(sql_rows)
Expand Down

0 comments on commit 1ebfa03

Please sign in to comment.