
[Hackathon No.28] implement logcumsumexp #42267

Merged: 27 commits merged on Jun 10, 2022

Commits (27)
5c3b6bb
implement logcumsumexp
tiancaishaonvjituizi Apr 26, 2022
4054f7c
polish
tiancaishaonvjituizi Apr 26, 2022
b8ade29
Merge remote-tracking branch 'origin/develop' into logcumsumexp
tiancaishaonvjituizi Apr 30, 2022
1f98cc7
fix ci
tiancaishaonvjituizi May 3, 2022
518c75a
reformat
tiancaishaonvjituizi May 3, 2022
e94f42c
update
tiancaishaonvjituizi May 3, 2022
8c680e6
address reviews
tiancaishaonvjituizi May 9, 2022
442bc00
Merge remote-tracking branch 'origin/develop' into logcumsumexp
tiancaishaonvjituizi May 9, 2022
a3e50da
add OpTest
tiancaishaonvjituizi May 9, 2022
0b4b8ca
use user defined grad
tiancaishaonvjituizi May 13, 2022
3bf4cfe
add formula in docs, address reviews
tiancaishaonvjituizi May 13, 2022
34f57f1
remove 'reference' comment
tiancaishaonvjituizi May 14, 2022
661bff3
Update logcumsumexp_grad_kernel.h
tiancaishaonvjituizi May 14, 2022
30241bb
Update logcumsumexp_sig.cc
tiancaishaonvjituizi May 14, 2022
2454012
Update logcumsumexp_grad_impl.h
tiancaishaonvjituizi May 14, 2022
1734440
decrease input size, update python
tiancaishaonvjituizi May 16, 2022
d6a773e
Merge branch 'logcumsumexp' of github.com:tiancaishaonvjituizi/Paddle…
tiancaishaonvjituizi May 16, 2022
3e4953a
shrink test data size
tiancaishaonvjituizi May 17, 2022
790e616
fix sample code
tiancaishaonvjituizi May 17, 2022
9797d10
refine docs
tiancaishaonvjituizi May 18, 2022
3b4b8fe
update docs
tiancaishaonvjituizi May 25, 2022
57bb711
fix docs;test=document_fix
tiancaishaonvjituizi May 27, 2022
4601d17
Merge remote-tracking branch 'origin/develop' into logcumsumexp
tiancaishaonvjituizi May 27, 2022
250998c
Merge remote-tracking branch 'origin/develop' into logcumsumexp
tiancaishaonvjituizi Jun 7, 2022
6a3647c
set test timeout to 30s
tiancaishaonvjituizi Jun 7, 2022
d6c7aa7
Merge remote-tracking branch 'origin/develop' into logcumsumexp
tiancaishaonvjituizi Jun 8, 2022
13edc4f
reformat
tiancaishaonvjituizi Jun 9, 2022
4 changes: 2 additions & 2 deletions paddle/fluid/operators/cum_op.cc
@@ -49,7 +49,7 @@ class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
The cumulative sum of the elements along a given axis.
By default, the first element of the result is the same of the first element of
the input. If exlusive is true, the first element of the result is 0.
the input. If exclusive is true, the first element of the result is 0.
)DOC");
}
};
@@ -97,7 +97,7 @@ class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Returns the logarithm of the cumulative summation of the exponentiation of elements of input along the given axis.
By default, the first element of the result is the same of the first element of
the input. If exlusive is true, the first element of the result is the minimum value of dtype.
the input. If exclusive is true, the first element of the result is the lowest finite value of the dtype of the output tensor.
)DOC");
}
};
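For orientation, the quantity the new operator computes is y_i = log(sum_{j <= i} exp(x_j)). A minimal NumPy sketch of the semantics described in the comment above, with the exclusive case seeding the scan with the lowest finite value of the dtype; this is an illustration of the doc string, not the operator's code:

import numpy as np

def logcumsumexp_1d(x, exclusive=False):
    # Inclusive scan using the numerically stable pairwise log(exp(a) + exp(b)).
    x = np.asarray(x, dtype=np.float32)
    y = np.logaddexp.accumulate(x)
    if not exclusive:
        return y
    # Exclusive: shift right and seed with the lowest finite value of the dtype,
    # which stands in for log(0).
    lowest = np.finfo(x.dtype).min
    return np.concatenate(([lowest], y[:-1]))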
2 changes: 0 additions & 2 deletions paddle/phi/kernels/cpu/cum_kernel.cc
@@ -146,8 +146,6 @@ void CumsumKernel(const Context& dev_ctx,
dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
}
Contributor Author (tiancaishaonvjituizi): The content of this file above this line was moved over from cumsum_kernel.cc, with a Reducer parameter added.

// Copied from
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/scan_ops.h
template <typename T>
struct LogSumExp {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
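The LogSumExp functor above (borrowed from TensorFlow's scan_ops.h) combines two partial results without overflowing by shifting the exponent by the larger operand. A rough Python equivalent of that pairwise combiner, shown as a sketch rather than a verbatim translation of the C++:

import itertools
import numpy as np

def log_sum_exp_pair(a, b):
    # log(exp(a) + exp(b)), evaluated so that the exponent is never positive.
    hi, lo = (a, b) if a >= b else (b, a)
    if np.isneginf(hi):          # both operands are -inf: log(0 + 0) = -inf
        return hi
    return hi + np.log1p(np.exp(lo - hi))

# Folding the combiner over the input gives an inclusive logcumsumexp scan
# that does not overflow even for large inputs:
print(list(itertools.accumulate([1000.0, 1000.5, 1001.0], log_sum_exp_pair)))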
31 changes: 31 additions & 0 deletions paddle/phi/kernels/gpu/cum_kernel.cu
@@ -241,6 +241,37 @@ void ScanKernel(const Context& dev_ctx,
T* out_data = dev_ctx.template Alloc<T>(out);
const T* in_data = x.data<T>();

// Use thrust for parallel acceleration when the input size is equal to the
// length of the 'axis' dimension.
if (std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
#ifdef __HIPCC__
const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
#endif
if (reverse) {
thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
thrust::device_pointer_cast(in_data) + size);
thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
thrust::device_pointer_cast(out_data) + size);
if (exclusive) {
thrust::exclusive_scan(
policy, reversed_in, reversed_in + size, reversed_out);
} else {
thrust::inclusive_scan(
policy, reversed_in, reversed_in + size, reversed_out);
}
} else {
if (exclusive) {
thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
} else {
thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
}
}
return;
}


size_t height = 1;
size_t width = 1;
for (size_t i = 0; i <= axis; i++) {
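The branch added above only handles the flattened 1-D case (size equal to out_dims[axis]) and delegates to thrust; scanning through reverse iterators is equivalent to flipping, scanning, and flipping back. A NumPy model of what that fast path produces, as an illustration rather than the kernel itself:

import numpy as np

def sum_scan_1d(x, exclusive=False, reverse=False):
    if reverse:
        # reverse iterators: flip, scan, flip back
        return np.flip(sum_scan_1d(np.flip(x), exclusive=exclusive))
    inclusive = np.cumsum(x)
    if not exclusive:
        return inclusive
    # exclusive scan: shift right and start from the sum identity, 0
    return np.concatenate(([0], inclusive[:-1]))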
9 changes: 0 additions & 9 deletions paddle/utils/variant.h
@@ -13,11 +13,6 @@

#pragma once

#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-copy"
#endif

/*
variant synopsis

@@ -2833,7 +2828,3 @@ struct hash<paddle::monostate> {
};

} // namespace std

#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 9
#pragma GCC diagnostic pop
#endif
57 changes: 53 additions & 4 deletions python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py
@@ -23,6 +23,11 @@
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.framework import _test_eager_guard
from op_test import OpTest


def np_naive_logcumsumexp(x: np.ndarray, axis: Optional[int]=None):
return np.log(np.cumsum(np.exp(x), axis=axis))


def np_logcumsumexp(x: np.ndarray,
@@ -56,8 +61,8 @@ def np_logcumsumexp(x: np.ndarray,
return x
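np_logcumsumexp above is the numerically stable reference used by the tests; its body is mostly collapsed in this hunk. For context, one common stable formulation keeps a running maximum and rescales the accumulated sum whenever the maximum grows. A sketch of that idea, as an assumption about the approach rather than the exact code used here:

import numpy as np

def np_stable_logcumsumexp_1d(x: np.ndarray) -> np.ndarray:
    out = np.empty_like(x, dtype=np.float64)
    running_max = -np.inf
    acc = 0.0
    for i, xi in enumerate(x):
        new_max = max(running_max, xi)
        # rescale the accumulated sum when the maximum grows, then add the new term
        acc = acc * np.exp(running_max - new_max) + np.exp(xi - new_max)
        running_max = new_max
        out[i] = running_max + np.log(acc)
    return out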


class TestLogcumsumexpOp(unittest.TestCase):
def run_cases(self):
class TestLogcumsumexp(unittest.TestCase):
def run_imperative(self):
data_np = np.arange(12, dtype=np.float32).reshape(3, 4)
data = paddle.to_tensor(data_np)

@@ -86,6 +91,17 @@ def run_cases(self):
with self.assertRaises(IndexError):
y = paddle.logcumsumexp(data, axis=2)

data_np = np.arange(10000, 10024, dtype=np.float32)
data = paddle.to_tensor(data_np)
y = paddle.logcumsumexp(data)
z = np_naive_logcumsumexp(data_np)
# check that naive algorithm overflows
self.assertTrue(all(z == np.inf))
z = np_logcumsumexp(data_np)
# check that our algorithm doesn't overflow
self.assertTrue(all(z != np.inf))
self.assertTrue(np.allclose(z, y.numpy()))

def run_static(self, use_gpu=False):
with fluid.program_guard(fluid.Program()):
data_np = np.random.random((100, 100)).astype(np.float32)
@@ -120,7 +136,7 @@ def run_static(self, use_gpu=False):

def test_cpu(self):
paddle.disable_static(paddle.fluid.CPUPlace())
self.run_cases()
self.run_imperative()
paddle.enable_static()

self.run_static()
@@ -129,7 +145,7 @@ def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
paddle.disable_static(paddle.fluid.CUDAPlace(0))
self.run_cases()
self.run_imperative()
paddle.enable_static()

self.run_static(use_gpu=True)
@@ -154,5 +170,38 @@ def test_type_error(self):
out = exe.run(feed={'X': data_np}, fetch_list=[y.name])


class BaseOpTest(OpTest):
def setUp(self):
self.op_type = "logcumsumexp"
input, attrs = self.input_and_attrs()
self.inputs = {'X': input}
self.attrs = attrs
self.outputs = {'Out': np_logcumsumexp(input)}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X'], 'Out')

def input_and_attrs(self):
raise NotImplementedError()
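For reference, the analytic gradient of y = logcumsumexp(x) that these checks exercise is dL/dx_j = sum_{i >= j} dout_i * exp(x_j - y_i). A NumPy sketch of that formula, illustrative only (not the PR's grad kernel, and not guarded against overflow):

import numpy as np

def np_logcumsumexp_grad(x, y, dout, axis=-1):
    # dL/dx_j = sum_{i >= j} dout_i * exp(x_j - y_i)
    #         = exp(x_j) * reverse_cumsum(dout * exp(-y))_j
    def rev_cumsum(a):
        return np.flip(np.cumsum(np.flip(a, axis=axis), axis=axis), axis=axis)
    return np.exp(x) * rev_cumsum(dout * np.exp(-y))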


class TestLogcumsumexpOp1(BaseOpTest):
def input_and_attrs(self):
return np.random.randn(20, 6), {'axis': 0, 'flatten': True, 'reverse': True}


class TestLogcumsumexpOp2(BaseOpTest):
def input_and_attrs(self):
return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': True}


class TestLogcumsumexpOp3(BaseOpTest):
def input_and_attrs(self):
return np.random.randn(20, 6), {'axis': 1, 'flatten': False, 'reverse': False}


if __name__ == '__main__':
unittest.main()
41 changes: 41 additions & 0 deletions python/paddle/tensor/math.py
@@ -2971,6 +2971,47 @@ def cumsum(x, axis=None, dtype=None, name=None):


def logcumsumexp(x, axis=None, dtype=None, name=None):
"""
The logarithm of the cumulative summation of the exponentiation of the elements along a given axis.

**Note**:
The first element of the result is the same of the first element of the input.
Contributor (reviewer): The preview does not pick up the Note styling here; suggest adding an indent.

Contributor Author (tiancaishaonvjituizi): OK, fixed.

Args:
x (Tensor): The input tensor
axis (int, optional): The dimension to do the operation along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array.
dtype (str, optional): The data type of the output tensor, can be float32, float64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
Tensor, the result of logcumsumexp operator.

Examples:
.. code-block:: python

import paddle

data = paddle.arange(12)
data = paddle.reshape(data, (3, 4))

y = paddle.logcumsumexp(data)
# [ 0. 1.3132617 2.4076061 3.4401898 4.4519143 5.4561934
# 6.4577627 7.4583397 8.458551 9.45863 10.458658 11.458669 ]

y = paddle.logcumsumexp(data, axis=0)
# [[ 0. 1. 2. 3. ]
# [ 4.01815 5.01815 6.01815 7.01815 ]
# [ 8.018479 9.018479 10.018479 11.018479]]

y = paddle.logcumsumexp(data, axis=-1)
# [[ 0. 1.3132617 2.4076061 3.4401898]
# [ 4. 5.3132615 6.407606 7.44019 ]
# [ 8. 9.313262 10.407606 11.440189 ]]

y = paddle.logcumsumexp(data, dtype='float64')
print(y.dtype)
# paddle.float64
"""
if axis is None:
flatten = True
else:
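As a quick sanity check (not part of the PR), the first example in the docstring above can be reproduced with NumPy, since axis=None flattens the input:

import numpy as np

data = np.arange(12, dtype=np.float32)
print(np.logaddexp.accumulate(data)[:4])
# ~[0.        1.3132617 2.4076061 3.4401898], matching the sample output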