diff --git a/examples/mpi4py-debug.py b/examples/mpi4py-debug.py new file mode 100644 index 00000000..f3c7b396 --- /dev/null +++ b/examples/mpi4py-debug.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# This example demonstrates how to debug an mpi4py application. +# Run this with 'mpirun -n 2 python mpi4py-debug.py'. +# You can then attach to the debugger by running 'telnet 127.0.0.1 6899' +# (when using the default pudb configuration) in another terminal. + +from mpi4py import MPI +from pudb.remote import debug_remote_on_single_rank + + +def debugged_function(x): + y = x + fail # noqa: F821 + return y + + +# debug 'debugged_function' on rank 0 +debug_remote_on_single_rank(MPI.COMM_WORLD, 0, debugged_function, 42) diff --git a/pudb/remote.py b/pudb/remote.py index 64bf7b1a..87c59246 100644 --- a/pudb/remote.py +++ b/pudb/remote.py @@ -3,6 +3,7 @@ .. autofunction:: set_trace .. autofunction:: debugger +.. autofunction:: debug_remote_on_single_rank """ __copyright__ = """ @@ -43,10 +44,12 @@ import termios import struct import atexit +from typing import Callable, Any from pudb.debugger import Debugger -__all__ = ["PUDB_RDB_HOST", "PUDB_RDB_PORT", "default_port", "debugger", "set_trace"] +__all__ = ["PUDB_RDB_HOST", "PUDB_RDB_PORT", "default_port", "debugger", "set_trace", + "debug_remote_on_single_rank"] default_port = 6899 @@ -237,3 +240,26 @@ def set_trace( return debugger( term_size=term_size, host=host, port=port, reverse=reverse ).set_trace(frame) + + +def debug_remote_on_single_rank(comm: Any, rank: int, func: Callable, + *args: Any, **kwargs: Any) -> None: + """Run a remote debugger on a single rank of an ``mpi4py`` application. + *func* will be called on rank *rank* running in a :class:`RemoteDebugger`, + and will be called normally on all other ranks. + + :param comm: an ``mpi4py`` ``Comm`` object. + :param rank: the rank to debug. All other ranks will spin until this rank exits. + :param func: the callable to debug. + :param args: the arguments passed to ``func``. + :param kwargs: the kwargs passed to ``func``. + """ + if comm.rank == rank: + debugger().runcall(func, *args, **kwargs) + else: + try: + func(*args, **kwargs) + finally: + from time import sleep + while True: + sleep(1)