diff --git a/test/dynamo/test_debug_dir.py b/test/dynamo/test_debug_dir.py
deleted file mode 100644
index 5827ff40ea781f7..000000000000000
--- a/test/dynamo/test_debug_dir.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Owner(s): ["module: dynamo"]
-import shutil
-import unittest
-
-import torch
-import torch._dynamo.test_case
-import torch._dynamo.testing
-from torch._dynamo.utils import DebugDir, get_debug_dir
-
-
-class DebugDirTests(torch._dynamo.test_case.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls._exit_stack.enter_context(
-            unittest.mock.patch.object(
-                torch._dynamo.config,
-                "debug_dir_root",
-                "/tmp/torch._dynamo_debug_dirs/",
-            )
-        )
-
-    @classmethod
-    def tearDownClass(cls):
-        shutil.rmtree(torch._dynamo.config.debug_dir_root, ignore_errors=True)
-        cls._exit_stack.close()
-
-    def setUp(self):
-        super().setUp()
-        torch._dynamo.utils.debug_dir = DebugDir()
-
-    def tearDown(self):
-        torch._dynamo.utils.debug_dir = DebugDir()
-        super().tearDown()
-
-    def _setup(self):
-        debug_dir = torch._dynamo.utils.debug_dir
-        debug_dir.setup()
-        self.assertIsNotNone(debug_dir.debug_path)
-        self.assertEqual(debug_dir.num_setup_calls, 1)
-        return debug_dir
-
-    def test_setup(self):
-        self._setup()
-
-    def test_clear(self):
-        debug_dir = self._setup()
-        debug_dir.clear()
-        self.assertIsNone(debug_dir.debug_path)
-        self.assertEqual(debug_dir.num_setup_calls, 0)
-
-    def test_multi_setup_single_clear(self):
-        debug_dir = self._setup()
-        prev = get_debug_dir()
-
-        debug_dir.setup()
-        self.assertEqual(prev, get_debug_dir())
-        self.assertEqual(debug_dir.num_setup_calls, 2)
-
-        debug_dir.clear()
-        self.assertEqual(prev, get_debug_dir())
-        self.assertEqual(debug_dir.num_setup_calls, 1)
-
-    def test_multi_setup_multi_clear(self):
-        debug_dir = self._setup()
-        prev = get_debug_dir()
-
-        debug_dir.setup()
-        self.assertEqual(prev, get_debug_dir())
-        self.assertEqual(debug_dir.num_setup_calls, 2)
-
-        debug_dir.clear()
-        self.assertEqual(prev, get_debug_dir())
-        self.assertEqual(debug_dir.num_setup_calls, 1)
-
-        debug_dir.clear()
-        self.assertIsNone(debug_dir.debug_path)
-        self.assertEqual(debug_dir.num_setup_calls, 0)
-
-    def test_single_setup_single_clear(self):
-        debug_dir = self._setup()
-        debug_dir.clear()
-        self.assertIsNone(debug_dir.debug_path)
-        self.assertEqual(debug_dir.num_setup_calls, 0)
-
-    def test_multi_get(self):
-        self._setup()
-        prev = get_debug_dir()
-        next = get_debug_dir()
-        self.assertEqual(prev, next)
-
-
-if __name__ == "__main__":
-    from torch._dynamo.test_case import run_tests
-
-    run_tests()
diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py
index a282485285797be..0cec7d202a9d446 100644
--- a/test/dynamo/test_minifier.py
+++ b/test/dynamo/test_minifier.py
@@ -43,10 +43,8 @@ def tearDownClass(cls):
 
     def setUp(self):
         super().setUp()
-        torch._dynamo.utils.debug_dir.setup()
 
     def tearDown(self):
-        torch._dynamo.utils.debug_dir.clear()
         super().tearDown()
 
     def test_after_dynamo(self):
diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py
index ea5671a81d02fe2..0ece930d1d13b25 100644
--- a/torch/_dynamo/debug_utils.py
+++ b/torch/_dynamo/debug_utils.py
@@ -240,7 +240,7 @@ def save_graph_repro(fd, gm, args, compiler_name):
 def isolate_fails(fx_g, args, compiler_name: str, env=None):
     if env is None:
         env = {}
-    subdir = f"{minifier_dir()}/isolate"
+    subdir = os.path.join(os.getcwd(), "isolate")
     if not os.path.exists(subdir):
         os.makedirs(subdir, exist_ok=True)
     file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py")
@@ -600,10 +600,11 @@ def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False):
     """
    Saves the repro to a repro.py file
    """
