Add some op yaml (#41173)
* add real and imag yaml

* add roi_align and roi_pool yaml

* add qr yaml

* add psroi_pool yaml

* fix bug

* fix param bug of psroi_pool

* fix infrt problem

* fix merge bug
zyfncg committed Apr 6, 2022
1 parent 7ed7c6c commit 176df91
Showing 13 changed files with 185 additions and 18 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/api/lib/CMakeLists.txt
@@ -165,7 +165,7 @@ cc_library(context_pool SRCS context_pool.cc DEPS phi_context phi_enforce place)
cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory context_pool)
cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor)
cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform)
cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform backward_infermeta)
cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform)
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)

cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl)
57 changes: 57 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.cc
@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/meta_tensor.h"
@@ -716,6 +717,62 @@ std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
return x_grad;
}

Tensor imag_grad_impl(const Tensor& out_grad) {
phi::KernelKey kernel_key{ParseBackend(out_grad),
out_grad.layout(),
phi::dtype::ToComplex(out_grad.dtype())};
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"imag_grad", kernel_key);

VLOG(6) << "imag_grad API kernel key: " << kernel_key;
VLOG(6) << "imag_grad API kernel: " << kernel;

auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());

auto dense_out_grad = TensorToDenseTensor(out_grad);

Tensor out;
auto kernel_out = SetKernelOutput(kernel_key.backend(), &out);
phi::MetaTensor meta_out(kernel_out);
phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out);

using kernel_signature = void (*)(
const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*);

auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *dense_out_grad, kernel_out);

return out;
}

Tensor real_grad_impl(const Tensor& out_grad) {
phi::KernelKey kernel_key{ParseBackend(out_grad),
out_grad.layout(),
phi::dtype::ToComplex(out_grad.dtype())};
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"real_grad", kernel_key);

VLOG(6) << "real_grad API kernel key: " << kernel_key;
VLOG(6) << "real_grad API kernel: " << kernel;

auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());

auto dense_out_grad = TensorToDenseTensor(out_grad);

Tensor out;
auto kernel_out = SetKernelOutput(kernel_key.backend(), &out);
phi::MetaTensor meta_out(kernel_out);
phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out);

using kernel_signature = void (*)(
const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*);

auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *dense_out_grad, kernel_out);

return out;
}

std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
int axis) {
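Note: the manual grad impls above build the kernel key with phi::dtype::ToComplex because real/imag map a complex tensor to a real one, so their gradients must come back in the matching complex dtype. A minimal Python sketch of that behaviour (assuming eager mode with complex autograd; the values are illustrative, not taken from this commit):

import paddle

# real()/imag() return a real-valued tensor holding one component of a
# complex input; the gradient flowing back to x is therefore complex,
# which is what ToComplex encodes in the grad kernel key above.
x = paddle.to_tensor([1.0 + 2.0j, 3.0 - 4.0j])
x.stop_gradient = False

y = paddle.real(x)   # real dtype, e.g. float64
y.sum().backward()   # dispatches the manually wired real_grad kernel

print(y.dtype)       # e.g. paddle.float64
print(x.grad.dtype)  # e.g. paddle.complex128
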
6 changes: 6 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.h
@@ -92,10 +92,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
bool trainable_statistics,
bool fuse_with_relu);

/************************ backward api impl ***************************/

std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Scalar& axis);

Tensor imag_grad_impl(const Tensor& out_grad);

Tensor real_grad_impl(const Tensor& out_grad);

std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
int axis);
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/phi/infermeta/backward.h"

#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"

namespace phi {
@@ -402,6 +403,12 @@ void PsroiPoolGradInferMeta(const MetaTensor& x,
dx->share_meta(x);
}

void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx) {
dx->set_dims(out_grad.dims());
dx->set_dtype(dtype::ToComplex(out_grad.dtype()));
dx->set_layout(out_grad.layout());
}

void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -174,6 +174,8 @@ void PoolGradInferMeta(const MetaTensor& x,
const std::string& padding_algorithm,
MetaTensor* dx);

void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx);

void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
8 changes: 5 additions & 3 deletions python/paddle/fluid/tests/unittests/test_psroi_pool_op.py
@@ -95,7 +95,8 @@ def set_data(self):
self.pooled_width).astype('float64')
self.inputs = {
'X': self.x,
'ROIs': (self.rois_with_batch_id[:, 1:5], self.rois_lod)
'ROIs': (self.rois_with_batch_id[:, 1:5], self.rois_lod),
'RoisNum': self.boxes_num
}
self.attrs = {
'output_channels': self.output_channels,
@@ -145,13 +146,14 @@ def make_rois(self):

def setUp(self):
self.op_type = 'psroi_pool'
self.python_api = lambda x, boxes, boxes_num, pooled_height, pooled_width, output_channels, spatial_scale: paddle.vision.ops.psroi_pool(x, boxes, boxes_num, (pooled_height, pooled_width), spatial_scale)
self.set_data()

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)


class TestPSROIPoolDynamicFunctionAPI(unittest.TestCase):
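For reference, a short usage sketch of paddle.vision.ops.psroi_pool, which the python_api lambda above routes the eager check through (shapes and box values are illustrative assumptions, not taken from the test):

import paddle

# Input channels must be divisible by output_size**2: 490 = 10 * 7 * 7.
x = paddle.uniform([2, 490, 28, 28], dtype='float32')
boxes = paddle.to_tensor(
    [[1., 5., 8., 10.], [4., 2., 6., 7.], [12., 12., 19., 21.]], dtype='float32')
boxes_num = paddle.to_tensor([1, 2], dtype='int32')  # boxes per image, sums to 3

out = paddle.vision.ops.psroi_pool(
    x, boxes, boxes_num, output_size=7, spatial_scale=1.0)
print(out.shape)  # [3, 10, 7, 7]: one 10-channel 7x7 map per box
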
7 changes: 5 additions & 2 deletions python/paddle/fluid/tests/unittests/test_real_imag_op.py
@@ -39,6 +39,7 @@ def setUp(self):
paddle.enable_static()
# op test attrs
self.op_type = "real"
self.python_api = paddle.real
self.dtype = np.float64
self.init_input_output()
# backward attrs
@@ -58,14 +59,15 @@ def init_grad_input_output(self):
self.grad_out.shape)

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(
['X'],
'Out',
user_defined_grads=[self.grad_x],
user_defined_grad_outputs=[self.grad_out])
user_defined_grad_outputs=[self.grad_out],
check_eager=True)


class TestImagOp(TestRealOp):
@@ -74,6 +76,7 @@ def setUp(self):
paddle.enable_static()
# op test attrs
self.op_type = "imag"
self.python_api = paddle.imag
self.dtype = np.float64
self.init_input_output()
# backward attrs
10 changes: 8 additions & 2 deletions python/paddle/fluid/tests/unittests/test_roi_pool_op.py
@@ -14,6 +14,7 @@

from __future__ import print_function

import paddle
import unittest
import numpy as np
import math
@@ -32,6 +33,7 @@ def set_data(self):
self.inputs = {
'X': self.x,
'ROIs': (self.rois[:, 1:5], self.rois_lod),
'RoisNum': self.boxes_num
}

self.attrs = {
@@ -130,16 +132,20 @@ def make_rois(self):
rois.append(roi)
self.rois_num = len(rois)
self.rois = np.array(rois).astype("float64")
self.boxes_num = np.array(
[bno + 1 for bno in range(self.batch_size)]).astype('int32')

def setUp(self):
self.op_type = "roi_pool"
self.python_api = lambda x, boxes, boxes_num, pooled_height, pooled_width, spatial_scale: paddle.vision.ops.roi_pool(x, boxes, boxes_num, (pooled_height, pooled_width), spatial_scale)
self.python_out_sig = ["Out"]
self.set_data()

def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)


class BadInputTestRoiPool(unittest.TestCase):
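Similarly, a hedged usage sketch of paddle.vision.ops.roi_pool, which this test now exercises via its python_api when check_eager=True (shapes and values are illustrative):

import paddle

x = paddle.uniform([1, 256, 32, 32], dtype='float32')  # N, C, H, W
boxes = paddle.to_tensor([[3., 4., 15., 17.], [6., 2., 20., 21.]], dtype='float32')
boxes_num = paddle.to_tensor([2], dtype='int32')  # both boxes belong to image 0

pooled = paddle.vision.ops.roi_pool(
    x, boxes, boxes_num, output_size=(7, 7), spatial_scale=1.0)
print(pooled.shape)  # [2, 256, 7, 7]
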
11 changes: 8 additions & 3 deletions python/paddle/tensor/attribute.py
@@ -18,12 +18,13 @@
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype

# TODO: define functions to get tensor attributes
# TODO: define functions to get tensor attributes
from ..fluid.layers import rank # noqa: F401
from ..fluid.layers import shape # noqa: F401
import paddle
from paddle import _C_ops
from paddle.static import Variable
from ..fluid.framework import _in_legacy_dygraph, in_dygraph_mode

__all__ = []

@@ -185,7 +186,9 @@ def real(x, name=None):
# [[1., 2., 3.],
# [4., 5., 6.]])
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_real(x)
if _in_legacy_dygraph():
return _C_ops.real(x)

check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'real')
@@ -229,7 +232,9 @@ def imag(x, name=None):
# [[6., 5., 4.],
# [3., 2., 1.]])
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_imag(x)
if _in_legacy_dygraph():
return _C_ops.imag(x)

check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'imag')
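With this change, eager mode takes the final-state path (_C_ops.final_state_real / final_state_imag) while legacy dygraph keeps the old _C_ops call. A minimal sketch of the public API these branches serve, mirroring the docstring examples above:

import paddle

x = paddle.to_tensor([[1 + 6j, 2 + 5j, 3 + 4j],
                      [4 + 3j, 5 + 2j, 6 + 1j]])
print(paddle.real(x))  # [[1., 2., 3.], [4., 5., 6.]]
print(paddle.imag(x))  # [[6., 5., 4.], [3., 2., 1.]]
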
48 changes: 48 additions & 0 deletions python/paddle/utils/code_gen/api.yaml
@@ -802,6 +802,15 @@
func : huber_loss
# backward : huber_loss_grad

- api : imag
args : (Tensor x)
output : Tensor
infer_meta :
func : RealAndImagInferMeta
kernel :
func : imag
backward : imag_grad

# increment
- api : increment
args : (Tensor x, float value)
@@ -1336,6 +1345,16 @@
func : prelu
backward : prelu_grad

- api : psroi_pool
args : (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, int output_channels, float spatial_scale)
output : Tensor
infer_meta :
func : PsroiPoolInferMeta
kernel :
func : psroi_pool
optional : boxes_num
backward : psroi_pool_grad

# put_along_axis
- api : put_along_axis
args : (Tensor x, Tensor index, Tensor value, int axis, str reduce)
@@ -1348,6 +1367,15 @@
data_type : x
backward : put_along_axis_grad

- api : qr
args : (Tensor x, str mode)
output : Tensor(q), Tensor(r)
infer_meta :
func : QrInferMeta
kernel :
func : qr
# backward : qr_grad

- api : randint
args : (int low, int high, IntArray shape, DataType dtype=DataType::INT64, Place place={})
output : Tensor(out)
@@ -1372,6 +1400,15 @@
data_type : dtype
backend : place

- api : real
args : (Tensor x)
output : Tensor
infer_meta :
func : RealAndImagInferMeta
kernel :
func : real
backward : real_grad

- api : reciprocal
args : (Tensor x)
output : Tensor
@@ -1423,6 +1460,17 @@
optional : boxes_num
backward : roi_align_grad

- api : roi_pool
args : (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, float spatial_scale)
output : Tensor(out), Tensor(arg_max)
infer_meta :
func : RoiPoolInferMeta
kernel :
func : roi_pool
optional : boxes_num
intermediate : arg_max
backward : roi_pool_grad

- api : roll
args : (Tensor x, IntArray shifts, int64_t[] axis)
output : Tensor(out)
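Of the new yaml entries, qr is forward-only for now (its backward line is still commented out). A hedged sketch of the Python entry point assumed to consume this kernel, paddle.linalg.qr (values illustrative):

import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype='float64')
q, r = paddle.linalg.qr(x, mode='reduced')  # q: [3, 2] orthonormal, r: [2, 2] upper triangular
print(paddle.allclose(q @ r, x))            # expected True: q @ r reconstructs x
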
