Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] ENH: add TBB support #47

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions .azure_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ jobs:
VERSION_PYTHON: '*'
CC_OUTER_LOOP: 'clang-8'
CC_INNER_LOOP: 'gcc'
# Same but with TBB as threading layer for MKL and tbb4py installed
pylatest_conda_mkl_tbb4py_clang_gcc:
PACKAGER: 'conda'
VERSION_PYTHON: '*'
CC_OUTER_LOOP: 'clang-8'
CC_INNER_LOOP: 'gcc'
MKL_THREADING_LAYER: 'TBB'
TBB4PY: 'true'
# Same but with TBB as threading layer for MKL and tbb4py not installed
pylatest_conda_mkl_tbb_clang_gcc:
PACKAGER: 'conda'
VERSION_PYTHON: '*'
CC_OUTER_LOOP: 'clang-8'
CC_INNER_LOOP: 'gcc'
MKL_THREADING_LAYER: 'TBB'
# Linux + Python 3.7 with numpy / scipy installed with pip from PyPI and
# heterogeneous openmp runtimes.
py37_pip_openblas_gcc_clang:
Expand Down
3 changes: 3 additions & 0 deletions continuous_integration/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ if [[ "$PACKAGER" == "conda" ]]; then
if [[ "$NO_MKL" == "true" ]]; then
TO_INSTALL="$TO_INSTALL nomkl"
fi
fi
if [[ ! -z "$TBB4PY" ]]; then
TO_INSTALL="$TO_INSTALL tbb4py"
fi
make_conda $TO_INSTALL

Expand Down
95 changes: 89 additions & 6 deletions threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ class _dl_phdr_info(ctypes.Structure):
_RTLD_NOLOAD = ctypes.DEFAULT_MODE


# TBB's API to globally control the threadpool is a C++ class
# (`global_control`) which can't be accessed using ctypes. It can be accessed
# through tbb4py, the Python wrappers for TBB.
try:
import tbb
except ImportError:
tbb = None


# List of the supported implementations. The items hold the prefix of loaded
# shared objects, the name of the internal_api to call, matching the
# MAP_API_TO_FUNC keys and the name of the user_api, in {"blas", "openmp"}.
Expand All @@ -89,6 +98,11 @@ class _dl_phdr_info(ctypes.Structure):
"internal_api": "blis",
"filename_prefixes": ("libblis",),
},
{
"user_api": "tbb",
"internal_api": "tbb",
"filename_prefixes": ("libtbb",),
},
]

# Map an internal_api (openmp, openblas, mkl) to its set and get functions
Expand Down Expand Up @@ -182,6 +196,9 @@ def _set_threadpool_limits(limits, user_api=None):
- 'dynlib': the instance of ctypes.CDLL use to access the dynamic
library.
"""
original_limits = limits
original_user_api = user_api

if isinstance(limits, int):
if user_api is None:
user_api = _ALL_USER_APIS
Expand Down Expand Up @@ -209,6 +226,8 @@ def _set_threadpool_limits(limits, user_api=None):
user_api = [module for module in limits if module in _ALL_USER_APIS]

modules = _load_modules(prefixes=prefixes, user_api=user_api)
_check_tbb_set_limits(original_limits, original_user_api, modules)

for module in modules:
# Workaround clang bug (TODO: report it)
module['get_num_threads']()
Expand All @@ -223,6 +242,36 @@ def _set_threadpool_limits(limits, user_api=None):
return modules


def _check_tbb_set_limits(limits, user_api, modules):
    """Check whether calling threadpool_limits for TBB will have an effect.

    Nothing is done when tbb4py could be imported (module-level `tbb` is not
    None) or when none of the loaded modules has the 'tbb' user_api.
    Otherwise the user is warned if TBB was only targeted implicitly, and an
    error is raised if TBB was requested explicitly.

    Parameters
    ----------
    limits : int, list or dict
        The limits as originally passed by the caller, before any
        normalization. Any value that is neither an int nor a list is
        treated as a dict.
    user_api : object or None
        The user_api as originally passed by the caller. Only inspected when
        `limits` is an int; it must then support the `in` operator.
    modules : list of dict
        The loaded modules. Each item must provide a 'user_api' key; a
        'prefix' key is also read when `limits` is a list.

    Raises
    ------
    RuntimeError
        If TBB is explicitly targeted while tbb4py is not installed.

    Warns
    -----
    RuntimeWarning
        If `limits` is an int, `user_api` is None and 'libtbb' is loaded
        while tbb4py is not installed.
    """
    if tbb is not None:
        # tbb4py is installed, no need to warn or raise anything
        return

    if not any(module['user_api'] == 'tbb' for module in modules):
        # tbb is not among the modules, no need to warn or raise anything
        return

    warn_msg = ("'libtbb' is loaded but tbb4py is not installed. This "
                "function has no effect on the number of threads for TBB.")
    raise_msg = "tbb4py is required to limit the number of threads for TBB."

    if isinstance(limits, int):
        if user_api is None:
            # TBB is only one of all the targeted libraries: the call is
            # still useful for the other ones, so a warning is enough.
            warnings.warn(warn_msg, RuntimeWarning)
        elif 'tbb' in user_api:
            raise RuntimeError(raise_msg)
    elif isinstance(limits, list):
        if any(module['prefix'] == 'libtbb' for module in modules):
            raise RuntimeError(raise_msg)
    else:
        # the only remaining possibility is that limits is a dict
        if 'libtbb' in limits or 'tbb' in limits:
            raise RuntimeError(raise_msg)


@_format_docstring(INTERNAL_APIS=_ALL_INTERNAL_APIS)
def threadpool_info():
"""Return the maximal number of threads for each detected library.
Expand All @@ -243,6 +292,8 @@ def threadpool_info():
# we map it to 1 for consistency with other libraries.
if module['num_threads'] == -1 and module['internal_api'] == 'blis':
module['num_threads'] = 1
# add a 'status' field to display additional information
_add_status(module)
# Remove the wrapper for the module and its function
del module['set_num_threads'], module['get_num_threads']
del module['dynlib']
Expand All @@ -251,6 +302,11 @@ def threadpool_info():
return infos


