Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[jvm-packages][xgboost4j-gpu] Support GPU dataframe and `DeviceQuanti…
…leDMatrix` (#7195) The following classes are added to support dataframes in the Java binding: - `Column` is an abstract type for a single column in tabular data. - `ColumnBatch` is an abstract type for a dataframe. - `CudfColumn` is an implementation of `Column` that consumes a cuDF column. - `CudfColumnBatch` is an implementation of `ColumnBatch` that consumes a cuDF dataframe. - `DeviceQuantileDMatrix` is the interface for quantized data. The Java implementation mimics the Python interface and uses the `__cuda_array_interface__` protocol for memory indexing. One difference is that in the JVM package the data batch is staged on the host, because Java iterators cannot be reset. Co-authored-by: jiamingy <jm.yuan@outlook.com>
- Loading branch information
1 parent
d27a427
commit 0ee11da
Showing
23 changed files
with
1,388 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
110 changes: 110 additions & 0 deletions
110
jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/gpu/java/CudfColumn.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/* | ||
Copyright (c) 2021 by Contributors | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package ml.dmlc.xgboost4j.gpu.java; | ||
|
||
import ai.rapids.cudf.BaseDeviceMemoryBuffer; | ||
import ai.rapids.cudf.BufferType; | ||
import ai.rapids.cudf.ColumnVector; | ||
import ai.rapids.cudf.DType; | ||
|
||
import ml.dmlc.xgboost4j.java.Column; | ||
|
||
/** | ||
* This class is composing of base data with Apache Arrow format from Cudf ColumnVector. | ||
* It will be used to generate the cuda array interface. | ||
*/ | ||
class CudfColumn extends Column { | ||
|
||
private final long dataPtr; // gpu data buffer address | ||
private final long shape; // row count | ||
private final long validPtr; // gpu valid buffer address | ||
private final int typeSize; // type size in bytes | ||
private final String typeStr; // follow array interface spec | ||
private final long nullCount; // null count | ||
|
||
private String arrayInterface = null; // the cuda array interface | ||
|
||
public static CudfColumn from(ColumnVector cv) { | ||
BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA); | ||
BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY); | ||
long validPtr = 0; | ||
if (validBuffer != null) { | ||
validPtr = validBuffer.getAddress(); | ||
} | ||
DType dType = cv.getType(); | ||
String typeStr = ""; | ||
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 || | ||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS || | ||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS || | ||
dType == DType.TIMESTAMP_SECONDS) { | ||
typeStr = "<f" + dType.getSizeInBytes(); | ||
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 || | ||
dType == DType.INT32 || dType == DType.INT64) { | ||
typeStr = "<i" + dType.getSizeInBytes(); | ||
} else { | ||
// Unsupported type. | ||
throw new IllegalArgumentException("Unsupported data type: " + dType); | ||
} | ||
|
||
return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr, | ||
dType.getSizeInBytes(), typeStr, cv.getNullCount()); | ||
} | ||
|
||
private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr, | ||
long nullCount) { | ||
this.dataPtr = dataPtr; | ||
this.shape = shape; | ||
this.validPtr = validPtr; | ||
this.typeSize = typeSize; | ||
this.typeStr = typeStr; | ||
this.nullCount = nullCount; | ||
} | ||
|
||
@Override | ||
public String getArrayInterfaceJson() { | ||
// There is no race-condition | ||
if (arrayInterface == null) { | ||
arrayInterface = CudfUtils.buildArrayInterface(this); | ||
} | ||
return arrayInterface; | ||
} | ||
|
||
public long getDataPtr() { | ||
return dataPtr; | ||
} | ||
|
||
public long getShape() { | ||
return shape; | ||
} | ||
|
||
public long getValidPtr() { | ||
return validPtr; | ||
} | ||
|
||
public int getTypeSize() { | ||
return typeSize; | ||
} | ||
|
||
public String getTypeStr() { | ||
return typeStr; | ||
} | ||
|
||
public long getNullCount() { | ||
return nullCount; | ||
} | ||
|
||
} |
88 changes: 88 additions & 0 deletions
88
jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/gpu/java/CudfColumnBatch.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
/* | ||
Copyright (c) 2021 by Contributors | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package ml.dmlc.xgboost4j.gpu.java; | ||
|
||
import java.util.stream.IntStream; | ||
|
||
import ai.rapids.cudf.Table; | ||
|
||
import ml.dmlc.xgboost4j.java.ColumnBatch; | ||
|
||
/** | ||
* Class to wrap CUDF Table to generate the cuda array interface. | ||
*/ | ||
public class CudfColumnBatch extends ColumnBatch { | ||
private final Table feature; | ||
private final Table label; | ||
private final Table weight; | ||
private final Table baseMargin; | ||
|
||
public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) { | ||
this.feature = feature; | ||
this.label = labels; | ||
this.weight = weights; | ||
this.baseMargin = baseMargins; | ||
} | ||
|
||
@Override | ||
public String getFeatureArrayInterface() { | ||
return getArrayInterface(this.feature); | ||
} | ||
|
||
@Override | ||
public String getLabelsArrayInterface() { | ||
return getArrayInterface(this.label); | ||
} | ||
|
||
@Override | ||
public String getWeightsArrayInterface() { | ||
return getArrayInterface(this.weight); | ||
} | ||
|
||
@Override | ||
public String getBaseMarginsArrayInterface() { | ||
return getArrayInterface(this.baseMargin); | ||
} | ||
|
||
@Override | ||
public void close() { | ||
if (feature != null) feature.close(); | ||
if (label != null) label.close(); | ||
if (weight != null) weight.close(); | ||
if (baseMargin != null) baseMargin.close(); | ||
} | ||
|
||
private String getArrayInterface(Table table) { | ||
if (table == null || table.getNumberOfColumns() == 0) { | ||
return ""; | ||
} | ||
return CudfUtils.buildArrayInterface(getAsCudfColumn(table)); | ||
} | ||
|
||
private CudfColumn[] getAsCudfColumn(Table table) { | ||
if (table == null || table.getNumberOfColumns() == 0) { | ||
// This will never happen. | ||
return new CudfColumn[]{}; | ||
} | ||
|
||
return IntStream.range(0, table.getNumberOfColumns()) | ||
.mapToObj((i) -> table.getColumn(i)) | ||
.map(CudfColumn::from) | ||
.toArray(CudfColumn[]::new); | ||
} | ||
|
||
} |
100 changes: 100 additions & 0 deletions
100
jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/gpu/java/CudfUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/* | ||
Copyright (c) 2021 by Contributors | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package ml.dmlc.xgboost4j.gpu.java; | ||
|
||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
|
||
import com.fasterxml.jackson.core.JsonFactory; | ||
import com.fasterxml.jackson.core.JsonGenerator; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import com.fasterxml.jackson.databind.node.ArrayNode; | ||
import com.fasterxml.jackson.databind.node.JsonNodeFactory; | ||
import com.fasterxml.jackson.databind.node.ObjectNode; | ||
|
||
/** | ||
* Cudf utilities to build cuda array interface against {@link CudfColumn} | ||
*/ | ||
class CudfUtils { | ||
|
||
/** | ||
* Build the cuda array interface based on CudfColumn(s) | ||
* @param cudfColumns the CudfColumn(s) to be built | ||
* @return the json format of cuda array interface | ||
*/ | ||
public static String buildArrayInterface(CudfColumn... cudfColumns) { | ||
return new Builder().add(cudfColumns).build(); | ||
} | ||
|
||
// Helper class to build array interface string | ||
private static class Builder { | ||
private JsonNodeFactory nodeFactory = new JsonNodeFactory(false); | ||
private ArrayNode rootArrayNode = nodeFactory.arrayNode(); | ||
|
||
private Builder add(CudfColumn... columns) { | ||
if (columns == null || columns.length <= 0) { | ||
throw new IllegalArgumentException("At least one ColumnData is required."); | ||
} | ||
for (CudfColumn cd : columns) { | ||
rootArrayNode.add(buildColumnObject(cd)); | ||
} | ||
return this; | ||
} | ||
|
||
private String build() { | ||
try { | ||
ByteArrayOutputStream bos = new ByteArrayOutputStream(); | ||
JsonGenerator jsonGen = new JsonFactory().createGenerator(bos); | ||
new ObjectMapper().writeTree(jsonGen, rootArrayNode); | ||
return bos.toString(); | ||
} catch (IOException ie) { | ||
ie.printStackTrace(); | ||
throw new RuntimeException("Failed to build array interface. Error: " + ie); | ||
} | ||
} | ||
|
||
private ObjectNode buildColumnObject(CudfColumn column) { | ||
if (column.getDataPtr() == 0) { | ||
throw new IllegalArgumentException("Empty column data is NOT accepted!"); | ||
} | ||
if (column.getTypeStr() == null || column.getTypeStr().isEmpty()) { | ||
throw new IllegalArgumentException("Empty type string is NOT accepted!"); | ||
} | ||
ObjectNode colDataObj = buildMetaObject(column.getDataPtr(), column.getShape(), | ||
column.getTypeStr()); | ||
|
||
if (column.getValidPtr() != 0 && column.getNullCount() != 0) { | ||
ObjectNode validObj = buildMetaObject(column.getValidPtr(), column.getShape(), "<t1"); | ||
colDataObj.set("mask", validObj); | ||
} | ||
return colDataObj; | ||
} | ||
|
||
private ObjectNode buildMetaObject(long ptr, long shape, final String typeStr) { | ||
ObjectNode objNode = nodeFactory.objectNode(); | ||
ArrayNode shapeNode = objNode.putArray("shape"); | ||
shapeNode.add(shape); | ||
ArrayNode dataNode = objNode.putArray("data"); | ||
dataNode.add(ptr) | ||
.add(false); | ||
objNode.put("typestr", typeStr) | ||
.put("version", 1); | ||
return objNode; | ||
} | ||
} | ||
|
||
} |
1 change: 1 addition & 0 deletions
1
jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../../../../../xgboost4j/src/main/java/ml/dmlc/xgboost4j/java |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#ifndef JVM_UTILS_H_ | ||
#define JVM_UTILS_H_ | ||
|
||
#define JVM_CHECK_CALL(__expr) \ | ||
{ \ | ||
int __errcode = (__expr); \ | ||
if (__errcode != 0) { \ | ||
return __errcode; \ | ||
} \ | ||
} | ||
|
||
JavaVM*& GlobalJvm(); | ||
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle); | ||
|
||
#endif // JVM_UTILS_H_ |
Oops, something went wrong.