diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc index 1f0a055cbd386..57932ec4c1e69 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -217,18 +217,20 @@ RunCustomOpNode::operator()( VLOG(6) << "Prepare Grad outputs for size: " << grad_outputs_names.size(); for (size_t i = 0; i < OutputMeta().size(); i++) { if (map[0][0].find(i) != map[0][0].end()) { + int grad_output_idx = map[0][0][i]; VLOG(7) << "Insert grad outputs: " << i - << " with size: " << OutputMeta()[i].size() - << " to tmp_outputs: " << map[0][0][i]; - for (size_t j = 0; j < OutputMeta()[i].size(); j++) { - outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */ - std::make_shared( - phi::DataType::UNDEFINED), - egr::Controller::Instance().GenerateUniqueName( - "custom_tmp_grad")); - egr::EagerUtils::autograd_meta(&(outs[i][j])); + << " with size: " << OutputMeta()[grad_output_idx].size() + << " to tmp_outputs: " << grad_output_idx; + for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) { + outs[grad_output_idx] + .emplace_back(/* init it incase of copy nullptr of shared_ptr */ + std::make_shared( + phi::DataType::UNDEFINED), + egr::Controller::Instance().GenerateUniqueName( + "custom_tmp_grad")); + egr::EagerUtils::autograd_meta(&(outs[grad_output_idx][j])); } - tmp_outs[map[0][0][i]] = outs[i]; + tmp_outs[grad_output_idx] = outs[grad_output_idx]; } } for (size_t i = 0; i < tmp_outs.size(); i++) { @@ -408,17 +410,19 @@ RunCustomOpDoubleGradNode::operator()( for (size_t i = 0; i < OutputMeta().size(); i++) { if (map[1][0].find(i) != map[1][0].end()) { + int grad_output_idx = map[1][0][i]; VLOG(7) << "Insert grad outputs: " << i - << " with size: " << OutputMeta()[i].size() - << " to tmp_outputs: " << map[1][0][i]; - for (size_t j = 0; j < OutputMeta()[i].size(); j++) { - outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */ - std::make_shared( - phi::DataType::UNDEFINED), - egr::Controller::Instance().GenerateUniqueName( - "custom_tmp_grad")); + << " with size: " << OutputMeta()[grad_output_idx].size() + << " to tmp_outputs: " << grad_output_idx; + for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) { + outs[grad_output_idx] + .emplace_back(/* init it incase of copy nullptr of shared_ptr */ + std::make_shared( + phi::DataType::UNDEFINED), + egr::Controller::Instance().GenerateUniqueName( + "custom_tmp_grad")); } - tmp_outs[map[1][0][i]] = outs[i]; + tmp_outs[grad_output_idx] = outs[grad_output_idx]; } } for (size_t i = 0; i < tmp_outs.size(); i++) { diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu index e791ea8cb7600..53b61b4bb6611 100644 --- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu +++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu @@ -44,7 +44,7 @@ __global__ void relu_cuda_double_backward_kernel(const data_t* out_data, data_t* ddout_data, int64_t num) { int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; - for (int64_t i = num; i < num; i += blockDim.x * gridDim.x) { + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { ddout_data[i] = ddx_data[i] * (out_data[i] > static_cast(0.) ? static_cast(1.) : static_cast(0.)); diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py index 0cc1b19e654bf..25209795db5ba 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py @@ -21,6 +21,7 @@ import tempfile import subprocess import numpy as np +from paddle import fluid from paddle.vision.transforms import Compose, Normalize from paddle.utils.cpp_extension.extension_utils import run_cmd from paddle.fluid.framework import _test_eager_guard @@ -43,12 +44,9 @@ def custom_relu_dynamic(func, device, dtype, np_x, use_func=True): return out.numpy(), t.grad.numpy() -def custom_relu_static(func, - device, - dtype, - np_x, - use_func=True, - test_infer=False): +def custom_relu_static( + func, device, dtype, np_x, use_func=True, test_infer=False +): paddle.enable_static() paddle.set_device(device) @@ -62,9 +60,11 @@ def custom_relu_static(func, exe = static.Executor() exe.run(static.default_startup_program()) # in static mode, x data has been covered by out - out_v = exe.run(static.default_main_program(), - feed={'X': np_x}, - fetch_list=[out.name]) + out_v = exe.run( + static.default_main_program(), + feed={'X': np_x}, + fetch_list=[out.name], + ) paddle.disable_static() return out_v @@ -87,11 +87,11 @@ def custom_relu_static_pe(func, device, dtype, np_x, use_func=True): # in static mode, x data has been covered by out compiled_prog = static.CompiledProgram( - static.default_main_program()).with_data_parallel( - loss_name=out.name, places=places) - out_v = exe.run(compiled_prog, - feed={'X': np_x}, - fetch_list=[out.name]) + static.default_main_program() + ).with_data_parallel(loss_name=out.name, places=places) + out_v = exe.run( + compiled_prog, feed={'X': np_x}, fetch_list=[out.name] + ) paddle.disable_static() return out_v @@ -103,9 +103,9 @@ def custom_relu_static_inference(func, device, np_data, np_label, path_prefix): with static.scope_guard(static.Scope()): with static.program_guard(static.Program()): # simple module - data = static.data(name='data', - shape=[None, 1, 28, 28], - dtype='float32') + data = static.data( + name='data', shape=[None, 1, 28, 28], dtype='float32' + ) label = static.data(name='label', shape=[None, 1], dtype='int64') hidden = static.nn.fc(data, size=128) @@ -124,23 +124,21 @@ def custom_relu_static_inference(func, device, np_data, np_label, path_prefix): # train for i in range(4): - avg_loss_v = exe.run(static.default_main_program(), - feed={ - 'data': np_data, - 'label': np_label - }, - fetch_list=[avg_loss]) + avg_loss_v = exe.run( + static.default_main_program(), + feed={'data': np_data, 'label': np_label}, + fetch_list=[avg_loss], + ) # save inference model static.save_inference_model(path_prefix, [data], [predict], exe) # get train predict value - predict_v = exe.run(static.default_main_program(), - feed={ - 'data': np_data, - 'label': np_label - }, - fetch_list=[predict]) + predict_v = exe.run( + static.default_main_program(), + feed={'data': np_data, 'label': np_label}, + fetch_list=[predict], + ) return predict_v @@ -151,30 +149,37 @@ def custom_relu_double_grad_dynamic(func, device, dtype, np_x, use_func=True): t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) out = func(t) if use_func else paddle.nn.functional.relu(t) - out.stop_gradient = False - - dx = paddle.grad(outputs=[out], - inputs=[t], - create_graph=True, - retain_graph=True) + dx = paddle.grad( + outputs=out, + inputs=t, + grad_outputs=paddle.ones_like(t), + create_graph=True, + retain_graph=True, + ) - dx[0].backward() + ddout = paddle.grad( + outputs=dx[0], + inputs=out.grad, + grad_outputs=paddle.ones_like(t), + create_graph=False, + ) - assert dx[0].grad is not None - return dx[0].numpy(), dx[0].grad.numpy() + assert ddout[0].numpy() is not None + return dx[0].numpy(), ddout[0].numpy() class TestNewCustomOpSetUpInstall(unittest.TestCase): - def setUp(self): cur_dir = os.path.dirname(os.path.abspath(__file__)) # compile, install the custom op egg into site-packages under background if os.name == 'nt': cmd = 'cd /d {} && python custom_relu_setup.py install'.format( - cur_dir) + cur_dir + ) else: cmd = 'cd {} && {} custom_relu_setup.py install'.format( - cur_dir, sys.executable) + cur_dir, sys.executable + ) run_cmd(cmd) # NOTE(Aurelius84): Normally, it's no need to add following codes for users. @@ -190,16 +195,18 @@ def setUp(self): custom_egg_path = [ x for x in os.listdir(site_dir) if 'custom_relu_module_setup' in x ] - assert len(custom_egg_path - ) == 1, "Matched egg number is %d." % len(custom_egg_path) + assert len(custom_egg_path) == 1, "Matched egg number is %d." % len( + custom_egg_path + ) sys.path.append(os.path.join(site_dir, custom_egg_path[0])) # usage: import the package directly import custom_relu_module_setup + # `custom_relu_dup` is same as `custom_relu_dup` self.custom_ops = [ custom_relu_module_setup.custom_relu, - custom_relu_module_setup.custom_relu_dup + custom_relu_module_setup.custom_relu_dup, ] self.dtypes = ['float32', 'float64'] @@ -222,13 +229,16 @@ def test_static(self): x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) for custom_op in self.custom_ops: out = custom_relu_static(custom_op, device, dtype, x) - pd_out = custom_relu_static(custom_op, device, dtype, x, - False) + pd_out = custom_relu_static( + custom_op, device, dtype, x, False + ) np.testing.assert_array_equal( out, pd_out, - err_msg='custom op out: {},\n paddle api out: {}'. - format(out, pd_out)) + err_msg='custom op out: {},\n paddle api out: {}'.format( + out, pd_out + ), + ) def test_static_pe(self): for device in self.devices: @@ -238,13 +248,16 @@ def test_static_pe(self): x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) for custom_op in self.custom_ops: out = custom_relu_static_pe(custom_op, device, dtype, x) - pd_out = custom_relu_static_pe(custom_op, device, dtype, x, - False) + pd_out = custom_relu_static_pe( + custom_op, device, dtype, x, False + ) np.testing.assert_array_equal( out, pd_out, - err_msg='custom op out: {},\n paddle api out: {}'. - format(out, pd_out)) + err_msg='custom op out: {},\n paddle api out: {}'.format( + out, pd_out + ), + ) def func_dynamic(self): for device in self.devices: @@ -253,20 +266,26 @@ def func_dynamic(self): continue x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) for custom_op in self.custom_ops: - out, x_grad = custom_relu_dynamic(custom_op, device, dtype, - x) + out, x_grad = custom_relu_dynamic( + custom_op, device, dtype, x + ) pd_out, pd_x_grad = custom_relu_dynamic( - custom_op, device, dtype, x, False) + custom_op, device, dtype, x, False + ) np.testing.assert_array_equal( out, pd_out, - err_msg='custom op out: {},\n paddle api out: {}'. - format(out, pd_out)) + err_msg='custom op out: {},\n paddle api out: {}'.format( + out, pd_out + ), + ) np.testing.assert_array_equal( x_grad, pd_x_grad, - err_msg='custom op x grad: {},\n paddle api x grad: {}'. - format(x_grad, pd_x_grad)) + err_msg='custom op x grad: {},\n paddle api x grad: {}'.format( + x_grad, pd_x_grad + ), + ) def test_dynamic(self): with _test_eager_guard(): @@ -279,22 +298,29 @@ def test_static_save_and_load_inference_model(self): np_label = np.random.random((1, 1)).astype("int64") path_prefix = "custom_op_inference/custom_relu" for device in self.devices: - predict = custom_relu_static_inference(self.custom_ops[0], device, - np_data, np_label, - path_prefix) + predict = custom_relu_static_inference( + self.custom_ops[0], device, np_data, np_label, path_prefix + ) # load inference model with static.scope_guard(static.Scope()): exe = static.Executor() - [inference_program, feed_target_names, - fetch_targets] = static.load_inference_model(path_prefix, exe) - predict_infer = exe.run(inference_program, - feed={feed_target_names[0]: np_data}, - fetch_list=fetch_targets) + [ + inference_program, + feed_target_names, + fetch_targets, + ] = static.load_inference_model(path_prefix, exe) + predict_infer = exe.run( + inference_program, + feed={feed_target_names[0]: np_data}, + fetch_list=fetch_targets, + ) np.testing.assert_array_equal( predict, predict_infer, - err_msg='custom op predict: {},\n custom op infer predict: {}' - .format(predict, predict_infer)) + err_msg='custom op predict: {},\n custom op infer predict: {}'.format( + predict, predict_infer + ), + ) paddle.disable_static() def test_static_save_and_run_inference_predictor(self): @@ -304,62 +330,80 @@ def test_static_save_and_run_inference_predictor(self): path_prefix = "custom_op_inference/custom_relu" from paddle.inference import Config from paddle.inference import create_predictor + for device in self.devices: - predict = custom_relu_static_inference(self.custom_ops[0], device, - np_data, np_label, - path_prefix) + predict = custom_relu_static_inference( + self.custom_ops[0], device, np_data, np_label, path_prefix + ) # load inference model - config = Config(path_prefix + ".pdmodel", - path_prefix + ".pdiparams") + config = Config( + path_prefix + ".pdmodel", path_prefix + ".pdiparams" + ) predictor = create_predictor(config) input_tensor = predictor.get_input_handle( - predictor.get_input_names()[0]) + predictor.get_input_names()[0] + ) input_tensor.reshape(np_data.shape) input_tensor.copy_from_cpu(np_data.copy()) predictor.run() output_tensor = predictor.get_output_handle( - predictor.get_output_names()[0]) + predictor.get_output_names()[0] + ) predict_infer = output_tensor.copy_to_cpu() self.assertTrue( np.isclose(predict, predict_infer, rtol=5e-5).any(), "custom op predict: {},\n custom op infer predict: {}".format( - predict, predict_infer)) + predict, predict_infer + ), + ) paddle.disable_static() - def test_func_double_grad_dynamic(self): + def test_double_grad_dynamic(self): + fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) for device in self.devices: for dtype in self.dtypes: if device == 'cpu' and dtype == 'float16': continue x = np.random.uniform(-1, 1, [4, 8]).astype(dtype) out, dx_grad = custom_relu_double_grad_dynamic( - self.custom_ops[0], device, dtype, x) + self.custom_ops[0], device, dtype, x + ) pd_out, pd_dx_grad = custom_relu_double_grad_dynamic( - self.custom_ops[0], device, dtype, x, False) + self.custom_ops[0], device, dtype, x, False + ) np.testing.assert_array_equal( out, pd_out, err_msg='custom op out: {},\n paddle api out: {}'.format( - out, pd_out)) + out, pd_out + ), + ) np.testing.assert_array_equal( dx_grad, pd_dx_grad, - err_msg='custom op dx grad: {},\n paddle api dx grad: {}'. - format(dx_grad, pd_dx_grad)) + err_msg='custom op dx grad: {},\n paddle api dx grad: {}'.format( + dx_grad, pd_dx_grad + ), + ) + fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False}) def test_with_dataloader(self): for device in self.devices: paddle.set_device(device) # data loader transform = Compose( - [Normalize(mean=[127.5], std=[127.5], data_format='CHW')]) - train_dataset = paddle.vision.datasets.MNIST(mode='train', - transform=transform) - train_loader = paddle.io.DataLoader(train_dataset, - batch_size=64, - shuffle=True, - drop_last=True, - num_workers=0) + [Normalize(mean=[127.5], std=[127.5], data_format='CHW')] + ) + train_dataset = paddle.vision.datasets.MNIST( + mode='train', transform=transform + ) + train_loader = paddle.io.DataLoader( + train_dataset, + batch_size=64, + shuffle=True, + drop_last=True, + num_workers=0, + ) for batch_id, (image, _) in enumerate(train_loader()): out = self.custom_ops[0](image) @@ -368,7 +412,9 @@ def test_with_dataloader(self): out, pd_out, err_msg='custom op out: {},\n paddle api out: {}'.format( - out, pd_out)) + out, pd_out + ), + ) if batch_id == 5: break