def _add_status(module):
    """Flag a TBB module whose threadpool could not be introspected.

    A 'status' entry is added in place to `module` when its 'internal_api'
    is 'tbb' and the module-level `tbb` is None. Any other module is left
    untouched.
    """
    if module['internal_api'] != 'tbb':
        return
    if tbb is not None:
        return
    module['status'] = "threadpool introspection failed (tbb4py required)"


def _get_version(dynlib, internal_api):
if internal_api == "mkl":
return _get_mkl_version(dynlib)
Expand All @@ -262,6 +318,8 @@ def _get_version(dynlib, internal_api):
return _get_openblas_version(dynlib)
elif internal_api == "blis":
return _get_blis_version(dynlib)
elif internal_api == "tbb":
return _get_tbb_version(dynlib)
else:
raise NotImplementedError("Unsupported API {}".format(internal_api))

Expand Down Expand Up @@ -298,6 +356,24 @@ def _get_blis_version(blis_dynlib):
get_version.restype = ctypes.c_char_p
return get_version().decode('utf-8')

def _get_tbb_version(tbb_dynlib):
"""Return the TBB version"""
get_version = getattr(tbb_dynlib, "TBB_runtime_interface_version")
return str(get_version())


def _tbb_get_num_threads():
    """Return TBB's current `max_allowed_parallelism` value.

    The value is queried through `tbb.global_control.active_value`. None is
    returned when the module-level `tbb` is None.
    """
    if tbb is None:
        return None

    control = tbb.global_control
    return control.active_value(control.max_allowed_parallelism)


# Reference to the most recent tbb.global_control instance created by
# _tbb_set_num_threads. It is retained only to keep that instance alive.
_tbb_active_control = None


def _tbb_set_num_threads(max_threads):
    """Limit the number of threads used by TBB to `max_threads`.

    This is a no-op when the module-level `tbb` is None.

    Parameters
    ----------
    max_threads : int
        Value passed to `tbb.global_control` for the
        `max_allowed_parallelism` parameter.
    """
    global _tbb_active_control

    if tbb is None:
        return

    # A TBB global_control only constrains the library while the object is
    # alive. Discarding the instance right away would let it be garbage
    # collected and the limit lifted immediately, so keep a reference to it.
    # The previous instance is released only after the new one has been
    # created.
    _tbb_active_control = tbb.global_control(
        tbb.global_control.max_allowed_parallelism, max_threads)


# Loading utilities for dynamically linked shared objects

Expand Down Expand Up @@ -337,12 +413,19 @@ def _make_module_info(filepath, module_info, prefix):
filepath = os.path.normpath(filepath)
dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
internal_api = module_info['internal_api']
set_func = getattr(dynlib,
_MAP_API_TO_FUNC[internal_api]['set_num_threads'],
lambda num_threads: None)
get_func = getattr(dynlib,
_MAP_API_TO_FUNC[internal_api]['get_num_threads'],
lambda: None)

if module_info['user_api'] == 'tbb':
# tbb's api can't be accessed with ctypes. We access it through tbb4py
set_func = _tbb_set_num_threads
get_func = _tbb_get_num_threads
else:
set_func = getattr(dynlib,
_MAP_API_TO_FUNC[internal_api]['set_num_threads'],
lambda num_threads: None)
get_func = getattr(dynlib,
_MAP_API_TO_FUNC[internal_api]['get_num_threads'],
lambda: None)

module_info = module_info.copy()
module_info.update(dynlib=dynlib, filepath=filepath, prefix=prefix,
set_num_threads=set_func, get_num_threads=get_func,
Expand Down