Skip to content

Commit

Permalink
add new field for event node (#43223)
Browse files Browse the repository at this point in the history
* add new field for event node

* fix

* fix bug

* fix bug

* fix clang

* fix clang format

* fix code format
  • Loading branch information
rainyfly committed Jun 10, 2022
1 parent 6d3a68c commit 06de489
Show file tree
Hide file tree
Showing 12 changed files with 612 additions and 38 deletions.
102 changes: 86 additions & 16 deletions paddle/fluid/platform/profiler/chrometracing_logger.cc
Expand Up @@ -27,7 +27,7 @@ limitations under the License. */
namespace paddle {
namespace platform {

static const char* kSchemaVersion = "1.0.0";
static const char* kSchemaVersion = "1.0.1";
static const char* kDefaultFilename = "pid_%s_time_%s.paddle_trace.json";
static uint32_t span_indx = 0;

Expand All @@ -37,14 +37,6 @@ static std::string DefaultFileName() {
GetStringFormatLocalTime().c_str());
}

// Display names used for the "cat" field of emitted trace events.
// Other code in this file indexes this table with
// static_cast<int>(<node>.Type()), so the entry order presumably mirrors the
// TracerEventType enumerator order — TODO confirm against the enum definition.
const char* ChromeTracingLogger::categary_name_[] = {
"Operator", "Dataloader", "ProfileStep",
"CudaRuntime", "Kernel", "Memcpy",
"Memset", "UserDefined", "OperatorInner",
"Forward", "Backward", "Optimization",
"Communication", "PythonOp", "PythonUserDefined",
"MluRuntime"};

void ChromeTracingLogger::OpenFile() {
output_file_stream_.open(filename_,
std::ofstream::out | std::ofstream::trunc);
Expand Down Expand Up @@ -116,10 +108,41 @@ void ChromeTracingLogger::LogNodeTrees(const NodeTrees& node_trees) {
(*devicenode)->LogMe(this);
}
}
for (auto memnode = (*hostnode)->GetMemTraceEventNodes().begin();
memnode != (*hostnode)->GetMemTraceEventNodes().end(); ++memnode) {
(*memnode)->LogMe(this);
}
}
}
}

void ChromeTracingLogger::LogMemTraceEventNode(
const MemTraceEventNode& mem_node) {
if (!output_file_stream_) {
return;
}
output_file_stream_ << string_format(
std::string(
R"JSON(
{
"name": "[memory]", "pid": %lld, "tid": "%lld",
"ts": %lld,
"ph": "i", "cat": "%s",
"args": {
"place": "%s",
"addr": "%llu",
"current_allocated": %llu,
"current_reserved": %llu,
"increase_bytes": %lld
}
},
)JSON"),
mem_node.ProcessId(), mem_node.ThreadId(), mem_node.TimeStampNs(),
StringTracerMemEventType(mem_node.Type()), mem_node.Place().c_str(),
mem_node.Addr(), mem_node.CurrentAllocated(), mem_node.CurrentReserved(),
mem_node.IncreaseBytes());
}

