Fix dask ip resolution. #6475

Merged
merged 3 commits into from Dec 8, 2020
5 changes: 2 additions & 3 deletions python-package/xgboost/dask.py
@@ -33,7 +33,7 @@
 from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
 from .core import _deprecate_positional_args
 from .training import train as worker_train
-from .tracker import RabitTracker
+from .tracker import RabitTracker, get_host_ip
 from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
 from .sklearn import xgboost_model_doc

@@ -70,8 +70,7 @@
 def _start_tracker(n_workers):
     """Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
-    import socket
-    host = socket.gethostbyname(socket.gethostname())
+    host = get_host_ip('auto')
     rabit_context = RabitTracker(hostIP=host, nslave=n_workers)
     env.update(rabit_context.slave_envs())

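The hunk above replaces `socket.gethostbyname(socket.gethostname())` with `get_host_ip('auto')`. One plausible motivation, not stated in the text shown here, is that on some machines the hostname is mapped to a loopback address (for example `127.0.0.1` or `127.0.1.1` in `/etc/hosts`), which workers on other machines cannot reach. A minimal sketch for checking this locally follows; `resolve_hostname_ip` and `is_loopback` are hypothetical helper names, not part of xgboost.

```python
import ipaddress
import socket


def resolve_hostname_ip():
    """Old-style lookup: resolve this machine's hostname to an IPv4 address.

    Returns None if the hostname cannot be resolved at all.
    """
    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError:  # socket.gaierror is a subclass of OSError
        return None


def is_loopback(addr):
    """True if addr is a loopback address such as 127.0.0.1 or 127.0.1.1."""
    return ipaddress.ip_address(addr).is_loopback


addr = resolve_hostname_ip()
if addr is None:
    print("hostname did not resolve")
elif is_loopback(addr):
    print(f"{addr} is loopback: remote workers could not reach a tracker here")
else:
    print(f"{addr} looks routable")
```

What this prints depends on the local resolver configuration, so it is a diagnostic rather than a fixed result.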
22 changes: 22 additions & 0 deletions python-package/xgboost/tracker.py
@@ -52,6 +52,28 @@ def get_some_ip(host):
     return socket.getaddrinfo(host, None)[0][4][0]


+def get_host_ip(hostIP=None):
+    if hostIP is None or hostIP == 'auto':
+        hostIP = 'ip'
+
+    if hostIP == 'dns':
+        hostIP = socket.getfqdn()
+    elif hostIP == 'ip':
+        from socket import gaierror
+        try:
+            hostIP = socket.gethostbyname(socket.getfqdn())
+        except gaierror:
+            logging.warning(
+                'gethostbyname(socket.getfqdn()) failed... trying on hostname()')
+            hostIP = socket.gethostbyname(socket.gethostname())
+        if hostIP.startswith("127."):
+            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+            # doesn't have to be reachable

Maybe just leave a more informative comment here, for posterity: 1) why localhost isn't appropriate in all contexts, and 2) why '10.255.255.255' on port 1. Did you choose this because it's a broadcast address and won't be forwarded, or just because it's unlikely to be hosting anything?

Member Author

Thanks for the review, @drobison00. Let me add some comments there later. If you have suggestions on how to do this better, feel free to share; I'm happy to revise it. I think you are much more familiar with this than I am.

+            s.connect(('10.255.255.255', 1))
+            hostIP = s.getsockname()[0]
+    return hostIP


 def get_family(addr):
     return socket.getaddrinfo(addr, None)[0][0]

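The loopback fallback in `get_host_ip` relies on a general property of UDP sockets: `connect()` on a datagram socket sends no packets. It only asks the kernel to choose a route and a matching local address, which is why the target "doesn't have to be reachable". The sketch below isolates that step. `probe_outbound_ip`, `target`, and `fallback` are hypothetical names, and the error handling and explicit `close()` are additions for illustration that are not in the diff.

```python
import socket


def probe_outbound_ip(target=("10.255.255.255", 1), fallback="127.0.0.1"):
    """Return the local IPv4 address the OS would use to reach target.

    connect() on a UDP socket transmits nothing; it only makes the kernel
    select a route and bind the corresponding local address, so nothing
    needs to be listening at the target.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect(target)
        return s.getsockname()[0]
    except OSError:
        # No usable route (e.g. an offline machine), or the OS refused the
        # target address: fall back to loopback (an assumption of this sketch).
        return fallback
    finally:
        s.close()


print(probe_outbound_ip())
```

On a multi-homed host the result depends on which interface the routing table selects for the target, so a different `target` may yield a different local address.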