diff --git a/demo/guide-python/spark_estimator_examples.py b/demo/guide-python/spark_estimator_examples.py index de1bda560251..e4f481a192d9 100644 --- a/demo/guide-python/spark_estimator_examples.py +++ b/demo/guide-python/spark_estimator_examples.py @@ -48,7 +48,7 @@ def create_spark_df(X, y): # train xgboost regressor model with validation dataset xgb_regressor2 = SparkXGBRegressor(max_depth=5, validation_indicator_col="validationIndicatorCol") -xgb_regressor_model2 = xgb_regressor.fit(diabetes_train_spark_df2) +xgb_regressor_model2 = xgb_regressor2.fit(diabetes_train_spark_df2) transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(diabetes_test_spark_df) print(f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}") @@ -75,7 +75,7 @@ def create_spark_df(X, y): # train xgboost classifier model with validation dataset xgb_classifier2 = SparkXGBClassifier(max_depth=5, validation_indicator_col="validationIndicatorCol") -xgb_classifier_model2 = xgb_classifier.fit(iris_train_spark_df2) +xgb_classifier_model2 = xgb_classifier2.fit(iris_train_spark_df2) transformed_iris_test_spark_df2 = xgb_classifier_model2.transform(iris_test_spark_df) print(f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}")