Skip to content

Commit

Permalink
[pyspark] Cleanup the comments (#8217)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Sep 5, 2022
1 parent ada4a86 commit 7ee10e3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python-package/setup.py
Expand Up @@ -322,7 +322,7 @@ def run(self) -> None:
# - python setup.py bdist_wheel && pip install <wheel-name>

# When XGBoost is compiled directly with CMake:
# - pip install . -e
# - pip install -e .
# - python setup.py develop # same as above
logging.basicConfig(level=logging.INFO)

Expand Down
17 changes: 16 additions & 1 deletion python-package/xgboost/spark/core.py
Expand Up @@ -713,6 +713,13 @@ def _fit(self, dataset):

is_local = _is_local(_get_spark_session().sparkContext)

# Remove the parameters whose value is None
booster_params = {k: v for k, v in booster_params.items() if v is not None}
train_call_kwargs_params = {
k: v for k, v in train_call_kwargs_params.items() if v is not None
}
dmatrix_kwargs = {k: v for k, v in dmatrix_kwargs.items() if v is not None}

def _train_booster(pandas_df_iter):
"""Takes in an RDD partition and outputs a booster for that partition after
going through the Rabit Ring protocol
Expand All @@ -737,6 +744,15 @@ def _train_booster(pandas_df_iter):

_rabit_args = ""
if context.partitionId() == 0:
get_logger("XGBoostPySpark").info(
"booster params: %s\n"
"train_call_kwargs_params: %s\n"
"dmatrix_kwargs: %s",
booster_params,
train_call_kwargs_params,
dmatrix_kwargs,
)

_rabit_args = str(_get_rabit_args(context, num_workers))

messages = context.allGather(message=str(_rabit_args))
Expand All @@ -754,7 +770,6 @@ def _train_booster(pandas_df_iter):
dval = [(dtrain, "training"), (dvalid, "validation")]
else:
dval = None

booster = worker_train(
params=booster_params,
dtrain=dtrain,
Expand Down
5 changes: 2 additions & 3 deletions python-package/xgboost/spark/params.py
Expand Up @@ -36,7 +36,7 @@ class HasBaseMarginCol(Params):

class HasFeaturesCols(Params):
"""
Mixin for param featuresCols: a list of feature column names.
Mixin for param features_cols: a list of feature column names.
This parameter is taken effect only when use_gpu is enabled.
"""

Expand Down Expand Up @@ -76,8 +76,7 @@ def __init__(self):

class HasQueryIdCol(Params):
"""
Mixin for param featuresCols: a list of feature column names.
This parameter is taken effect only when use_gpu is enabled.
Mixin for param qid_col: query id column name.
"""

qid_col = Param(
Expand Down

0 comments on commit 7ee10e3

Please sign in to comment.