Commit f832f4e

Work around LargeUtf8 to Utf8 comparison error (#288)
* Work around LargeUtf8 sorting bug by casting to Utf8

apache/arrow-rs#2654

* Fix pre-transform arrow

* Add test for LargeUtf8 crash
jonmmease committed Mar 30, 2023
1 parent 6ae4295 commit f832f4e
Showing 3 changed files with 51 additions and 2 deletions.
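The heart of the fix is casting Arrow LargeUtf8 (large_string) columns down to Utf8 before any sorting or comparison kernel sees them, sidestepping apache/arrow-rs#2654. A minimal sketch of the same cast from Python with pyarrow (illustrative only; the column name and data here are made up, not taken from this commit):

    import pyarrow as pa
    import pyarrow.compute as pc

    # Polars-backed inputs arrive as large_string (LargeUtf8), the type
    # that triggers the arrow-rs comparison bug in the Rust runtime.
    tbl = pa.table({"weather": pa.array(["rain", "sun", "rain"], type=pa.large_string())})

    # The workaround: cast LargeUtf8 -> Utf8 and leave everything else alone.
    idx = tbl.schema.get_field_index("weather")
    tbl = tbl.set_column(idx, "weather", pc.cast(tbl.column("weather"), pa.string()))

    assert tbl.schema.field("weather").type == pa.string()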
python/vegafusion/tests/test_transformed_data.py (20 additions, 0 deletions)
@@ -5,6 +5,9 @@
 import pytest
 from altair.utils.execeval import eval_block
 import vegafusion as vf
+from vega_datasets import data
+import polars as pl
+import altair as alt
 
 
 here = Path(__file__).parent

@@ -126,3 +129,20 @@ def test_transformed_data_for_mock(mock_name, expected_len, expected_cols, connection):
 
     # Check expected length
     assert len(df) == expected_len
+
+
+def test_gh_286():
+    # https://github.com/hex-inc/vegafusion/issues/286
+    source = pl.from_pandas(data.seattle_weather())
+
+    chart = alt.Chart(source).mark_bar(
+        cornerRadiusTopLeft=3,
+        cornerRadiusTopRight=3
+    ).encode(
+        x='month(date):O',
+        y='count():Q',
+        color='weather:N'
+    )
+    transformed = vf.transformed_data(chart)
+    assert isinstance(transformed, pl.DataFrame)
+    assert len(transformed) == 53
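Why this chart crashed before the fix: polars exports its string columns to Arrow as large_string (LargeUtf8). A quick check, assuming the same polars/vega_datasets setup as the test above:

    import polars as pl
    from vega_datasets import data

    source = pl.from_pandas(data.seattle_weather())
    # Polars strings round-trip to Arrow as large_string (LargeUtf8),
    # the type the new pre_process_column_types pass casts down to Utf8.
    print(source.to_arrow().schema.field("weather").type)  # large_string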
python/vegafusion/vegafusion/runtime.py (1 addition, 1 deletion)

@@ -311,7 +311,7 @@ def pre_transform_datasets(self, spec, datasets, local_tz, default_input_tz=None
 
             return processed_datasets, warnings
         elif _all_datasets_have_type(inline_datasets, pa.Table):
-            return datasets, warnings
+            return values, warnings
         else:
             # Deserialize values to pandas DataFrames
             datasets = [value.to_pandas() for value in values]
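A note on this one-line fix: inside pre_transform_datasets, values holds the pre-transformed result tables while datasets is the caller's original argument, so the old branch handed back unprocessed inputs whenever every inline dataset was a pa.Table. A standalone, hypothetical sketch of the difference (the helper and data below are made up; only the names mirror the hunk):

    import pyarrow as pa

    def _all_tables(items):
        # Hypothetical stand-in for _all_datasets_have_type(..., pa.Table)
        return all(isinstance(item, pa.Table) for item in items)

    datasets = [pa.table({"a": [1, 2, 3]})]                # caller's raw inputs
    values = [pa.table({"a": [1, 2, 3], "b": [2, 4, 6]})]  # pre-transform output
    warnings = []

    if _all_tables(datasets):
        # Before the fix this branch returned `datasets`, silently
        # dropping the transformed results in `values`.
        result, warns = values, warnings

    assert "b" in result[0].column_names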
vegafusion-runtime/src/data/tasks.rs (30 additions, 1 deletion)

@@ -5,7 +5,7 @@ use crate::task_graph::task::TaskCall;
 
 use async_trait::async_trait;
 
-use datafusion_expr::{lit, Expr};
+use datafusion_expr::{expr, lit, Expr};
 use std::collections::{HashMap, HashSet};
 
 use std::sync::Arc;

@@ -142,6 +142,9 @@ impl TaskCall for DataUrlTask {
             df
         };
 
+        // Perform any up-front type conversions
+        let df = pre_process_column_types(df).await?;
+
         // Process datetime columns
         let df = process_datetimes(&parse, df, &config.tz_config).await?;
 

@@ -268,6 +271,32 @@ fn check_builtin_dataset(url: String) -> String {
     }
 }
 
+async fn pre_process_column_types(df: Arc<dyn DataFrame>) -> Result<Arc<dyn DataFrame>> {
+    let mut selections: Vec<Expr> = Vec::new();
+    let mut pre_proc_needed = false;
+    for field in df.schema().fields.iter() {
+        if field.data_type() == &DataType::LargeUtf8 {
+            // Work around https://github.com/apache/arrow-rs/issues/2654 by converting
+            // LargeUtf8 to Utf8
+            selections.push(
+                Expr::Cast(expr::Cast {
+                    expr: Box::new(flat_col(field.name())),
+                    data_type: DataType::Utf8,
+                })
+                .alias(field.name()),
+            );
+            pre_proc_needed = true;
+        } else {
+            selections.push(flat_col(field.name()))
+        }
+    }
+    if pre_proc_needed {
+        df.select(selections).await
+    } else {
+        Ok(df)
+    }
+}
+
 async fn process_datetimes(
     parse: &Option<Parse>,
     sql_df: Arc<dyn DataFrame>,
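A design note on pre_process_column_types: the workaround is expressed as a single projection, so every column appears in the select list, only LargeUtf8 columns get wrapped in a cast (aliased back to their original name), and when no such column exists the DataFrame is returned untouched rather than adding a no-op plan node.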
