Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update schema in udfs #4165

Merged
merged 2 commits into from Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
90 changes: 51 additions & 39 deletions polars/polars-lazy/src/frame/mod.rs
Expand Up @@ -151,7 +151,7 @@ impl LazyFrame {
/// Get a hold on the schema of the current LazyFrame computation.
pub fn schema(&self) -> SchemaRef {
let logical_plan = self.clone().get_plan_builder().build();
logical_plan.schema().clone()
logical_plan.schema().into_owned()
}

pub(crate) fn get_plan_builder(self) -> LogicalPlanBuilder {
Expand Down Expand Up @@ -345,12 +345,16 @@ impl LazyFrame {
}
}

// schema after renaming
let mut new_schema = (*self.schema()).clone();

for (old, new) in existing.iter().zip(new.iter()) {
new_schema.rename(old, new.to_string()).unwrap();
}
let existing2 = existing.clone();
let new2 = new.clone();
let udf_schema = move |s: &Schema| {
// schema after renaming
let mut new_schema = s.clone();
for (old, new) in existing2.iter().zip(new2.iter()) {
new_schema.rename(old, new.to_string()).unwrap();
}
Ok(Arc::new(new_schema))
};

let prefix = "__POLARS_TEMP_";

Expand Down Expand Up @@ -393,17 +397,21 @@ impl LazyFrame {
DataFrame::new(cols)
},
None,
Some(new_schema),
Some(Arc::new(udf_schema)),
Some("RENAME_SWAPPING"),
)
}

