Skip to content

Commit

Permalink
Add a getData function for DMatrix to jvm-packages, similar to the Python API
Browse files Browse the repository at this point in the history
  • Loading branch information
jinmfeng001 committed Jul 13, 2023
1 parent a1367ea commit b170d57
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 2 deletions.
Expand Up @@ -395,6 +395,36 @@ public float[] getBaseMargin() throws XGBoostError {
return getFloatInfo("base_margin");
}

/**
 * Get the feature data as a dense matrix.
 *
 * <p>The content is fetched from the native library in CSR form (row offsets, feature
 * indices, feature values) and expanded into a {@code rowNum x colNum} matrix. Cells that
 * have no entry in the CSR data are set to {@code 0.0f}.
 *
 * @return feature matrix
 * @throws XGBoostError native error
 * @throws ArithmeticException if the row count, column count or number of non-missing
 *                             values does not fit in an {@code int}
 */
public BigDenseMatrix getData() throws XGBoostError {
  // The arrays and the result matrix below are sized with ints; fail loudly on overflow
  // instead of silently truncating the long counts reported by the native library.
  final int rowNum = Math.toIntExact(rowNum());
  final int colNum = Math.toIntExact(colNum());
  final int nonMissingNum = Math.toIntExact(nonMissingNum());

  // CSR buffers filled by the native call: rowOffset[r]..rowOffset[r + 1] delimits row r.
  final long[] rowOffset = new long[rowNum + 1];
  final int[] featureIndex = new int[nonMissingNum];
  final float[] featureValue = new float[nonMissingNum];

  XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetDataAsCSR(handle, "{}",
      rowOffset, featureIndex, featureValue));

  final BigDenseMatrix denseMatrix = new BigDenseMatrix(rowNum, colNum);

  for (int row = 0; row < rowNum; row++) {
    // NOTE(review): BigDenseMatrix's constructor is not visible here; zero the row explicitly
    // so cells without a CSR entry have a defined value regardless of how storage is allocated.
    for (int col = 0; col < colNum; col++) {
      denseMatrix.set(row, col, 0.0f);
    }

    final int rowStart = Math.toIntExact(rowOffset[row]);
    final int rowEnd = Math.toIntExact(rowOffset[row + 1]);
    for (int idx = rowStart; idx < rowEnd; idx++) {
      denseMatrix.set(row, featureIndex[idx], featureValue[idx]);
    }
  }

  return denseMatrix;
}

/**
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
*
Expand Down Expand Up @@ -422,6 +452,17 @@ public long rowNum() throws XGBoostError {
return rowNum[0];
}

/**
 * Get the number of columns of the DMatrix.
 *
 * @return number of columns
 * @throws XGBoostError native error
 */
public long colNum() throws XGBoostError {
  // The native call reports the count through a one-element output array.
  final long[] out = new long[1];
  final int status = XGBoostJNI.XGDMatrixNumCol(handle, out);
  XGBoostJNI.checkCall(status);
  return out[0];
}

/**
* Get the number of non-missing values of DMatrix.
*
Expand Down
Expand Up @@ -100,6 +100,9 @@ public final static native int XGDMatrixGetStrFeatureInfo(long handle, String fi
long[] outLength, String[][] outValues);

// JNI binding. NOTE(review): presumably writes the row count into row[0], mirroring
// XGDMatrixNumCol below -- the native implementation is only partly visible here; confirm.
public final static native int XGDMatrixNumRow(long handle, long[] row);

// JNI binding. The native side writes the column count of the DMatrix behind `handle`
// into col[0] and returns the status code of the underlying XGDMatrixNumCol call.
public final static native int XGDMatrixNumCol(long handle, long[] col);

// JNI binding. NOTE(review): presumably writes the number of non-missing values into
// nonMissings[0]; the native implementation is outside this view -- confirm.
public final static native int XGDMatrixNumNonMissing(long handle, long[] nonMissings);

// JNI binding; the native implementation is outside this view.
public final static native int XGBoosterCreate(long[] handles, long[] out);
Expand Down Expand Up @@ -168,4 +171,6 @@ public final static native int XGDMatrixCreateFromArrayInterfaceColumns(

// JNI binding; the native implementation is outside this view.
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);

/**
 * JNI binding that copies the content of a DMatrix, in CSR form, into caller-allocated arrays.
 *
 * <p>The native side does not allocate or resize the arrays; it sizes its work from the
 * lengths of {@code rowOffset} and {@code featureIndex}, and copies {@code featureIndex.length}
 * values into {@code featureValue} as well. The Java caller in DMatrix.getData() allocates
 * {@code rowNum + 1} offsets and {@code nonMissingNum} indices/values and passes "{}" as config.
 *
 * @param handle       native DMatrix handle
 * @param config       configuration string forwarded unchanged to the native library
 * @param rowOffset    output: row offsets
 * @param featureIndex output: feature (column) index of each stored value
 * @param featureValue output: the stored values; must be as long as {@code featureIndex}
 * @return status code of the native call, as consumed by XGBoostJNI.checkCall
 */
