Skip to content

Commit

Permalink
Use direct arrow conversion methods if available
Browse files Browse the repository at this point in the history
  • Loading branch information
jonmmease committed Mar 23, 2024
1 parent 9c6454d commit f225b55
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions altair/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,11 @@ def raise_max_rows_error():
# as equivalent to TDataType
return data # type: ignore[return-value]
elif hasattr(data, "__dataframe__"):
pi = import_pyarrow_interchange()
pa_table = pi.from_dataframe(data)
pa_table = arrow_table_from_dfi_dataframe(data)
if max_rows is not None and pa_table.num_rows > max_rows:
raise_max_rows_error()
# Return pyarrow Table instead of input since the
# `from_dataframe` call may be expensive
# `arrow_table_from_dfi_dataframe` call above may be expensive
return pa_table

if max_rows is not None and len(values) > max_rows:
Expand Down Expand Up @@ -143,9 +142,7 @@ def sample(
# Maybe this should raise an error or return something useful?
return None
elif hasattr(data, "__dataframe__"):
# experimental interchange dataframe support
pi = import_pyarrow_interchange()
pa_table = pi.from_dataframe(data)
pa_table = arrow_table_from_dfi_dataframe(data)
if not n:
if frac is None:
raise ValueError(
Expand Down Expand Up @@ -233,9 +230,7 @@ def to_values(data: DataType) -> ToValuesReturnType:
raise KeyError("values expected in data dict, but not present.")
return data
elif hasattr(data, "__dataframe__"):
# experimental interchange dataframe support
pi = import_pyarrow_interchange()
pa_table = sanitize_arrow_table(pi.from_dataframe(data))
pa_table = sanitize_arrow_table(arrow_table_from_dfi_dataframe(data))
return {"values": pa_table.to_pylist()}
else:
# Should never reach this state as tested by check_data_type
Expand Down Expand Up @@ -278,9 +273,7 @@ def _data_to_json_string(data: DataType) -> str:
raise KeyError("values expected in data dict, but not present.")
return json.dumps(data["values"], sort_keys=True)
elif hasattr(data, "__dataframe__"):
# experimental interchange dataframe support
pi = import_pyarrow_interchange()
pa_table = pi.from_dataframe(data)
pa_table = arrow_table_from_dfi_dataframe(data)
return json.dumps(pa_table.to_pylist())
else:
raise NotImplementedError(
Expand All @@ -305,11 +298,10 @@ def _data_to_csv_string(data: Union[dict, pd.DataFrame, DataFrameLike]) -> str:
return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
elif hasattr(data, "__dataframe__"):
# experimental interchange dataframe support
pi = import_pyarrow_interchange()
import pyarrow as pa
import pyarrow.csv as pa_csv

pa_table = pi.from_dataframe(data)
pa_table = arrow_table_from_dfi_dataframe(data)
csv_buffer = pa.BufferOutputStream()
pa_csv.write_csv(pa_table, csv_buffer)
return csv_buffer.getvalue().to_pybytes().decode()
Expand Down Expand Up @@ -346,3 +338,23 @@ def curry(*args, **kwargs):
stacklevel=1,
)
return curried.curry(*args, **kwargs)


def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> "pyarrow.lib.Table":
    """Build a pyarrow Table from a DataFrame Interchange Protocol object.

    The object's own arrow-conversion method is preferred when it has one;
    otherwise the conversion falls back to the ``from_dataframe`` function of
    the module returned by ``import_pyarrow_interchange()``.

    Parameters
    ----------
    dfi_df
        Object to convert. Callers in this module pass objects that expose
        ``__dataframe__``.

    Returns
    -------
    pyarrow.lib.Table
        The result of the first native conversion method that yields a
        ``pyarrow.Table``, or else the result of the interchange-based
        ``from_dataframe`` conversion.
    """
    import pyarrow as pa

    # A native conversion method is tried first because the object itself has
    # more control over the conversion and may support more column types.
    # Polars, for example, handles Date32 columns in its direct conversion,
    # while pyarrow's from_dataframe does not yet support that type.
    for candidate_name in ("arrow", "to_arrow", "to_arrow_table"):
        converter = getattr(dfi_df, candidate_name, None)
        if not callable(converter):
            # Attribute is missing or is not a method; try the next name.
            continue
        table = converter()
        # Only accept the result when it really is an Arrow table; anything
        # else is discarded and the next candidate is tried.
        if isinstance(table, pa.Table):
            return table

    # No usable native conversion: go through the interchange protocol.
    return import_pyarrow_interchange().from_dataframe(dfi_df)

0 comments on commit f225b55

Please sign in to comment.