Skip to content

Commit

Permalink
feat: add dump_stats to advanced profiler (#19698)
Browse files Browse the repository at this point in the history
  • Loading branch information
azzhipa committed Mar 26, 2024
1 parent 8b378f0 commit cf1947c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `on_exception` hook to `LightningDataModule` ([#19601](https://github.com/Lightning-AI/pytorch-lightning/pull/19601))

- Added `dump_stats` flag to `AdvancedProfiler` ([#19698](https://github.com/Lightning-AI/pytorch-lightning/issues/19698))

-

### Changed
Expand Down
25 changes: 25 additions & 0 deletions src/lightning/pytorch/profilers/advanced.py
Expand Up @@ -16,12 +16,15 @@
import cProfile
import io
import logging
import os
import pstats
import tempfile
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

from typing_extensions import override

from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.pytorch.profilers.profiler import Profiler

log = logging.getLogger(__name__)
Expand All @@ -40,6 +43,7 @@ def __init__(
dirpath: Optional[Union[str, Path]] = None,
filename: Optional[str] = None,
line_count_restriction: float = 1.0,
dump_stats: bool = False,
) -> None:
"""
Args:
Expand All @@ -54,13 +58,17 @@ def __init__(
reported for each action. either an integer (to select a count of lines),
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
dump_stats: Whether to save raw profiler results. When ``True`` then ``dirpath`` must be provided.
Raises:
ValueError:
If you attempt to stop recording an action which was never started.
"""
super().__init__(dirpath=dirpath, filename=filename)
self.profiled_actions: Dict[str, cProfile.Profile] = {}
self.line_count_restriction = line_count_restriction
self.dump_stats = dump_stats
assert not self.dump_stats or self.dirpath is not None, "dirname must be provided for dump_states to work"

@override
def start(self, action_name: str) -> None:
Expand All @@ -75,10 +83,27 @@ def stop(self, action_name: str) -> None:
raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
pr.disable()

def _maybe_dump_stats(self, action_name: str, pr: cProfile.Profile) -> None:
if not self.dump_stats:
return
dst_filepath = os.path.join(self.dirpath, self._prepare_filename(action_name=action_name, extension=".prof"))
dst_fs = get_filesystem(dst_filepath)
dst_fs.mkdirs(self.dirpath, exist_ok=True)
# temporarily save to local since pstats can only dump into a local file
with tempfile.TemporaryDirectory(prefix="test", suffix="test", dir=os.getcwd()) as tmp_dir, dst_fs.open(
dst_filepath, "wb"
) as dst_file:
src_filepath = os.path.join(tmp_dir, "tmp.prof")
pr.dump_stats(src_filepath)
src_fs = get_filesystem(src_filepath)
with src_fs.open(src_filepath, "rb") as src_file:
dst_file.write(src_file.read())

@override
def summary(self) -> str:
recorded_stats = {}
for action_name, pr in self.profiled_actions.items():
self._maybe_dump_stats(action_name, pr)
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats("cumulative")
ps.print_stats(self.line_count_restriction)
Expand Down
19 changes: 19 additions & 0 deletions tests/tests_pytorch/profilers/test_profiler.py
Expand Up @@ -299,6 +299,25 @@ def test_advanced_profiler_describe(tmp_path, advanced_profiler):
assert len(data) > 0


def test_advanced_profiler_dump_states(tmp_path):
advanced_profiler = AdvancedProfiler(dirpath=tmp_path, dump_stats=True)
"""Ensure the profiler dump stats during summary."""
# record at least one event
with advanced_profiler.profile(action_name := "test"):
pass
# dump_stats to file
advanced_profiler.describe()
path = advanced_profiler.dirpath / f"{action_name}.prof"
data = path.read_bytes()
assert len(data) > 0


def test_advanced_profiler_dump_states_needs_dirpath():
"""Ensure the profiler requires dirpath to dump stats."""
with pytest.raises(AssertionError):
AdvancedProfiler(dump_stats=True)


def test_advanced_profiler_value_errors(advanced_profiler):
"""Ensure errors are raised where expected."""
action = "test"
Expand Down

0 comments on commit cf1947c

Please sign in to comment.