Skip to content

Commit

Permalink
Merge.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 6, 2023
1 parent e964d55 commit fa48de4
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 103 deletions.
163 changes: 63 additions & 100 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,24 +270,24 @@ def _is_modin_df(data: DataType) -> bool:
"Int16": "int",
"Int32": "int",
"Int64": "int",
"UInt8": "i",
"UInt16": "i",
"UInt32": "i",
"UInt64": "i",
"UInt8": "int",
"UInt16": "int",
"UInt32": "int",
"UInt64": "int",
"Float32": "float",
"Float64": "float",
"boolean": "i",
}

pandas_pyarrow_mapper = {
"int8[pyarrow]": "i",
"int16[pyarrow]": "i",
"int32[pyarrow]": "i",
"int64[pyarrow]": "i",
"uint8[pyarrow]": "i",
"uint16[pyarrow]": "i",
"uint32[pyarrow]": "i",
"uint64[pyarrow]": "i",
"int8[pyarrow]": "int",
"int16[pyarrow]": "int",
"int32[pyarrow]": "int",
"int64[pyarrow]": "int",
"uint8[pyarrow]": "int",
"uint16[pyarrow]": "int",
"uint32[pyarrow]": "int",
"uint64[pyarrow]": "int",
"float[pyarrow]": "float",
"float32[pyarrow]": "float",
"double[pyarrow]": "float",
Expand Down Expand Up @@ -412,83 +412,36 @@ def is_pd_sparse_dtype(dtype: PandasDType) -> bool:
return is_sparse(dtype)


def pandas_cat_null(data: DataFrame) -> List[np.ndarray]:
"""Handle categorical dtype and nullable extension types from pandas."""
def pandas_transform_data(data: DataFrame) -> List[np.ndarray]:
"""Handle categorical dtype and extension types from pandas."""
import pandas as pd

# handle category codes and nullable.
cat_columns = []
nul_columns = []
# avoid an unnecessary conversion if possible
for col, dtype in zip(data.columns, data.dtypes):
if is_pd_cat_dtype(dtype):
cat_columns.append(col)
elif is_pa_ext_categorical_dtype(dtype):
raise ValueError(
"pyarrow dictionary type is not supported. Use pandas category instead."
)
elif is_nullable_dtype(dtype):
nul_columns.append(col)

if cat_columns or nul_columns:
# Avoid transformation due to: PerformanceWarning: DataFrame is highly
# fragmented
transformed = data.copy(deep=False)
else:
transformed = data
result = []

def cat_codes(ser: pd.Series) -> pd.Series:
def cat_codes(ser: pd.Series) -> np.ndarray:
if is_pd_cat_dtype(ser.dtype):
return ser.cat.codes
assert is_pa_ext_categorical_dtype(ser.dtype)
return _ensure_np_dtype(
ser.cat.codes.astype(np.float32).replace(-1.0, np.NaN).values,
np.float32,
)[0]
# Not yet supported, the index is not ordered for some reason. Alternately:
# `combine_chunks().to_pandas().cat.codes`. The result is the same.
return ser.array.__arrow_array__().combine_chunks().dictionary_encode().indices

if cat_columns:
# DF doesn't have the cat attribute, as a result, we use apply here
transformed[cat_columns] = (
transformed[cat_columns]
.apply(cat_codes)
.astype(np.float32)
assert is_pa_ext_categorical_dtype(ser.dtype)
return (
ser.array.__arrow_array__()
.combine_chunks()
.dictionary_encode()
.indices.astype(np.float32)
.replace(-1.0, np.NaN)
)
if nul_columns:
transformed[nul_columns] = transformed[nul_columns].astype(np.float32)

# FIXME(jiamingy): Investigate the possibility of using dataframe protocol or arrow
# IPC format for pandas so that we can apply the data transformation inside XGBoost
# for better memory efficiency.

def map_np(column: str) -> np.ndarray:
arr = transformed[column].values
if is_pd_sparse_dtype(arr.dtype):
# FIXME(jiamingy): We can support mixed type with (Quantile)DMatrix, but
# needs more work on column-major data. Inplace-predict however, is less
# likely.
arr = cast(pd.arrays.SparseArray, arr)
arr = arr.to_dense()
if _is_np_array_like(arr):
arr, _ = _ensure_np_dtype(arr, arr.dtype)
return arr

return list(map(map_np, transformed.columns))


def pandas_ext_num_types(data: DataFrame) -> List[np.ndarray]:
"""Experimental suppport for handling pyarrow extension numeric types."""
import pandas as pd
import pyarrow as pa

arrays = []
def pa_type(ser: pd.Series) -> np.ndarray:
import pyarrow as pa

for col, dtype in zip(data.columns, data.dtypes):
if not is_pa_ext_dtype(dtype):
continue
# No copy, callstack:
# pandas.core.internals.managers.SingleBlockManager.array_values()
# pandas.core.internals.blocks.EABackedBlock.values
d_array: pd.arrays.ArrowExtensionArray = data[col].array
d_array: pd.arrays.ArrowExtensionArray = ser.array
# no copy in __arrow_array__
# ArrowExtensionArray._data is a chunked array
aa: pa.ChunkedArray = d_array.__arrow_array__()
Expand All @@ -503,16 +456,39 @@ def pandas_ext_num_types(data: DataFrame) -> List[np.ndarray]:
# mask.
arr: np.ndarray = chunk.to_numpy(zero_copy_only=zero_copy, writable=False)
arr, _ = _ensure_np_dtype(arr, arr.dtype)
arrays.append(arr)
return arr

# avoid an unnecessary conversion if possible
for col, dtype in zip(data.columns, data.dtypes):
if is_pd_cat_dtype(dtype):
result.append(cat_codes(data[col]))
elif is_pa_ext_categorical_dtype(dtype):
raise ValueError(
"pyarrow dictionary type is not supported. Use pandas category instead."
)
elif is_pa_ext_dtype(dtype):
result.append(pa_type(data[col]))
elif is_nullable_dtype(dtype):
result.append(data[col].astype(np.float32).values)
elif is_pd_sparse_dtype(dtype):
arr = cast(pd.arrays.SparseArray, data[col].values)
arr = arr.to_dense()
if _is_np_array_like(arr):
arr, _ = _ensure_np_dtype(arr, arr.dtype)
result.append(arr)
else:
result.append(_ensure_np_dtype(data[col].values, data[col].dtype)[0])

return arrays
# FIXME(jiamingy): Investigate the possibility of using dataframe protocol or arrow
# IPC format for pandas so that we can apply the data transformation inside XGBoost
# for better memory efficiency.
return result


def pandas_check_dtypes(data: DataFrame, enable_categorical: bool) -> bool:
def pandas_check_dtypes(data: DataFrame, enable_categorical: bool) -> None:
"""Validate the input types, returns True if the dataframe is backed by arrow."""
pyarrow_extension = False
sparse_extension = False
non_arrow = False

for dtype in data.dtypes:
if not (
(dtype.name in _pandas_dtype_mapper)
Expand All @@ -522,22 +498,11 @@ def pandas_check_dtypes(data: DataFrame, enable_categorical: bool) -> bool:
):
_invalid_dataframe_dtype(data)

if is_pa_ext_dtype(dtype):
pyarrow_extension = True
else:
non_arrow = True
if is_pd_sparse_dtype(dtype):
sparse_extension = True

if sparse_extension:
warnings.warn("Sparse arrays from pandas are converted into dense.")
if pyarrow_extension and non_arrow:
raise TypeError(
"Mixed column type in a dataframe is not supported. Found arrow-based"
" columns and non-arrow-based columns at the same time."
)

return pyarrow_extension


class PandasTransformed:
Expand Down Expand Up @@ -565,20 +530,15 @@ def _transform_pandas_df(
feature_types: Optional[FeatureTypes] = None,
meta: Optional[str] = None,
) -> Tuple[PandasTransformed, Optional[FeatureNames], Optional[FeatureTypes]]:
pyarrow_extension = pandas_check_dtypes(data, enable_categorical)
pandas_check_dtypes(data, enable_categorical)
if meta and len(data.columns) > 1 and meta not in _matrix_meta:
raise ValueError(f"DataFrame for {meta} cannot have multiple columns")

feature_names, feature_types = pandas_feature_info(
data, meta, feature_names, feature_types, enable_categorical
)

if pyarrow_extension:
# Categorical dtype doesn't work wtih arrow yet.
arrays = pandas_ext_num_types(data)
else:
arrays = pandas_cat_null(data)

arrays = pandas_transform_data(data)
return PandasTransformed(arrays), feature_names, feature_types


Expand Down Expand Up @@ -636,7 +596,10 @@ def _meta_from_pandas_series(
data: DataType, name: str, dtype: Optional[NumpyDType], handle: ctypes.c_void_p
) -> None:
"""Help transform pandas series for meta data like labels"""
data = data.values.astype("float")
if is_pd_sparse_dtype(data.dtype):
data = data.values.to_dense().astype(np.float32)
else:
data = data.values.astype("float")

if is_pd_sparse_dtype(getattr(data, "dtype", data)):
data = data.to_dense() # type: ignore
Expand Down
17 changes: 14 additions & 3 deletions tests/python/test_with_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def test_pyarrow_type(self, DMatrixT: Type[xgb.DMatrix]) -> None:

@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
def test_mixed_type(self, DMatrixT: Type[xgb.DMatrix]) -> None:
f0 = np.arange(0, 10)
f0 = np.arange(0, 4)
f1 = pd.Series(f0, dtype="int64[pyarrow]")
f2l = list(f0)
f2l[0] = pd.NA
Expand All @@ -538,8 +538,19 @@ def test_mixed_type(self, DMatrixT: Type[xgb.DMatrix]) -> None:
assert m.num_col() == df.shape[1]

df["f1"] = f1
with pytest.raises(TypeError, match="arrow-based"):
DMatrixT(df)
m = DMatrixT(df)
assert m.num_col() == df.shape[1]
assert m.num_row() == df.shape[0]
assert m.num_nonmissing() == df.size - 1
assert m.feature_names == list(map(str, df.columns))
assert m.feature_types == ["int"] * df.shape[1]

y = f0
m.set_info(label=y)
booster = xgb.train({}, m)
p0 = booster.inplace_predict(df)
p1 = booster.predict(m)
np.testing.assert_allclose(p0, p1)

@pytest.mark.skipif(tm.is_windows(), reason="Rabit does not run on windows")
def test_pandas_column_split(self):
Expand Down

0 comments on commit fa48de4

Please sign in to comment.