Commit 089260d

Reviewer's comment.
trivialfis committed Jul 22, 2022
1 parent ea18e0a commit 089260d
Showing 3 changed files with 8 additions and 11 deletions.
2 changes: 1 addition & 1 deletion python-package/xgboost/spark/core.py
@@ -585,7 +585,7 @@ def _train_booster(pandas_df_iter):
             dtrain, dvalid = create_dmatrix_from_partitions(
                 pandas_df_iter,
                 None,
-                **dmatrix_kwargs,
+                dmatrix_kwargs,
             )
             if dvalid is not None:
                 dval = [(dtrain, "training"), (dvalid, "validation")]
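The call site now passes `dmatrix_kwargs` as a plain positional dict instead of unpacking it with `**`. A minimal sketch (stand-in functions, not XGBoost code) of the failure mode this guards against: with a `**kwargs` signature, forgetting to forward the options is silent, while a required dict parameter fails loudly at the call site.

# Minimal sketch (stand-in functions, not XGBoost code) of why an explicit
# dict parameter is safer than **kwargs for forwarding DMatrix options.
from typing import Any, Dict


def with_var_kwargs(**kwargs: Any) -> Dict[str, Any]:
    # Forgetting to forward the options is silent: kwargs is just {}.
    return kwargs


def with_dict(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # A required positional dict: omitting it raises TypeError at the
    # call site instead of silently dropping the options.
    return kwargs


opts = {"feature_types": ["q", "q"]}
assert with_var_kwargs() == {}   # the bug slips through silently
assert with_dict(opts) == opts
try:
    with_dict()  # type: ignore[call-arg]
except TypeError:
    print("missing options dict caught at call time")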
8 changes: 4 additions & 4 deletions python-package/xgboost/spark/data.py
@@ -100,7 +100,7 @@ def reset(self) -> None:
 def create_dmatrix_from_partitions(
     iterator: Iterator[pd.DataFrame],
     feature_cols: Optional[Sequence[str]],
-    **kwargs: Any,
+    kwargs: Dict[str, Any],  # use dict to make sure this parameter is passed.
 ) -> Tuple[DMatrix, Optional[DMatrix]]:
     """Create DMatrix from spark data partitions. This is not particularly efficient as
     we need to convert the pandas series format to numpy then concatenate all the data.
@@ -154,7 +154,7 @@ def append_dqm(part: pd.DataFrame, name: str, is_valid: bool) -> None:
             else:
                 train_data[name].append(array)
 
-    def make(values: Dict[str, List[np.ndarray]]) -> DMatrix:
+    def make(values: Dict[str, List[np.ndarray]], kwargs: Dict[str, Any]) -> DMatrix:
         data = concat_or_none(values[alias.data])
         label = concat_or_none(values.get(alias.label, None))
         weight = concat_or_none(values.get(alias.weight, None))
@@ -166,13 +166,13 @@ def make(values: Dict[str, List[np.ndarray]]) -> DMatrix:
     is_dmatrix = feature_cols is None
     if is_dmatrix:
         cache_partitions(iterator, append_m)
-        dtrain = make(train_data)
+        dtrain = make(train_data, kwargs)
     else:
         cache_partitions(iterator, append_dqm)
         it = PartIter(train_data, True)
         dtrain = DeviceQuantileDMatrix(it, **kwargs)
 
-    dvalid = make(valid_data) if len(valid_data) != 0 else None
+    dvalid = make(valid_data, kwargs) if len(valid_data) != 0 else None
 
     assert dtrain.num_col() == n_features
     if dvalid is not None:
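With the signature change, the nested `make` helper now takes the options dict as an explicit argument rather than capturing it from the enclosing scope, and the dict is unpacked with `**` only at the constructor boundary, as in the `DeviceQuantileDMatrix(it, **kwargs)` branch. A hedged sketch of that pattern, with `build_matrix` standing in for the real DMatrix constructor:

# Hedged sketch of the pattern after this commit: options travel as a
# plain dict and are unpacked with ** only at the constructor boundary.
# build_matrix is a stand-in, not the real xgboost.DMatrix constructor.
from typing import Any, Dict, List

import numpy as np


def build_matrix(data: np.ndarray, **options: Any) -> Dict[str, Any]:
    return {"shape": data.shape, **options}


def make(parts: List[np.ndarray], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors the nested make() above: concatenate partitions, then build.
    data = np.concatenate(parts)
    return build_matrix(data, **kwargs)  # unpack only when constructing


print(make([np.ones((2, 3)), np.zeros((1, 3))], {"feature_types": ["q"] * 3}))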
9 changes: 3 additions & 6 deletions tests/python/test_spark/test_data.py
@@ -59,15 +59,12 @@ def run_dmatrix_ctor(is_dqm: bool) -> None:
         df[alias.data] = pd.Series(list(X))
         dfs.append(df)
 
+    kwargs = {"feature_types": feature_types}
     if is_dqm:
         cols = [f"feat-{i}" for i in range(n_features)]
-        train_Xy, valid_Xy = create_dmatrix_from_partitions(
-            iter(dfs), cols, feature_types=feature_types
-        )
+        train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), cols, kwargs)
     else:
-        train_Xy, valid_Xy = create_dmatrix_from_partitions(
-            iter(dfs), None, feature_types=feature_types
-        )
+        train_Xy, valid_Xy = create_dmatrix_from_partitions(iter(dfs), None, kwargs)
 
     assert valid_Xy is not None
     assert valid_Xy.num_row() + train_Xy.num_row() == n_samples_per_batch * n_batches
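The test now builds the keyword dict once and reuses it on both branches instead of repeating `feature_types=...` at each call. A self-contained illustration of that refactor, with a stand-in in place of `create_dmatrix_from_partitions`:

# Self-contained illustration of the test refactor: build the keyword
# dict once and reuse the same object on both branches. fake_ctor stands
# in for create_dmatrix_from_partitions; names here are illustrative.
from typing import Any, Dict, Optional, Sequence


def fake_ctor(cols: Optional[Sequence[str]], kwargs: Dict[str, Any]) -> str:
    return f"cols={cols}, opts={sorted(kwargs)}"


n_features = 3
kwargs = {"feature_types": ["q"] * n_features}

for is_dqm in (True, False):
    cols = [f"feat-{i}" for i in range(n_features)] if is_dqm else None
    print(fake_ctor(cols, kwargs))  # one dict shared by both call paths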
