Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] ENH: add TBB support #47

Closed
wants to merge 9 commits into from
Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 37 additions & 6 deletions threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ class _dl_phdr_info(ctypes.Structure):
_RTLD_NOLOAD = ctypes.DEFAULT_MODE


# TBB's api to globally control the threadpool is a C++ class `global_control`
# which can't be accessed using ctypes. It can be accessed through tbb4py,
# pythons wrappers for TBB.
try:
import tbb
except ImportError:
tbb = None


# List of the supported implementations. The items hold the prefix of loaded
# shared objects, the name of the internal_api to call, matching the
# MAP_API_TO_FUNC keys and the name of the user_api, in {"blas", "openmp"}.
Expand All @@ -89,6 +98,11 @@ class _dl_phdr_info(ctypes.Structure):
"internal_api": "blis",
"filename_prefixes": ("libblis",),
},
{
"user_api": "tbb",
"internal_api": "tbb",
"filename_prefixes": ("libtbb",),
},
]

# map a internal_api (openmp, openblas, mkl) to set and get functions
Expand Down Expand Up @@ -262,6 +276,9 @@ def _get_version(dynlib, internal_api):
return _get_openblas_version(dynlib)
elif internal_api == "blis":
return _get_blis_version(dynlib)
elif internal_api == "tbb":
# tbb does not expose it's version
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm.. that's an interesting question. adding an issue oneapi-src/oneTBB#189
There is no way to get ready-to-use version in python package release format like 2019.2. TBB has TBB_runtime_interface_version() which return internal version which is not connected to release versions. Would that be enough or you want it to be the same as the package version?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually we need the version of libtbb DSO. So I guess TBB_runtime_interface_version() is what we want. Is that right ?
Thanks

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I believe so.

return None
else:
raise NotImplementedError("Unsupported API {}".format(internal_api))

Expand Down Expand Up @@ -337,12 +354,26 @@ def _make_module_info(filepath, module_info, prefix):
filepath = os.path.normpath(filepath)
dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
internal_api = module_info['internal_api']
set_func = getattr(dynlib,
_MAP_API_TO_FUNC[internal_api]['set_num_threads'],
lambda num_threads: None)
get_func = getattr(dynlib,
_MAP_API_TO_FUNC[internal_api]['get_num_threads'],
lambda: None)

if module_info['user_api'] == 'tbb':
# tbb's api can't be accessed with ctypes. We access it through
# tbb4py.
def set_func(max_threads):
if tbb is not None:
tbb.global_control(
tbb.global_control.max_allowed_parallelism, max_threads)
def get_func():
if tbb is not None:
return tbb.global_control.active_value(
tbb.global_control.max_allowed_parallelism)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe move the definition outside of the function scope (something like tbb_*_num_threads). It would make the code more readable IMO.

else:
set_func = getattr(dynlib,
_MAP_API_TO_FUNC[internal_api]['set_num_threads'],
lambda num_threads: None)
get_func = getattr(dynlib,
_MAP_API_TO_FUNC[internal_api]['get_num_threads'],
lambda: None)

module_info = module_info.copy()
module_info.update(dynlib=dynlib, filepath=filepath, prefix=prefix,
set_num_threads=set_func, get_num_threads=get_func,
Expand Down