transform.
trivialfis committed Jul 20, 2022
1 parent 1da9e6b commit a9d9550
Showing 2 changed files with 21 additions and 13 deletions.
30 changes: 17 additions & 13 deletions python-package/xgboost/spark/core.py
@@ -719,7 +719,7 @@ def _transform(self, dataset):
         def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]:
             model = xgb_sklearn_model
             for data in iterator:
-                X = np.array(data[alias.data].tolist())
+                X = stack_series(data[alias.data])
                 if has_base_margin:
                     base_margin = data[alias.margin].to_numpy()
                 else:
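
The one-line change in this hunk replaces the np.array(data[alias.data].tolist()) round-trip with a stack_series helper. Only the name appears in the diff; the body is defined elsewhere in the package. A minimal sketch of what such a helper could look like, assuming every element of the Series is an array-like of the same length (the implementation below is an assumption, not the committed code):

    import numpy as np
    import pandas as pd

    def stack_series(series: pd.Series) -> np.ndarray:
        # Sketch only: take the underlying object array without copying,
        # then stack the per-row arrays into one 2-D ndarray.
        array = series.to_numpy(copy=False)
        return np.stack(array)

The gain over .tolist() is that no intermediate Python list of rows is materialized before the final 2-D array is built.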
@@ -773,6 +773,21 @@ def _transform(self, dataset):
                 alias.margin
             )
 
+        def transform_margin(margins: np.ndarray):
+            if margins.ndim == 1:
+                # binomial case
+                classone_probs = expit(margins)
+                classzero_probs = 1.0 - classone_probs
+                raw_preds = np.vstack((-margins, margins)).transpose()
+                class_probs = np.vstack(
+                    (classzero_probs, classone_probs)
+                ).transpose()
+            else:
+                # multinomial case
+                raw_preds = margins
+                class_probs = softmax(raw_preds, axis=1)
+            return raw_preds, class_probs
+
         @pandas_udf(
             "rawPrediction array<float>, prediction float, probability array<float>"
         )
@@ -794,18 +809,7 @@ def predict_udf(
                     validate_features=False,
                     **predict_params,
                 )
-                if margins.ndim == 1:
-                    # binomial case
-                    classone_probs = expit(margins)
-                    classzero_probs = 1.0 - classone_probs
-                    raw_preds = np.vstack((-margins, margins)).transpose()
-                    class_probs = np.vstack(
-                        (classzero_probs, classone_probs)
-                    ).transpose()
-                else:
-                    # multinomial case
-                    raw_preds = margins
-                    class_probs = softmax(raw_preds, axis=1)
+                raw_preds, class_probs = transform_margin(margins)
 
                 # It seems that they use argmax of class probs,
                 # not of margin to get the prediction (Note: scala implementation)
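
For a concrete sense of what the newly extracted transform_margin computes, here is a small worked example; it assumes scipy is importable, which the module already relies on for expit and softmax:

    import numpy as np
    from scipy.special import expit, softmax

    # Binary case (margins.ndim == 1): one margin per row.
    margins = np.array([-1.2, 0.3])
    expit(margins)                      # class-1 probabilities, ~[0.23, 0.57]
    np.vstack((-margins, margins)).T    # (n_rows, 2) rawPrediction columns

    # Multiclass case (margins.ndim == 2): one margin per class per row.
    multi = np.array([[0.1, 1.5, -0.2]])
    softmax(multi, axis=1)              # each row sums to 1

Factoring this into one function means the UDF body and any future callers share a single definition of the margin-to-probability mapping.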
4 changes: 4 additions & 0 deletions python-package/xgboost/spark/data.py
@@ -31,6 +31,10 @@ def concat_or_none(seq: Optional[Sequence[np.ndarray]]) -> Optional[np.ndarray]:
 def create_dmatrix_from_partitions(
     iterator: Iterator[pd.DataFrame], has_validation: bool, **kwargs: Any
 ) -> Tuple[DMatrix, Optional[DMatrix]]:
+    """Create DMatrix from spark data partitions. This is not particularly efficient as
+    we need to convert the pandas series format to numpy then concatenate all the data.
+
+    """
     train_data: Dict[str, List[np.ndarray]] = defaultdict(list)
     valid_data: Dict[str, List[np.ndarray]] = defaultdict(list)
     n_features: int = 0
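The docstring added above describes a convert-then-concatenate pattern: each partition's pandas columns become numpy chunks, and each column is concatenated once after the iterator is exhausted. A hedged sketch of that accumulation loop, with illustrative column names and a hypothetical function name (not the module's actual API):

    from collections import defaultdict
    from typing import Dict, List

    import numpy as np

    def accumulate_partitions(parts, columns=("data", "label")):
        # Keep one list of numpy chunks per column; the single
        # np.concatenate per column at the end is the copy the
        # docstring flags as not particularly efficient.
        acc: Dict[str, List[np.ndarray]] = defaultdict(list)
        for part in parts:  # each part is a pandas DataFrame
            for col in columns:
                acc[col].append(part[col].to_numpy())
        return {col: np.concatenate(chunks) for col, chunks in acc.items()}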
