diff --git a/src/metric/survival_metric.cu b/src/metric/survival_metric.cu index ac25e4548919..ed82505f5be4 100644 --- a/src/metric/survival_metric.cu +++ b/src/metric/survival_metric.cu @@ -206,10 +206,6 @@ struct EvalEWiseSurvivalBase : public Metric { bst_float Eval(const HostDeviceVector& preds, const MetaInfo& info, bool distributed) override { - CHECK_NE(info.labels_lower_bound_.Size(), 0U) - << "labels_lower_bound cannot be empty"; - CHECK_NE(info.labels_upper_bound_.Size(), 0U) - << "labels_upper_bound cannot be empty"; CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size()); CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size()); diff --git a/tests/python/test_survival.py b/tests/python/test_survival.py index a4a5d6ac65be..41a618de5bd5 100644 --- a/tests/python/test_survival.py +++ b/tests/python/test_survival.py @@ -52,6 +52,16 @@ def gather_split_thresholds(tree): for tree in model_json: assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5}) + +def test_aft_empty_dmatrix(): + X = np.array([]).reshape((0, 2)) + y_lower, y_upper = np.array([]), np.array([]) + dtrain = xgb.DMatrix(X) + dtrain.set_info(label_lower_bound=y_lower, label_upper_bound=y_upper) + bst = xgb.train({'objective': 'survival:aft', 'tree_method': 'hist'}, + dtrain, num_boost_round=2, evals=[(dtrain, 'train')]) + + @pytest.mark.skipif(**tm.no_pandas()) def test_aft_survival_demo_data(): import pandas as pd diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index d8643251275b..ca1da70425bb 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -594,7 +594,6 @@ def test_predict_with_meta(client): def run_aft_survival(client, dmatrix_t): - # survival doesn't handle empty dataset well. df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data', 'veterans_lung_cancer.csv')) y_lower_bound = df['Survival_label_lower_bound'] @@ -632,7 +631,7 @@ def run_aft_survival(client, dmatrix_t): def test_aft_survival(): - with LocalCluster(n_workers=1) as cluster: + with LocalCluster(n_workers=kWorkers) as cluster: with Client(cluster) as client: run_aft_survival(client, DaskDMatrix)