Skip to content

Commit

Permalink
[dnl] add NCCL/PT debug log for S413673
Browse files Browse the repository at this point in the history
Test Plan:
Smoke test w/ NCCL ut cannot repro segfault w/ dynamic register + len 213942272
```
IFNAME=eth2 HOSTS="rtptest908.pci1,rtptest693.pci1" ENVS="NCCL_DEBUG=INFO;NCCL_DEBUG_SUBSYS=INIT,COLL,ALLOC" buck2 run fbcode//mode/opt fbsource//third-party/nccl-exp/v2.18.3-1/src/ctran/tests:ctran_dist_allgather
```
P1223852264

Differential Revision: D56659330
  • Loading branch information
minsii authored and facebook-github-bot committed Apr 27, 2024
1 parent fd24d8c commit ad32a1d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
8 changes: 8 additions & 0 deletions c10/cuda/CUDACachingAllocator.cpp
Expand Up @@ -34,6 +34,8 @@
#include <utility>
#include <vector>

#include <c10/util/Logging.h>

TORCH_SDT_DEFINE_SEMAPHORE(malloc)
TORCH_SDT_DEFINE_SEMAPHORE(free)

Expand Down Expand Up @@ -2796,6 +2798,12 @@ class DeviceCachingAllocator {
cudaStream_t stream,
c10::DeviceIndex device,
std::shared_ptr<GatheredContext> context) {
if (action == TraceEntry::SEGMENT_ALLOC) {
LOG(INFO) << "CacheAllocator: SEGMENT_ALLOC addr=0x" << std::hex << addr
<< std::dec << ", size=" << size
<< ", trace_trackers_.size=" << trace_trackers_.size();
}

if (!record_history && trace_trackers_.empty())
return;

Expand Down
1 change: 1 addition & 0 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
Expand Up @@ -875,6 +875,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
// lazyInitCUDA is called (and is a no-op if CUDA is already initialized).
if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) {
at::globalContext().lazyInitCUDA();
LOG(INFO) << logPrefix() << "Registered cacheAllocatorRegisterHook";
c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
&cacheAllocatorRegisterHook);
c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
Expand Down

0 comments on commit ad32a1d

Please sign in to comment.