public final static native int XGDMatrixGetDataAsCSR(long handle, String config, long[] rowOffset,
int[] featureIndex, float[] featureValue);
}
Expand Up @@ -19,6 +19,7 @@ package ml.dmlc.xgboost4j.scala
import _root_.scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.java.util.BigDenseMatrix
import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DataBatch, XGBoostError, DMatrix => JDMatrix}

class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
Expand Down Expand Up @@ -283,6 +284,16 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
jDMatrix.getFeatureTypes
}

/**
 * Fetch the raw feature data of this DMatrix as a BigDenseMatrix.
 *
 * Delegates to the wrapped Java DMatrix.
 *
 * @throws ml.dmlc.xgboost4j.java.XGBoostError
 * @return the dense matrix produced by the wrapped Java DMatrix
 */
@throws(classOf[XGBoostError])
def getData: BigDenseMatrix = jDMatrix.getData

/**
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
*
Expand All @@ -304,6 +315,17 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
jDMatrix.rowNum
}

/**
 * Number of columns of this DMatrix.
 *
 * Delegates to the wrapped Java DMatrix.
 *
 * @throws ml.dmlc.xgboost4j.java.XGBoostError
 * @return the column count reported by the wrapped Java DMatrix
 */
@throws(classOf[XGBoostError])
def colNum: Long = jDMatrix.colNum

