Skip to content

Commit

Permalink
Merge pull request #8542 from gmarkall/cuda-callback-default-arg
Browse files Browse the repository at this point in the history
CUDA: Make arg optional for Stream.add_callback()
  • Loading branch information
sklam committed Oct 26, 2022
2 parents 1591fc0 + 7950828 commit 598aadf
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
4 changes: 2 additions & 2 deletions numba/cuda/cudadrv/driver.py
Expand Up @@ -2238,7 +2238,7 @@ def auto_synchronize(self):
yield self
self.synchronize()

def add_callback(self, callback, arg):
def add_callback(self, callback, arg=None):
"""
Add a callback to a compute stream.
The user provided function is called from a driver thread once all
Expand All @@ -2256,7 +2256,7 @@ def add_callback(self, callback, arg):
eventual deprecation and may be replaced in a future CUDA release.
:param callback: Callback function with arguments (stream, status, arg).
:param arg: User data to be passed to the callback function.
:param arg: Optional user data to be passed to the callback function.
"""
data = (self, callback, arg)
_py_incref(data)
Expand Down
11 changes: 11 additions & 0 deletions numba/cuda/tests/cudadrv/test_streams.py
Expand Up @@ -29,6 +29,17 @@ def callback(stream, status, event):
stream.add_callback(callback, callback_event)
self.assertTrue(callback_event.wait(1.0))

def test_add_callback_with_default_arg(self):
callback_event = threading.Event()

def callback(stream, status, arg):
self.assertIsNone(arg)
callback_event.set()

stream = cuda.stream()
stream.add_callback(callback)
self.assertTrue(callback_event.wait(1.0))

@with_asyncio_loop
async def test_async_done(self):
stream = cuda.stream()
Expand Down

0 comments on commit 598aadf

Please sign in to comment.