Skip to content

Commit

Permalink
Infer large_string type as pyarrow_numpy strings (pandas-dev#54826)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl committed Sep 2, 2023
1 parent 7688d52 commit 1539526
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 2 deletions.
9 changes: 9 additions & 0 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,15 @@ def _str_rstrip(self, to_strip=None):
class ArrowStringArrayNumpySemantics(ArrowStringArray):
_storage = "pyarrow_numpy"

def __init__(self, values) -> None:
    """
    Initialize the array, normalizing Arrow ``large_string`` input.

    Parameters
    ----------
    values
        Data forwarded to the parent constructor. A ``pa.Array`` or
        ``pa.ChunkedArray`` whose type is ``large_string`` is first cast
        to ``pa.string()``; anything else is passed through unchanged.
    """
    # Run the pyarrow availability check before any ``pa``/``pc`` usage.
    _chk_pyarrow_available()

    is_arrow_container = isinstance(values, (pa.Array, pa.ChunkedArray))
    if is_arrow_container and pa.types.is_large_string(values.type):
        # Hand the parent a plain ``string``-typed array instead.
        values = pc.cast(values, pa.string())

    super().__init__(values)

@classmethod
def _result_converter(cls, values, na=None):
if not isna(na):
Expand Down
5 changes: 4 additions & 1 deletion pandas/io/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,7 @@ def _arrow_dtype_mapping() -> dict:
def arrow_string_types_mapper() -> Callable:
    """
    Build a lookup from Arrow string types to the pandas string dtype.

    Returns
    -------
    Callable
        The bound ``dict.get`` of a mapping keyed by ``pa.string()`` and
        ``pa.large_string()``, each mapped to
        ``pd.StringDtype(storage="pyarrow_numpy")``. Calling it with any
        other key returns ``None``.
    """
    pa = import_optional_dependency("pyarrow")

    # Both Arrow string flavours must be listed: a single early return
    # covering only ``pa.string()`` would leave ``large_string`` unmapped.
    return {
        pa.string(): pd.StringDtype(storage="pyarrow_numpy"),
        pa.large_string(): pd.StringDtype(storage="pyarrow_numpy"),
    }.get
8 changes: 7 additions & 1 deletion pandas/tests/arrays/string_/test_string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
StringArray,
StringDtype,
)
from pandas.core.arrays.string_arrow import ArrowStringArray
from pandas.core.arrays.string_arrow import (
ArrowStringArray,
ArrowStringArrayNumpySemantics,
)

skip_if_no_pyarrow = pytest.mark.skipif(
pa_version_under7p0,
Expand Down Expand Up @@ -166,6 +169,9 @@ def test_pyarrow_not_installed_raises():
with pytest.raises(ImportError, match=msg):
ArrowStringArray([])

with pytest.raises(ImportError, match=msg):
ArrowStringArrayNumpySemantics([])

with pytest.raises(ImportError, match=msg):
ArrowStringArray._from_sequence(["a", None, "b"])

Expand Down
19 changes: 19 additions & 0 deletions pandas/tests/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,25 @@ def test_roundtrip_decimal(self, tmp_path, pa):
expected = pd.DataFrame({"a": ["123"]}, dtype="string[python]")
tm.assert_frame_equal(result, expected)

def test_infer_string_large_string_type(self, tmp_path, pa):
    # GH#54798
    # A parquet column written as Arrow ``large_string`` should be read
    # back as ``string[pyarrow_numpy]`` when ``future.infer_string`` is on.
    # NOTE: the local import below rebinds the ``pa`` argument to the
    # pyarrow module for the rest of the test.
    import pyarrow as pa
    import pyarrow.parquet as pq

    string_dtype = "string[pyarrow_numpy]"
    parquet_path = tmp_path / "large_string.p"

    large_values = pa.array([None, "b", "c"], pa.large_string())
    arrow_table = pa.table({"a": large_values})
    pq.write_table(arrow_table, parquet_path)

    with pd.option_context("future.infer_string", True):
        result = read_parquet(parquet_path)

    expected_columns = pd.Index(["a"], dtype=string_dtype)
    expected = pd.DataFrame(
        data={"a": [None, "b", "c"]},
        dtype=string_dtype,
        columns=expected_columns,
    )
    tm.assert_frame_equal(result, expected)


class TestParquetFastParquet(Base):
def test_basic(self, fp, df_full):
Expand Down

0 comments on commit 1539526

Please sign in to comment.