diff --git a/paddle/fluid/operators/logspace_op.cc b/paddle/fluid/operators/logspace_op.cc new file mode 100644 index 0000000000000..1d1653b053679 --- /dev/null +++ b/paddle/fluid/operators/logspace_op.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" + +namespace paddle { +namespace operators { + +class LogspaceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::proto::VarType::Type(ctx.Attr("dtype")), + ctx.GetPlace()); + } +}; + +class LogspaceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Start", + "Exponent of first entry in the sequence. It is a tensor of " + "shape [1], should be of type int32, int64, float32 or float64."); + AddInput("Stop", + "Exponent of last entry in the sequence. It is a tensor of " + "shape [1], should be of type int32, int64, float32 or float64."); + AddInput("Num", + "Number of entry in the sequence. It is a tensor of shape [1], " + "should be of type int32."); + AddInput("Base", + "Base of the logarithm function. It is a tensor of shape [1], " + "should be of type int32, int64, float32 or float64."); + AddAttr("dtype", "The output data type."); + AddOutput("Out", "A sequence of numbers."); + AddComment(R"DOC( + Return fixed number of logarithmical-evenly spaced values within a given + interval. First entry is exponential of Start with base Base, and last + entry is exponential of Stop with base Base. In the case when Num is 1, + only exponential of Start with base Base is returned. If dtype is int32 + or int64, the decimal part of values will be truncated. + Like logspace function of numpy. + )DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(logspace, LogspaceInferShapeFunctor, + PD_INFER_META(phi::LogspaceInferMeta)); +REGISTER_OPERATOR( + logspace, ops::LogspaceOp, ops::LogspaceOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + LogspaceInferShapeFunctor); diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 6cf805bc1a127..519d21b323fc2 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1489,6 +1489,43 @@ void InterpolateInferMeta( } } +void LogspaceInferMeta(const MetaTensor& start, + const MetaTensor& stop, + const MetaTensor& number, + const MetaTensor& base, + MetaTensor* out) { + auto s_dims = start.dims(); + PADDLE_ENFORCE_EQ( + (s_dims.size() == 1) && (s_dims[0] == 1), + true, + phi::errors::InvalidArgument("The shape of Input(Start) must be [1]," + "but received input shape is [%s].", + s_dims)); + auto e_dims = stop.dims(); + PADDLE_ENFORCE_EQ( + (e_dims.size() == 1) && (e_dims[0] == 1), + true, + phi::errors::InvalidArgument("The shape of Input(Stop) must be [1]," + "but received input shape is [%s].", + e_dims)); + auto num_dims = number.dims(); + PADDLE_ENFORCE_EQ( + (num_dims.size() == 1) && (num_dims[0] == 1), + true, + phi::errors::InvalidArgument("The shape of Input(Num) must be [1]," + "but received input shape is [%s].", + num_dims)); + auto b_dims = base.dims(); + PADDLE_ENFORCE_EQ( + (b_dims.size() == 1) && (b_dims[0] == 1), + true, + phi::errors::InvalidArgument("The shape of Input(Base) must be [1]," + "but received input shape is [%s].", + b_dims)); + out->set_dims(phi::make_ddim({-1})); + out->set_dtype(start.dtype()); +} + void MeshgridInferMeta(const std::vector& inputs, std::vector outputs) { const size_t inputs_num = inputs.size(); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 557855219bb51..65b5819b602ba 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -228,6 +228,12 @@ void InterpolateInferMeta( MetaTensor* output, MetaConfig config = MetaConfig()); +void LogspaceInferMeta(const MetaTensor& start, + const MetaTensor& stop, + const MetaTensor& number, + const MetaTensor& base, + MetaTensor* out); + void MeshgridInferMeta(const std::vector& inputs, std::vector outputs); diff --git a/paddle/phi/kernels/cpu/logspace_kernel.cc b/paddle/phi/kernels/cpu/logspace_kernel.cc new file mode 100644 index 0000000000000..fbb31057a35ae --- /dev/null +++ b/paddle/phi/kernels/cpu/logspace_kernel.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/logspace_kernel.h" + +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/data_type_transform.h" + +namespace phi { + +template +void LogspaceKernel(const Context& ctx, + const DenseTensor& start, + const DenseTensor& stop, + const DenseTensor& number, + const DenseTensor& base, + DataType dtype, + DenseTensor* out) { + int32_t num = number.data()[0]; + auto start_t = phi::funcs::TransDataType(ctx, start, dtype); + auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype); + auto base_t = phi::funcs::TransDataType(ctx, base, dtype); + + T start_data = start_t.template data()[0]; + T stop_data = stop_t.template data()[0]; + T base_data = base_t.template data()[0]; + PADDLE_ENFORCE_GT( + num, + 0, + phi::errors::InvalidArgument("The num of logspace op should be larger " + "than 0, but received num is %d", + num)); + + out->Resize(phi::make_ddim({num})); + T* out_data = ctx.template Alloc(out); + + if (num > 1) { + // step should be of double type for all types + double step = (static_cast(stop_data - start_data)) / (num - 1); + int half_num = num / 2; + for (int i = 0; i < num; ++i) { + if (i < half_num) { + out_data[i] = + static_cast(std::pow(base_data, start_data + step * i)); + } else { + out_data[i] = static_cast( + std::pow(base_data, stop_data - step * (num - i - 1))); + } + } + } else { + out_data[0] = static_cast(std::pow(base_data, start_data)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(logspace, + CPU, + ALL_LAYOUT, + phi::LogspaceKernel, + float, + int32_t, + int64_t, + double) {} diff --git a/paddle/phi/kernels/gpu/logspace_kernel.cu b/paddle/phi/kernels/gpu/logspace_kernel.cu new file mode 100644 index 0000000000000..f47b7d35cdcda --- /dev/null +++ b/paddle/phi/kernels/gpu/logspace_kernel.cu @@ -0,0 +1,107 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/logspace_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/data_type_transform.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +__global__ void LogspaceKernelInner( + T start, T stop, double step, T base, int64_t size, T* out) { + int64_t index = blockIdx.x * blockDim.x + threadIdx.x; + + for (; index < size; index += blockDim.x * gridDim.x) { + if (index < size / 2) { + out[index] = + static_cast(pow(static_cast(base), + static_cast(start + step * index))); + } else { + out[index] = static_cast( + pow(static_cast(base), + static_cast(stop - step * (size - index - 1)))); + } + } +} + +template +__global__ void LogspaceSpecialKernel(T start, T base, T* out) { + out[0] = static_cast( + pow(static_cast(base), static_cast(start))); +} + +template +void LogspaceKernel(const Context& ctx, + const DenseTensor& start, + const DenseTensor& stop, + const DenseTensor& number, + const DenseTensor& base, + DataType dtype, + DenseTensor* out) { + auto start_t = phi::funcs::TransDataType(ctx, start, dtype); + auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype); + auto base_t = phi::funcs::TransDataType(ctx, base, dtype); + + DenseTensor n_start; + DenseTensor n_stop; + DenseTensor n_num; + DenseTensor n_base; + phi::Copy(ctx, start_t, phi::CPUPlace(), false, &n_start); + T start_data = n_start.data()[0]; + phi::Copy(ctx, stop_t, phi::CPUPlace(), false, &n_stop); + T stop_data = n_stop.data()[0]; + phi::Copy(ctx, number, phi::CPUPlace(), false, &n_num); + int64_t num = static_cast(n_num.data()[0]); + phi::Copy(ctx, base_t, phi::CPUPlace(), false, &n_base); + T base_data = n_base.data()[0]; + + PADDLE_ENFORCE_GT( + num, + 0, + phi::errors::InvalidArgument("The num of logspace op should be larger " + "than 0, but received num is %d", + num)); + + out->Resize(phi::make_ddim({num})); + T* out_data = ctx.template Alloc(out); + + double step = 0; + auto stream = ctx.stream(); + int block = 512; + int grid = (num + block - 1) / block; + if (num != 1) { + step = (static_cast(stop_data - start_data)) / (num - 1); + LogspaceKernelInner<<>>( + start_data, stop_data, step, base_data, num, out_data); + } else { + LogspaceSpecialKernel<<>>( + start_data, base_data, out_data); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(logspace, + GPU, + ALL_LAYOUT, + phi::LogspaceKernel, + float, + int32_t, + int64_t, + double) {} diff --git a/paddle/phi/kernels/logspace_kernel.h b/paddle/phi/kernels/logspace_kernel.h new file mode 100644 index 0000000000000..59862514e78ae --- /dev/null +++ b/paddle/phi/kernels/logspace_kernel.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LogspaceKernel(const Context& ctx, + const DenseTensor& start, + const DenseTensor& stop, + const DenseTensor& number, + const DenseTensor& base, + DataType dtype, + DenseTensor* out); + +} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 3578b9a1aaeea..cb0135d9b4c29 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -89,6 +89,7 @@ from .tensor.creation import diagflat # noqa: F401 from .tensor.creation import eye # noqa: F401 from .tensor.creation import linspace # noqa: F401 +from .tensor.creation import logspace # noqa: F401 from .tensor.creation import ones # noqa: F401 from .tensor.creation import ones_like # noqa: F401 from .tensor.creation import zeros # noqa: F401 @@ -591,6 +592,7 @@ 'sqrt', 'randperm', 'linspace', + 'logspace', 'reshape', 'reshape_', 'reverse', diff --git a/python/paddle/fluid/tests/unittests/test_logspace.py b/python/paddle/fluid/tests/unittests/test_logspace.py new file mode 100644 index 0000000000000..ffa9885e7671e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_logspace.py @@ -0,0 +1,231 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle + + +class TestLogspaceOpCommonCase(OpTest): + def setUp(self): + self.op_type = "logspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([0]).astype(dtype), + 'Stop': np.array([10]).astype(dtype), + 'Num': np.array([11]).astype('int32'), + 'Base': np.array([2]).astype(dtype), + } + self.attrs = {'dtype': int(paddle.float32)} + + self.outputs = {'Out': np.power(2, np.arange(0, 11)).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLogspaceOpReverseCase(OpTest): + def setUp(self): + self.op_type = "logspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([10]).astype(dtype), + 'Stop': np.array([0]).astype(dtype), + 'Num': np.array([11]).astype('int32'), + 'Base': np.array([2]).astype(dtype) + } + self.attrs = {'dtype': int(paddle.float32)} + + self.outputs = {'Out': np.power(2, np.arange(10, -1, -1)).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLogspaceOpNumOneCase(OpTest): + def setUp(self): + self.op_type = "logspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([10]).astype(dtype), + 'Stop': np.array([0]).astype(dtype), + 'Num': np.array([1]).astype('int32'), + 'Base': np.array([2]).astype(dtype) + } + self.attrs = {'dtype': int(paddle.float32)} + + self.outputs = {'Out': np.power(2, np.array(10)).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLogspaceOpMinusBaseCase(OpTest): + def setUp(self): + self.op_type = "logspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([0]).astype(dtype), + 'Stop': np.array([10]).astype(dtype), + 'Num': np.array([11]).astype('int32'), + 'Base': np.array([-2]).astype(dtype), + } + self.attrs = {'dtype': int(paddle.float32)} + + self.outputs = {'Out': np.power(-2, np.arange(0, 11)).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLogspaceOpZeroBaseCase(OpTest): + def setUp(self): + self.op_type = "logspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([0]).astype(dtype), + 'Stop': np.array([10]).astype(dtype), + 'Num': np.array([11]).astype('int32'), + 'Base': np.array([0]).astype(dtype), + } + self.attrs = {'dtype': int(paddle.float32)} + + self.outputs = {'Out': np.power(0, np.arange(0, 11)).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLogspaceAPI(unittest.TestCase): + def test_variable_input1(self): + paddle.enable_static() + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + start = paddle.full(shape=[1], fill_value=0, dtype='float32') + stop = paddle.full(shape=[1], fill_value=10, dtype='float32') + num = paddle.full(shape=[1], fill_value=5, dtype='int32') + base = paddle.full(shape=[1], fill_value=2, dtype='float32') + out = paddle.logspace(start, stop, num, base, dtype='float32') + + exe = paddle.static.Executor() + res = exe.run(prog, fetch_list=[out]) + np_res = np.logspace(0, 10, 5, base=2, dtype='float32') + self.assertEqual((res == np_res).all(), True) + paddle.disable_static() + + def test_variable_input2(self): + paddle.disable_static() + start = paddle.full(shape=[1], fill_value=0, dtype='float32') + stop = paddle.full(shape=[1], fill_value=10, dtype='float32') + num = paddle.full(shape=[1], fill_value=5, dtype='int32') + base = paddle.full(shape=[1], fill_value=2, dtype='float32') + out = paddle.logspace(start, stop, num, base, dtype='float32') + np_res = np.logspace(0, 10, 5, base=2, dtype='float32') + self.assertEqual((out.numpy() == np_res).all(), True) + paddle.enable_static() + + def test_dtype(self): + paddle.enable_static() + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + out_1 = paddle.logspace(0, 10, 5, 2, dtype='float32') + out_2 = paddle.logspace(0, 10, 5, 2, dtype=np.float32) + + exe = paddle.static.Executor() + res_1, res_2 = exe.run(prog, fetch_list=[out_1, out_2]) + assert np.array_equal(res_1, res_2) + paddle.disable_static() + + def test_name(self): + with paddle.static.program_guard(paddle.static.Program()): + out = paddle.logspace( + 0, 10, 5, 2, dtype='float32', name='logspace_res') + assert 'logspace_res' in out.name + + def test_imperative(self): + paddle.disable_static() + out1 = paddle.logspace(0, 10, 5, 2, dtype='float32') + np_out1 = np.logspace(0, 10, 5, base=2, dtype='float32') + out2 = paddle.logspace(0, 10, 5, 2, dtype='int32') + np_out2 = np.logspace(0, 10, 5, base=2, dtype='int32') + out3 = paddle.logspace(0, 10, 200, 2, dtype='int32') + np_out3 = np.logspace(0, 10, 200, base=2, dtype='int32') + paddle.enable_static() + self.assertEqual((out1.numpy() == np_out1).all(), True) + self.assertEqual((out2.numpy() == np_out2).all(), True) + self.assertEqual((out3.numpy() == np_out3).all(), True) + + +class TestLogspaceOpError(unittest.TestCase): + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + + def test_dtype(): + paddle.logspace(0, 10, 1, 2, dtype="int8") + + self.assertRaises(TypeError, test_dtype) + + def test_dtype1(): + paddle.logspace(0, 10, 1.33, 2, dtype="int32") + + self.assertRaises(TypeError, test_dtype1) + + def test_start_type(): + paddle.logspace([0], 10, 1, 2, dtype="float32") + + self.assertRaises(TypeError, test_start_type) + + def test_end_type(): + paddle.logspace(0, [10], 1, 2, dtype="float32") + + self.assertRaises(TypeError, test_end_type) + + def test_num_type(): + paddle.logspace(0, 10, [0], 2, dtype="float32") + + self.assertRaises(TypeError, test_num_type) + + def test_start_dtype(): + start = paddle.static.data( + shape=[1], dtype="float64", name="start") + paddle.logspace(start, 10, 1, 2, dtype="float32") + + self.assertRaises(ValueError, test_start_dtype) + + def test_end_dtype(): + end = paddle.static.data(shape=[1], dtype="float64", name="end") + paddle.logspace(0, end, 1, 2, dtype="float32") + + self.assertRaises(ValueError, test_end_dtype) + + def test_num_dtype(): + num = paddle.static.data( + shape=[1], dtype="float32", name="step") + paddle.logspace(0, 10, num, 2, dtype="float32") + + self.assertRaises(TypeError, test_num_dtype) + + def test_base_dtype(): + base = paddle.static.data( + shape=[1], dtype="float64", name="end") + paddle.logspace(0, 10, 1, base, dtype="float32") + + self.assertRaises(ValueError, test_base_dtype) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index f2dc16071c2c8..aeec256bc1580 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -146,6 +146,130 @@ def linspace(start, stop, num, dtype=None, name=None): return out +def logspace(start, stop, num, base=10.0, dtype=None, name=None): + r""" + Return fixed number of logarithmical-evenly spaced values within the interval \ + :math:`[base^{start}, base^{stop}]`. + + Notes: + This API does not compute the gradient. + + Args: + start(int|float|Tensor): The input :attr:`start` is exponent of first entry in \ + the sequence. It is a scalar, or a Tensor of shape [1] with input data \ + type int32, int64, float32 or float64. + stop(int|float|Tensor): The input :attr:`stop` is exponent of last entry in the \ + sequence. It is a scalar, or a Tensor of shape [1] with input data \ + type int32, int64, float32 or float64. + num(int|Tensor): The input :attr:`num` is given number of items in the sequence. \ + It is an int scalar, or a Tensor of shape [1] with data type int32. + base(int|float|Tensor): The input :attr:`base` is base of the logarithm function. \ + It is a scalar, or a Tensor of shape [1] with input data type int32, int64, \ + float32 or float64. + dtype(np.dtype|str, optional): The data type of output tensor, it could be \ + int32, int64, float32 or float64. Default: if None, the data type is float32. \ + name(str, optional): Normally there is no need for user to set this property. \ + For more information, please refer to :ref:`api_guide_Name`. Default: None. + + Returns: + Tensor: The output data type will be float32, float64. The 1-D tensor with \ + fixed number of logarithmical-evenly spaced values, the data shape of this \ + tensor is :math:`[num]`. If the :attr:`num` is set 1, the output tensor \ + just has the value with exponential of :attr:`start` with base :attr:`base`. + + Examples: + .. code-block:: python + :name: logspace-example + + import paddle + data = paddle.logspace(0, 10, 5, 2, 'float32') + # [1. , 5.65685415 , 32. , 181.01933289, 1024. ] + data = paddle.logspace(0, 10, 1, 2, 'float32') + # [1.] + """ + if dtype is None: + dtype = 'float32' + tensor_num = num + tensor_start = start + tensor_stop = stop + tensor_base = base + if not isinstance(num, Variable): + check_type(num, 'num', (int), 'logspace') + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + if not isinstance(start, Variable): + with device_guard("cpu"): + tensor_start = fill_constant([1], dtype, start) + if not isinstance(stop, Variable): + with device_guard("cpu"): + tensor_stop = fill_constant([1], dtype, stop) + if not isinstance(num, Variable): + with device_guard("cpu"): + tensor_num = fill_constant([1], 'int32', num) + if not isinstance(base, Variable): + with device_guard("cpu"): + tensor_base = fill_constant([1], dtype, base) + if _non_static_mode(): + return _C_ops.logspace(tensor_start, tensor_stop, tensor_num, + tensor_base, 'dtype', dtype) + + helper = LayerHelper("logspace", **locals()) + + start_dtype = convert_dtype(tensor_start.dtype) + stop_dtype = convert_dtype(tensor_stop.dtype) + base_dtype = convert_dtype(tensor_base.dtype) + out_dtype = convert_dtype(dtype) + if isinstance(start, Variable): + check_dtype(start.dtype, 'start', + ['float32', 'float64', 'int32', 'int64'], 'logspace') + else: + check_type(start, 'start', (int, float), 'logspace') + + if isinstance(stop, Variable): + check_dtype(stop.dtype, 'stop', + ['float32', 'float64', 'int32', 'int64'], 'logspace') + else: + check_type(stop, 'stop', (int, float), 'logspace') + + if isinstance(num, Variable): + check_dtype(num.dtype, 'num', ['int32'], 'logspace') + + if isinstance(base, Variable): + check_dtype(base.dtype, 'base', + ['float32', 'float64', 'int32', 'int64'], 'logspace') + else: + check_type(base, 'base', (int, float), 'logspace') + + check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'], + 'logspace') + if ((stop_dtype == "float64" or start_dtype == "float64" + or base_dtype == "float64") + and out_dtype in ["float32", "int32"]) or \ + ((stop_dtype == "int64" or start_dtype == "int64" + or base_dtype == "int64") + and out_dtype == "int32"): + raise ValueError( + "The dtype of start/stop/base is {}/{}/{} but the attr(dtype) of logspace is {}, " + "which may cause data type overflows. Please reset attr(dtype) of logspace." + .format(start_dtype, stop_dtype, base_dtype, dtype)) + + out = helper.create_variable_for_type_inference(dtype=dtype) + + helper.append_op( + type='logspace', + inputs={ + 'Start': tensor_start, + 'Stop': tensor_stop, + 'Num': tensor_num, + 'Base': tensor_base + }, + attrs={'dtype': dtype}, + outputs={'Out': [out]}) + if isinstance(num, int): + out.desc.set_shape((num, )) + return out + + @dygraph_only def to_tensor(data, dtype=None, place=None, stop_gradient=True): r""" diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index f907d51e4d038..47b1ba5700e1b 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -306,6 +306,7 @@ 'test_linear_interp_op', 'test_linear_interp_v2_op', 'test_linspace', + 'test_logspace', 'test_load_op', 'test_load_vars_shape_check', 'test_locality_aware_nms_op',