/
conv2d_fwd_function.cc
153 lines (138 loc) · 5.99 KB
/
conv2d_fwd_function.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#pragma GCC diagnostic ignored "-Wunused-variable"
DECLARE_bool(check_nan_inf);
paddle::experimental::Tensor conv2d_dygraph_function(
const paddle::experimental::Tensor& input,
const paddle::experimental::Tensor& filter,
std::vector<int> strides,
std::vector<int> paddings,
std::string paddding_algorithm,
int groups,
std::vector<int> dilations,
std::string data_format,
bool use_addto,
int workspace_size_MB,
bool exhaustive_search) {
// Dygraph Record Event
paddle::platform::RecordEvent dygraph_entrance_record_event(
"conv2d dygraph", paddle::platform::TracerEventType::Operator, 1);
// AMP Logic
if (egr::Controller::Instance().GetAMPLevel() !=
paddle::imperative::AmpLevel::O0) {
VLOG(5) << "Check and Prepare For AMP";
auto op_name = phi::TransToFluidOpName("conv2d");
paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>
amp_tensors_vector = {{input}, {filter}};
auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);
auto NEW_input =
egr::EagerAmpAutoCast("input", input, amp_dst_dtype, op_name);
auto NEW_filter =
egr::EagerAmpAutoCast("filter", filter, amp_dst_dtype, op_name);
{
paddle::imperative::AutoCastGuard guard(
egr::Controller::Instance().GetCurrentTracer(),
paddle::imperative::AmpLevel::O0);
return conv2d_dygraph_function(NEW_input,
NEW_filter,
strides,
paddings,
paddding_algorithm,
groups,
dilations,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search);
}
}
// Get Input AutoGradMeta
egr::AutogradMeta* input_autograd_meta =
egr::EagerUtils::nullable_autograd_meta(input);
egr::AutogradMeta* filter_autograd_meta =
egr::EagerUtils::nullable_autograd_meta(filter);
// Forward API Call
VLOG(3) << "Final State Running: "
<< "conv2d_dygraph_function";
auto api_result = paddle::experimental::conv2d(input,
filter,
strides,
paddings,
paddding_algorithm,
groups,
dilations,
data_format,
use_addto,
workspace_size_MB,
exhaustive_search);
// Check NaN and Inf if needed
if (FLAGS_check_nan_inf) {
egr::CheckTensorHasNanOrInf("conv2d", api_result);
}
// Get Outputs
auto& out = api_result;
// Get Output AutoGradMeta
egr::AutogradMeta* out_autograd_meta = egr::EagerUtils::autograd_meta(&out);
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(
trace_backward, input_autograd_meta, filter_autograd_meta);
// Check Inplace if needed
// Node Creation
if (require_any_grad) {
paddle::platform::RecordEvent node_creation_record_event(
"conv2d node_creation",
paddle::platform::TracerEventType::OperatorInner,
1);
egr::EagerUtils::PassStopGradient(false, out_autograd_meta);
// Node Construction
auto grad_node =
std::shared_ptr<Conv2dGradNodeFinal>(new Conv2dGradNodeFinal(1, 2));
// SetAttributes if needed
grad_node->SetAttributestrides(strides);
grad_node->SetAttributepaddings(paddings);
grad_node->SetAttributepaddding_algorithm(paddding_algorithm);
grad_node->SetAttributegroups(groups);
grad_node->SetAttributedilations(dilations);
grad_node->SetAttributedata_format(data_format);
grad_node->SetAttributeuse_addto(use_addto);
grad_node->SetAttributeworkspace_size_MB(workspace_size_MB);
grad_node->SetAttributeexhaustive_search(exhaustive_search);
// Set TensorWrappers for Forward Inputs if needed
grad_node->SetTensorWrapperinput(input);
grad_node->SetTensorWrapperfilter(filter);
// SetGradOutMeta & SetEdges
grad_node->SetGradOutMeta(input, 0);
grad_node->SetGradOutMeta(filter, 1);
// SetOutRank & SetHistory & SetGradInMeta & RetainGrad
if (out_autograd_meta) {
egr::EagerUtils::SetOutRankWithSlot(out_autograd_meta, 0);
}
if (out_autograd_meta) {
egr::EagerUtils::SetHistory(out_autograd_meta, grad_node);
}
grad_node->SetGradInMeta(out, 0);
egr::EagerUtils::CheckAndRetainGrad(out);
// Set TensorWrappers for Forward Outputs if needed
}
// Returns
return out;
}