Skip to content

Commit

Permalink
use RabitContext intead of init/finalize (#7911)
Browse files Browse the repository at this point in the history
  • Loading branch information
rongou committed May 17, 2022
1 parent 4fcfd9c commit 77d4a53
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 163 deletions.
4 changes: 3 additions & 1 deletion python-package/xgboost/rabit.py
Expand Up @@ -230,7 +230,9 @@ def version_number() -> int:
class RabitContext:
"""A context controlling rabit initialization and finalization."""

def __init__(self, args: List[bytes]) -> None:
def __init__(self, args: List[bytes] = None) -> None:
if args is None:
args = []
self.args = args

def __enter__(self) -> None:
Expand Down
78 changes: 38 additions & 40 deletions tests/distributed/distributed_gpu.py
Expand Up @@ -8,46 +8,44 @@
def run_test(name, params_fun):
"""Runs a distributed GPU test."""
# Always call this before using distributed module
xgb.rabit.init()
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()

# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

params, n_rounds = params_fun(rank)

# Specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]

# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(params, dtrain, n_rounds, watchlist, early_stopping_rounds=2)

# Have each worker save its model
model_name = "test.model.%s.%d" % (name, rank)
bst.dump_model(model_name, with_stats=True)
xgb.rabit.allreduce(np.ones((1, 1)), xgb.rabit.Op.MAX) # sync
xgb.rabit.tracker_print("Finished training\n")

if (rank == 0):
for i in range(0, world):
model_name_root = "test.model.%s.%d" % (name, i)
for j in range(0, world):
if i == j:
continue
with open(model_name_root, 'r') as model_root:
contents_root = model_root.read()
model_name_rank = "test.model.%s.%d" % (name, j)
with open(model_name_rank, 'r') as model_rank:
contents_rank = model_rank.read()
if contents_root != contents_rank:
raise Exception(
('Worker models diverged: test.model.%s.%d '
'differs from test.model.%s.%d') % (name, i, name, j))

xgb.rabit.finalize()
with xgb.rabit.RabitContext():
rank = xgb.rabit.get_rank()
world = xgb.rabit.get_world_size()

# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

params, n_rounds = params_fun(rank)

# Specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]

# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(params, dtrain, n_rounds, watchlist, early_stopping_rounds=2)

# Have each worker save its model
model_name = "test.model.%s.%d" % (name, rank)
bst.dump_model(model_name, with_stats=True)
xgb.rabit.allreduce(np.ones((1, 1)), xgb.rabit.Op.MAX) # sync
xgb.rabit.tracker_print("Finished training\n")

if (rank == 0):
for i in range(0, world):
model_name_root = "test.model.%s.%d" % (name, i)
for j in range(0, world):
if i == j:
continue
with open(model_name_root, 'r') as model_root:
contents_root = model_root.read()
model_name_rank = "test.model.%s.%d" % (name, j)
with open(model_name_rank, 'r') as model_rank:
contents_rank = model_rank.read()
if contents_root != contents_rank:
raise Exception(
('Worker models diverged: test.model.%s.%d '
'differs from test.model.%s.%d') % (name, i, name, j))


base_params = {
Expand Down
37 changes: 16 additions & 21 deletions tests/distributed/test_basic.py
Expand Up @@ -2,28 +2,23 @@
import xgboost as xgb

# Always call this before using distributed module
xgb.rabit.init()
with xgb.rabit.RabitContext():
# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

# Load file, file will be automatically sharded in distributed mode.
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
# Specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}

# Specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
# Specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 20

# Specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 20
# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)

# Save the model, only ask process 0 to save the model.
if xgb.rabit.get_rank() == 0:
bst.save_model("test.model")
xgb.rabit.tracker_print("Finished training\n")

# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
# Save the model, only ask process 0 to save the model.
if xgb.rabit.get_rank() == 0:
bst.save_model("test.model")
xgb.rabit.tracker_print("Finished training\n")
45 changes: 20 additions & 25 deletions tests/distributed/test_federated.py
Expand Up @@ -27,31 +27,26 @@ def run_worker(port: int, world_size: int, rank: int) -> None:
f'federated_client_key={CLIENT_KEY}',
f'federated_client_cert={CLIENT_CERT}'
]
xgb.rabit.init([e.encode() for e in rabit_env])

# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)

# Specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}

# Specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 20

# Run training, all the features in training API is available.
# Currently, this script only support calling train once for fault recovery purpose.
bst = xgb.train(param, dtrain, num_round, evals=watchlist, early_stopping_rounds=2)

# Save the model, only ask process 0 to save the model.
if xgb.rabit.get_rank() == 0:
bst.save_model("test.model.json")
xgb.rabit.tracker_print("Finished training\n")

# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
dtest = xgb.DMatrix('agaricus.txt.test-%02d' % rank)

# Specify parameters via map, definition are same as c++ version
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}

# Specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 20

# Run training, all the features in training API is available.
bst = xgb.train(param, dtrain, num_round, evals=watchlist,
early_stopping_rounds=2)

# Save the model, only ask process 0 to save the model.
if xgb.rabit.get_rank() == 0:
bst.save_model("test.model.json")
xgb.rabit.tracker_print("Finished training\n")


def run_test() -> None:
Expand Down
139 changes: 67 additions & 72 deletions tests/distributed/test_issue3402.py
Expand Up @@ -2,78 +2,73 @@
import xgboost as xgb
import numpy as np

xgb.rabit.init()
with xgb.rabit.RabitContext():
X = [
[15.00,28.90,29.00,3143.70,0.00,0.10,69.90,90.00,13726.07,0.00,2299.70,0.00,0.05,
4327.03,0.00,24.00,0.18,3.00,0.41,3.77,0.00,0.00,4.00,0.00,150.92,0.00,2.00,0.00,
0.01,138.00,1.00,0.02,69.90,0.00,0.83,5.00,0.01,0.12,47.30,0.00,296.00,0.16,0.00,
0.00,27.70,7.00,7.25,4406.16,1.00,0.54,245.28,3.00,0.06,306.50,5143.00,29.00,23.74,
548.00,2.00,68.00,70.90,25.45,0.39,0.00,0.01,497.11,0.00,42.00,83.00,4.00,0.00,1.00,
0.00,104.35,94.12,0.03,79.23,237.69,1.00,0.04,0.01,0.02,2.00,108.81,7.00,12.00,0.46,
31.00,0.00,0.15,74.59,0.00,19.50,0.00,0.75,0.06,0.08,118.00,35.90,0.01,0.07,1.00,
0.03,81.18,13.33,0.00,0.00,0.00,0.00,0.00,0.41,0.00,0.15,57.00,0.00,22.00,449.68,
0.00,0.00,2.00,195.26,51.58,306.50,0.10,1.00,0.00,258.00,21.00,0.43,3.00,16.00,0.00,
0.00,0.00,0.00,1.00,74.51,4.00,0.02,35.90,30.00,8.69,0.00,0.36,5.00,2.00,3.00,0.26,
9.50,8.00,11.00,11918.15,0.00,258.00,13.00,9.04,0.14,604.65,0.92,74.59,0.00,0.00,
72.76,1.00,0.22,64.00,2.00,0.00,0.00,0.02,0.00,305.50,27.70,0.02,0.00,177.00,14.00,
0.00,0.05,90.00,0.03,0.00,1.00,0.43,4.00,0.05,0.09,431.00,0.00,2.00,0.00,0.00,1.00,
0.25,0.17,0.00,0.00,21.00,94.12,0.17,0.00,0.00,0.00,548.00,0.00,68.00,0.00,0.00,9.50,
25.45,1390.31,7.00,0.00,2.00,310.70,0.00,0.01,0.01,0.03,81.40,1.00,0.02,0.00,9.00,
6.00,0.00,175.76,36.00,0.00,20.75,2.00,0.00,0.00,0.00,0.22,74.16,0.10,56.81,0.00,
2197.03,0.00,197.66,0.00,55.00,20.00,367.18,22.00,0.00,0.01,1510.26,0.24,0.00,0.01,
0.00,11.00,278.10,61.70,278.10,0.00,0.08,0.57,1.00,0.65,255.60,0.00,0.86,0.25,70.95,
2299.70,0.23,0.05,92.70,1.00,38.00,0.00,0.00,56.81,21.85,0.00,23.74,0.00,2.00,0.03,
2.00,0.00,347.58,30.00,243.55,109.00,0.00,296.00,6.00,6.00,0.00,0.00,109.00,2299.70,
0.00,0.01,0.08,1.00,4745.09,4.00,0.18,0.00,0.17,0.02,0.00,1.00,147.13,71.07,2115.16,
0.00,0.26,0.00,43.00,604.90,49.44,4327.03,0.68,0.75,0.10,86.36,52.98,0.20,0.00,22.50,
305.50,0.00,1.00,0.00,7.00,0.78,0.00,296.00,22.50,0.00,5.00,2979.54,1.00,14.00,51.00,
0.42,0.11,0.00,1.00,0.00,0.00,70.90,37.84,0.02,548.40,0.00,46.35,5.00,1.66,0.29,0.00,
0.02,2255.69,160.53,790.64,6775.15,0.68,19.50,2299.70,79.87,6.00,0.00,60.00,0.27,
233.77,10.00,0.00,0.00,23.00,82.27,1.00,0.00,1.00,0.42,1.00,0.01,0.40,0.41,9.50,2299.70,
46.30,0.00,0.00,2299.70,3.00,0.00,0.00,83.00,1.00],
[48.00,80.89,69.90,11570.00,26.00,0.40,468.00,0.00,5739.46,0.00,1480.00,90.89,0.00,
14042.09,3600.08,120.00,0.09,31.00,0.25,2.36,0.00,7.00,22.00,0.00,257.59,0.00,6.00,
260.00,0.05,313.00,1.00,0.07,468.00,0.00,0.67,11.00,0.02,0.32,0.00,0.00,1387.61,0.34,
0.00,0.00,158.04,6.00,13.98,12380.05,0.00,0.16,122.74,3.00,0.18,291.33,7517.79,124.00,
45.08,900.00,1.00,0.00,577.25,79.75,0.39,0.00,0.00,244.62,0.00,57.00,178.00,19.00,
0.00,1.00,386.10,103.51,480.00,0.06,129.41,334.31,1.00,0.06,0.00,0.06,3.00,125.55,
0.00,76.00,0.14,30.00,0.00,0.03,411.29,791.33,55.00,0.12,3.80,0.07,0.01,188.00,221.11,
0.01,0.15,1.00,0.18,144.32,15.00,0.00,0.05,0.00,3.00,0.00,0.20,0.00,0.14,62.00,0.06,
55.00,239.35,0.00,0.00,2.00,534.20,747.50,400.57,0.40,0.00,0.00,219.98,30.00,0.25,
1.00,70.00,0.02,0.04,0.00,0.00,7.00,747.50,8.67,0.06,271.01,28.00,5.63,75.39,0.46,
11.00,3.00,19.00,0.38,131.74,23.00,39.00,30249.41,0.00,202.68,2.00,64.94,0.03,2787.68,
0.54,35.00,0.02,106.03,25.00,1.00,0.10,45.00,2.00,0.00,0.00,0.00,0.00,449.27,172.38,
0.05,0.00,550.00,130.00,2006.55,0.07,0.00,0.03,0.00,5.00,0.21,22.00,0.05,0.01,1011.40,
0.00,4.00,3600.08,0.00,1.00,1.00,1.00,0.00,3.00,9.00,270.00,0.12,0.03,0.00,0.00,820.00,
1827.50,0.00,100.33,0.00,131.74,53.16,9557.97,7.00,0.00,11.00,180.81,0.00,0.01,0.04,
0.02,1480.00,0.92,0.05,0.00,15.00,6.00,0.00,161.42,28.00,169.00,35.60,4.00,0.12,0.00,
0.00,0.27,230.56,0.42,171.90,0.00,28407.51,1.00,883.10,0.00,261.00,9.00,1031.67,38.00,
0.00,0.04,1607.68,0.32,791.33,0.04,1403.00,2.00,2260.50,88.08,2260.50,0.00,0.12,0.75,
3.00,0.00,1231.68,0.07,0.60,0.24,0.00,0.00,0.15,0.14,753.50,1.00,95.00,7.00,0.26,
77.63,38.45,0.00,42.65,0.00,14.00,0.07,6.00,0.00,1911.59,43.00,386.77,1324.80,0.00,
518.00,10.00,10.00,0.11,0.00,1324.80,0.00,0.00,0.02,0.16,1.00,10492.12,5.00,0.94,
5.00,0.08,0.10,1.00,0.92,3731.49,105.81,6931.39,0.00,0.43,0.00,118.00,5323.71,81.66,
14042.09,0.08,0.20,0.40,96.64,0.00,0.08,4.00,1028.82,353.00,0.00,2.00,32.00,43.00,
5.16,75.39,900.00,232.10,3.00,5.00,6049.88,1.00,126.00,46.00,0.59,0.15,0.00,8.00,
7.00,0.00,577.25,0.00,0.07,2415.10,0.00,83.72,9.00,1.76,0.20,0.00,0.17,3278.65,155.26,
4415.50,22731.62,1.00,55.00,0.00,499.94,22.00,0.58,67.00,0.21,341.72,16.00,0.00,965.07,
17.00,138.41,0.00,0.00,1.00,0.14,1.00,0.02,0.35,1.69,369.00,1300.00,25.00,0.00,0.01,
0.00,0.00,0.00,0.00,52.00,8.00]]
X = np.array(X)
y = [1, 0]

