From 006eef129010898e151532a5bc54153710885415 Mon Sep 17 00:00:00 2001 From: Praateek Mahajan Date: Mon, 1 Aug 2022 14:56:14 -0700 Subject: [PATCH 1/2] use the validation_indicator model --- demo/guide-python/spark_estimator_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demo/guide-python/spark_estimator_examples.py b/demo/guide-python/spark_estimator_examples.py index de1bda560251..8462371cf061 100644 --- a/demo/guide-python/spark_estimator_examples.py +++ b/demo/guide-python/spark_estimator_examples.py @@ -75,7 +75,7 @@ def create_spark_df(X, y): # train xgboost classifier model with validation dataset xgb_classifier2 = SparkXGBClassifier(max_depth=5, validation_indicator_col="validationIndicatorCol") -xgb_classifier_model2 = xgb_classifier.fit(iris_train_spark_df2) +xgb_classifier_model2 = xgb_classifier2.fit(iris_train_spark_df2) transformed_iris_test_spark_df2 = xgb_classifier_model2.transform(iris_test_spark_df) print(f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}") From a805a0324145f9d6a10ffe219c367bbbcffb6186 Mon Sep 17 00:00:00 2001 From: Praateek Mahajan Date: Mon, 1 Aug 2022 14:57:15 -0700 Subject: [PATCH 2/2] use the validation_indicator model for regression --- demo/guide-python/spark_estimator_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demo/guide-python/spark_estimator_examples.py b/demo/guide-python/spark_estimator_examples.py index 8462371cf061..e4f481a192d9 100644 --- a/demo/guide-python/spark_estimator_examples.py +++ b/demo/guide-python/spark_estimator_examples.py @@ -48,7 +48,7 @@ def create_spark_df(X, y): # train xgboost regressor model with validation dataset xgb_regressor2 = SparkXGBRegressor(max_depth=5, validation_indicator_col="validationIndicatorCol") -xgb_regressor_model2 = xgb_regressor.fit(diabetes_train_spark_df2) +xgb_regressor_model2 = xgb_regressor2.fit(diabetes_train_spark_df2) transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(diabetes_test_spark_df) print(f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}")