fn rename_imp(self, existing: Vec<String>, new: Vec<String>) -> Self {
let mut schema = (*self.schema()).clone();

for (old, new) in existing.iter().zip(&new) {
let _ = schema.rename(old, new.clone());
}
fn rename_impl(self, existing: Vec<String>, new: Vec<String>) -> Self {
let existing2 = existing.clone();
let new2 = new.clone();
let udf_schema = move |s: &Schema| {
let mut new_schema = s.clone();
for (old, new) in existing2.iter().zip(&new2) {
let _ = new_schema.rename(old, new.clone());
}
Ok(Arc::new(new_schema))
};

self.with_columns(
existing
Expand All @@ -427,7 +435,7 @@ impl LazyFrame {
Ok(df)
},
None,
Some(schema),
Some(Arc::new(udf_schema)),
Some("RENAME"),
)
}
Expand Down Expand Up @@ -457,7 +465,7 @@ impl LazyFrame {
if new.iter().any(|name| schema.get(name).is_some()) {
self.rename_impl_swapping(existing, new)
} else {
self.rename_imp(existing, new)
self.rename_impl(existing, new)
}
}

Expand Down Expand Up @@ -556,7 +564,7 @@ impl LazyFrame {

// during debug we check if the optimizations have not modified the final schema
#[cfg(debug_assertions)]
let prev_schema = logical_plan.schema().clone();
let prev_schema = logical_plan.schema().into_owned();

let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena)?;

Expand Down Expand Up @@ -1154,7 +1162,7 @@ impl LazyFrame {
self,
function: F,
optimizations: Option<AllowedOptimizations>,
schema: Option<Schema>,
schema: Option<Arc<dyn UdfSchema>>,
name: Option<&'static str>,
) -> LazyFrame
where
Expand All @@ -1166,7 +1174,7 @@ impl LazyFrame {
.map(
function,
optimizations.unwrap_or_default(),
schema.map(Arc::new),
schema,
name.unwrap_or("ANONYMOUS UDF"),
)
.build();
Expand Down Expand Up @@ -1208,10 +1216,12 @@ impl LazyFrame {
}
}

let new_schema = self
.schema()
.insert_index(0, name.to_string(), IDX_DTYPE)
.unwrap();
let name2 = name.to_string();
let udf_schema = move |s: &Schema| {
let new = s.insert_index(0, name2.clone(), IDX_DTYPE).unwrap();
Ok(Arc::new(new))
};

let name = name.to_owned();

// if we do the row count at scan we add a dummy map, to update the schema
Expand All @@ -1234,7 +1244,7 @@ impl LazyFrame {
}
},
Some(opt),
Some(new_schema),
Some(Arc::new(udf_schema)),
Some("WITH ROW COUNT"),
)
}
Expand All @@ -1250,27 +1260,29 @@ impl LazyFrame {

#[cfg(feature = "dtype-struct")]
fn unnest_impl(self, cols: PlHashSet<String>) -> Self {
let schema = self.schema();

let mut new_schema = Schema::with_capacity(schema.len() * 2);
for (name, dtype) in schema.iter() {
if cols.contains(name) {
if let DataType::Struct(flds) = dtype {
for fld in flds {
new_schema.with_column(fld.name().clone(), fld.data_type().clone())
let cols2 = cols.clone();
let udf_schema = move |schema: &Schema| {
let mut new_schema = Schema::with_capacity(schema.len() * 2);
for (name, dtype) in schema.iter() {
if cols.contains(name) {
if let DataType::Struct(flds) = dtype {
for fld in flds {
new_schema.with_column(fld.name().clone(), fld.data_type().clone())
}
} else {
// todo: return lazy error here.
panic!("expected struct dtype")
}
} else {
// todo: return lazy error here.
panic!("expected struct dtype")
new_schema.with_column(name.clone(), dtype.clone())
}
} else {
new_schema.with_column(name.clone(), dtype.clone())
}
}
Ok(Arc::new(new_schema))
};
self.map(
move |df| df.unnest(&cols),
move |df| df.unnest(&cols2),
Some(AllowedOptimizations::default()),
Some(new_schema),
Some(Arc::new(udf_schema)),
Some("unnest"),
)
}
Expand Down
55 changes: 33 additions & 22 deletions polars/polars-lazy/src/logical_plan/alp.rs
Expand Up @@ -8,6 +8,7 @@ use crate::utils::{aexprs_to_schema, PushNode};
use polars_core::frame::explode::MeltArgs;
use polars_core::prelude::*;
use polars_utils::arena::{Arena, Node};
use std::borrow::Cow;
#[cfg(any(feature = "ipc", feature = "csv-file", feature = "parquet"))]
use std::path::PathBuf;
use std::sync::Arc;
Expand Down Expand Up @@ -132,7 +133,7 @@ pub enum ALogicalPlan {
input: Node,
function: Arc<dyn DataFrameUdf>,
options: LogicalPlanUdfOptions,
schema: Option<SchemaRef>,
schema: Option<Arc<dyn UdfSchema>>,
},
Union {
inputs: Vec<Node>,
Expand Down Expand Up @@ -171,14 +172,14 @@ impl ALogicalPlan {
}

/// Get the schema of the logical plan node.
pub(crate) fn schema<'a>(&'a self, arena: &'a Arena<ALogicalPlan>) -> &'a SchemaRef {
pub(crate) fn schema<'a>(&'a self, arena: &'a Arena<ALogicalPlan>) -> Cow<'a, SchemaRef> {
use ALogicalPlan::*;
match self {
let schema = match self {
#[cfg(feature = "python")]
PythonScan { options } => &options.schema,
Union { inputs, .. } => arena.get(inputs[0]).schema(arena),
Cache { input } => arena.get(*input).schema(arena),
Sort { input, .. } => arena.get(*input).schema(arena),
Union { inputs, .. } => return arena.get(inputs[0]).schema(arena),
Cache { input } => return arena.get(*input).schema(arena),
Sort { input, .. } => return arena.get(*input).schema(arena),
Explode { schema, .. } => schema,
#[cfg(feature = "parquet")]
ParquetScan {
Expand All @@ -198,7 +199,7 @@ impl ALogicalPlan {
output_schema,
..
} => output_schema.as_ref().unwrap_or(schema),
Selection { input, .. } => arena.get(*input).schema(arena),
Selection { input, .. } => return arena.get(*input).schema(arena),
#[cfg(feature = "csv-file")]
CsvScan {
schema,
Expand All @@ -210,14 +211,18 @@ impl ALogicalPlan {
Aggregate { schema, .. } => schema,
Join { schema, .. } => schema,
HStack { schema, .. } => schema,
Distinct { input, .. } => arena.get(*input).schema(arena),
Slice { input, .. } => arena.get(*input).schema(arena),
Distinct { input, .. } => return arena.get(*input).schema(arena),
Slice { input, .. } => return arena.get(*input).schema(arena),
Melt { schema, .. } => schema,
Udf { input, schema, .. } => match schema {
Some(schema) => schema,
None => arena.get(*input).schema(arena),
},
}
Udf { input, schema, .. } => {
let input_schema = arena.get(*input).schema(arena);
return match schema {
Some(schema) => Cow::Owned(schema.get_schema(&input_schema).unwrap()),
None => input_schema,
};
}
};
Cow::Borrowed(schema)
}
}

Expand Down Expand Up @@ -622,7 +627,7 @@ impl<'a> ALogicalPlanBuilder<'a> {
}

pub fn melt(self, args: Arc<MeltArgs>) -> Self {
let schema = det_melt_schema(&args, self.schema());
let schema = det_melt_schema(&args, &self.schema());

let lp = ALogicalPlan::Melt {
input: self.root,
Expand All @@ -635,7 +640,7 @@ impl<'a> ALogicalPlanBuilder<'a> {

pub fn project_local(self, exprs: Vec<Node>) -> Self {
let input_schema = self.lp_arena.get(self.root).schema(self.lp_arena);
let schema = aexprs_to_schema(&exprs, input_schema, Context::Default, self.expr_arena);
let schema = aexprs_to_schema(&exprs, &input_schema, Context::Default, self.expr_arena);
let lp = ALogicalPlan::LocalProjection {
expr: exprs,
input: self.root,
Expand All @@ -647,7 +652,7 @@ impl<'a> ALogicalPlanBuilder<'a> {

pub fn project(self, exprs: Vec<Node>) -> Self {
let input_schema = self.lp_arena.get(self.root).schema(self.lp_arena);
let schema = aexprs_to_schema(&exprs, input_schema, Context::Default, self.expr_arena);
let schema = aexprs_to_schema(&exprs, &input_schema, Context::Default, self.expr_arena);

// if len == 0, no projection has to be done. This is a select all operation.
if !exprs.is_empty() {
Expand All @@ -671,19 +676,19 @@ impl<'a> ALogicalPlanBuilder<'a> {
}
}

pub(crate) fn schema(&self) -> &Schema {
pub(crate) fn schema(&'a self) -> Cow<'a, SchemaRef> {
self.lp_arena.get(self.root).schema(self.lp_arena)
}

pub(crate) fn with_columns(self, exprs: Vec<Node>) -> Self {
let schema = self.schema();
let mut new_schema = (*schema).clone();
let mut new_schema = (**schema).clone();

for e in &exprs {
let field = self
.expr_arena
.get(*e)
.to_field(schema, Context::Default, self.expr_arena)
.to_field(&schema, Context::Default, self.expr_arena)
.unwrap();

new_schema.with_column(field.name().clone(), field.data_type().clone());
Expand All @@ -710,8 +715,14 @@ impl<'a> ALogicalPlanBuilder<'a> {
// TODO! add this line if LogicalPlan is dropped in favor of ALogicalPlan
// let aggs = rewrite_projections(aggs, current_schema);

let mut schema = aexprs_to_schema(&keys, current_schema, Context::Default, self.expr_arena);
let other = aexprs_to_schema(&aggs, current_schema, Context::Aggregation, self.expr_arena);
let mut schema =
aexprs_to_schema(&keys, &current_schema, Context::Default, self.expr_arena);
let other = aexprs_to_schema(
&aggs,
&current_schema,
Context::Aggregation,
self.expr_arena,
);
schema.merge(other);

let index_columns = &[
Expand Down
21 changes: 20 additions & 1 deletion polars/polars-lazy/src/logical_plan/apply.rs
Expand Up @@ -16,6 +16,25 @@ where

impl Debug for dyn DataFrameUdf {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "udf")
write!(f, "dyn DataFrameUdf")
}
}

pub trait UdfSchema: Send + Sync {
fn get_schema(&self, input_schema: &Schema) -> Result<SchemaRef>;
}

impl<F> UdfSchema for F
where
F: Fn(&Schema) -> Result<SchemaRef> + Send + Sync,
{
fn get_schema(&self, input_schema: &Schema) -> Result<SchemaRef> {
self(input_schema)
}
}

impl Debug for dyn UdfSchema {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "dyn UdfSchema")
}
}