(torch/elastic) add fqdn hostname to error printout (#66182) (#66662)
Summary:
Pull Request resolved: #66182

closes #63174

Does a few things:

1. adds hostname to the error report
2. moves the "root cause" section to the end (since the logs are typically "tailed", we want the root cause to appear last)
3. moves redundant error info logging to debug
4. caps the border at 60 characters and left-justifies the header

NOTE: YOU HAVE TO annotate your main function with torch.distributed.elastic.multiprocessing.errors.record, otherwise no traceback is printed (this is because Python exception propagation does NOT work out of the box over IPC - hence the extra record annotation).
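
A minimal sketch of the required annotation (mirroring the usage example this PR adds to the torch/distributed/run.py docstring; the failing body is a placeholder that reproduces the RuntimeError in the sample below):

```
from torch.distributed.elastic.multiprocessing.errors import record

# Sketch: @record writes exceptions raised in this process to the error file
# so the launcher can include the traceback in the error summary.
@record
def main():
    # placeholder training logic that simply fails, as in the sample below
    raise RuntimeError("foobar")

if __name__ == "__main__":
    main()
```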

Test Plan:
Sample

```
============================================================
run_script_path FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2021-10-05_17:37:22
  host      : devvm4955.prn0.facebook.com
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 3296201)
  error_file: /home/kiuk/tmp/elastic/none_3_lsytqe/attempt_0/0/error.json
  traceback :
  Traceback (most recent call last):
    File "/tmp/jetter.xr3_x6qq/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 372, in wrapper
      return f(*args, **kwargs)
    File "main.py", line 28, in main
      raise RuntimeError(args.throws)
  RuntimeError: foobar

============================================================
```
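
For reference, output like the above can presumably be produced by launching a failing script through the elastic launcher; the script name and --throws flag are hypothetical, inferred from the traceback in the sample:

```
python -m torch.distributed.run --nproc_per_node=1 main.py --throws foobar
```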

Reviewed By: cbalioglu, aivanou

Differential Revision: D31416492

fbshipit-source-id: 0aeaf6e634e23ce0ea7f6a03b12c8a9ac57246e9
kiukchung committed Oct 15, 2021
1 parent b544cbd commit 36449ea
Showing 6 changed files with 84 additions and 66 deletions.
45 changes: 28 additions & 17 deletions test/distributed/elastic/multiprocessing/api_test.py
@@ -15,7 +15,7 @@
import time
import unittest
from itertools import product
from typing import Dict, List, Union, Callable
from typing import Callable, Dict, List, Union
from unittest import mock
from unittest.mock import patch

@@ -24,25 +24,25 @@
from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
from torch.distributed.elastic.multiprocessing.api import (
MultiprocessContext,
SignalException,
RunProcsResult,
SignalException,
Std,
_validate_full_rank,
to_map,
_wrap,
to_map,
)
from torch.distributed.elastic.multiprocessing.errors.error_handler import _write_error
from torch.testing._internal.common_utils import (
IS_IN_CI,
IS_MACOS,
IS_WINDOWS,
NO_MULTIPROCESSING_SPAWN,
TEST_WITH_ASAN,
TEST_WITH_TSAN,
TEST_WITH_DEV_DBG_ASAN,
IS_IN_CI,
IS_WINDOWS,
IS_MACOS,
TEST_WITH_TSAN,
run_tests,
sandcastle_skip_if,
)
from torch.testing._internal.common_utils import run_tests


class RunProcResultsTest(unittest.TestCase):
@@ -224,6 +224,7 @@ def start_processes_zombie_test(

# tests incompatible with tsan or asan
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):

class StartProcessesTest(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
@@ -251,12 +252,15 @@ def assert_pids_noexist(self, pids: Dict[int, int]):

def test_to_map(self):
local_world_size = 2
self.assertEqual({0: Std.OUT, 1: Std.OUT}, to_map(Std.OUT, local_world_size))
self.assertEqual(
{0: Std.OUT, 1: Std.OUT}, to_map(Std.OUT, local_world_size)
)
self.assertEqual(
{0: Std.NONE, 1: Std.OUT}, to_map({1: Std.OUT}, local_world_size)
)
self.assertEqual(
{0: Std.ERR, 1: Std.OUT}, to_map({0: Std.ERR, 1: Std.OUT}, local_world_size)
{0: Std.ERR, 1: Std.OUT},
to_map({0: Std.ERR, 1: Std.OUT}, local_world_size),
)

def test_invalid_log_dir(self):
@@ -382,9 +386,7 @@ def test_void_function(self):
results = pc.wait(period=0.1)
self.assertEqual({0: None, 1: None}, results.return_values)

@sandcastle_skip_if(
TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan"
)
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan")
def test_function_large_ret_val(self):
# python multiprocessing.queue module uses pipes and actually PipedQueues
# This means that if a single object is greater than a pipe size
@@ -439,7 +441,9 @@ def test_function_raise(self):
self.assertEqual(1, failure.exitcode)
self.assertEqual("<N/A>", failure.signal_name())
self.assertEqual(pc.pids()[0], failure.pid)
self.assertEqual(os.path.join(log_dir, "0", "error.json"), error_file)
self.assertEqual(
os.path.join(log_dir, "0", "error.json"), error_file
)
self.assertEqual(
int(error_file_data["message"]["extraInfo"]["timestamp"]),
int(failure.timestamp),
@@ -541,17 +545,22 @@ def test_multiprocessing_context_poll_raises_exception(self):
run_result = mp_context._poll()
self.assertEqual(1, len(run_result.failures))
failure = run_result.failures[0]
self.assertEqual("Signal 1 (SIGHUP) received by PID 123", failure.message)
self.assertEqual(
"Signal 1 (SIGHUP) received by PID 123", failure.message
)


# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):

class StartProcessesListTest(StartProcessesTest):
########################################
# start_processes as binary tests
########################################
def test_function(self):
for start_method, redirs in product(self._start_methods, redirects_oss_test()):
for start_method, redirs in product(
self._start_methods, redirects_oss_test()
):
with self.subTest(start_method=start_method, redirs=redirs):
pc = start_processes(
name="echo",
Expand Down Expand Up @@ -644,6 +653,7 @@ def test_binary_redirect_and_tee(self):

# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_IN_CI):

class StartProcessesNotCITest(StartProcessesTest):
def test_wrap_bad(self):
none = ""
Expand Down Expand Up @@ -796,7 +806,8 @@ def test_function_exit(self):
self.assertEqual(pc.pids()[0], failure.pid)
self.assertEqual("<N/A>", error_file)
self.assertEqual(
f"Process failed with exitcode {FAIL}", failure.message
"To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html",
failure.message,
)
self.assertLessEqual(failure.timestamp, int(time.time()))

9 changes: 6 additions & 3 deletions test/distributed/elastic/multiprocessing/errors/api_test.py
@@ -115,7 +115,10 @@ def test_process_failure_no_error_file(self):
pf = self.failure_without_error_file(exitcode=138)
self.assertEqual("<N/A>", pf.signal_name())
self.assertEqual("<N/A>", pf.error_file)
self.assertEqual("Process failed with exitcode 138", pf.message)
self.assertEqual(
"To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html",
pf.message,
)

def test_child_failed_error(self):
pf0 = self.failure_with_error_file(exception=SentinelError("rank 0"))
@@ -134,7 +137,7 @@ def test_child_failed_error(self):
rank: 0 (local_rank: 0)
exitcode: 1 (pid: 997)
error_file: /tmp/ApiTesttbb37ier/error.json
msg: "SentinelError: rank 0"
traceback: "SentinelError: rank 0"
=============================================
Other Failures:
[1]:
Expand All @@ -148,7 +151,7 @@ def test_child_failed_error(self):
rank: 2 (local_rank: 0)
exitcode: 138 (pid: 997)
error_file: <N/A>
msg: "Process failed with exitcode 138"
traceback: To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
*********************************************
"""
print(ex)
@@ -21,10 +21,11 @@
WorkerState,
)
from torch.distributed.elastic.metrics.api import prof
from torch.distributed.elastic.multiprocessing import start_processes, PContext
from torch.distributed.elastic.multiprocessing import PContext, start_processes
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger


log = get_logger()


68 changes: 25 additions & 43 deletions torch/distributed/elastic/multiprocessing/errors/__init__.py
@@ -51,6 +51,7 @@
import json
import os
import signal
import socket
import time
import warnings
from dataclasses import dataclass, field
@@ -109,7 +110,7 @@ def __post_init__(self):
try:
with open(self.error_file, "r") as fp:
self.error_file_data = json.load(fp)
log.info(
log.debug(
f"User process failed with error data: {json.dumps(self.error_file_data, indent=2)}"
)
self.message, self.timestamp = self._get_error_data(
@@ -130,7 +131,7 @@ def __post_init__(self):
f" received by PID {self.pid}"
)
else:
self.message = f"Process failed with exitcode {self.exitcode}"
self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"

def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]:
message = error_file_data["message"]
@@ -162,24 +163,24 @@ def timestamp_isoformat(self):
GlobalRank = int

_FAILURE_FORMAT_TEMPLATE = """[${idx}]:
time: ${time}
rank: ${rank} (local_rank: ${local_rank})
exitcode: ${exitcode} (pid: ${pid})
time : ${time}
host : ${hostname}
rank : ${rank} (local_rank: ${local_rank})
exitcode : ${exitcode} (pid: ${pid})
error_file: ${error_file}
msg: ${message}"""
traceback : ${message}"""

# extra new lines before and after are intentional
_MSG_FORMAT_TEMPLATE = """
${boarder}
${title}
${section}
Root Cause:
${root_failure}
${section}
Other Failures:
Failures:
${other_failures}
${boarder}
"""
${section}
Root Cause (first observed failure):
${root_failure}
${boarder}"""


class ChildFailedError(Exception):
@@ -230,8 +231,8 @@ def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]:
rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp)
return rank, self.failures[rank]

def format_msg(self, boarder_delim="*", section_delim="="):
title = f" {self.name} FAILED "
def format_msg(self, boarder_delim="=", section_delim="-"):
title = f"{self.name} FAILED"
root_rank, root_failure = self.get_first_failure()

root_failure_fmt: str = ""
@@ -246,11 +247,11 @@ def format_msg(self, boarder_delim="*", section_delim="="):
other_failures_fmt.append(fmt)

# upper boundary on width
width = min(width, 80)
width = min(width, 60)

return Template(_MSG_FORMAT_TEMPLATE).substitute(
boarder=boarder_delim * width,
title=title.center(width),
title=title,
section=section_delim * width,
root_failure=root_failure_fmt,
other_failures="\n".join(other_failures_fmt or [" <NO_OTHER_FAILURES>"]),
@@ -279,6 +280,7 @@ def _format_failure(
fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute(
idx=idx,
time=failure.timestamp_isoformat(),
hostname=socket.getfqdn(),
rank=rank,
local_rank=failure.local_rank,
exitcode=failure.exitcode,
@@ -292,32 +294,6 @@
return fmt, width


def _no_error_file_warning_msg(rank: int, failure: ProcessFailure) -> str:
msg = [
"CHILD PROCESS FAILED WITH NO ERROR_FILE",
f"Child process {failure.pid} (local_rank {rank}) FAILED (exitcode {failure.exitcode})",
f"Error msg: {failure.message}",
f"Without writing an error file to {failure.error_file}.",
"While this DOES NOT affect the correctness of your application,",
"no trace information about the error will be available for inspection.",
"Consider decorating your top level entrypoint function with",
"torch.distributed.elastic.multiprocessing.errors.record. Example:",
"",
r" from torch.distributed.elastic.multiprocessing.errors import record",
"",
r" @record",
r" def trainer_main(args):",
r" # do train",
]
width = 0
for line in msg:
width = max(width, len(line))

boarder = "*" * width
header = "CHILD PROCESS FAILED WITH NO ERROR_FILE".center(width)
return "\n".join(["\n", boarder, header, boarder, *msg, boarder])


def record(
fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None
) -> Callable[..., T]:
@@ -372,7 +348,13 @@ def wrapper(*args, **kwargs):
if failure.error_file != _NOT_AVAILABLE:
error_handler.dump_error_file(failure.error_file, failure.exitcode)
else:
warnings.warn(_no_error_file_warning_msg(rank, failure))
log.info(
(
f"local_rank {rank} FAILED with no error file."
f" Decorate your entrypoint fn with @record for traceback info."
f" See: https://pytorch.org/docs/stable/elastic/errors.html"
)
)
raise
except Exception as e:
error_handler.record_exception(e)
@@ -107,7 +107,7 @@ def dump_error_file(self, rootcause_error_file: str, error_code: int = 0):
else:
rootcause_error["message"]["errorCode"] = error_code

log.info(
log.debug(
f"child error file ({rootcause_error_file}) contents:\n"
f"{json.dumps(rootcause_error, indent=2)}"
)
23 changes: 22 additions & 1 deletion torch/distributed/run.py
@@ -304,6 +304,27 @@ def train():
if should_checkpoint:
save_checkpoint(checkpoint_path)
9. (Recommended) On worker errors, this tool will summarize the details of the error
(e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
is heuristically reported as the "Root Cause" error. To get tracebacks as part of this
error summary print out, you must decorate your main entrypoint function in your
training script as shown in the example below. If not decorated, then the summary
will not include the traceback of the exception and will only contain the exitcode.
For details on torchelastic error handling see: https://pytorch.org/docs/stable/elastic/errors.html
::
from torch.distributed.elastic.multiprocessing.errors import record
@record
def main():
# do train
pass
if __name__ == "__main__":
main()
"""
import logging
import os
@@ -597,7 +618,7 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str
if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
omp_num_threads = 1
log.warning(
f"*****************************************\n"
f"\n*****************************************\n"
f"Setting OMP_NUM_THREADS environment variable for each process to be "
f"{omp_num_threads} in default, to avoid your system being overloaded, "
f"please further tune the variable for optimal performance in "