X = [
[15.00,28.90,29.00,3143.70,0.00,0.10,69.90,90.00,13726.07,0.00,2299.70,0.00,0.05,
4327.03,0.00,24.00,0.18,3.00,0.41,3.77,0.00,0.00,4.00,0.00,150.92,0.00,2.00,0.00,
0.01,138.00,1.00,0.02,69.90,0.00,0.83,5.00,0.01,0.12,47.30,0.00,296.00,0.16,0.00,
0.00,27.70,7.00,7.25,4406.16,1.00,0.54,245.28,3.00,0.06,306.50,5143.00,29.00,23.74,
548.00,2.00,68.00,70.90,25.45,0.39,0.00,0.01,497.11,0.00,42.00,83.00,4.00,0.00,1.00,
0.00,104.35,94.12,0.03,79.23,237.69,1.00,0.04,0.01,0.02,2.00,108.81,7.00,12.00,0.46,
31.00,0.00,0.15,74.59,0.00,19.50,0.00,0.75,0.06,0.08,118.00,35.90,0.01,0.07,1.00,
0.03,81.18,13.33,0.00,0.00,0.00,0.00,0.00,0.41,0.00,0.15,57.00,0.00,22.00,449.68,
0.00,0.00,2.00,195.26,51.58,306.50,0.10,1.00,0.00,258.00,21.00,0.43,3.00,16.00,0.00,
0.00,0.00,0.00,1.00,74.51,4.00,0.02,35.90,30.00,8.69,0.00,0.36,5.00,2.00,3.00,0.26,
9.50,8.00,11.00,11918.15,0.00,258.00,13.00,9.04,0.14,604.65,0.92,74.59,0.00,0.00,
72.76,1.00,0.22,64.00,2.00,0.00,0.00,0.02,0.00,305.50,27.70,0.02,0.00,177.00,14.00,
0.00,0.05,90.00,0.03,0.00,1.00,0.43,4.00,0.05,0.09,431.00,0.00,2.00,0.00,0.00,1.00,
0.25,0.17,0.00,0.00,21.00,94.12,0.17,0.00,0.00,0.00,548.00,0.00,68.00,0.00,0.00,9.50,
25.45,1390.31,7.00,0.00,2.00,310.70,0.00,0.01,0.01,0.03,81.40,1.00,0.02,0.00,9.00,
6.00,0.00,175.76,36.00,0.00,20.75,2.00,0.00,0.00,0.00,0.22,74.16,0.10,56.81,0.00,
2197.03,0.00,197.66,0.00,55.00,20.00,367.18,22.00,0.00,0.01,1510.26,0.24,0.00,0.01,
0.00,11.00,278.10,61.70,278.10,0.00,0.08,0.57,1.00,0.65,255.60,0.00,0.86,0.25,70.95,
2299.70,0.23,0.05,92.70,1.00,38.00,0.00,0.00,56.81,21.85,0.00,23.74,0.00,2.00,0.03,
2.00,0.00,347.58,30.00,243.55,109.00,0.00,296.00,6.00,6.00,0.00,0.00,109.00,2299.70,
0.00,0.01,0.08,1.00,4745.09,4.00,0.18,0.00,0.17,0.02,0.00,1.00,147.13,71.07,2115.16,
0.00,0.26,0.00,43.00,604.90,49.44,4327.03,0.68,0.75,0.10,86.36,52.98,0.20,0.00,22.50,
305.50,0.00,1.00,0.00,7.00,0.78,0.00,296.00,22.50,0.00,5.00,2979.54,1.00,14.00,51.00,
0.42,0.11,0.00,1.00,0.00,0.00,70.90,37.84,0.02,548.40,0.00,46.35,5.00,1.66,0.29,0.00,
0.02,2255.69,160.53,790.64,6775.15,0.68,19.50,2299.70,79.87,6.00,0.00,60.00,0.27,
233.77,10.00,0.00,0.00,23.00,82.27,1.00,0.00,1.00,0.42,1.00,0.01,0.40,0.41,9.50,2299.70,
46.30,0.00,0.00,2299.70,3.00,0.00,0.00,83.00,1.00],
[48.00,80.89,69.90,11570.00,26.00,0.40,468.00,0.00,5739.46,0.00,1480.00,90.89,0.00,
14042.09,3600.08,120.00,0.09,31.00,0.25,2.36,0.00,7.00,22.00,0.00,257.59,0.00,6.00,
260.00,0.05,313.00,1.00,0.07,468.00,0.00,0.67,11.00,0.02,0.32,0.00,0.00,1387.61,0.34,
0.00,0.00,158.04,6.00,13.98,12380.05,0.00,0.16,122.74,3.00,0.18,291.33,7517.79,124.00,
45.08,900.00,1.00,0.00,577.25,79.75,0.39,0.00,0.00,244.62,0.00,57.00,178.00,19.00,
0.00,1.00,386.10,103.51,480.00,0.06,129.41,334.31,1.00,0.06,0.00,0.06,3.00,125.55,
0.00,76.00,0.14,30.00,0.00,0.03,411.29,791.33,55.00,0.12,3.80,0.07,0.01,188.00,221.11,
0.01,0.15,1.00,0.18,144.32,15.00,0.00,0.05,0.00,3.00,0.00,0.20,0.00,0.14,62.00,0.06,
55.00,239.35,0.00,0.00,2.00,534.20,747.50,400.57,0.40,0.00,0.00,219.98,30.00,0.25,
1.00,70.00,0.02,0.04,0.00,0.00,7.00,747.50,8.67,0.06,271.01,28.00,5.63,75.39,0.46,
11.00,3.00,19.00,0.38,131.74,23.00,39.00,30249.41,0.00,202.68,2.00,64.94,0.03,2787.68,
0.54,35.00,0.02,106.03,25.00,1.00,0.10,45.00,2.00,0.00,0.00,0.00,0.00,449.27,172.38,
0.05,0.00,550.00,130.00,2006.55,0.07,0.00,0.03,0.00,5.00,0.21,22.00,0.05,0.01,1011.40,
0.00,4.00,3600.08,0.00,1.00,1.00,1.00,0.00,3.00,9.00,270.00,0.12,0.03,0.00,0.00,820.00,
1827.50,0.00,100.33,0.00,131.74,53.16,9557.97,7.00,0.00,11.00,180.81,0.00,0.01,0.04,
0.02,1480.00,0.92,0.05,0.00,15.00,6.00,0.00,161.42,28.00,169.00,35.60,4.00,0.12,0.00,
0.00,0.27,230.56,0.42,171.90,0.00,28407.51,1.00,883.10,0.00,261.00,9.00,1031.67,38.00,
0.00,0.04,1607.68,0.32,791.33,0.04,1403.00,2.00,2260.50,88.08,2260.50,0.00,0.12,0.75,
3.00,0.00,1231.68,0.07,0.60,0.24,0.00,0.00,0.15,0.14,753.50,1.00,95.00,7.00,0.26,
77.63,38.45,0.00,42.65,0.00,14.00,0.07,6.00,0.00,1911.59,43.00,386.77,1324.80,0.00,
518.00,10.00,10.00,0.11,0.00,1324.80,0.00,0.00,0.02,0.16,1.00,10492.12,5.00,0.94,
5.00,0.08,0.10,1.00,0.92,3731.49,105.81,6931.39,0.00,0.43,0.00,118.00,5323.71,81.66,
14042.09,0.08,0.20,0.40,96.64,0.00,0.08,4.00,1028.82,353.00,0.00,2.00,32.00,43.00,
5.16,75.39,900.00,232.10,3.00,5.00,6049.88,1.00,126.00,46.00,0.59,0.15,0.00,8.00,
7.00,0.00,577.25,0.00,0.07,2415.10,0.00,83.72,9.00,1.76,0.20,0.00,0.17,3278.65,155.26,
4415.50,22731.62,1.00,55.00,0.00,499.94,22.00,0.58,67.00,0.21,341.72,16.00,0.00,965.07,
17.00,138.41,0.00,0.00,1.00,0.14,1.00,0.02,0.35,1.69,369.00,1300.00,25.00,0.00,0.01,
0.00,0.00,0.00,0.00,52.00,8.00]]
X = np.array(X)
y = [1, 0]
dtrain = xgb.DMatrix(X, label=y)

dtrain = xgb.DMatrix(X, label=y)
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic' }
watchlist = [(dtrain,'train')]
num_round = 2
bst = xgb.train(param, dtrain, num_round, watchlist)

param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic' }
watchlist = [(dtrain,'train')]
num_round = 2
bst = xgb.train(param, dtrain, num_round, watchlist)

if xgb.rabit.get_rank() == 0:
bst.save_model("test_issue3402.model")
xgb.rabit.tracker_print("Finished training\n")

# Notify the tracker all training has been successful
# This is only needed in distributed training.
xgb.rabit.finalize()
if xgb.rabit.get_rank() == 0:
bst.save_model("test_issue3402.model")
xgb.rabit.tracker_print("Finished training\n")
7 changes: 3 additions & 4 deletions tests/python/test_tracker.py
Expand Up @@ -16,10 +16,9 @@ def test_rabit_tracker():
rabit_env = []
for k, v in worker_env.items():
rabit_env.append(f"{k}={v}".encode())
xgb.rabit.init(rabit_env)
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'
xgb.rabit.finalize()
with xgb.rabit.RabitContext(rabit_env):
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'


def run_rabit_ops(client, n_workers):
Expand Down

0 comments on commit 77d4a53

Please sign in to comment.