Skip to content

Commit

Permalink
[Eager]Fix segment_pool/allclose/isclose/scale API bug (#41506)
Browse files Browse the repository at this point in the history
* [Eager]Fix segment_pool/allclose/isclose/scale API bug

* fix kernel register problem
  • Loading branch information
Aurelius84 committed Apr 8, 2022
1 parent 70036d5 commit 0a6fe69
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 15 deletions.
22 changes: 11 additions & 11 deletions paddle/fluid/operators/cast_op.cu
Expand Up @@ -19,15 +19,15 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;

using CUDA = paddle::platform::CUDADeviceContext;
#define REGISTER_CAST_CUDA_BASE(op_name, ...) \
REGISTER_OP_CUDA_KERNEL( \
op_name, ops::CastOpKernel<CUDA, float>, \
ops::CastOpKernel<CUDA, double>, ops::CastOpKernel<CUDA, int>, \
ops::CastOpKernel<CUDA, int64_t>, ops::CastOpKernel<CUDA, int16_t>, \
ops::CastOpKernel<CUDA, bool>, ops::CastOpKernel<CUDA, uint8_t>, \
ops::CastOpKernel<CUDA, plat::float16>, \
ops::CastOpKernel<CUDA, plat::complex<float>>, \
ops::CastOpKernel<CUDA, plat::complex<double>>, ##__VA_ARGS__);

// See [ why register transfer_dtype_op alias with cast_op? ] in cast_op.cc
REGISTER_CAST_CUDA_BASE(transfer_dtype, ops::CastOpKernel<CUDA, plat::bfloat16>)
REGISTER_OP_CUDA_KERNEL(transfer_dtype, ops::CastOpKernel<CUDA, float>,
ops::CastOpKernel<CUDA, double>,
ops::CastOpKernel<CUDA, int>,
ops::CastOpKernel<CUDA, int64_t>,
ops::CastOpKernel<CUDA, int16_t>,
ops::CastOpKernel<CUDA, bool>,
ops::CastOpKernel<CUDA, uint8_t>,
ops::CastOpKernel<CUDA, plat::float16>,
ops::CastOpKernel<CUDA, plat::complex<float>>,
ops::CastOpKernel<CUDA, plat::complex<double>>,
ops::CastOpKernel<CUDA, plat::bfloat16>);
2 changes: 1 addition & 1 deletion python/paddle/incubate/tensor/math.py
Expand Up @@ -222,7 +222,7 @@ def segment_max(data, segment_ids, name=None):
"""

if in_dygraph_mode():
out = _C_ops.final_state_segment_pool(data, segment_ids, "MAX")[0]
out, tmp = _C_ops.final_state_segment_pool(data, segment_ids, "MAX")
return out

if _non_static_mode():
Expand Down
14 changes: 12 additions & 2 deletions python/paddle/tensor/logic.py
Expand Up @@ -127,7 +127,12 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
"""

if in_dygraph_mode():
return _C_ops.final_state_allclose(x, y, rtol, atol, equal_nan)
# NOTE(dev): Pass tol as Tensor to fix precision loss problem, because
# C++ backend will cast it into float32 if passing float from python.
as_tensor = lambda x: paddle.to_tensor([x], dtype='float64', place='cpu')
return _C_ops.final_state_allclose(x, y,
as_tensor(rtol),
as_tensor(atol), equal_nan)
if _in_legacy_dygraph():
return _C_ops.allclose(x, y, 'rtol',
str(rtol), 'atol',
Expand Down Expand Up @@ -689,7 +694,12 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
"""

if in_dygraph_mode():
return _C_ops.final_state_isclose(x, y, rtol, atol, equal_nan)
# NOTE(dev): Pass tol as Tensor to fix precision loss problem, because
# C++ backend will cast it into float32 if passing float from python.
as_tensor = lambda x: paddle.to_tensor([x], dtype='float64', place='cpu')
return _C_ops.final_state_isclose(x, y,
as_tensor(rtol),
as_tensor(atol), equal_nan)
if _in_legacy_dygraph():
return _C_ops.isclose(x, y, 'rtol',
str(rtol), 'atol',
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/utils/code_gen/backward.yaml
Expand Up @@ -1217,7 +1217,7 @@
forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out)
args : (Tensor out_grad, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true)
output : Tensor(x_grad)
invoke : scale(out_grad, scale, bias, bias_after_scale)
invoke : scale(out_grad, scale, 0.0, bias_after_scale)

- backward_api : scatter_grad
forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite) -> Tensor(out)
Expand Down Expand Up @@ -1250,6 +1250,7 @@
param : [x]
kernel :
func : segment_pool_grad
data_type : x
optional : summed_ids

- backward_api : selu_grad
Expand Down

0 comments on commit 0a6fe69

Please sign in to comment.