Skip to content

Commit

Permalink
Verify shared object version at load. (#7928)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 23, 2022
1 parent 474366c commit d314680
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 12 deletions.
18 changes: 11 additions & 7 deletions python-package/xgboost/__init__.py
@@ -1,12 +1,16 @@
# coding: utf-8
"""XGBoost: eXtreme Gradient Boosting library.
Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
"""

import os

from .core import DMatrix, DeviceQuantileDMatrix, Booster, DataIter, build_info
from .core import (
DMatrix,
DeviceQuantileDMatrix,
Booster,
DataIter,
build_info,
_py_version,
)
from .training import train, cv
from . import rabit # noqa
from . import tracker # noqa
Expand All @@ -21,9 +25,9 @@
except ImportError:
pass

VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION")
with open(VERSION_FILE, encoding="ascii") as f:
__version__ = f.read().strip()

__version__ = _py_version()


__all__ = [
# core
Expand Down
52 changes: 47 additions & 5 deletions python-package/xgboost/core.py
Expand Up @@ -139,14 +139,30 @@ def _get_log_callback_func() -> Callable:
return c_callback(_log_callback)


def _lib_version(lib: ctypes.CDLL) -> Tuple[int, int, int]:
"""Get the XGBoost version from native shared object."""
major = ctypes.c_int()
minor = ctypes.c_int()
patch = ctypes.c_int()
lib.XGBoostVersion(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
return major.value, minor.value, patch.value


def _py_version() -> str:
"""Get the XGBoost version from Python version file."""
VERSION_FILE = os.path.join(os.path.dirname(__file__), "VERSION")
with open(VERSION_FILE, encoding="ascii") as f:
return f.read().strip()


def _load_lib() -> ctypes.CDLL:
"""Load xgboost Library."""
lib_paths = find_lib_path()
if not lib_paths:
# This happens only when building document.
return None # type: ignore
try:
pathBackup = os.environ['PATH'].split(os.pathsep)
pathBackup = os.environ["PATH"].split(os.pathsep)
except KeyError:
pathBackup = []
lib_success = False
Expand All @@ -155,16 +171,17 @@ def _load_lib() -> ctypes.CDLL:
try:
# needed when the lib is linked with non-system-available
# dependencies
os.environ['PATH'] = os.pathsep.join(
pathBackup + [os.path.dirname(lib_path)])
os.environ["PATH"] = os.pathsep.join(
pathBackup + [os.path.dirname(lib_path)]
)
lib = ctypes.cdll.LoadLibrary(lib_path)
setattr(lib, "path", os.path.normpath(lib_path))
lib_success = True
except OSError as e:
os_error_list.append(str(e))
continue
finally:
os.environ['PATH'] = os.pathsep.join(pathBackup)
os.environ["PATH"] = os.pathsep.join(pathBackup)
if not lib_success:
libname = os.path.basename(lib_paths[0])
raise XGBoostError(
Expand All @@ -180,11 +197,36 @@ def _load_lib() -> ctypes.CDLL:
* You are running 32-bit Python on a 64-bit OS
Error message(s): {os_error_list}
""")
"""
)
lib.XGBGetLastError.restype = ctypes.c_char_p
lib.callback = _get_log_callback_func() # type: ignore
if lib.XGBRegisterLogCallback(lib.callback) != 0:
raise XGBoostError(lib.XGBGetLastError())

def parse(ver: str) -> Tuple[int, int, int]:
"""Avoid dependency on packaging (PEP 440)."""
# 2.0.0-dev or 2.0.0
major, minor, patch = ver.split("-")[0].split(".")
return int(major), int(minor), int(patch)

libver = _lib_version(lib)
pyver = parse(_py_version())

# verify that we are loading the correct binary.
if pyver != libver:
pyver_str = ".".join((str(v) for v in pyver))
libver_str = ".".join((str(v) for v in libver))
msg = (
"Mismatched version between the Python package and the native shared "
f"""object. Python package version: {pyver_str}. Shared object """
f"""version: {libver_str}. Shared object is loaded from: {lib.path}.
Likely cause:
* XGBoost is first installed with anaconda then upgraded with pip. To fix it """
"please remove one of the installations."
)
raise ValueError(msg)

return lib


Expand Down

0 comments on commit d314680

Please sign in to comment.