From 77d4a53c3201fff3e7f198b71249a5c8e2f22fa0 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 16 May 2022 21:15:41 -0700 Subject: [PATCH] use RabitContext intead of init/finalize (#7911) --- python-package/xgboost/rabit.py | 4 +- tests/distributed/distributed_gpu.py | 78 ++++++++------- tests/distributed/test_basic.py | 37 +++---- tests/distributed/test_federated.py | 45 ++++----- tests/distributed/test_issue3402.py | 139 +++++++++++++-------------- tests/python/test_tracker.py | 7 +- 6 files changed, 147 insertions(+), 163 deletions(-) diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py index 465a5611a2d1..f5da7a353330 100644 --- a/python-package/xgboost/rabit.py +++ b/python-package/xgboost/rabit.py @@ -230,7 +230,9 @@ def version_number() -> int: class RabitContext: """A context controlling rabit initialization and finalization.""" - def __init__(self, args: List[bytes]) -> None: + def __init__(self, args: List[bytes] = None) -> None: + if args is None: + args = [] self.args = args def __enter__(self) -> None: diff --git a/tests/distributed/distributed_gpu.py b/tests/distributed/distributed_gpu.py index a2ab6d398018..d10d2aed4884 100644 --- a/tests/distributed/distributed_gpu.py +++ b/tests/distributed/distributed_gpu.py @@ -8,46 +8,44 @@ def run_test(name, params_fun): """Runs a distributed GPU test.""" # Always call this before using distributed module - xgb.rabit.init() - rank = xgb.rabit.get_rank() - world = xgb.rabit.get_world_size() - - # Load file, file will be automatically sharded in distributed mode. - dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train') - dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test') - - params, n_rounds = params_fun(rank) - - # Specify validations set to watch performance - watchlist = [(dtest, 'eval'), (dtrain, 'train')] - - # Run training, all the features in training API is available. - # Currently, this script only support calling train once for fault recovery purpose. - bst = xgb.train(params, dtrain, n_rounds, watchlist, early_stopping_rounds=2) - - # Have each worker save its model - model_name = "test.model.%s.%d" % (name, rank) - bst.dump_model(model_name, with_stats=True) - xgb.rabit.allreduce(np.ones((1, 1)), xgb.rabit.Op.MAX) # sync - xgb.rabit.tracker_print("Finished training\n") - - if (rank == 0): - for i in range(0, world): - model_name_root = "test.model.%s.%d" % (name, i) - for j in range(0, world): - if i == j: - continue - with open(model_name_root, 'r') as model_root: - contents_root = model_root.read() - model_name_rank = "test.model.%s.%d" % (name, j) - with open(model_name_rank, 'r') as model_rank: - contents_rank = model_rank.read() - if contents_root != contents_rank: - raise Exception( - ('Worker models diverged: test.model.%s.%d ' - 'differs from test.model.%s.%d') % (name, i, name, j)) - - xgb.rabit.finalize() + with xgb.rabit.RabitContext(): + rank = xgb.rabit.get_rank() + world = xgb.rabit.get_world_size() + + # Load file, file will be automatically sharded in distributed mode. + dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train') + dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test') + + params, n_rounds = params_fun(rank) + + # Specify validations set to watch performance + watchlist = [(dtest, 'eval'), (dtrain, 'train')] + + # Run training, all the features in training API is available. + # Currently, this script only support calling train once for fault recovery purpose. 
+ bst = xgb.train(params, dtrain, n_rounds, watchlist, early_stopping_rounds=2) + + # Have each worker save its model + model_name = "test.model.%s.%d" % (name, rank) + bst.dump_model(model_name, with_stats=True) + xgb.rabit.allreduce(np.ones((1, 1)), xgb.rabit.Op.MAX) # sync + xgb.rabit.tracker_print("Finished training\n") + + if (rank == 0): + for i in range(0, world): + model_name_root = "test.model.%s.%d" % (name, i) + for j in range(0, world): + if i == j: + continue + with open(model_name_root, 'r') as model_root: + contents_root = model_root.read() + model_name_rank = "test.model.%s.%d" % (name, j) + with open(model_name_rank, 'r') as model_rank: + contents_rank = model_rank.read() + if contents_root != contents_rank: + raise Exception( + ('Worker models diverged: test.model.%s.%d ' + 'differs from test.model.%s.%d') % (name, i, name, j)) base_params = { diff --git a/tests/distributed/test_basic.py b/tests/distributed/test_basic.py index f7c1ffee3efc..db2916b39a3c 100644 --- a/tests/distributed/test_basic.py +++ b/tests/distributed/test_basic.py @@ -2,28 +2,23 @@ import xgboost as xgb # Always call this before using distributed module -xgb.rabit.init() +with xgb.rabit.RabitContext(): + # Load file, file will be automatically sharded in distributed mode. + dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train') + dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test') -# Load file, file will be automatically sharded in distributed mode. -dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train') -dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test') + # Specify parameters via map, definition are same as c++ version + param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} -# Specify parameters via map, definition are same as c++ version -param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} + # Specify validations set to watch performance + watchlist = [(dtest, 'eval'), (dtrain, 'train')] + num_round = 20 -# Specify validations set to watch performance -watchlist = [(dtest, 'eval'), (dtrain, 'train')] -num_round = 20 + # Run training, all the features in training API is available. + # Currently, this script only support calling train once for fault recovery purpose. + bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2) -# Run training, all the features in training API is available. -# Currently, this script only support calling train once for fault recovery purpose. -bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2) - -# Save the model, only ask process 0 to save the model. -if xgb.rabit.get_rank() == 0: - bst.save_model("test.model") - xgb.rabit.tracker_print("Finished training\n") - -# Notify the tracker all training has been successful -# This is only needed in distributed training. -xgb.rabit.finalize() + # Save the model, only ask process 0 to save the model. + if xgb.rabit.get_rank() == 0: + bst.save_model("test.model") + xgb.rabit.tracker_print("Finished training\n") diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py index 5b5b167fcd32..a3cdbc1e2912 100644 --- a/tests/distributed/test_federated.py +++ b/tests/distributed/test_federated.py @@ -27,31 +27,26 @@ def run_worker(port: int, world_size: int, rank: int) -> None: f'federated_client_key={CLIENT_KEY}', f'federated_client_cert={CLIENT_CERT}' ] - xgb.rabit.init([e.encode() for e in rabit_env]) - - # Load file, file will not be sharded in federated mode. 
- dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank) - dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank) - - # Specify parameters via map, definition are same as c++ version - param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} - - # Specify validations set to watch performance - watchlist = [(dtest, 'eval'), (dtrain, 'train')] - num_round = 20 - - # Run training, all the features in training API is available. - # Currently, this script only support calling train once for fault recovery purpose. - bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2) - - # Save the model, only ask process 0 to save the model. - if xgb.rabit.get_rank() == 0: - bst.save_model("test.model.json") - xgb.rabit.tracker_print("Finished training\n") - - # Notify the tracker all training has been successful - # This is only needed in distributed training. - xgb.rabit.finalize() + with xgb.rabit.RabitContext([e.encode() for e in rabit_env]): + # Load file, file will not be sharded in federated mode. + dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank) + dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank) + + # Specify parameters via map, definition are same as c++ version + param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} + + # Specify validations set to watch performance + watchlist = [(dtest, 'eval'), (dtrain, 'train')] + num_round = 20 + + # Run training, all the features in training API is available. + bst = xgb.train(param, dtrain, num_round, evals=watchlist, + early_stopping_rounds=2) + + # Save the model, only ask process 0 to save the model. + if xgb.rabit.get_rank() == 0: + bst.save_model("test.model.json") + xgb.rabit.tracker_print("Finished training\n") def run_test() -> None: diff --git a/tests/distributed/test_issue3402.py b/tests/distributed/test_issue3402.py index e3b87931bf67..7a40d3420ebb 100644 --- a/tests/distributed/test_issue3402.py +++ b/tests/distributed/test_issue3402.py @@ -2,78 +2,73 @@ import xgboost as xgb import numpy as np -xgb.rabit.init() +with xgb.rabit.RabitContext(): + X = [ + [15.00,28.90,29.00,3143.70,0.00,0.10,69.90,90.00,13726.07,0.00,2299.70,0.00,0.05, + 4327.03,0.00,24.00,0.18,3.00,0.41,3.77,0.00,0.00,4.00,0.00,150.92,0.00,2.00,0.00, + 0.01,138.00,1.00,0.02,69.90,0.00,0.83,5.00,0.01,0.12,47.30,0.00,296.00,0.16,0.00, + 0.00,27.70,7.00,7.25,4406.16,1.00,0.54,245.28,3.00,0.06,306.50,5143.00,29.00,23.74, + 548.00,2.00,68.00,70.90,25.45,0.39,0.00,0.01,497.11,0.00,42.00,83.00,4.00,0.00,1.00, + 0.00,104.35,94.12,0.03,79.23,237.69,1.00,0.04,0.01,0.02,2.00,108.81,7.00,12.00,0.46, + 31.00,0.00,0.15,74.59,0.00,19.50,0.00,0.75,0.06,0.08,118.00,35.90,0.01,0.07,1.00, + 0.03,81.18,13.33,0.00,0.00,0.00,0.00,0.00,0.41,0.00,0.15,57.00,0.00,22.00,449.68, + 0.00,0.00,2.00,195.26,51.58,306.50,0.10,1.00,0.00,258.00,21.00,0.43,3.00,16.00,0.00, + 0.00,0.00,0.00,1.00,74.51,4.00,0.02,35.90,30.00,8.69,0.00,0.36,5.00,2.00,3.00,0.26, + 9.50,8.00,11.00,11918.15,0.00,258.00,13.00,9.04,0.14,604.65,0.92,74.59,0.00,0.00, + 72.76,1.00,0.22,64.00,2.00,0.00,0.00,0.02,0.00,305.50,27.70,0.02,0.00,177.00,14.00, + 0.00,0.05,90.00,0.03,0.00,1.00,0.43,4.00,0.05,0.09,431.00,0.00,2.00,0.00,0.00,1.00, + 0.25,0.17,0.00,0.00,21.00,94.12,0.17,0.00,0.00,0.00,548.00,0.00,68.00,0.00,0.00,9.50, + 25.45,1390.31,7.00,0.00,2.00,310.70,0.00,0.01,0.01,0.03,81.40,1.00,0.02,0.00,9.00, + 6.00,0.00,175.76,36.00,0.00,20.75,2.00,0.00,0.00,0.00,0.22,74.16,0.10,56.81,0.00, + 2197.03,0.00,197.66,0.00,55.00,20.00,367.18,22.00,0.00,0.01,1510.26,0.24,0.00,0.01, + 
0.00,11.00,278.10,61.70,278.10,0.00,0.08,0.57,1.00,0.65,255.60,0.00,0.86,0.25,70.95, + 2299.70,0.23,0.05,92.70,1.00,38.00,0.00,0.00,56.81,21.85,0.00,23.74,0.00,2.00,0.03, + 2.00,0.00,347.58,30.00,243.55,109.00,0.00,296.00,6.00,6.00,0.00,0.00,109.00,2299.70, + 0.00,0.01,0.08,1.00,4745.09,4.00,0.18,0.00,0.17,0.02,0.00,1.00,147.13,71.07,2115.16, + 0.00,0.26,0.00,43.00,604.90,49.44,4327.03,0.68,0.75,0.10,86.36,52.98,0.20,0.00,22.50, + 305.50,0.00,1.00,0.00,7.00,0.78,0.00,296.00,22.50,0.00,5.00,2979.54,1.00,14.00,51.00, + 0.42,0.11,0.00,1.00,0.00,0.00,70.90,37.84,0.02,548.40,0.00,46.35,5.00,1.66,0.29,0.00, + 0.02,2255.69,160.53,790.64,6775.15,0.68,19.50,2299.70,79.87,6.00,0.00,60.00,0.27, + 233.77,10.00,0.00,0.00,23.00,82.27,1.00,0.00,1.00,0.42,1.00,0.01,0.40,0.41,9.50,2299.70, + 46.30,0.00,0.00,2299.70,3.00,0.00,0.00,83.00,1.00], + [48.00,80.89,69.90,11570.00,26.00,0.40,468.00,0.00,5739.46,0.00,1480.00,90.89,0.00, + 14042.09,3600.08,120.00,0.09,31.00,0.25,2.36,0.00,7.00,22.00,0.00,257.59,0.00,6.00, + 260.00,0.05,313.00,1.00,0.07,468.00,0.00,0.67,11.00,0.02,0.32,0.00,0.00,1387.61,0.34, + 0.00,0.00,158.04,6.00,13.98,12380.05,0.00,0.16,122.74,3.00,0.18,291.33,7517.79,124.00, + 45.08,900.00,1.00,0.00,577.25,79.75,0.39,0.00,0.00,244.62,0.00,57.00,178.00,19.00, + 0.00,1.00,386.10,103.51,480.00,0.06,129.41,334.31,1.00,0.06,0.00,0.06,3.00,125.55, + 0.00,76.00,0.14,30.00,0.00,0.03,411.29,791.33,55.00,0.12,3.80,0.07,0.01,188.00,221.11, + 0.01,0.15,1.00,0.18,144.32,15.00,0.00,0.05,0.00,3.00,0.00,0.20,0.00,0.14,62.00,0.06, + 55.00,239.35,0.00,0.00,2.00,534.20,747.50,400.57,0.40,0.00,0.00,219.98,30.00,0.25, + 1.00,70.00,0.02,0.04,0.00,0.00,7.00,747.50,8.67,0.06,271.01,28.00,5.63,75.39,0.46, + 11.00,3.00,19.00,0.38,131.74,23.00,39.00,30249.41,0.00,202.68,2.00,64.94,0.03,2787.68, + 0.54,35.00,0.02,106.03,25.00,1.00,0.10,45.00,2.00,0.00,0.00,0.00,0.00,449.27,172.38, + 0.05,0.00,550.00,130.00,2006.55,0.07,0.00,0.03,0.00,5.00,0.21,22.00,0.05,0.01,1011.40, + 0.00,4.00,3600.08,0.00,1.00,1.00,1.00,0.00,3.00,9.00,270.00,0.12,0.03,0.00,0.00,820.00, + 1827.50,0.00,100.33,0.00,131.74,53.16,9557.97,7.00,0.00,11.00,180.81,0.00,0.01,0.04, + 0.02,1480.00,0.92,0.05,0.00,15.00,6.00,0.00,161.42,28.00,169.00,35.60,4.00,0.12,0.00, + 0.00,0.27,230.56,0.42,171.90,0.00,28407.51,1.00,883.10,0.00,261.00,9.00,1031.67,38.00, + 0.00,0.04,1607.68,0.32,791.33,0.04,1403.00,2.00,2260.50,88.08,2260.50,0.00,0.12,0.75, + 3.00,0.00,1231.68,0.07,0.60,0.24,0.00,0.00,0.15,0.14,753.50,1.00,95.00,7.00,0.26, + 77.63,38.45,0.00,42.65,0.00,14.00,0.07,6.00,0.00,1911.59,43.00,386.77,1324.80,0.00, + 518.00,10.00,10.00,0.11,0.00,1324.80,0.00,0.00,0.02,0.16,1.00,10492.12,5.00,0.94, + 5.00,0.08,0.10,1.00,0.92,3731.49,105.81,6931.39,0.00,0.43,0.00,118.00,5323.71,81.66, + 14042.09,0.08,0.20,0.40,96.64,0.00,0.08,4.00,1028.82,353.00,0.00,2.00,32.00,43.00, + 5.16,75.39,900.00,232.10,3.00,5.00,6049.88,1.00,126.00,46.00,0.59,0.15,0.00,8.00, + 7.00,0.00,577.25,0.00,0.07,2415.10,0.00,83.72,9.00,1.76,0.20,0.00,0.17,3278.65,155.26, + 4415.50,22731.62,1.00,55.00,0.00,499.94,22.00,0.58,67.00,0.21,341.72,16.00,0.00,965.07, + 17.00,138.41,0.00,0.00,1.00,0.14,1.00,0.02,0.35,1.69,369.00,1300.00,25.00,0.00,0.01, + 0.00,0.00,0.00,0.00,52.00,8.00]] + X = np.array(X) + y = [1, 0] -X = [ - [15.00,28.90,29.00,3143.70,0.00,0.10,69.90,90.00,13726.07,0.00,2299.70,0.00,0.05, - 4327.03,0.00,24.00,0.18,3.00,0.41,3.77,0.00,0.00,4.00,0.00,150.92,0.00,2.00,0.00, - 0.01,138.00,1.00,0.02,69.90,0.00,0.83,5.00,0.01,0.12,47.30,0.00,296.00,0.16,0.00, - 
0.00,27.70,7.00,7.25,4406.16,1.00,0.54,245.28,3.00,0.06,306.50,5143.00,29.00,23.74, - 548.00,2.00,68.00,70.90,25.45,0.39,0.00,0.01,497.11,0.00,42.00,83.00,4.00,0.00,1.00, - 0.00,104.35,94.12,0.03,79.23,237.69,1.00,0.04,0.01,0.02,2.00,108.81,7.00,12.00,0.46, - 31.00,0.00,0.15,74.59,0.00,19.50,0.00,0.75,0.06,0.08,118.00,35.90,0.01,0.07,1.00, - 0.03,81.18,13.33,0.00,0.00,0.00,0.00,0.00,0.41,0.00,0.15,57.00,0.00,22.00,449.68, - 0.00,0.00,2.00,195.26,51.58,306.50,0.10,1.00,0.00,258.00,21.00,0.43,3.00,16.00,0.00, - 0.00,0.00,0.00,1.00,74.51,4.00,0.02,35.90,30.00,8.69,0.00,0.36,5.00,2.00,3.00,0.26, - 9.50,8.00,11.00,11918.15,0.00,258.00,13.00,9.04,0.14,604.65,0.92,74.59,0.00,0.00, - 72.76,1.00,0.22,64.00,2.00,0.00,0.00,0.02,0.00,305.50,27.70,0.02,0.00,177.00,14.00, - 0.00,0.05,90.00,0.03,0.00,1.00,0.43,4.00,0.05,0.09,431.00,0.00,2.00,0.00,0.00,1.00, - 0.25,0.17,0.00,0.00,21.00,94.12,0.17,0.00,0.00,0.00,548.00,0.00,68.00,0.00,0.00,9.50, - 25.45,1390.31,7.00,0.00,2.00,310.70,0.00,0.01,0.01,0.03,81.40,1.00,0.02,0.00,9.00, - 6.00,0.00,175.76,36.00,0.00,20.75,2.00,0.00,0.00,0.00,0.22,74.16,0.10,56.81,0.00, - 2197.03,0.00,197.66,0.00,55.00,20.00,367.18,22.00,0.00,0.01,1510.26,0.24,0.00,0.01, - 0.00,11.00,278.10,61.70,278.10,0.00,0.08,0.57,1.00,0.65,255.60,0.00,0.86,0.25,70.95, - 2299.70,0.23,0.05,92.70,1.00,38.00,0.00,0.00,56.81,21.85,0.00,23.74,0.00,2.00,0.03, - 2.00,0.00,347.58,30.00,243.55,109.00,0.00,296.00,6.00,6.00,0.00,0.00,109.00,2299.70, - 0.00,0.01,0.08,1.00,4745.09,4.00,0.18,0.00,0.17,0.02,0.00,1.00,147.13,71.07,2115.16, - 0.00,0.26,0.00,43.00,604.90,49.44,4327.03,0.68,0.75,0.10,86.36,52.98,0.20,0.00,22.50, - 305.50,0.00,1.00,0.00,7.00,0.78,0.00,296.00,22.50,0.00,5.00,2979.54,1.00,14.00,51.00, - 0.42,0.11,0.00,1.00,0.00,0.00,70.90,37.84,0.02,548.40,0.00,46.35,5.00,1.66,0.29,0.00, - 0.02,2255.69,160.53,790.64,6775.15,0.68,19.50,2299.70,79.87,6.00,0.00,60.00,0.27, - 233.77,10.00,0.00,0.00,23.00,82.27,1.00,0.00,1.00,0.42,1.00,0.01,0.40,0.41,9.50,2299.70, - 46.30,0.00,0.00,2299.70,3.00,0.00,0.00,83.00,1.00], - [48.00,80.89,69.90,11570.00,26.00,0.40,468.00,0.00,5739.46,0.00,1480.00,90.89,0.00, - 14042.09,3600.08,120.00,0.09,31.00,0.25,2.36,0.00,7.00,22.00,0.00,257.59,0.00,6.00, - 260.00,0.05,313.00,1.00,0.07,468.00,0.00,0.67,11.00,0.02,0.32,0.00,0.00,1387.61,0.34, - 0.00,0.00,158.04,6.00,13.98,12380.05,0.00,0.16,122.74,3.00,0.18,291.33,7517.79,124.00, - 45.08,900.00,1.00,0.00,577.25,79.75,0.39,0.00,0.00,244.62,0.00,57.00,178.00,19.00, - 0.00,1.00,386.10,103.51,480.00,0.06,129.41,334.31,1.00,0.06,0.00,0.06,3.00,125.55, - 0.00,76.00,0.14,30.00,0.00,0.03,411.29,791.33,55.00,0.12,3.80,0.07,0.01,188.00,221.11, - 0.01,0.15,1.00,0.18,144.32,15.00,0.00,0.05,0.00,3.00,0.00,0.20,0.00,0.14,62.00,0.06, - 55.00,239.35,0.00,0.00,2.00,534.20,747.50,400.57,0.40,0.00,0.00,219.98,30.00,0.25, - 1.00,70.00,0.02,0.04,0.00,0.00,7.00,747.50,8.67,0.06,271.01,28.00,5.63,75.39,0.46, - 11.00,3.00,19.00,0.38,131.74,23.00,39.00,30249.41,0.00,202.68,2.00,64.94,0.03,2787.68, - 0.54,35.00,0.02,106.03,25.00,1.00,0.10,45.00,2.00,0.00,0.00,0.00,0.00,449.27,172.38, - 0.05,0.00,550.00,130.00,2006.55,0.07,0.00,0.03,0.00,5.00,0.21,22.00,0.05,0.01,1011.40, - 0.00,4.00,3600.08,0.00,1.00,1.00,1.00,0.00,3.00,9.00,270.00,0.12,0.03,0.00,0.00,820.00, - 1827.50,0.00,100.33,0.00,131.74,53.16,9557.97,7.00,0.00,11.00,180.81,0.00,0.01,0.04, - 0.02,1480.00,0.92,0.05,0.00,15.00,6.00,0.00,161.42,28.00,169.00,35.60,4.00,0.12,0.00, - 0.00,0.27,230.56,0.42,171.90,0.00,28407.51,1.00,883.10,0.00,261.00,9.00,1031.67,38.00, - 
0.00,0.04,1607.68,0.32,791.33,0.04,1403.00,2.00,2260.50,88.08,2260.50,0.00,0.12,0.75, - 3.00,0.00,1231.68,0.07,0.60,0.24,0.00,0.00,0.15,0.14,753.50,1.00,95.00,7.00,0.26, - 77.63,38.45,0.00,42.65,0.00,14.00,0.07,6.00,0.00,1911.59,43.00,386.77,1324.80,0.00, - 518.00,10.00,10.00,0.11,0.00,1324.80,0.00,0.00,0.02,0.16,1.00,10492.12,5.00,0.94, - 5.00,0.08,0.10,1.00,0.92,3731.49,105.81,6931.39,0.00,0.43,0.00,118.00,5323.71,81.66, - 14042.09,0.08,0.20,0.40,96.64,0.00,0.08,4.00,1028.82,353.00,0.00,2.00,32.00,43.00, - 5.16,75.39,900.00,232.10,3.00,5.00,6049.88,1.00,126.00,46.00,0.59,0.15,0.00,8.00, - 7.00,0.00,577.25,0.00,0.07,2415.10,0.00,83.72,9.00,1.76,0.20,0.00,0.17,3278.65,155.26, - 4415.50,22731.62,1.00,55.00,0.00,499.94,22.00,0.58,67.00,0.21,341.72,16.00,0.00,965.07, - 17.00,138.41,0.00,0.00,1.00,0.14,1.00,0.02,0.35,1.69,369.00,1300.00,25.00,0.00,0.01, - 0.00,0.00,0.00,0.00,52.00,8.00]] -X = np.array(X) -y = [1, 0] + dtrain = xgb.DMatrix(X, label=y) -dtrain = xgb.DMatrix(X, label=y) + param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic' } + watchlist = [(dtrain,'train')] + num_round = 2 + bst = xgb.train(param, dtrain, num_round, watchlist) -param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic' } -watchlist = [(dtrain,'train')] -num_round = 2 -bst = xgb.train(param, dtrain, num_round, watchlist) - -if xgb.rabit.get_rank() == 0: - bst.save_model("test_issue3402.model") - xgb.rabit.tracker_print("Finished training\n") - -# Notify the tracker all training has been successful -# This is only needed in distributed training. -xgb.rabit.finalize() + if xgb.rabit.get_rank() == 0: + bst.save_model("test_issue3402.model") + xgb.rabit.tracker_print("Finished training\n") diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index 2e113898f4de..885221aae4ae 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -16,10 +16,9 @@ def test_rabit_tracker(): rabit_env = [] for k, v in worker_env.items(): rabit_env.append(f"{k}={v}".encode()) - xgb.rabit.init(rabit_env) - ret = xgb.rabit.broadcast('test1234', 0) - assert str(ret) == 'test1234' - xgb.rabit.finalize() + with xgb.rabit.RabitContext(rabit_env): + ret = xgb.rabit.broadcast('test1234', 0) + assert str(ret) == 'test1234' def run_rabit_ops(client, n_workers):
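
For reference, the usage pattern this patch converges on is the standard Python
context-manager idiom: RabitContext runs rabit initialization on entry and
finalization on exit, so cleanup happens even if the body raises. The sketch below
is illustrative only and is not part of the diff; the tracker environment entries
shown are assumptions about what a launcher or tracker would supply, and a bare
RabitContext() (using the new default, empty argument list) is enough for local,
single-process testing.

    import xgboost as xgb

    # Entries such as DMLC_TRACKER_URI / DMLC_TRACKER_PORT are illustrative here;
    # in a real distributed run the launcher or the Python RabitTracker provides them.
    rabit_env = [b"DMLC_TRACKER_URI=127.0.0.1", b"DMLC_TRACKER_PORT=9091"]

    with xgb.rabit.RabitContext(rabit_env):
        # init() has run; collective calls are valid inside the block.
        rank = xgb.rabit.get_rank()
        xgb.rabit.tracker_print("worker %d ready\n" % rank)
        # ... build DMatrix objects and call xgb.train() here ...
    # finalize() has run on exit, even if training raised.

    # With the defaulted constructor argument, a bare context works locally:
    with xgb.rabit.RabitContext():
        pass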