diff --git a/demo/nvflare/custom/trainer.py b/demo/nvflare/custom/trainer.py index fd56da0e57fe..fd93ae3a4031 100644 --- a/demo/nvflare/custom/trainer.py +++ b/demo/nvflare/custom/trainer.py @@ -53,6 +53,7 @@ def _do_training(self, fl_ctx: FLContext): client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME) rank = int(client_name.split('-')[1]) - 1 communicator_env = { + 'xgboost_communicator': 'federated', 'federated_server_address': self._server_address, 'federated_world_size': self._world_size, 'federated_rank': rank,