diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala
new file mode 100644
index 000000000000..579e3dd37447
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala
@@ -0,0 +1,276 @@
+/*
+ Copyright (c) 2014-2022 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.util.concurrent.LinkedBlockingDeque
+
+import scala.util.Random
+
+import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
+import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
+import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
+import ml.dmlc.xgboost4j.scala.DMatrix
+import org.scalatest.FunSuite
+
+class CommunicatorRobustnessSuite extends FunSuite with PerTest {
+
+  private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
+    val classifier = new XGBoostClassifier(paramMap)
+    val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc)
+    xgbParamsFactory.buildXGBRuntimeParams
+  }
+
+  test("Customize host ip and python exec for Rabit tracker") {
+    val hostIp = "192.168.22.111"
+    val pythonExec = "/usr/bin/python3"
+
+    val paramMap = Map(
+      "num_workers" -> numWorkers,
+      "tracker_conf" -> TrackerConf(0L, "python", hostIp))
+    val xgbExecParams = getXGBoostExecutionParams(paramMap)
+    val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
+    tracker match {
+      case pyTracker: PyRabitTracker =>
+        val cmd = pyTracker.getRabitTrackerCommand
+        assert(cmd.contains(hostIp))
+        assert(cmd.startsWith("python"))
+      case _ => assert(false, "expected python tracker implementation")
+    }
+
+    val paramMap1 = Map(
+      "num_workers" -> numWorkers,
+      "tracker_conf" -> TrackerConf(0L, "python", "", pythonExec))
+    val xgbExecParams1 = getXGBoostExecutionParams(paramMap1)
+    val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf)
+    tracker1 match {
+      case pyTracker: PyRabitTracker =>
+        val cmd = pyTracker.getRabitTrackerCommand
+        assert(cmd.startsWith(pythonExec))
+        assert(!cmd.contains(hostIp))
+      case _ => assert(false, "expected python tracker implementation")
+    }
+
+    val paramMap2 = Map(
+      "num_workers" -> numWorkers,
+      "tracker_conf" -> TrackerConf(0L, "python", hostIp, pythonExec))
+    val xgbExecParams2 = getXGBoostExecutionParams(paramMap2)
+    val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf)
+    tracker2 match {
+      case pyTracker: PyRabitTracker =>
+        val cmd = pyTracker.getRabitTrackerCommand
+        assert(cmd.startsWith(pythonExec))
+        assert(cmd.contains(s" --host-ip=${hostIp}"))
+      case _ => assert(false, "expected python tracker implementation")
+    }
+  }
+
+  test("training with Scala-implemented Rabit tracker") {
+    val eval = new EvalError()
+    val training = buildDataFrame(Classification.train)
+    val testDM = new DMatrix(Classification.test.iterator)
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6",
+      "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
+      "tracker_conf" -> TrackerConf(60 * 60 * 1000, "scala"))
+    val model = new XGBoostClassifier(paramMap).fit(training)
+    assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
+  }
+
+  test("test Communicator allreduce to validate Scala-implemented Rabit tracker") {
+    val vectorLength = 100
+    val rdd = sc.parallelize(
+      (1 to numWorkers * vectorLength).toArray.map { _ => Random.nextFloat() }, numWorkers).cache()
+
+    val tracker = new ScalaRabitTracker(numWorkers)
+    tracker.start(0)
+    val trackerEnvs = tracker.getWorkerEnvs
+    val collectedAllReduceResults = new LinkedBlockingDeque[Array[Float]]()
+
+    val rawData = rdd.mapPartitions { iter =>
+      Iterator(iter.toArray)
+    }.collect()
+
+    val maxVec = (0 until vectorLength).toArray.map { j =>
+      (0 until numWorkers).toArray.map { i => rawData(i)(j) }.max
+    }
+
+    val allReduceResults = rdd.mapPartitions { iter =>
+      Communicator.init(trackerEnvs)
+      val arr = iter.toArray
+      val results = Communicator.allReduce(arr, Communicator.OpType.MAX)
+      Communicator.shutdown()
+      Iterator(results)
+    }.cache()
+
+    val sparkThread = new Thread() {
+      override def run(): Unit = {
+        allReduceResults.foreachPartition(() => _)
+        val byPartitionResults = allReduceResults.collect()
+        assert(byPartitionResults(0).length == vectorLength)
+        collectedAllReduceResults.put(byPartitionResults(0))
+      }
+    }
+    sparkThread.start()
+    assert(tracker.waitFor(0L) == 0)
+    sparkThread.join()
+
+    assert(collectedAllReduceResults.poll().sameElements(maxVec))
+  }
+
+  test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
+    /*
+      Deliberately create new instances of SparkContext in each unit test to avoid reusing the
+      same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
+      by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
+      corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
+      tests on a reentrant thread will crash the entire Spark application, an undesired
+      side effect that should be avoided.
+     */
+    val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
+
+    val tracker = new PyRabitTracker(numWorkers)
+    tracker.start(0)
+    val trackerEnvs = tracker.getWorkerEnvs
+
+    val workerCount: Int = numWorkers
+    /*
+      Simulate worker crash events by creating dummy Rabit workers and throwing an exception in
+      the last created worker. A cascading event chain will be triggered once the RuntimeException
+      is thrown: the thread running the dummy Spark job (sparkThread) catches the exception and
+      delegates it to the UncaughtExceptionHandler, which is the Rabit tracker itself.
+
+      The Java RabitTracker class reacts to exceptions by killing the spawned process running
+      the Python tracker. If at least one Rabit worker has yet to connect to the tracker before
+      it is killed, the resulting connection failure will trigger the Rabit worker to call
+      "exit(-1);" in the native C++ code, effectively ending the dummy Spark task.
+
+      In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are
+      isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks
+      running in separate containers.
+      However, as unit tests are run in Spark local mode, in which tasks are executed by threads
+      belonging to the same process, one thread calling "exit(-1);" ultimately kills the entire
+      process, which also happens to host the Spark driver, causing the entire Spark application
+      to crash.
+
+      To prevent unit tests from crashing, deterministic delays were introduced to make sure that
+      the exception is thrown last, ideally after all worker connections have been established.
+      For the same reason, the Java RabitTracker class delays the killing of the Python tracker
+      process to ensure that pending worker connections are handled.
+     */
+    val dummyTasks = rdd.mapPartitions { iter =>
+      Communicator.init(trackerEnvs)
+      val index = iter.next()
+      Thread.sleep(100 + index * 10)
+      if (index == workerCount) {
+        // kill the worker by throwing an exception
+        throw new RuntimeException("Worker exception.")
+      }
+      Communicator.shutdown()
+      Iterator(index)
+    }.cache()
+
+    val sparkThread = new Thread() {
+      override def run(): Unit = {
+        // forces a Spark job.
+        dummyTasks.foreachPartition(() => _)
+      }
+    }
+
+    sparkThread.setUncaughtExceptionHandler(tracker)
+    sparkThread.start()
+    assert(tracker.waitFor(0) != 0)
+  }
+
+  test("test Scala RabitTracker's exception handling: it should not hang forever.") {
+    val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
+
+    val tracker = new ScalaRabitTracker(numWorkers)
+    tracker.start(0)
+    val trackerEnvs = tracker.getWorkerEnvs
+
+    val workerCount: Int = numWorkers
+    val dummyTasks = rdd.mapPartitions { iter =>
+      Communicator.init(trackerEnvs)
+      val index = iter.next()
+      Thread.sleep(100 + index * 10)
+      if (index == workerCount) {
+        // kill the worker by throwing an exception
+        throw new RuntimeException("Worker exception.")
+      }
+      Communicator.shutdown()
+      Iterator(index)
+    }.cache()
+
+    val sparkThread = new Thread() {
+      override def run(): Unit = {
+        // forces a Spark job.
+        dummyTasks.foreachPartition(() => _)
+      }
+    }
+    sparkThread.setUncaughtExceptionHandler(tracker)
+    sparkThread.start()
+    assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
+  }
+
+  test("test Scala RabitTracker's workerConnectionTimeout") {
+    val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
+
+    val tracker = new ScalaRabitTracker(numWorkers)
+    tracker.start(500)
+    val trackerEnvs = tracker.getWorkerEnvs
+
+    val dummyTasks = rdd.mapPartitions { iter =>
+      val index = iter.next()
+      // simulate that the first worker cannot connect to the tracker due to network issues.
+      if (index != 1) {
+        Communicator.init(trackerEnvs)
+        Thread.sleep(1000)
+        Communicator.shutdown()
+      }
+
+      Iterator(index)
+    }.cache()
+
+    val sparkThread = new Thread() {
+      override def run(): Unit = {
+        // forces a Spark job.
+        dummyTasks.foreachPartition(() => _)
+      }
+    }
+    sparkThread.setUncaughtExceptionHandler(tracker)
+    sparkThread.start()
+    // should fail due to connection timeout
+    assert(tracker.waitFor(0L) == TrackerStatus.FAILURE.getStatusCode)
+  }
+
+  test("should allow the dataframe containing communicator calls to be partially evaluated" +
+    " multiple times (ISSUE-4406)") {
+    val paramMap = Map(
+      "eta" -> "1",
+      "max_depth" -> "6",
+      "silent" -> "1",
+      "objective" -> "binary:logistic")
+    val trainingDF = buildDataFrame(Classification.train)
+    val model = new XGBoostClassifier(paramMap ++ Array("num_round" -> 10,
+      "num_workers" -> numWorkers)).fit(trainingDF)
+    val prediction = model.transform(trainingDF)
+    // a partial evaluation of the dataframe will leave rabit initialized but not shut down in
+    // some threads
+    prediction.show()
+    // a full evaluation here will re-run init and shutdown for all rabit proxies,
+    // expecting no error
+    prediction.collect()
+  }
+}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java
new file mode 100644
index 000000000000..795e7d99e8fe
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java
@@ -0,0 +1,152 @@
+package ml.dmlc.xgboost4j.java;
+
+import java.io.Serializable;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Collective communicator global class for synchronization.
+ *
+ * Currently the communicator API is experimental; function signatures may change in the future
+ * without notice.
+ */
+public class Communicator {
+
+  public enum OpType implements Serializable {
+    MAX(0), MIN(1), SUM(2);
+
+    private int op;
+
+    public int getOperand() {
+      return this.op;
+    }
+
+    OpType(int op) {
+      this.op = op;
+    }
+  }
+
+  public enum DataType implements Serializable {
+    INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4),
+    INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8);
+
+    private final int enumOp;
+    private final int size;
+
+    public int getEnumOp() {
+      return this.enumOp;
+    }
+
+    public int getSize() {
+      return this.size;
+    }
+
+    DataType(int enumOp, int size) {
+      this.enumOp = enumOp;
+      this.size = size;
+    }
+  }
+
+  private static void checkCall(int ret) throws XGBoostError {
+    if (ret != 0) {
+      throw new XGBoostError(XGBoostJNI.XGBGetLastError());
+    }
+  }
+
+  // used as a way to test/debug the communicator init parameters that were passed in
+  public static Map<String, String> communicatorEnvs;
+  public static List<String> mockList = new LinkedList<>();
+
+  /**
+   * Initialize the collective communicator on the current working thread.
+   *
+   * @param envs The additional environment variables to pass to the communicator.
+   * @throws XGBoostError
+   */
+  public static void init(Map<String, String> envs) throws XGBoostError {
+    communicatorEnvs = envs;
+    String[] args = new String[envs.size() * 2 + mockList.size() * 2];
+    int idx = 0;
+    for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
+      args[idx++] = e.getKey();
+      args[idx++] = e.getValue();
+    }
+    // pass the list of rabit mock strings, e.g. mock=0,1,0,0
+    for (String mock : mockList) {
+      args[idx++] = "mock";
+      args[idx++] = mock;
+    }
+    checkCall(XGBoostJNI.CommunicatorInit(args));
+  }
+
+  /**
+   * Shut down the communicator in the current working thread; equivalent to finalize.
+   *
+   * @throws XGBoostError
+   */
+  public static void shutdown() throws XGBoostError {
+    checkCall(XGBoostJNI.CommunicatorFinalize());
+  }
+
+  /**
+   * Print the message via the communicator.
+   *
+   * @param msg the message to print.
+   * @throws XGBoostError
+   */
+  public static void communicatorPrint(String msg) throws XGBoostError {
+    checkCall(XGBoostJNI.CommunicatorPrint(msg));
+  }
+
+  /**
+   * Get the rank of the current thread.
+   *
+   * @return the rank.
+   * @throws XGBoostError
+   */
+  public static int getRank() throws XGBoostError {
+    int[] out = new int[1];
+    checkCall(XGBoostJNI.CommunicatorGetRank(out));
+    return out[0];
+  }
+
+  /**
+   * Get the world size of the current job.
+   *
+   * @return the world size.
+   * @throws XGBoostError
+   */
+  public static int getWorldSize() throws XGBoostError {
+    int[] out = new int[1];
+    checkCall(XGBoostJNI.CommunicatorGetWorldSize(out));
+    return out[0];
+  }
+
+  /**
+   * Perform an allreduce on distributed float vectors using the given operator.
+   *
+   * @param elements local elements on distributed workers.
+   * @param op       operator used for the allreduce.
+   * @return the all-reduced float elements according to the given operator.
+   */
+  public static float[] allReduce(float[] elements, OpType op) {
+    DataType dataType = DataType.FLOAT32;
+    ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
+        .order(ByteOrder.nativeOrder());
+
+    for (float el : elements) {
+      buffer.putFloat(el);
+    }
+    buffer.flip();
+
+    XGBoostJNI.CommunicatorAllreduce(buffer, elements.length, dataType.getEnumOp(),
+        op.getOperand());
+    float[] results = new float[elements.length];
+    buffer.asFloatBuffer().get(results);
+
+    return results;
+  }
+}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
index d2285af90e08..72234f526b08 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
@@ -148,6 +148,17 @@ public final static native int XGBoosterDumpModelExWithFeatures(
   final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count, int enum_dtype,
                                          int enum_op);
 
+  // communicator functions
+  public final static native int CommunicatorInit(String[] args);
+  public final static native int CommunicatorFinalize();
+  public final static native int CommunicatorPrint(String msg);
+  public final static native int CommunicatorGetRank(int[] out);
+  public final static native int CommunicatorGetWorldSize(int[] out);
+
+  // Perform Allreduce operation on data in sendrecvbuf.
+  final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count,
+                                                int enum_dtype, int enum_op);
+
   public final static native int XGDMatrixSetInfoFromInterface(
       long handle, String field, String json);
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
index 630040731e67..a89e0f07a341 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -977,6 +977,89 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
   return 0;
 }
 
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorInit
+ * Signature: ([Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
+  (JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
+  xgboost::Json config{xgboost::Object{}};
+  bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
+  assert(len % 2 == 0);
+  for (bst_ulong i = 0; i < len / 2; ++i) {
+    jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i);
+    std::string key_str(jenv->GetStringUTFChars(key, 0), jenv->GetStringLength(key));
+    jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1);
+    std::string value_str(jenv->GetStringUTFChars(value, 0), jenv->GetStringLength(value));
+    config[key_str] = xgboost::String(value_str);
+  }
+  std::string json_str;
+  xgboost::Json::Dump(config, &json_str);
+  JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str()));
+  return 0;
+}
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorFinalize
+ * Signature: ()I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
+  (JNIEnv *jenv, jclass jcls) {
+  JVM_CHECK_CALL(XGCommunicatorFinalize());
+  return 0;
+}
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorPrint
+ * Signature: (Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorPrint
+  (JNIEnv *jenv, jclass jcls, jstring jmsg) {
+  std::string str(jenv->GetStringUTFChars(jmsg, 0),
+                  jenv->GetStringLength(jmsg));
+  JVM_CHECK_CALL(XGCommunicatorPrint(str.c_str()));
+  return 0;
+}
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorGetRank
+ * Signature: ([I)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRank
+  (JNIEnv *jenv, jclass jcls, jintArray jout) {
+  jint rank = XGCommunicatorGetRank();
+  jenv->SetIntArrayRegion(jout, 0, 1, &rank);
+  return 0;
+}
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorGetWorldSize
+ * Signature: ([I)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
+  (JNIEnv *jenv, jclass jcls, jintArray jout) {
+  jint out = XGCommunicatorGetWorldSize();
+  jenv->SetIntArrayRegion(jout, 0, 1, &out);
+  return 0;
+}
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorAllreduce
+ * Signature: (Ljava/nio/ByteBuffer;III)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllreduce
+  (JNIEnv *jenv, jclass jcls, jobject jsendrecvbuf, jint jcount, jint jenum_dtype, jint jenum_op) {
+  void *ptr_sendrecvbuf = jenv->GetDirectBufferAddress(jsendrecvbuf);
+  JVM_CHECK_CALL(XGCommunicatorAllreduce(ptr_sendrecvbuf, (size_t) jcount, jenum_dtype, jenum_op));
+  return 0;
+}
+
 namespace xgboost {
 namespace jni {
 XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h
index 2db64a16992c..7baae983cf51
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.h
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h
@@ -335,6 +335,54 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
   (JNIEnv *, jclass, jobject, jint, jint, jint);
 
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorInit
+ * Signature: ([Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit
+  (JNIEnv *, jclass, jobjectArray);
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorFinalize
+ * Signature: ()I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize
+  (JNIEnv *, jclass);
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorPrint
+ * Signature: (Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorPrint
+  (JNIEnv *, jclass, jstring);
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorGetRank
+ * Signature: ([I)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRank
+  (JNIEnv *, jclass, jintArray);
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorGetWorldSize
+ * Signature: ([I)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize
+  (JNIEnv *, jclass, jintArray);
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    CommunicatorAllreduce
+ * Signature: (Ljava/nio/ByteBuffer;III)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorAllreduce
+  (JNIEnv *, jclass, jobject, jint, jint, jint);
+
 /*
  * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
  * Method:    XGDMatrixSetInfoFromInterface
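
Editor's note (not part of the patch): the changes layer as Spark test suite -> Communicator Java class -> XGBoostJNI native stubs -> C++ JNI bridge into the XGCommunicator* C API. Below is a minimal sketch of how a Spark job is expected to drive the new JVM API, mirroring the allreduce test above. The object name CommunicatorSketch, the method sumAcrossWorkers, and the pre-existing SparkContext are illustrative assumptions; RabitTracker, Communicator.init/allReduce/shutdown, and OpType.SUM are taken from the diff itself.

import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker}
import org.apache.spark.SparkContext

// Hypothetical driver-side helper; only the RabitTracker/Communicator calls are real API.
object CommunicatorSketch {
  def sumAcrossWorkers(sc: SparkContext, numWorkers: Int): Float = {
    val tracker = new PyRabitTracker(numWorkers) // tracker runs on the driver
    tracker.start(0)                             // 0 = no worker-connection timeout
    val envs = tracker.getWorkerEnvs             // host/port handed to each task via the closure

    val totals = sc.parallelize(1 to numWorkers, numWorkers).mapPartitions { iter =>
      Communicator.init(envs)                    // join the collective on this task's thread
      val local = Array(iter.map(_.toFloat).sum) // one partial sum per worker
      val global = Communicator.allReduce(local, Communicator.OpType.SUM)
      Communicator.shutdown()                    // always pair init with shutdown
      Iterator(global(0))
    }.collect()

    tracker.waitFor(0L)                          // block until every worker has finished
    totals.head                                  // each partition holds the same reduced value
  }
}

As in the tests, the invariant to preserve is that Communicator.init and Communicator.shutdown run on the same task thread, and that the tracker's waitFor is used to surface worker failures rather than letting the driver hang; the ISSUE-4406 regression test guards exactly this init/shutdown pairing under partial dataframe evaluation.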