Skip to content

Commit

Permalink
update schema in udfs
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 27, 2022
1 parent c82c177 commit 092603b
Show file tree
Hide file tree
Showing 12 changed files with 107 additions and 84 deletions.
86 changes: 49 additions & 37 deletions polars/polars-lazy/src/frame/mod.rs
Expand Up @@ -345,12 +345,16 @@ impl LazyFrame {
}
}

// schema after renaming
let mut new_schema = (*self.schema()).clone();

for (old, new) in existing.iter().zip(new.iter()) {
new_schema.rename(old, new.to_string()).unwrap();
}
let existing2 = existing.clone();
let new2 = new.clone();
let udf_schema = move |s: &Schema| {
// schema after renaming
let mut new_schema = s.clone();
for (old, new) in existing2.iter().zip(new2.iter()) {
new_schema.rename(old, new.to_string()).unwrap();
}
Ok(Arc::new(new_schema))
};

let prefix = "__POLARS_TEMP_";

Expand Down Expand Up @@ -393,17 +397,21 @@ impl LazyFrame {
DataFrame::new(cols)
},
None,
Some(new_schema),
Some(Arc::new(udf_schema)),
Some("RENAME_SWAPPING"),
)
}

fn rename_imp(self, existing: Vec<String>, new: Vec<String>) -> Self {
let mut schema = (*self.schema()).clone();

for (old, new) in existing.iter().zip(&new) {
let _ = schema.rename(old, new.clone());
}
fn rename_impl(self, existing: Vec<String>, new: Vec<String>) -> Self {
let existing2 = existing.clone();
let new2 = new.clone();
let udf_schema = move |s: &Schema| {
let mut new_schema = s.clone();
for (old, new) in existing2.iter().zip(&new2) {
let _ = new_schema.rename(old, new.clone());
}
Ok(Arc::new(new_schema))
};

self.with_columns(
existing
Expand All @@ -427,7 +435,7 @@ impl LazyFrame {
Ok(df)
},
None,
Some(schema),
Some(Arc::new(udf_schema)),
Some("RENAME"),
)
}
Expand Down Expand Up @@ -457,7 +465,7 @@ impl LazyFrame {
if new.iter().any(|name| schema.get(name).is_some()) {
self.rename_impl_swapping(existing, new)
} else {
self.rename_imp(existing, new)
self.rename_impl(existing, new)
}
}

Expand Down Expand Up @@ -1154,7 +1162,7 @@ impl LazyFrame {
self,
function: F,
optimizations: Option<AllowedOptimizations>,
schema: Option<Schema>,
schema: Option<Arc<dyn UdfSchema>>,
name: Option<&'static str>,
) -> LazyFrame
where
Expand All @@ -1166,7 +1174,7 @@ impl LazyFrame {
.map(
function,
optimizations.unwrap_or_default(),
schema.map(Arc::new),
schema,
name.unwrap_or("ANONYMOUS UDF"),
)
.build();
Expand Down Expand Up @@ -1208,10 +1216,12 @@ impl LazyFrame {
}
}

let new_schema = self
.schema()
.insert_index(0, name.to_string(), IDX_DTYPE)
.unwrap();
let name2 = name.to_string();
let udf_schema = move |s: &Schema| {
let new = s.insert_index(0, name2.clone(), IDX_DTYPE).unwrap();
Ok(Arc::new(new))
};

let name = name.to_owned();

// if we do the row count at scan we add a dummy map, to update the schema
Expand All @@ -1234,7 +1244,7 @@ impl LazyFrame {
}
},
Some(opt),
Some(new_schema),
Some(Arc::new(udf_schema)),
Some("WITH ROW COUNT"),
)
}
Expand All @@ -1250,27 +1260,29 @@ impl LazyFrame {

#[cfg(feature = "dtype-struct")]
fn unnest_impl(self, cols: PlHashSet<String>) -> Self {
let schema = self.schema();

let mut new_schema = Schema::with_capacity(schema.len() * 2);
for (name, dtype) in schema.iter() {
if cols.contains(name) {
if let DataType::Struct(flds) = dtype {
for fld in flds {
new_schema.with_column(fld.name().clone(), fld.data_type().clone())
let cols2 = cols.clone();
let udf_schema = move |schema: &Schema| {
let mut new_schema = Schema::with_capacity(schema.len() * 2);
for (name, dtype) in schema.iter() {
if cols.contains(name) {
if let DataType::Struct(flds) = dtype {
for fld in flds {
new_schema.with_column(fld.name().clone(), fld.data_type().clone())
}
} else {
// todo: return lazy error here.
panic!("expected struct dtype")
}
} else {
// todo: return lazy error here.
panic!("expected struct dtype")
new_schema.with_column(name.clone(), dtype.clone())
}
} else {
new_schema.with_column(name.clone(), dtype.clone())
}
}
Ok(Arc::new(new_schema))
};
self.map(
move |df| df.unnest(&cols),
move |df| df.unnest(&cols2),
Some(AllowedOptimizations::default()),
Some(new_schema),
Some(Arc::new(udf_schema)),
Some("unnest"),
)
}
Expand Down
18 changes: 12 additions & 6 deletions polars/polars-lazy/src/logical_plan/alp.rs
@@ -1,4 +1,3 @@
use std::borrow::Cow;
#[cfg(feature = "ipc")]
use crate::logical_plan::IpcScanOptionsInner;
#[cfg(feature = "parquet")]
Expand All @@ -9,6 +8,7 @@ use crate::utils::{aexprs_to_schema, PushNode};
use polars_core::frame::explode::MeltArgs;
use polars_core::prelude::*;
use polars_utils::arena::{Arena, Node};
use std::borrow::Cow;
#[cfg(any(feature = "ipc", feature = "csv-file", feature = "parquet"))]
use std::path::PathBuf;
use std::sync::Arc;
Expand Down Expand Up @@ -133,7 +133,7 @@ pub enum ALogicalPlan {
input: Node,
function: Arc<dyn DataFrameUdf>,
options: LogicalPlanUdfOptions,
schema: Option<Arc<dyn UdfSchema>>
schema: Option<Arc<dyn UdfSchema>>,
},
Union {
inputs: Vec<Node>,
Expand Down Expand Up @@ -219,8 +219,8 @@ impl ALogicalPlan {
return match schema {
Some(schema) => Cow::Owned(schema.get_schema(&input_schema).unwrap()),
None => input_schema,
}
},
};
}
};
Cow::Borrowed(schema)
}
Expand Down Expand Up @@ -715,8 +715,14 @@ impl<'a> ALogicalPlanBuilder<'a> {
// TODO! add this line if LogicalPlan is dropped in favor of ALogicalPlan
// let aggs = rewrite_projections(aggs, current_schema);

let mut schema = aexprs_to_schema(&keys, &current_schema, Context::Default, self.expr_arena);
let other = aexprs_to_schema(&aggs, &current_schema, Context::Aggregation, self.expr_arena);
let mut schema =
aexprs_to_schema(&keys, &current_schema, Context::Default, self.expr_arena);
let other = aexprs_to_schema(
&aggs,
&current_schema,
Context::Aggregation,
self.expr_arena,
);
schema.merge(other);

let index_columns = &[
Expand Down
5 changes: 2 additions & 3 deletions polars/polars-lazy/src/logical_plan/apply.rs
Expand Up @@ -20,14 +20,13 @@ impl Debug for dyn DataFrameUdf {
}
}


pub trait UdfSchema: Send + Sync {
fn get_schema(&self, input_schema: &Schema) -> Result<SchemaRef>;
}

impl<F> UdfSchema for F
where
F: Fn(&Schema) -> Result<SchemaRef> + Send + Sync,
where
F: Fn(&Schema) -> Result<SchemaRef> + Send + Sync,
{
fn get_schema(&self, input_schema: &Schema) -> Result<SchemaRef> {
self(input_schema)
Expand Down
8 changes: 3 additions & 5 deletions polars/polars-lazy/src/logical_plan/builder.rs
Expand Up @@ -236,7 +236,7 @@ impl LogicalPlanBuilder {
self.map(
|_| Ok(DataFrame::new_no_checks(vec![])),
AllowedOptimizations::default(),
Some(Arc::new(Schema::default())),
Some(Arc::new(|_: &Schema| Ok(Arc::new(Schema::default())))),
"EMPTY PROJECTION",
)
} else {
Expand Down Expand Up @@ -524,7 +524,7 @@ impl LogicalPlanBuilder {
self,
function: F,
optimizations: AllowedOptimizations,
schema: Option<SchemaRef>,
schema: Option<Arc<dyn UdfSchema>>,
name: &'static str,
) -> Self
where
Expand All @@ -540,9 +540,7 @@ impl LogicalPlanBuilder {
input: Box::new(self.0),
function: Arc::new(function),
options,
schema: schema.map(|s| {
Arc::new(move |_: &Schema| Ok(s.clone())) as Arc<dyn UdfSchema>
})
schema,
}
.into()
}
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-lazy/src/logical_plan/mod.rs
@@ -1,8 +1,8 @@
use parking_lot::Mutex;
use std::borrow::Cow;
#[cfg(any(feature = "ipc", feature = "csv-file", feature = "parquet"))]
use std::path::PathBuf;
use std::{cell::Cell, fmt::Debug, sync::Arc};
use std::borrow::Cow;

use polars_core::prelude::*;

Expand Down Expand Up @@ -215,7 +215,7 @@ impl LogicalPlan {
}

impl LogicalPlan {
pub(crate) fn schema<'a>(&'a self) -> Cow<'a, SchemaRef> {
pub(crate) fn schema(&self) -> Cow<'_, SchemaRef> {
use LogicalPlan::*;
match self {
#[cfg(feature = "python")]
Expand Down Expand Up @@ -247,7 +247,7 @@ impl LogicalPlan {
Some(schema) => Cow::Owned(schema.get_schema(&input_schema).unwrap()),
None => input_schema,
}
},
}
Error { input, .. } => input.schema(),
}
}
Expand Down
Expand Up @@ -40,9 +40,7 @@ fn impl_fast_projection(
let lp = ALogicalPlan::Udf {
input,
function: Arc::new(function),
schema: schema.map(|s| {
Arc::new(move |_: &Schema| Ok(s.clone())) as Arc<dyn UdfSchema>
}),
schema: schema.map(|s| Arc::new(move |_: &Schema| Ok(s.clone())) as Arc<dyn UdfSchema>),
options,
};

Expand Down
10 changes: 6 additions & 4 deletions polars/polars-lazy/src/logical_plan/optimizer/type_coercion.rs
@@ -1,7 +1,7 @@
use std::borrow::Cow;
use crate::dsl::function_expr::FunctionExpr;
use polars_core::prelude::*;
use polars_core::utils::get_supertype;
use std::borrow::Cow;

use crate::logical_plan::optimizer::stack_opt::OptimizationRule;
use crate::logical_plan::Context;
Expand Down Expand Up @@ -96,7 +96,7 @@ fn get_input(lp_arena: &Arena<ALogicalPlan>, lp_node: Node) -> [Option<Node>; 2]
inputs
}

fn get_schema<'a>(lp_arena: &'a Arena<ALogicalPlan>, lp_node: Node) -> Cow<'a, SchemaRef> {
fn get_schema(lp_arena: &Arena<ALogicalPlan>, lp_node: Node) -> Cow<'_, SchemaRef> {
match get_input(lp_arena, lp_node) {
[Some(input), _] => lp_arena.get(input).schema(lp_arena),
// files don't have an input, so we must take their schema
Expand Down Expand Up @@ -140,7 +140,8 @@ impl OptimizationRule for TypeCoercionRule {
let input_schema = get_schema(lp_arena, lp_node);
let (truthy, type_true) =
get_aexpr_and_type(expr_arena, truthy_node, &input_schema)?;
let (falsy, type_false) = get_aexpr_and_type(expr_arena, falsy_node, &input_schema)?;
let (falsy, type_false) =
get_aexpr_and_type(expr_arena, falsy_node, &input_schema)?;

if type_true == type_false {
None
Expand Down Expand Up @@ -186,7 +187,8 @@ impl OptimizationRule for TypeCoercionRule {
} => {
let input_schema = get_schema(lp_arena, lp_node);
let (left, type_left) = get_aexpr_and_type(expr_arena, node_left, &input_schema)?;
let (right, type_right) = get_aexpr_and_type(expr_arena, node_right, &input_schema)?;
let (right, type_right) =
get_aexpr_and_type(expr_arena, node_right, &input_schema)?;

// don't coerce string with number comparisons. They must error
match (&type_left, &type_right, op) {
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/physical_plan/planner/lp.rs
Expand Up @@ -442,7 +442,7 @@ impl DefaultPlanner {
phys_aggs,
maintain_order,
options.slice,
input_schema
input_schema,
)))
} else {
Ok(Box::new(executors::GroupByExec::new(
Expand Down
1 change: 0 additions & 1 deletion polars/polars-lazy/src/tests/tpch.rs
Expand Up @@ -96,7 +96,6 @@ fn test_q2() -> Result<()> {
Field::new("s_comment", DataType::Utf8),
]);
assert_eq!(&out.schema(), &schema);
dbg!(out);

Ok(())
}

0 comments on commit 092603b

Please sign in to comment.