void ChromeTracingLogger::LogHostTraceEventNode(
const HostTraceEventNode& host_node) {
if (!output_file_stream_) {
Expand All @@ -132,6 +155,16 @@ void ChromeTracingLogger::LogHostTraceEventNode(
} else {
dur_display = string_format(std::string("%.3f us"), dur * 1000);
}
std::map<std::string, std::vector<std::vector<int64_t>>> input_shapes;
std::map<std::string, std::vector<std::string>> input_dtypes;
std::string callstack;
OperatorSupplementEventNode* op_supplement_node =
host_node.GetOperatorSupplementEventNode();
if (op_supplement_node != nullptr) {
input_shapes = op_supplement_node->InputShapes();
input_dtypes = op_supplement_node->Dtypes();
callstack = op_supplement_node->CallStack();
}
switch (host_node.Type()) {
case TracerEventType::ProfileStep:
case TracerEventType::Forward:
Expand Down Expand Up @@ -159,10 +192,48 @@ void ChromeTracingLogger::LogHostTraceEventNode(
host_node.Name().c_str(), dur_display.c_str(), host_node.ProcessId(),
host_node.ThreadId(), nsToUs(host_node.StartNs()),
nsToUsFloat(host_node.Duration()),
categary_name_[static_cast<int>(host_node.Type())],
StringTracerEventType(host_node.Type()),
nsToUsFloat(host_node.StartNs(), start_time_),
nsToUsFloat(host_node.EndNs(), start_time_));
break;

case TracerEventType::Operator:

output_file_stream_ << string_format(
std::string(
R"JSON(
{
"name": "%s[%s]", "pid": %lld, "tid": "%lld(C++)",
"ts": %lld, "dur": %.3f,
"ph": "X", "cat": "%s",
"cname": "thread_state_runnable",
"args": {
"start_time": "%.3f us",
"end_time": "%.3f us",
"input_shapes": %s,
"input_dtypes": %s,
"callstack": "%s"
}
},
)JSON"),
host_node.Name().c_str(), dur_display.c_str(), host_node.ProcessId(),
host_node.ThreadId(), nsToUs(host_node.StartNs()),
nsToUsFloat(host_node.Duration()),
StringTracerEventType(host_node.Type()),
nsToUsFloat(host_node.StartNs(), start_time_),
nsToUsFloat(host_node.EndNs(), start_time_),
json_dict(input_shapes).c_str(), json_dict(input_dtypes).c_str(),
callstack.c_str());
break;
case TracerEventType::CudaRuntime:
case TracerEventType::Kernel:
case TracerEventType::Memcpy:
case TracerEventType::Memset:
case TracerEventType::UserDefined:
case TracerEventType::OperatorInner:
case TracerEventType::Communication:
case TracerEventType::MluRuntime:
case TracerEventType::NumTypes:
default:
output_file_stream_ << string_format(
std::string(
Expand All @@ -181,7 +252,7 @@ void ChromeTracingLogger::LogHostTraceEventNode(
host_node.Name().c_str(), dur_display.c_str(), host_node.ProcessId(),
host_node.ThreadId(), nsToUs(host_node.StartNs()),
nsToUsFloat(host_node.Duration()),
categary_name_[static_cast<int>(host_node.Type())],
StringTracerEventType(host_node.Type()),
nsToUsFloat(host_node.StartNs(), start_time_),
nsToUsFloat(host_node.EndNs(), start_time_));
break;
Expand Down Expand Up @@ -220,8 +291,7 @@ void ChromeTracingLogger::LogRuntimeTraceEventNode(
runtime_node.Name().c_str(), dur_display.c_str(),
runtime_node.ProcessId(), runtime_node.ThreadId(),
nsToUs(runtime_node.StartNs()), nsToUsFloat(runtime_node.Duration()),
categary_name_[static_cast<int>(runtime_node.Type())],
runtime_node.CorrelationId(),
StringTracerEventType(runtime_node.Type()), runtime_node.CorrelationId(),
nsToUsFloat(runtime_node.StartNs(), start_time_),
nsToUsFloat(runtime_node.EndNs(), start_time_));
pid_tid_set_.insert({runtime_node.ProcessId(), runtime_node.ThreadId()});
Expand Down Expand Up @@ -347,7 +417,7 @@ void ChromeTracingLogger::HandleTypeKernel(
device_node.Name().c_str(), dur_display.c_str(), device_node.DeviceId(),
device_node.StreamId(), nsToUs(device_node.StartNs()),
nsToUsFloat(device_node.Duration()),
categary_name_[static_cast<int>(device_node.Type())],
StringTracerEventType(device_node.Type()),
nsToUsFloat(device_node.StartNs(), start_time_),
nsToUsFloat(device_node.EndNs(), start_time_), device_node.DeviceId(),
device_node.ContextId(), device_node.StreamId(),
Expand Down Expand Up @@ -391,7 +461,7 @@ void ChromeTracingLogger::HandleTypeMemcpy(
device_node.Name().c_str(), dur_display.c_str(), device_node.DeviceId(),
device_node.StreamId(), nsToUs(device_node.StartNs()),
nsToUsFloat(device_node.Duration()),
categary_name_[static_cast<int>(device_node.Type())],
StringTracerEventType(device_node.Type()),
nsToUsFloat(device_node.StartNs(), start_time_),
nsToUsFloat(device_node.EndNs(), start_time_), device_node.StreamId(),
device_node.CorrelationId(), memcpy_info.num_bytes, memory_bandwidth);
Expand Down Expand Up @@ -427,7 +497,7 @@ void ChromeTracingLogger::HandleTypeMemset(
device_node.Name().c_str(), dur_display.c_str(), device_node.DeviceId(),
device_node.StreamId(), nsToUs(device_node.StartNs()),
nsToUsFloat(device_node.Duration()),
categary_name_[static_cast<int>(device_node.Type())],
StringTracerEventType(device_node.Type()),
nsToUsFloat(device_node.StartNs(), start_time_),
nsToUsFloat(device_node.EndNs(), start_time_), device_node.DeviceId(),
device_node.ContextId(), device_node.StreamId(),
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/platform/profiler/chrometracing_logger.h
Expand Up @@ -37,6 +37,7 @@ class ChromeTracingLogger : public BaseLogger {
void LogRuntimeTraceEventNode(const CudaRuntimeTraceEventNode&) override;
void LogNodeTrees(const NodeTrees&) override;
void LogMetaInfo(const std::unordered_map<std::string, std::string>);
void LogMemTraceEventNode(const MemTraceEventNode&) override;

private:
void OpenFile();
Expand Down
12 changes: 10 additions & 2 deletions paddle/fluid/platform/profiler/dump/test_serialization_logger.cc
Expand Up @@ -27,7 +27,9 @@ using paddle::platform::HostTraceEventNode;
using paddle::platform::KernelEventInfo;
using paddle::platform::MemcpyEventInfo;
using paddle::platform::MemsetEventInfo;
using paddle::platform::MemTraceEvent;
using paddle::platform::NodeTrees;
using paddle::platform::OperatorSupplementEvent;
using paddle::platform::ProfilerResult;
using paddle::platform::RuntimeTraceEvent;
using paddle::platform::SerializationLogger;
Expand All @@ -37,6 +39,8 @@ TEST(SerializationLoggerTest, dump_case0) {
std::list<HostTraceEvent> host_events;
std::list<RuntimeTraceEvent> runtime_events;
std::list<DeviceTraceEvent> device_events;
std::list<MemTraceEvent> mem_events;
std::list<OperatorSupplementEvent> op_supplement_events;
host_events.push_back(HostTraceEvent(std::string("dataloader#1"),
TracerEventType::Dataloader, 1000, 10000,
10, 10));
Expand Down Expand Up @@ -72,7 +76,8 @@ TEST(SerializationLoggerTest, dump_case0) {
DeviceTraceEvent(std::string("memset1"), TracerEventType::Memset, 66000,
69000, 0, 10, 11, 5, MemsetEventInfo()));
SerializationLogger logger("test_serialization_logger_case0.pb");
NodeTrees tree(host_events, runtime_events, device_events);
NodeTrees tree(host_events, runtime_events, device_events, mem_events,
op_supplement_events);
std::map<uint64_t, std::vector<HostTraceEventNode*>> nodes =
tree.Traverse(true);
EXPECT_EQ(nodes[10].size(), 4u);
Expand Down Expand Up @@ -101,6 +106,8 @@ TEST(SerializationLoggerTest, dump_case1) {
std::list<HostTraceEvent> host_events;
std::list<RuntimeTraceEvent> runtime_events;
std::list<DeviceTraceEvent> device_events;
std::list<MemTraceEvent> mem_events;
std::list<OperatorSupplementEvent> op_supplement_events;
runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000,
17000, 10, 10, 1, 0));
runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 25000,
Expand All @@ -127,7 +134,8 @@ TEST(SerializationLoggerTest, dump_case1) {
DeviceTraceEvent(std::string("memset1"), TracerEventType::Memset, 66000,
69000, 0, 10, 11, 5, MemsetEventInfo()));
SerializationLogger logger("test_serialization_logger_case1.pb");
NodeTrees tree(host_events, runtime_events, device_events);
NodeTrees tree(host_events, runtime_events, device_events, mem_events,
op_supplement_events);
std::map<uint64_t, std::vector<HostTraceEventNode*>> nodes =
tree.Traverse(true);
EXPECT_EQ(nodes[10].size(), 1u);
Expand Down

0 comments on commit 06de489

Please sign in to comment.