-    subdir = os.path.join(minifier_dir())
+    curdir = os.getcwd()
+    subdir = os.path.join(os.getcwd(), "checkpoints")
     if not os.path.exists(subdir):
         os.makedirs(subdir, exist_ok=True)
-    file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py")
+    file_name = os.path.join(subdir, f"minified_{len(gm.graph.nodes)}_nodes.py")
     log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
 
     model_str = NNModuleToString.convert(gm)
@@ -613,19 +614,10 @@ def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False):
             model_str, args, compiler_name, check_accuracy
         )
     )
-    latest_repro = os.path.join(subdir, "repro.py")
+    latest_repro = os.path.join(curdir, "repro.py")
     log.warning(f"Copying {file_name} to {latest_repro} for convenience")
     shutil.copyfile(file_name, latest_repro)
 
-    local_path = os.path.join(config.base_dir, "repro.py")
-    try:
-        shutil.copyfile(file_name, local_path)
-        log.warning(
-            f"Copying minified repro from {file_name} to {local_path} for convenience"
-        )
-    except OSError:
-        log.warning("No write permissions for {local_path}")
-
     # TODO - Commented because we are assuming that nn.Modules can be safely repr'd
     # If that does not work, we might have to bring this code back. So, keeping it
@@ -748,8 +740,6 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name):
 from {config.dynamo_import}.optimizations.backends import BACKENDS
 from {config.dynamo_import}.testing import rand_strided
 
-{config.dynamo_import}.config.repro_dir = \"{minifier_dir()}\"
-
 args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]}
 args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]
 
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 9895da4ad9bbab5..d86653f9973cc93 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -103,14 +103,12 @@ def __enter__(self):
                 "Please refer to https://github.com/pytorch/torchdynamo#usage-example "
                 "to use torchdynamo.optimize(...) as an annotation/decorator. "
             )
-        utils.debug_dir.setup()
         self.on_enter()
         self.prior = set_eval_frame(self.callback)
         self.backend_ctx = self.extra_ctx_ctor()
         self.backend_ctx.__enter__()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        utils.debug_dir.clear()
         set_eval_frame(self.prior)
         self.prior = unset
         self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)
@@ -152,14 +150,12 @@ def __call__(self, *args, **kwargs):
         @functools.wraps(fn)
         def _fn(*args, **kwargs):
             on_enter()
-            utils.debug_dir.setup()
             prior = set_eval_frame(callback)
             backend_ctx = backend_ctx_ctor()
             backend_ctx.__enter__()
             try:
                 return fn(*args, **kwargs)
             finally:
-                utils.debug_dir.clear()
                 set_eval_frame(prior)
                 backend_ctx.__exit__(None, None, None)
 
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 1bc646be454356a..ef2c1c38ea8ba5b 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -975,35 +975,13 @@ def recompile_reasons(code):
     return rpt
 
 
-class DebugDir:
-    def __init__(self):
-        self.num_setup_calls = 0
-        self.debug_path = None
-
-    def setup(self):
-        assert self.num_setup_calls >= 0
-        if self.num_setup_calls == 0:
-            debug_root = config.debug_dir_root
-            dir_name = "run_" + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
-            self.debug_path = os.path.join(debug_root, dir_name)
-
-        self.num_setup_calls += 1
-
-    def clear(self):
-        assert self.num_setup_calls >= 0
-        if self.num_setup_calls == 1:
-            self.debug_path = None
-
-        self.num_setup_calls -= 1
-        assert self.num_setup_calls >= 0
-
-    def get(self):
-        assert self.debug_path is not None
-        return self.debug_path
-
-
-debug_dir = DebugDir()
+# return same dir unless user changes config between calls
+@functools.lru_cache(None)
+def _get_debug_dir(root_dir):
+    dir_name = "run_" + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
+    return os.path.join(root_dir, dir_name)
 
 
 def get_debug_dir():
-    return debug_dir.get()
+    debug_root = config.debug_dir_root
+    return _get_debug_dir(debug_root)