From b170d57b036503e328ad41500b8fc2c7058bb590 Mon Sep 17 00:00:00 2001 From: jinmfeng Date: Wed, 12 Jul 2023 16:05:04 +0800 Subject: [PATCH] Add getData function for DMatrix similar for jvm-packages to python API --- .../java/ml/dmlc/xgboost4j/java/DMatrix.java | 41 ++++++++++++++++++ .../ml/dmlc/xgboost4j/java/XGBoostJNI.java | 5 +++ .../ml/dmlc/xgboost4j/scala/DMatrix.scala | 22 ++++++++++ .../xgboost4j/src/native/xgboost4j.cpp | 42 +++++++++++++++++++ jvm-packages/xgboost4j/src/native/xgboost4j.h | 16 +++++++ .../ml/dmlc/xgboost4j/java/DMatrixTest.java | 22 ++++++++++ .../dmlc/xgboost4j/scala/DMatrixSuite.scala | 28 ++++++++++++- 7 files changed, 174 insertions(+), 2 deletions(-) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index 2e7540bd2b30..af0a60bc5e1f 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -395,6 +395,36 @@ public float[] getBaseMargin() throws XGBoostError { return getFloatInfo("base_margin"); } + /** + * Get feature data as BigDenseMatrix + * @return feature Matrix + * @throws XGBoostError + */ + public BigDenseMatrix getData() throws XGBoostError { + int rowNum = (int) rowNum(); + int colNum = (int) colNum(); + int nonMissingNum = (int) nonMissingNum(); + + long[] rowOffset = new long[rowNum + 1]; + int[] featureIndex = new int[nonMissingNum]; + float[] featureValue = new float[nonMissingNum]; + + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetDataAsCSR(handle, "{}", + rowOffset, featureIndex, featureValue)); + + BigDenseMatrix denseMatrix = new BigDenseMatrix(rowNum, colNum); + + for (int row = 0; row < rowNum; row++) { + int rowStart = (int)rowOffset[row]; + int rowEnd = (int)rowOffset[row + 1]; + for(int idx = rowStart; idx < rowEnd; idx ++) { + denseMatrix.set(row, featureIndex[idx], featureValue[idx]); + } + } + 
+ return denseMatrix; + } + /** * Slice the DMatrix and return a new DMatrix that only contains `rowIndex`. * @@ -422,6 +452,17 @@ public long rowNum() throws XGBoostError { return rowNum[0]; } + /** + * Get the col number of DMatrix + * @return number of columns + * @throws XGBoostError native error + */ + public long colNum() throws XGBoostError { + long[] colNum = new long[1]; + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixNumCol(handle, colNum)); + return colNum[0]; + } + /** * Get the number of non-missing values of DMatrix. * diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index abe584f05fe4..ba04196d8a84 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -100,6 +100,9 @@ public final static native int XGDMatrixGetStrFeatureInfo(long handle, String fi long[] outLength, String[][] outValues); public final static native int XGDMatrixNumRow(long handle, long[] row); + + public final static native int XGDMatrixNumCol(long handle, long[] col); + public final static native int XGDMatrixNumNonMissing(long handle, long[] nonMissings); public final static native int XGBoosterCreate(long[] handles, long[] out); @@ -168,4 +171,6 @@ public final static native int XGDMatrixCreateFromArrayInterfaceColumns( public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out); + public final static native int XGDMatrixGetDataAsCSR(long handle, String config, long[] rowOffset, + int[] featureIndex, float[] featureValue); } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala index 714adf726292..f79cb6bcdd27 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala +++ 
b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala @@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala import _root_.scala.collection.JavaConverters._ import ml.dmlc.xgboost4j.LabeledPoint +import ml.dmlc.xgboost4j.java.util.BigDenseMatrix import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DataBatch, XGBoostError, DMatrix => JDMatrix} class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { @@ -283,6 +284,16 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { jDMatrix.getFeatureTypes } + /** + * get raw data from DMatrix as BigDenseMatrix + * @throws ml.dmlc.xgboost4j.java.XGBoostError + * @return + */ + @throws(classOf[XGBoostError]) + def getData: BigDenseMatrix = { + jDMatrix.getData + } + /** * Slice the DMatrix and return a new DMatrix that only contains `rowIndex`. * @@ -304,6 +315,17 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { jDMatrix.rowNum } + /** + * get the col number of DMatrix + * + * @throws ml.dmlc.xgboost4j.java.XGBoostError + * @return + */ + @throws(classOf[XGBoostError]) + def colNum: Long = { + jDMatrix.colNum + } + /** * Get the number of non-missing values of DMatrix. 
* diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index a61a68dbcb88..23e7667d44bd 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -495,6 +495,21 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow return ret; } +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGDMatrixNumCol + * Signature: (J[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumCol( + JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { + DMatrixHandle handle = (DMatrixHandle)jhandle; + bst_ulong result[1]; + int ret = (jint)XGDMatrixNumCol(handle, result); + JVM_CHECK_CALL(ret); + jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *)result); + return ret; +} + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGDMatrixNumNonMissing @@ -1213,3 +1228,30 @@ Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo( return ret; } + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGDMatrixGetDataAsCSR + * Signature: (JLjava/lang/String;[J[I[F)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetDataAsCSR + (JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jconfig, jlongArray jrowOffset, + jintArray jfeatureIndex, jfloatArray jfeatureValue) { + DMatrixHandle handle = (DMatrixHandle) jhandle; + const char *config = jenv->GetStringUTFChars(jconfig, 0); + + const int offsetNum = jenv->GetArrayLength(jrowOffset); + const int nonMissingNum = jenv->GetArrayLength(jfeatureIndex); + bst_ulong rowOffset[offsetNum]; + unsigned int featureIndex[nonMissingNum]; + float featureValue[nonMissingNum]; + + int ret = XGDMatrixGetDataAsCSR(handle, config, &rowOffset[0], &featureIndex[0], &featureValue[0]); + + jenv->SetLongArrayRegion(jrowOffset, 0, offsetNum, (const long *)rowOffset); + jenv->SetIntArrayRegion(jfeatureIndex, 0, nonMissingNum, (const int 
*)featureIndex); + jenv->SetFloatArrayRegion(jfeatureValue, 0, nonMissingNum, (const float *)featureValue); + + JVM_CHECK_CALL(ret); + return ret; +} diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 11a2f86ffb82..7b8635f18c5c 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -143,6 +143,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFea JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow (JNIEnv *, jclass, jlong, jlongArray); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGDMatrixNumCol + * Signature: (J[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumCol + (JNIEnv *, jclass, jlong, jlongArray); + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGDMatrixNumNonMissing @@ -401,6 +409,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo (JNIEnv *, jclass, jlong, jstring, jobjectArray); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGDMatrixGetDataAsCSR + * Signature: (JLjava/lang/String;[J[I[F)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetDataAsCSR + (JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jfloatArray); + #ifdef __cplusplus } #endif diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index d658c55292c4..c4aafed6757c 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -428,4 +428,26 @@ public void testSetAndGetFeatureInfo() throws XGBoostError { String[] retFeatureTypes = dmat.getFeatureTypes(); assertArrayEquals(featureTypes, retFeatureTypes); } + + @Test + public void 
testGetDataMatrixForCSR() throws XGBoostError { + // create a DMatrix from a CSR-format sparse matrix + /** + * sparse matrix + * 1 0 2 3 0 + * 4 0 2 3 5 + * 3 1 2 5 0 + */ + float[] data = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5}; + int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3}; + long[] rowHeaders = new long[]{0, 3, 7, 11}; + DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, 5); + BigDenseMatrix denseMatrix = dmat1.getData(); + + TestCase.assertTrue(denseMatrix.get(0, 0) == 1.0f); + TestCase.assertTrue(denseMatrix.get(0, 3) == 3.0f); + TestCase.assertTrue(denseMatrix.get(1, 2) == 2.0f); + TestCase.assertTrue(denseMatrix.get(2, 3) == 5.0f); + TestCase.assertTrue(denseMatrix.get(2, 4) == 0.0f); + } } diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala index 53325effa6ab..46ea9a592cc9 100644 --- a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala +++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala @@ -16,10 +16,10 @@ package ml.dmlc.xgboost4j.scala -import java.util.Arrays +import ml.dmlc.xgboost4j.java.DMatrix.SparseType +import java.util.Arrays import scala.util.Random - import org.scalatest.funsuite.AnyFunSuite import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} @@ -173,4 +173,28 @@ class DMatrixSuite extends AnyFunSuite { assert(dmat0.rowNum === 10) assert(dmat0.getLabel.length === 10) } + + test("create get data from DMatrix as BigDenseMatrix") { + // create a DMatrix from a CSR-format sparse matrix + /** + * sparse matrix + * 1 0 2 3 0 + * 4 0 2 3 5 + * 3 1 2 5 0 + */ + val data = Array[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5) + val colIndex = Array[Int](0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3) + val rowHeaders = Array[Long](0, 3, 7, 11) + val dmatrix = new DMatrix(rowHeaders, colIndex, data, SparseType.CSR, 5) + + val 
denseMatrix = dmatrix.getData + + assert(denseMatrix.get(0, 0) == 1.0f) + assert(denseMatrix.get(0, 3) == 3.0f) + assert(denseMatrix.get(1, 2) == 2.0f) + assert(denseMatrix.get(2, 3) == 5.0f) + assert(denseMatrix.get(2, 4) == 0.0f) + + } + }