/**
* Get the number of non-missing values of DMatrix.
*
Expand Down
42 changes: 42 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Expand Up @@ -495,6 +495,21 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow
return ret;
}

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGDMatrixNumCol
 * Signature: (J[J)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumCol(
    JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
  // Ask the native library for the column count of the DMatrix behind jhandle.
  bst_ulong numCol[1];
  int status = (jint)XGDMatrixNumCol((DMatrixHandle)jhandle, numCol);
  JVM_CHECK_CALL(status);
  // Hand the single value back through element 0 of the caller's long[].
  jenv->SetLongArrayRegion(jout, 0, 1, reinterpret_cast<const jlong *>(numCol));
  return status;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixNumNonMissing
Expand Down Expand Up @@ -1213,3 +1228,30 @@ Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(

return ret;
}

/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGDMatrixGetDataAsCSR
 * Signature: (JLjava/lang/String;[J[I[F)I
 *
 * Fills the caller-allocated Java arrays with the CSR content of the DMatrix.
 * The native library writes straight into the Java arrays' element buffers, so
 * no stack or heap scratch space proportional to the matrix size is needed.
 * On failure the Java arrays are left untouched.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetDataAsCSR
  (JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jconfig, jlongArray jrowOffset,
   jintArray jfeatureIndex, jfloatArray jfeatureValue) {
  // The element buffers are reinterpreted as the native API's types below.
  static_assert(sizeof(jlong) == sizeof(bst_ulong), "jlong and bst_ulong must have the same size");
  static_assert(sizeof(jint) == sizeof(unsigned int), "jint and unsigned int must have the same size");
  static_assert(sizeof(jfloat) == sizeof(float), "jfloat and float must have the same size");

  DMatrixHandle handle = (DMatrixHandle) jhandle;

  // -1 is reported when a JNI resource could not be acquired; in that case the
  // JVM already has an exception pending for the Java caller.
  int ret = -1;
  const char *config = nullptr;
  jlong *rowOffset = nullptr;
  jint *featureIndex = nullptr;
  jfloat *featureValue = nullptr;

  // Acquire one resource at a time and stop at the first failure, so that no
  // further acquiring JNI call is made while an exception is pending.
  config = jenv->GetStringUTFChars(jconfig, nullptr);
  if (config != nullptr) {
    rowOffset = jenv->GetLongArrayElements(jrowOffset, nullptr);
  }
  if (rowOffset != nullptr) {
    featureIndex = jenv->GetIntArrayElements(jfeatureIndex, nullptr);
  }
  if (featureIndex != nullptr) {
    featureValue = jenv->GetFloatArrayElements(jfeatureValue, nullptr);
  }

  if (featureValue != nullptr) {
    ret = XGDMatrixGetDataAsCSR(handle, config,
                                reinterpret_cast<bst_ulong *>(rowOffset),
                                reinterpret_cast<unsigned int *>(featureIndex),
                                reinterpret_cast<float *>(featureValue));
  }

  // Commit the buffers back to the Java arrays only on success; otherwise
  // discard whatever was (partially) written.
  const jint releaseMode = (ret == 0) ? 0 : JNI_ABORT;
  if (featureValue != nullptr) {
    jenv->ReleaseFloatArrayElements(jfeatureValue, featureValue, releaseMode);
  }
  if (featureIndex != nullptr) {
    jenv->ReleaseIntArrayElements(jfeatureIndex, featureIndex, releaseMode);
  }
  if (rowOffset != nullptr) {
    jenv->ReleaseLongArrayElements(jrowOffset, rowOffset, releaseMode);
  }
  // The UTF-8 copy of the config string must always be released, or it leaks.
  if (config != nullptr) {
    jenv->ReleaseStringUTFChars(jconfig, config);
  }

  JVM_CHECK_CALL(ret);
  return ret;
}
16 changes: 16 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -428,4 +428,26 @@ public void testSetAndGetFeatureInfo() throws XGBoostError {
String[] retFeatureTypes = dmat.getFeatureTypes();
assertArrayEquals(featureTypes, retFeatureTypes);
}

@Test
public void testGetDataMatrixForCSR() throws XGBoostError {
  // Create a DMatrix from a CSR-format sparse matrix:
  //   1 0 2 3 0
  //   4 0 2 3 5
  //   3 1 2 5 0
  float[] data = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
  int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
  long[] rowHeaders = new long[]{0, 3, 7, 11};
  DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, 5);
  BigDenseMatrix denseMatrix = dmat1.getData();

  // Every stored CSR entry must come back at its (row, column) position. assertEquals
  // reports the offending cell and both values, unlike assertTrue(a == b).
  for (int row = 0; row + 1 < rowHeaders.length; row++) {
    for (int idx = (int) rowHeaders[row]; idx < rowHeaders[row + 1]; idx++) {
      TestCase.assertEquals("value at (" + row + ", " + colIndex[idx] + ")",
          data[idx], denseMatrix.get(row, colIndex[idx]), 0.0f);
    }
  }

  // A cell without a CSR entry is reported as zero.
  TestCase.assertEquals("missing value at (2, 4)", 0.0f, denseMatrix.get(2, 4), 0.0f);
}
}
Expand Up @@ -16,10 +16,10 @@

package ml.dmlc.xgboost4j.scala

import java.util.Arrays
import ml.dmlc.xgboost4j.java.DMatrix.SparseType

import java.util.Arrays
import scala.util.Random

import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}

Expand Down Expand Up @@ -173,4 +173,28 @@ class DMatrixSuite extends AnyFunSuite {
assert(dmat0.rowNum === 10)
assert(dmat0.getLabel.length === 10)
}

test("create get data from DMatrix as BigDenseMatrix") {
  // Create a DMatrix from a CSR-format sparse matrix:
  //   1 0 2 3 0
  //   4 0 2 3 5
  //   3 1 2 5 0
  val data = Array[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5)
  val colIndex = Array[Int](0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3)
  val rowHeaders = Array[Long](0, 3, 7, 11)
  val dmatrix = new DMatrix(rowHeaders, colIndex, data, SparseType.CSR, 5)

  val denseMatrix = dmatrix.getData

  // Every stored CSR entry must come back at its (row, column) position.
  for {
    row <- 0 until rowHeaders.length - 1
    idx <- rowHeaders(row).toInt until rowHeaders(row + 1).toInt
  } {
    assert(denseMatrix.get(row, colIndex(idx)) === data(idx),
      s"value at ($row, ${colIndex(idx)})")
  }

  // A cell without a CSR entry is reported as zero.
  assert(denseMatrix.get(2, 4) === 0.0f)
}

}

0 comments on commit b170d57

Please sign in to comment.