From 1bf389998383f333490155dba4608bff9ca63b42 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Tue, 8 Dec 2020 08:36:23 +0800
Subject: [PATCH] Fix dask ip resolution. (#6475)

This adopts the solution used in dask/dask-xgboost#40 which employs the
get_host_ip from dmlc-core tracker.
---
 python-package/xgboost/dask.py    |  5 ++---
 python-package/xgboost/tracker.py | 22 ++++++++++++++++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index bbd2fe1d2c64..4000c280ad76 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -33,7 +33,7 @@
 from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
 from .core import _deprecate_positional_args
 from .training import train as worker_train
-from .tracker import RabitTracker
+from .tracker import RabitTracker, get_host_ip
 from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
 from .sklearn import xgboost_model_doc
 
@@ -70,8 +70,7 @@
 def _start_tracker(n_workers):
     """Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
-    import socket
-    host = socket.gethostbyname(socket.gethostname())
+    host = get_host_ip('auto')
     rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
     env.update(rabit_context.slave_envs())
 
diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py
index 5b217b5c86f0..700b6898fa44 100644
--- a/python-package/xgboost/tracker.py
+++ b/python-package/xgboost/tracker.py
@@ -52,6 +52,28 @@ def get_some_ip(host):
     return socket.getaddrinfo(host, None)[0][4][0]
 
 
+def get_host_ip(hostIP=None):
+    if hostIP is None or hostIP == 'auto':
+        hostIP = 'ip'
+
+    if hostIP == 'dns':
+        hostIP = socket.getfqdn()
+    elif hostIP == 'ip':
+        from socket import gaierror
+        try:
+            hostIP = socket.gethostbyname(socket.getfqdn())
+        except gaierror:
+            logging.warning(
+                'gethostbyname(socket.getfqdn()) failed... trying on hostname()')
+            hostIP = socket.gethostbyname(socket.gethostname())
+        if hostIP.startswith("127."):
+            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+            # doesn't have to be reachable
+            s.connect(('10.255.255.255', 1))
+            hostIP = s.getsockname()[0]
+    return hostIP
+
+
 def get_family(addr):
     return socket.getaddrinfo(addr, None)[0][0]
 