diff --git a/doc/install.rst b/doc/install.rst
index a5ffad85acd3..7ce06aced7a4 100644
--- a/doc/install.rst
+++ b/doc/install.rst
@@ -101,7 +101,7 @@ R
 JVM
 ---
 
-You can use XGBoost4J in your Java/Scala application by adding XGBoost4J as a dependency:
+* XGBoost4j/XGBoost4j-Spark
 
 .. code-block:: xml
   :caption: Maven
 
@@ -134,6 +134,39 @@ You can use XGBoost4J in your Java/Scala application by adding XGBoost4J as a de
     "ml.dmlc" %% "xgboost4j-spark" % "latest_version_num"
   )
 
+* XGBoost4j-GPU/XGBoost4j-Spark-GPU
+
+.. code-block:: xml
+  :caption: Maven
+
+  <properties>
+    ...
+    <scala.binary.version>2.12</scala.binary.version>
+  </properties>
+
+  <dependencies>
+    ...
+    <dependency>
+      <groupId>ml.dmlc</groupId>
+      <artifactId>xgboost4j-gpu_${scala.binary.version}</artifactId>
+      <version>latest_version_num</version>
+    </dependency>
+    <dependency>
+      <groupId>ml.dmlc</groupId>
+      <artifactId>xgboost4j-spark-gpu_${scala.binary.version}</artifactId>
+      <version>latest_version_num</version>
+    </dependency>
+  </dependencies>
+
+.. code-block:: scala
+  :caption: sbt
+
+  libraryDependencies ++= Seq(
+    "ml.dmlc" %% "xgboost4j-gpu" % "latest_version_num",
+    "ml.dmlc" %% "xgboost4j-spark-gpu" % "latest_version_num"
+  )
+
 This will check out the latest stable version from the Maven Central. For the latest release version number, please check the `release page <https://github.com/dmlc/xgboost/releases>`_.
 
@@ -185,7 +218,7 @@ and Windows.) Download it and run the following commands:
 
 JVM
 ---
 
-First add the following Maven repository hosted by the XGBoost project:
+* XGBoost4j/XGBoost4j-Spark
 
 .. code-block:: xml
   :caption: Maven
 
@@ -234,6 +267,40 @@ Then add XGBoost4J as a dependency:
     "ml.dmlc" %% "xgboost4j-spark" % "latest_version_num-SNAPSHOT"
   )
 
+* XGBoost4j-GPU/XGBoost4j-Spark-GPU
+
+.. code-block:: xml
+  :caption: Maven
+
+  <properties>
+    ...
+    <scala.binary.version>2.12</scala.binary.version>
+  </properties>
+
+  <dependencies>
+    ...
+    <dependency>
+      <groupId>ml.dmlc</groupId>
+      <artifactId>xgboost4j-gpu_${scala.binary.version}</artifactId>
+      <version>latest_version_num-SNAPSHOT</version>
+    </dependency>
+    <dependency>
+      <groupId>ml.dmlc</groupId>
+      <artifactId>xgboost4j-spark-gpu_${scala.binary.version}</artifactId>
+      <version>latest_version_num-SNAPSHOT</version>
+    </dependency>
+  </dependencies>
+
+.. code-block:: scala
+  :caption: sbt
+
+  libraryDependencies ++= Seq(
+    "ml.dmlc" %% "xgboost4j-gpu" % "latest_version_num-SNAPSHOT",
+    "ml.dmlc" %% "xgboost4j-spark-gpu" % "latest_version_num-SNAPSHOT"
+  )
+
 Look up the ``version`` field in `pom.xml <https://github.com/dmlc/xgboost/blob/master/jvm-packages/pom.xml>`_ to get the correct version number.
 The SNAPSHOT JARs are hosted by the XGBoost project. Every commit in the ``master`` branch will automatically trigger generation of a new SNAPSHOT JAR. You can control how often Maven should upgrade your SNAPSHOT installation by specifying ``updatePolicy``. See `here `_ for details.

diff --git a/doc/jvm/index.rst b/doc/jvm/index.rst
index cb90d82f3f53..895a325954c1 100644
--- a/doc/jvm/index.rst
+++ b/doc/jvm/index.rst
@@ -35,6 +35,7 @@ Contents
 
   java_intro
   XGBoost4J-Spark Tutorial <xgboost4j_spark_tutorial>
+  XGBoost4J-Spark-GPU Tutorial <xgboost4j_spark_gpu_tutorial>
   Code Examples
   XGBoost4J Java API
   XGBoost4J Scala API

diff --git a/doc/jvm/xgboost4j_spark_gpu_tutorial.rst b/doc/jvm/xgboost4j_spark_gpu_tutorial.rst
new file mode 100644
index 000000000000..60fcbbc35aef
--- /dev/null
+++ b/doc/jvm/xgboost4j_spark_gpu_tutorial.rst
@@ -0,0 +1,246 @@
#############################################
XGBoost4J-Spark-GPU Tutorial (version 1.6.0+)
#############################################

**XGBoost4J-Spark-GPU** is a project aiming to accelerate XGBoost distributed training on Spark from
end to end with GPUs by leveraging the `Spark-Rapids <https://nvidia.github.io/spark-rapids/>`_ project.

This tutorial will show you how to use **XGBoost4J-Spark-GPU**.

.. contents::
  :backlinks: none
  :local:

************************************************
Build an ML Application with XGBoost4J-Spark-GPU
************************************************

Adding XGBoost to Your Project
==============================

Before we go into the tour of how to use XGBoost4J-Spark-GPU, you should first consult
:ref:`Installation from Maven repository ` in order to add XGBoost4J-Spark-GPU as
a dependency for your project. We provide both stable releases and snapshots.

Data Preparation
================

In this section, we use the `Iris <https://archive.ics.uci.edu/ml/datasets/iris>`_ dataset as an example to
showcase how we use Spark to transform a raw dataset and make it fit the data interface of XGBoost.

The Iris dataset is shipped in CSV format. Each instance contains 4 features, "sepal length", "sepal width",
"petal length" and "petal width". In addition, it contains the "class" column, which is essentially the
label with three possible values: "Iris Setosa", "Iris Versicolour" and "Iris Virginica".

Read Dataset with Spark's Built-In Reader
-----------------------------------------

.. code-block:: scala

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

  val spark = SparkSession.builder().getOrCreate()

  val labelName = "class"
  val schema = new StructType(Array(
    StructField("sepal length", DoubleType, true),
    StructField("sepal width", DoubleType, true),
    StructField("petal length", DoubleType, true),
    StructField("petal width", DoubleType, true),
    StructField(labelName, StringType, true)))

  val xgbInput = spark.read.option("header", "false")
    .schema(schema)
    .csv(dataPath)

In the first line, we create an instance of `SparkSession <https://spark.apache.org/docs/latest/sql-getting-started.html#starting-point-sparksession>`_,
which is the entry point of any Spark program working with DataFrames. The ``schema`` variable
defines the schema of the DataFrame wrapping the Iris data. With this explicitly set schema, we
can define the column names as well as their types; otherwise the column names would be
the default ones derived by Spark, such as ``_col0``, etc. Finally, we can use Spark's
built-in CSV reader to load the Iris CSV file as a DataFrame named ``xgbInput``.

Spark also contains many built-in readers for other formats, e.g. ORC, Parquet, Avro and JSON.
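For instance, reading the same data from one of those self-describing formats is a one-liner; a minimal sketch, where ``parquetPath`` and ``orcPath`` are hypothetical file locations (Parquet and ORC files carry their own schema, so no explicit ``.schema(...)`` call is needed):

.. code-block:: scala

  // Parquet files embed their schema, so Spark infers the column names and types.
  val parquetInput = spark.read.parquet(parquetPath)

  // ORC works the same way through the built-in orc reader.
  val orcInput = spark.read.orc(orcPath)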
Transform Raw Iris Dataset
--------------------------

To make the Iris dataset recognizable to XGBoost, we need to encode the String-typed
label, i.e. "class", into a Double-typed label.

One way to convert the String-typed label to Double is to use Spark's built-in feature transformer
`StringIndexer <https://spark.apache.org/docs/latest/ml-features.html#stringindexer>`_,
but it has not been accelerated by Spark-Rapids yet, which means it will fall back
to the CPU and cause a performance issue. Instead, we use an alternative way to achieve
the same goal with the following code:

.. code-block:: scala

  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions._

  val spec = Window.orderBy(labelName)
  val Array(train, test) = xgbInput
    .withColumn("tmpClassName", dense_rank().over(spec) - 1)
    .drop(labelName)
    .withColumnRenamed("tmpClassName", labelName)
    .randomSplit(Array(0.7, 0.3), seed = 1)

  train.show(5)

.. code-block:: none

  +------------+-----------+------------+-----------+-----+
  |sepal length|sepal width|petal length|petal width|class|
  +------------+-----------+------------+-----------+-----+
  |         4.3|        3.0|         1.1|        0.1|    0|
  |         4.4|        2.9|         1.4|        0.2|    0|
  |         4.4|        3.0|         1.3|        0.2|    0|
  |         4.4|        3.2|         1.3|        0.2|    0|
  |         4.6|        3.2|         1.4|        0.2|    0|
  +------------+-----------+------------+-----------+-----+

With window operations, we have mapped the string column of labels to label indices.

Training
========

The GPU version of XGBoost-Spark supports both regression and classification
models. Although we use the Iris dataset in this tutorial to show how to use
``XGBoost/XGBoost4J-Spark-GPU`` to solve a multi-class classification problem, the
usage for regression is very similar to that for classification.

To train an XGBoost model for classification, we first create an XGBoostClassifier:

.. code-block:: scala

  import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
  val xgbParam = Map(
    "objective" -> "multi:softprob",
    "num_class" -> 3,
    "num_round" -> 100,
    "tree_method" -> "gpu_hist",
    "num_workers" -> 1)

  val featuresNames = schema.fieldNames.filter(name => name != labelName)

  val xgbClassifier = new XGBoostClassifier(xgbParam)
    .setFeaturesCol(featuresNames)
    .setLabelCol(labelName)

The available parameters for training an XGBoost model can be found :doc:`here </parameter>`.
Similar to the XGBoost4J-Spark package, in addition to the default set of parameters,
XGBoost4J-Spark-GPU also supports the camel-case variants of these parameters, to be
consistent with Spark's MLlib naming convention.

Specifically, each parameter in :doc:`this page </parameter>` has its equivalent camel-case form in
XGBoost4J-Spark-GPU. For example, to set ``max_depth`` for each tree, you can pass
the parameter just as we did in the above code snippet (as ``max_depth`` wrapped in a Map), or
you can do it through setters in XGBoostClassifier:

.. code-block:: scala

  val xgbClassifier = new XGBoostClassifier(xgbParam)
    .setFeaturesCol(featuresNames)
    .setLabelCol(labelName)
  xgbClassifier.setMaxDepth(2)

.. note::

  In contrast to the XGBoost4J-Spark package, which requires the numeric feature columns
  to first be assembled into one column of VectorUDT type by VectorAssembler,
  XGBoost4J-Spark-GPU does not require such a transformation; it accepts an array of feature
  column names via ``setFeaturesCol(value: Array[String])``.

After we set the XGBoostClassifier parameters and the feature/label columns, we can build a
transformer, XGBoostClassificationModel, by fitting the XGBoostClassifier with the input
DataFrame. This ``fit`` operation is essentially the training process, and the generated
model can then be used in other tasks like prediction.

.. code-block:: scala

  val xgbClassificationModel = xgbClassifier.fit(train)
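Like other Spark ML models, the fitted model can be persisted and reloaded through the standard MLWritable/MLReadable interface; a minimal sketch, assuming a hypothetical ``modelPath``:

.. code-block:: scala

  import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel

  // Persist the trained model, overwriting any existing copy at modelPath.
  xgbClassificationModel.write.overwrite().save(modelPath)

  // Load it back later, e.g. in a separate scoring job.
  val loadedModel = XGBoostClassificationModel.load(modelPath)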
Prediction
==========

When we get a model, either an XGBoostClassificationModel or an XGBoostRegressionModel, it takes a DataFrame as input,
reads the column containing feature vectors, predicts for each feature vector, and outputs a new DataFrame
with the following columns by default:

* XGBoostClassificationModel will output margins (``rawPredictionCol``), probabilities (``probabilityCol``) and the eventual prediction label (``predictionCol``).
* XGBoostRegressionModel will output the prediction label (``predictionCol``).

.. code-block:: scala

  val xgbClassificationModel = xgbClassifier.fit(train)
  val results = xgbClassificationModel.transform(test)
  results.show()

With the above code snippet, we get a result DataFrame containing the margin, the probability for each class,
and the prediction for each instance:

.. code-block:: none

  +------------+-----------+------------------+-------------------+-----+--------------------+--------------------+----------+
  |sepal length|sepal width|      petal length|        petal width|class|       rawPrediction|         probability|prediction|
  +------------+-----------+------------------+-------------------+-----+--------------------+--------------------+----------+
  |         4.5|        2.3|               1.3|0.30000000000000004|    0|[3.16666603088378...|[0.98853939771652...|       0.0|
  |         4.6|        3.1|               1.5|                0.2|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         4.8|        3.1|               1.6|                0.2|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         4.8|        3.4|               1.6|                0.2|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         4.8|        3.4|1.9000000000000001|                0.2|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         4.9|        2.4|               3.3|                1.0|    1|[-2.1498908996582...|[0.00596602633595...|       1.0|
  |         4.9|        2.5|               4.5|                1.7|    2|[-2.1498908996582...|[0.00596602633595...|       1.0|
  |         5.0|        3.5|               1.3|0.30000000000000004|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         5.1|        2.5|               3.0|                1.1|    1|[3.16666603088378...|[0.98853939771652...|       0.0|
  |         5.1|        3.3|               1.7|                0.5|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         5.1|        3.5|               1.4|                0.2|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         5.1|        3.8|               1.6|                0.2|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         5.2|        3.4|               1.4|                0.2|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         5.2|        3.5|               1.5|                0.2|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         5.2|        4.1|               1.5|                0.1|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         5.4|        3.9|               1.7|                0.4|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         5.5|        2.4|               3.8|                1.1|    1|[-2.1498908996582...|[0.00596602633595...|       1.0|
  |         5.5|        4.2|               1.4|                0.2|    0|[3.25857257843017...|[0.98969423770904...|       0.0|
  |         5.7|        2.5|               5.0|                2.0|    2|[-2.1498908996582...|[0.00280966912396...|       2.0|
  |         5.7|        3.0|               4.2|                1.2|    1|[-2.1498908996582...|[0.00643939292058...|       1.0|
  +------------+-----------+------------------+-------------------+-----+--------------------+--------------------+----------+
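To put a number on the quality of these predictions, Spark MLlib's built-in evaluator can be applied to the result DataFrame. This evaluation step is an illustrative addition (it is not part of the original tutorial and runs on the CPU); a minimal sketch:

.. code-block:: scala

  import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

  // Compare the predicted class against the ground-truth label column.
  val evaluator = new MulticlassClassificationEvaluator()
    .setLabelCol(labelName)
    .setPredictionCol("prediction")
    .setMetricName("accuracy")

  val accuracy = evaluator.evaluate(results)
  println(s"Test accuracy = $accuracy")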
**********************
Submit the application
**********************

Take submitting the Spark job to a Spark Standalone cluster as an example, and assume your application main class
is ``Iris`` and the application jar is ``iris-1.0.0.jar``:

.. code-block:: bash

  cudf_version=22.02.0
  rapids_version=22.02.0
  xgboost_version=1.6.0
  main_class=Iris
  app_jar=iris-1.0.0.jar

  spark-submit \
    --master $master \
    --packages ai.rapids:cudf:${cudf_version},com.nvidia:rapids-4-spark_2.12:${rapids_version},ml.dmlc:xgboost4j-gpu_2.12:${xgboost_version},ml.dmlc:xgboost4j-spark-gpu_2.12:${xgboost_version} \
    --conf spark.executor.cores=12 \
    --conf spark.task.cpus=1 \
    --conf spark.executor.resource.gpu.amount=1 \
    --conf spark.task.resource.gpu.amount=0.08 \
    --conf spark.rapids.sql.csv.read.double.enabled=true \
    --conf spark.rapids.sql.hasNans=false \
    --conf spark.plugins=com.nvidia.spark.SQLPlugin \
    --class ${main_class} \
    ${app_jar}

* First, we need to specify the ``spark-rapids``, ``cudf``, ``xgboost4j-gpu`` and ``xgboost4j-spark-gpu`` packages via ``--packages``.
* Second, ``spark-rapids`` is a Spark plugin, so we need to configure it by specifying ``spark.plugins=com.nvidia.spark.SQLPlugin``.
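For reference, a minimal sketch of what the submitted ``Iris`` main class could look like, stitched together from the snippets above (the object layout and the ``args(0)`` data-path handling are illustrative assumptions, not part of this patch):

.. code-block:: scala

  import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions._
  import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

  object Iris {
    def main(args: Array[String]): Unit = {
      // Path to the Iris CSV file, passed on the spark-submit command line.
      val dataPath = args(0)
      val spark = SparkSession.builder().appName("Iris").getOrCreate()

      val labelName = "class"
      val schema = new StructType(Array(
        StructField("sepal length", DoubleType, true),
        StructField("sepal width", DoubleType, true),
        StructField("petal length", DoubleType, true),
        StructField("petal width", DoubleType, true),
        StructField(labelName, StringType, true)))

      val xgbInput = spark.read.option("header", "false").schema(schema).csv(dataPath)

      // Encode the string label into consecutive indices, as shown earlier.
      val spec = Window.orderBy(labelName)
      val Array(train, test) = xgbInput
        .withColumn("tmpClassName", dense_rank().over(spec) - 1)
        .drop(labelName)
        .withColumnRenamed("tmpClassName", labelName)
        .randomSplit(Array(0.7, 0.3), seed = 1)

      val xgbParam = Map(
        "objective" -> "multi:softprob",
        "num_class" -> 3,
        "num_round" -> 100,
        "tree_method" -> "gpu_hist",
        "num_workers" -> 1)

      val model = new XGBoostClassifier(xgbParam)
        .setFeaturesCol(schema.fieldNames.filter(_ != labelName))
        .setLabelCol(labelName)
        .fit(train)

      model.transform(test).show()
      spark.stop()
    }
  }

Compile this into ``iris-1.0.0.jar`` with your build tool of choice, and the ``spark-submit`` command above will run it on the cluster.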
For details about other ``spark-rapids`` configurations, please refer to the
`configuration <https://nvidia.github.io/spark-rapids/docs/configs.html>`_ page.

For ``spark-rapids`` frequently asked questions, please refer to the
`frequently-asked-questions <https://nvidia.github.io/spark-rapids/docs/FAQ.html>`_ page.

diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst
index d2cf979e39f3..2d0ec8a2fb8b 100644
--- a/doc/tutorials/index.rst
+++ b/doc/tutorials/index.rst
@@ -14,6 +14,7 @@ See `Awesome XGBoost `_ for mo
   Distributed XGBoost with AWS YARN
   kubernetes
   Distributed XGBoost with XGBoost4J-Spark
+  Distributed XGBoost with XGBoost4J-Spark-GPU
   dask
   ray
   dart

diff --git a/jvm-packages/xgboost4j-gpu/pom.xml b/jvm-packages/xgboost4j-gpu/pom.xml
index 19de82c1e756..b601ebefc076 100644
--- a/jvm-packages/xgboost4j-gpu/pom.xml
+++ b/jvm-packages/xgboost4j-gpu/pom.xml
@@ -20,11 +20,6 @@
       <classifier>${cudf.classifier}</classifier>
       <scope>provided</scope>
     </dependency>
-    <dependency>
-      <groupId>com.fasterxml.jackson.core</groupId>
-      <artifactId>jackson-databind</artifactId>
-      <version>2.10.5.1</version>
-    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-hdfs</artifactId>

diff --git a/jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/gpu/java/CudfUtils.java b/jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/gpu/java/CudfUtils.java
index b63ef7f30538..f7071dcd5fb2 100644
--- a/jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/gpu/java/CudfUtils.java
+++ b/jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/gpu/java/CudfUtils.java
@@ -1,5 +1,5 @@
 /*
-  Copyright (c) 2021 by Contributors
+  Copyright (c) 2021-2022 by Contributors
 
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
@@ -16,15 +16,7 @@
 
 package ml.dmlc.xgboost4j.gpu.java;
 
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-
-import com.fasterxml.jackson.core.JsonFactory;
-import com.fasterxml.jackson.core.JsonGenerator;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.node.ArrayNode;
-import com.fasterxml.jackson.databind.node.JsonNodeFactory;
-import com.fasterxml.jackson.databind.node.ObjectNode;
+import java.util.ArrayList;
 
 /**
  * Cudf utilities to build cuda array interface against {@link CudfColumn}
@@ -42,58 +34,64 @@ public static String buildArrayInterface(CudfColumn... cudfColumns) {
 
   // Helper class to build array interface string
   private static class Builder {
-    private JsonNodeFactory nodeFactory = new JsonNodeFactory(false);
-    private ArrayNode rootArrayNode = nodeFactory.arrayNode();
+    private ArrayList<String> colArrayInterfaces = new ArrayList<String>();
 
     private Builder add(CudfColumn...
columns) { if (columns == null || columns.length <= 0) { throw new IllegalArgumentException("At least one ColumnData is required."); } for (CudfColumn cd : columns) { - rootArrayNode.add(buildColumnObject(cd)); + colArrayInterfaces.add(buildColumnObject(cd)); } return this; } private String build() { - try { - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - JsonGenerator jsonGen = new JsonFactory().createGenerator(bos); - new ObjectMapper().writeTree(jsonGen, rootArrayNode); - return bos.toString(); - } catch (IOException ie) { - ie.printStackTrace(); - throw new RuntimeException("Failed to build array interface. Error: " + ie); + StringBuilder builder = new StringBuilder(); + builder.append("["); + for (int i = 0; i < colArrayInterfaces.size(); i++) { + builder.append(colArrayInterfaces.get(i)); + if (i != colArrayInterfaces.size() - 1) { + builder.append(","); + } } + builder.append("]"); + return builder.toString(); } - private ObjectNode buildColumnObject(CudfColumn column) { + /** build the whole column information including data and valid info */ + private String buildColumnObject(CudfColumn column) { if (column.getDataPtr() == 0) { throw new IllegalArgumentException("Empty column data is NOT accepted!"); } if (column.getTypeStr() == null || column.getTypeStr().isEmpty()) { throw new IllegalArgumentException("Empty type string is NOT accepted!"); } - ObjectNode colDataObj = buildMetaObject(column.getDataPtr(), column.getShape(), - column.getTypeStr()); + StringBuilder builder = new StringBuilder(); + String colData = buildMetaObject(column.getDataPtr(), column.getShape(), + column.getTypeStr()); + builder.append("{"); + builder.append(colData); if (column.getValidPtr() != 0 && column.getNullCount() != 0) { - ObjectNode validObj = buildMetaObject(column.getValidPtr(), column.getShape(), " checkNumericType(schema, fn)) if (fitting) { require(labelName.nonEmpty, "label column is not set.") diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostClassifierSuite.scala index 9bed82072c2d..fc26b29858a6 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostClassifierSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostClassifierSuite.scala @@ -126,7 +126,7 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite { val vectorAssembler = new VectorAssembler() .setHandleInvalid("keep") - .setInputCols(featureNames.toArray) + .setInputCols(featureNames) .setOutputCol("features") val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName) @@ -147,12 +147,12 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite { .csv(dataPath).randomSplit(Array(0.7, 0.3), seed = 1) // Since CPU model does not know the information about the features cols that GPU transform - // pipeline requires. End user needs to setFeaturesCols in the model manually - val thrown = intercept[IllegalArgumentException](cpuModel + // pipeline requires. End user needs to setFeaturesCol(features: Array[String]) in the model + // manually + val thrown = intercept[NoSuchElementException](cpuModel .transform(testDf) .collect()) - assert(thrown.getMessage.contains("Gpu transform requires features columns. 
" + - "please refer to setFeaturesCols")) + assert(thrown.getMessage.contains("Failed to find a default value for featuresCols")) val left = cpuModel .setFeaturesCol(featureNames) @@ -195,17 +195,16 @@ class GpuXGBoostClassifierSuite extends GpuTestSuite { val featureColName = "feature_col" val vectorAssembler = new VectorAssembler() .setHandleInvalid("keep") - .setInputCols(featureNames.toArray) + .setInputCols(featureNames) .setOutputCol(featureColName) val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName) // Since GPU model does not know the information about the features col name that CPU // transform pipeline requires. End user needs to setFeaturesCol in the model manually - val thrown = intercept[IllegalArgumentException]( + intercept[IllegalArgumentException]( gpuModel .transform(testDf) .collect()) - assert(thrown.getMessage.contains("features does not exist")) val left = gpuModel .setFeaturesCol(featureColName) diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala index 53cdcb923739..3d643761a3df 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostGeneralSuite.scala @@ -108,12 +108,15 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite { val trainingDf = trainingData.toDF(allColumnNames: _*) val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "multi:softprob", "num_class" -> 3, "num_round" -> 5, "num_workers" -> 1, "tree_method" -> "gpu_hist") - val thrown = intercept[IllegalArgumentException] { + + // GPU train requires featuresCols. If not specified, + // then NoSuchElementException will be thrown + val thrown = intercept[NoSuchElementException] { new XGBoostClassifier(xgbParam) .setLabelCol(labelName) .fit(trainingDf) } - assert(thrown.getMessage.contains("Gpu train requires features columns.")) + assert(thrown.getMessage.contains("Failed to find a default value for featuresCols")) val thrown1 = intercept[IllegalArgumentException] { new XGBoostClassifier(xgbParam) diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala index 18f35ee87dd4..5342aa563621 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2021 by Contributors + Copyright (c) 2021-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -86,7 +86,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { .csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1) val classifier = new XGBoostRegressor(xgbParam) - .setFeaturesCols(featureNames) + .setFeaturesCol(featureNames) .setLabelCol(labelName) .setTreeMethod("gpu_hist") (classifier.fit(rawInput), testDf) @@ -122,7 +122,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { val vectorAssembler = new VectorAssembler() .setHandleInvalid("keep") - .setInputCols(featureNames.toArray) + .setInputCols(featureNames) .setOutputCol("features") val trainingDf = vectorAssembler.transform(rawInput).select("features", labelName) @@ -143,20 +143,20 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { .csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1) // Since CPU model does not know the information about the features cols that GPU transform - // pipeline requires. End user needs to setFeaturesCols in the model manually - val thrown = intercept[IllegalArgumentException](cpuModel + // pipeline requires. End user needs to setFeaturesCol(features: Array[String]) in the model + // manually + val thrown = intercept[NoSuchElementException](cpuModel .transform(testDf) .collect()) - assert(thrown.getMessage.contains("Gpu transform requires features columns. " + - "please refer to setFeaturesCols")) + assert(thrown.getMessage.contains("Failed to find a default value for featuresCols")) val left = cpuModel - .setFeaturesCols(featureNames) + .setFeaturesCol(featureNames) .transform(testDf) .collect() val right = cpuModelFromFile - .setFeaturesCols(featureNames) + .setFeaturesCol(featureNames) .transform(testDf) .collect() @@ -173,7 +173,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { .csv(getResourcePath("/rank.train.csv")).randomSplit(Array(0.7, 0.3), seed = 1) val classifier = new XGBoostRegressor(xgbParam) - .setFeaturesCols(featureNames) + .setFeaturesCol(featureNames) .setLabelCol(labelName) .setTreeMethod("gpu_hist") classifier.fit(rawInput) @@ -191,17 +191,16 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { val featureColName = "feature_col" val vectorAssembler = new VectorAssembler() .setHandleInvalid("keep") - .setInputCols(featureNames.toArray) + .setInputCols(featureNames) .setOutputCol(featureColName) val testDf = vectorAssembler.transform(rawInput).select(featureColName, labelName) // Since GPU model does not know the information about the features col name that CPU // transform pipeline requires. End user needs to setFeaturesCol in the model manually - val thrown = intercept[IllegalArgumentException]( + intercept[IllegalArgumentException]( gpuModel .transform(testDf) .collect()) - assert(thrown.getMessage.contains("features does not exist")) val left = gpuModel .setFeaturesCol(featureColName) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala index 8baaafba7ee7..67deb6979628 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2021 by Contributors + Copyright (c) 2021-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -35,8 +35,10 @@ import org.apache.commons.logging.LogFactory import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.{Estimator, Model, PipelineStage} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType} import org.apache.spark.storage.StorageLevel @@ -112,7 +114,7 @@ object PreXGBoost extends PreXGBoostProvider { return optionProvider.get.buildDatasetToRDD(estimator, dataset, params) } - val (packedParams, evalSet) = estimator match { + val (packedParams, evalSet, xgbInput) = estimator match { case est: XGBoostEstimatorCommon => // get weight column, if weight is not defined, default to lit(1.0) val weight = if (!est.isDefined(est.weightCol) || est.getWeightCol.isEmpty) { @@ -136,15 +138,18 @@ object PreXGBoost extends PreXGBoostProvider { } - (PackedParams(col(est.getLabelCol), col(est.getFeaturesCol), weight, baseMargin, group, - est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params)) + val (xgbInput, featuresName) = est.vectorize(dataset) + + (PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group, + est.getNumWorkers, est.needDeterministicRepartitioning), est.getEvalSets(params), + xgbInput) case _ => throw new RuntimeException("Unsupporting " + estimator) } // transform the training Dataset[_] to RDD[XGBLabeledPoint] val trainingSet: RDD[XGBLabeledPoint] = DataUtils.convertDataFrameToXGBLabeledPointRDDs( - packedParams, dataset.asInstanceOf[DataFrame]).head + packedParams, xgbInput.asInstanceOf[DataFrame]).head // transform the eval Dataset[_] to RDD[XGBLabeledPoint] val evalRDDMap = evalSet.map { @@ -184,11 +189,11 @@ object PreXGBoost extends PreXGBoostProvider { } /** get the necessary parameters */ - val (booster, inferBatchSize, featuresCol, useExternalMemory, missing, allowNonZeroForMissing, - predictFunc, schema) = + val (booster, inferBatchSize, xgbInput, featuresCol, useExternalMemory, missing, + allowNonZeroForMissing, predictFunc, schema) = model match { case m: XGBoostClassificationModel => - + val (xgbInput, featuresName) = m.vectorize(dataset) // predict and turn to Row val predictFunc = (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => { @@ -199,7 +204,7 @@ object PreXGBoost extends PreXGBoostProvider { } // prepare the final Schema - var schema = StructType(dataset.schema.fields ++ + var schema = StructType(xgbInput.schema.fields ++ Seq(StructField(name = XGBoostClassificationModel._rawPredictionCol, dataType = ArrayType(FloatType, containsNull = false), nullable = false)) ++ Seq(StructField(name = XGBoostClassificationModel._probabilityCol, dataType = @@ -214,11 +219,12 @@ object PreXGBoost extends PreXGBoostProvider { ArrayType(FloatType, containsNull = false), nullable = false)) } - (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing, - m.getAllowNonZeroForMissingValue, predictFunc, schema) + (m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory, + m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema) case m: XGBoostRegressionModel => // predict and turn to Row + val (xgbInput, featuresName) = m.vectorize(dataset) val predictFunc = (broadcastBooster: Broadcast[Booster], dm: DMatrix, originalRowItr: Iterator[Row]) => { val Array(rawPredictionItr, predLeafItr, 
predContribItr) = @@ -227,7 +233,7 @@ object PreXGBoost extends PreXGBoostProvider { } // prepare the final Schema - var schema = StructType(dataset.schema.fields ++ + var schema = StructType(xgbInput.schema.fields ++ Seq(StructField(name = XGBoostRegressionModel._originalPredictionCol, dataType = ArrayType(FloatType, containsNull = false), nullable = false))) @@ -240,14 +246,14 @@ object PreXGBoost extends PreXGBoostProvider { ArrayType(FloatType, containsNull = false), nullable = false)) } - (m._booster, m.getInferBatchSize, m.getFeaturesCol, m.getUseExternalMemory, m.getMissing, - m.getAllowNonZeroForMissingValue, predictFunc, schema) + (m._booster, m.getInferBatchSize, xgbInput, featuresName, m.getUseExternalMemory, + m.getMissing, m.getAllowNonZeroForMissingValue, predictFunc, schema) } - val bBooster = dataset.sparkSession.sparkContext.broadcast(booster) - val appName = dataset.sparkSession.sparkContext.appName + val bBooster = xgbInput.sparkSession.sparkContext.broadcast(booster) + val appName = xgbInput.sparkSession.sparkContext.appName - val resultRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator => + val resultRDD = xgbInput.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIterator => new AbstractIterator[Row] { private var batchCnt = 0 @@ -295,7 +301,7 @@ object PreXGBoost extends PreXGBoostProvider { } bBooster.unpersist(blocking = false) - dataset.sparkSession.createDataFrame(resultRDD, schema) + xgbInput.sparkSession.createDataFrame(resultRDD, schema) } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala index c8635d93cc4b..3e62e99465f4 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala @@ -144,13 +144,6 @@ class XGBoostClassifier ( def setSinglePrecisionHistogram(value: Boolean): this.type = set(singlePrecisionHistogram, value) - /** - * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires - * all feature columns must be numeric types. - */ - def setFeaturesCol(value: Array[String]): this.type = - set(featuresCols, value) - // called at the start of fit/train when 'eval_metric' is not defined private def setupDefaultEvalMetric(): String = { require(isDefined(objective), "Users must set \'objective\' via xgboostParams.") @@ -165,7 +158,12 @@ class XGBoostClassifier ( // Callback from PreXGBoost private[spark] def transformSchemaInternal(schema: StructType): StructType = { - super.transformSchema(schema) + if (isFeaturesColSet(schema)) { + // User has vectorized the features into VectorUDT. + super.transformSchema(schema) + } else { + transformSchemaWithFeaturesCols(true, schema) + } } override def transformSchema(schema: StructType): StructType = { @@ -260,13 +258,6 @@ class XGBoostClassificationModel private[ml]( def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value) - /** - * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires - * all feature columns must be numeric types. - */ - def setFeaturesCol(value: Array[String]): this.type = - set(featuresCols, value) - /** * Single instance prediction. * Note: The performance is not ideal, use it carefully! 
@@ -359,7 +350,12 @@ class XGBoostClassificationModel private[ml]( } private[spark] def transformSchemaInternal(schema: StructType): StructType = { - super.transformSchema(schema) + if (isFeaturesColSet(schema)) { + // User has vectorized the features into VectorUDT. + super.transformSchema(schema) + } else { + transformSchemaWithFeaturesCols(false, schema) + } } override def transformSchema(schema: StructType): StructType = { @@ -385,8 +381,6 @@ class XGBoostClassificationModel private[ml]( Vectors.dense(rawPredictions) } - - if ($(rawPredictionCol).nonEmpty) { outputData = outputData .withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol))) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala index 3ca1e7988a6d..9af52d165390 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala @@ -146,13 +146,6 @@ class XGBoostRegressor ( def setSinglePrecisionHistogram(value: Boolean): this.type = set(singlePrecisionHistogram, value) - /** - * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires - * all feature columns must be numeric types. - */ - def setFeaturesCols(value: Array[String]): this.type = - set(featuresCols, value) - // called at the start of fit/train when 'eval_metric' is not defined private def setupDefaultEvalMetric(): String = { require(isDefined(objective), "Users must set \'objective\' via xgboostParams.") @@ -164,7 +157,12 @@ class XGBoostRegressor ( } private[spark] def transformSchemaInternal(schema: StructType): StructType = { - super.transformSchema(schema) + if (isFeaturesColSet(schema)) { + // User has vectorized the features into VectorUDT. + super.transformSchema(schema) + } else { + transformSchemaWithFeaturesCols(false, schema) + } } override def transformSchema(schema: StructType): StructType = { @@ -253,13 +251,6 @@ class XGBoostRegressionModel private[ml] ( def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value) - /** - * This API is only used in GPU train pipeline of xgboost4j-spark-gpu, which requires - * all feature columns must be numeric types. - */ - def setFeaturesCols(value: Array[String]): this.type = - set(featuresCols, value) - /** * Single instance prediction. * Note: The performance is not ideal, use it carefully! @@ -331,7 +322,12 @@ class XGBoostRegressionModel private[ml] ( } private[spark] def transformSchemaInternal(schema: StructType): StructType = { - super.transformSchema(schema) + if (isFeaturesColSet(schema)) { + // User has vectorized the features into VectorUDT. 
+      super.transformSchema(schema)
+    } else {
+      transformSchemaWithFeaturesCols(false, schema)
+    }
   }
 
   override def transformSchema(schema: StructType): StructType = {

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
index a75f64dd8aba..2416df0b3b21 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
@@ -1,5 +1,5 @@
 /*
- Copyright (c) 2014,2021 by Contributors
+ Copyright (c) 2014-2022 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -247,6 +247,27 @@ trait HasNumClass extends Params {
   final def getNumClass: Int = $(numClass)
 }
 
+/**
+ * Trait for shared param featuresCols.
+ */
+trait HasFeaturesCols extends Params {
+  /**
+   * Param for the names of feature columns.
+   * @group param
+   */
+  final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
+    "an array of feature column names.")
+
+  /** @group getParam */
+  final def getFeaturesCols: Array[String] = $(featuresCols)
+
+  /** Check if featuresCols is defined and non-empty. Note: comparing against
+   *  `Array.empty` would use reference equality and always be true, so use nonEmpty. */
+  def isFeaturesColsValid: Boolean = {
+    isDefined(featuresCols) && $(featuresCols).nonEmpty
+  }
+
+}
+
 private[spark] trait ParamMapFuncs extends Params {
 
   def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {

diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GpuParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GpuParams.scala
deleted file mode 100644
index 9ab4c7357095..000000000000
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GpuParams.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- Copyright (c) 2021-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark.params
-
-import org.apache.spark.ml.param.{Params, StringArrayParam}
-
-trait GpuParams extends Params {
-  /**
-   * Param for the names of feature columns for GPU pipeline.
- * @group param - */ - final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols", - "an array of feature column names for GPU pipeline.") - - setDefault(featuresCols, Array.empty[String]) - - /** @group getParam */ - final def getFeaturesCols: Array[String] = $(featuresCols) - -} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala index 025757021fb4..5d2a1c04ea58 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014,2021 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,16 +16,101 @@ package ml.dmlc.xgboost4j.scala.spark.params -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.ml.linalg.xgboost.XGBoostSchemaUtils +import org.apache.spark.ml.param.{Param, ParamValidators} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.types.StructType private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol - with HasLabelCol with GpuParams { + with HasLabelCol with HasFeaturesCols with HasHandleInvalid { def needDeterministicRepartitioning: Boolean = { getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0 } + + /** + * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + * output). Column lengths are taken from the size of ML Attribute Group, which can be set using + * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + * Default: "error" + * @group param + */ + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out + |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN + |in the output). Column lengths are taken from the size of ML Attribute Group, which can be + |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also + |be inferred from first rows of the data since it is safe to do so but only in case of 'error' + |or 'skip'.""".stripMargin.replaceAll("\n", " "), + ParamValidators.inArray(Array("skip", "error", "keep"))) + + setDefault(handleInvalid, "error") + + /** + * Specify an array of feature column names which must be numeric types. 
+ */ + def setFeaturesCol(value: Array[String]): this.type = set(featuresCols, value) + + /** Set the handleInvalid for VectorAssembler */ + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** + * Check if schema has a field named with the value of "featuresCol" param and it's data type + * must be VectorUDT + */ + def isFeaturesColSet(schema: StructType): Boolean = { + schema.fieldNames.contains(getFeaturesCol) && + XGBoostSchemaUtils.isVectorUDFType(schema(getFeaturesCol).dataType) + } + + /** check the features columns type */ + def transformSchemaWithFeaturesCols(fit: Boolean, schema: StructType): StructType = { + if (isFeaturesColsValid) { + if (fit) { + XGBoostSchemaUtils.checkNumericType(schema, $(labelCol)) + } + $(featuresCols).foreach(feature => + XGBoostSchemaUtils.checkFeatureColumnType(schema(feature).dataType)) + schema + } else { + throw new IllegalArgumentException("featuresCol or featuresCols must be specified") + } + } + + /** + * Vectorize the features columns if necessary. + * + * @param input the input dataset + * @return (output dataset and the feature column name) + */ + def vectorize(input: Dataset[_]): (Dataset[_], String) = { + val schema = input.schema + if (isFeaturesColSet(schema)) { + // Dataset already has vectorized. + (input, getFeaturesCol) + } else if (isFeaturesColsValid) { + val featuresName = if (!schema.fieldNames.contains(getFeaturesCol)) { + getFeaturesCol + } else { + "features_" + uid + } + val vectorAssembler = new VectorAssembler() + .setHandleInvalid($(handleInvalid)) + .setInputCols(getFeaturesCols) + .setOutputCol(featuresName) + (vectorAssembler.transform(input).select(featuresName, getLabelCol), featuresName) + } else { + // never reach here, since transformSchema will take care of the case + // that featuresCols is invalid + (input, getFeaturesCol) + } + } } private[scala] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/linalg/xgboost/XGBoostSchemaUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/linalg/xgboost/XGBoostSchemaUtils.scala new file mode 100644 index 000000000000..0976067ec38f --- /dev/null +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/linalg/xgboost/XGBoostSchemaUtils.scala @@ -0,0 +1,51 @@ +/* + Copyright (c) 2022 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ + +package org.apache.spark.ml.linalg.xgboost + +import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType} +import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.ml.util.SchemaUtils + +object XGBoostSchemaUtils { + + /** check if the dataType is VectorUDT */ + def isVectorUDFType(dataType: DataType): Boolean = { + dataType match { + case _: VectorUDT => true + case _ => false + } + } + + /** The feature columns will be vectorized by VectorAssembler first, which only + * supports Numeric, Boolean and VectorUDT types */ + def checkFeatureColumnType(dataType: DataType): Unit = { + dataType match { + case _: NumericType | BooleanType => + case _: VectorUDT => + case d => throw new UnsupportedOperationException(s"featuresCols only supports Numeric, " + + s"boolean and VectorUDT types, found: ${d}") + } + } + + def checkNumericType( + schema: StructType, + colName: String, + msg: String = ""): Unit = { + SchemaUtils.checkNumericType(schema, colName, msg) + } + +} diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala index 7940a51e5cad..91f4a4cfa20b 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql._ import org.scalatest.FunSuite import org.apache.spark.Partitioner +import org.apache.spark.ml.feature.VectorAssembler class XGBoostClassifierSuite extends FunSuite with PerTest { @@ -316,4 +317,77 @@ class XGBoostClassifierSuite extends FunSuite with PerTest { xgb.fit(repartitioned) } + test("featuresCols with features column can work") { + val spark = ss + import spark.implicits._ + val xgbInput = Seq( + (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0), + (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1)) + .toDF("f1", "f2", "f3", "features", "label") + + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1) + + val featuresName = Array("f1", "f2", "f3", "features") + val xgbClassifier = new XGBoostClassifier(paramMap) + .setFeaturesCol(featuresName) + .setLabelCol("label") + + val model = xgbClassifier.fit(xgbInput) + assert(model.getFeaturesCols.sameElements(featuresName)) + + val df = model.transform(xgbInput) + assert(df.schema.fieldNames.contains("features_" + model.uid)) + df.show() + + val newFeatureName = "features_new" + // transform also can work for vectorized dataset + val vectorizedInput = new VectorAssembler() + .setInputCols(featuresName) + .setOutputCol(newFeatureName) + .transform(xgbInput) + .select(newFeatureName, "label") + + val df1 = model + .setFeaturesCol(newFeatureName) + .transform(vectorizedInput) + assert(df1.schema.fieldNames.contains(newFeatureName)) + df1.show() + } + + test("featuresCols without features column can work") { + val spark = ss + import spark.implicits._ + val xgbInput = Seq( + (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0), + (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1)) + .toDF("f1", "f2", "f3", "f4", "label") + + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1) + + val featuresName = Array("f1", "f2", "f3", "f4") + val xgbClassifier = new 
XGBoostClassifier(paramMap) + .setFeaturesCol(featuresName) + .setLabelCol("label") + + val model = xgbClassifier.fit(xgbInput) + assert(model.getFeaturesCols.sameElements(featuresName)) + + // transform should work for the dataset which includes the feature column names. + val df = model.transform(xgbInput) + assert(df.schema.fieldNames.contains("features")) + df.show() + + // transform also can work for vectorized dataset + val vectorizedInput = new VectorAssembler() + .setInputCols(featuresName) + .setOutputCol("features") + .transform(xgbInput) + .select("features", "label") + + val df1 = model.transform(vectorizedInput) + df1.show() + } + } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala index b06ffc9399a5..04e5106402b4 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014-2022 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,12 +17,15 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost} -import org.apache.spark.ml.linalg.Vector + +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.sql.functions._ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types._ import org.scalatest.FunSuite +import org.apache.spark.ml.feature.VectorAssembler + class XGBoostRegressorSuite extends FunSuite with PerTest { protected val treeMethod: String = "auto" @@ -216,4 +219,77 @@ class XGBoostRegressorSuite extends FunSuite with PerTest { assert(resultDF.columns.contains("predictLeaf")) assert(resultDF.columns.contains("predictContrib")) } + + test("featuresCols with features column can work") { + val spark = ss + import spark.implicits._ + val xgbInput = Seq( + (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0), + (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1)) + .toDF("f1", "f2", "f3", "features", "label") + + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1) + + val featuresName = Array("f1", "f2", "f3", "features") + val xgbClassifier = new XGBoostRegressor(paramMap) + .setFeaturesCol(featuresName) + .setLabelCol("label") + + val model = xgbClassifier.fit(xgbInput) + assert(model.getFeaturesCols.sameElements(featuresName)) + + val df = model.transform(xgbInput) + assert(df.schema.fieldNames.contains("features_" + model.uid)) + df.show() + + val newFeatureName = "features_new" + // transform also can work for vectorized dataset + val vectorizedInput = new VectorAssembler() + .setInputCols(featuresName) + .setOutputCol(newFeatureName) + .transform(xgbInput) + .select(newFeatureName, "label") + + val df1 = model + .setFeaturesCol(newFeatureName) + .transform(vectorizedInput) + assert(df1.schema.fieldNames.contains(newFeatureName)) + df1.show() + } + + test("featuresCols without features column can work") { + val spark = ss + import spark.implicits._ + val xgbInput = Seq( + (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0), + (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1)) + .toDF("f1", "f2", "f3", "f4", 
"label") + + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", + "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1) + + val featuresName = Array("f1", "f2", "f3", "f4") + val xgbClassifier = new XGBoostRegressor(paramMap) + .setFeaturesCol(featuresName) + .setLabelCol("label") + + val model = xgbClassifier.fit(xgbInput) + assert(model.getFeaturesCols.sameElements(featuresName)) + + // transform should work for the dataset which includes the feature column names. + val df = model.transform(xgbInput) + assert(df.schema.fieldNames.contains("features")) + df.show() + + // transform also can work for vectorized dataset + val vectorizedInput = new VectorAssembler() + .setInputCols(featuresName) + .setOutputCol("features") + .transform(xgbInput) + .select("features", "label") + + val df1 = model.transform(vectorizedInput) + df1.show() + } }