[API/OP] Migrate Lstsq op into phi (#44318)
* migrate lstsq op

* update

* fix bugs for CIs

* update

* fix bugs

* add uts

* update

* update

* update

* fix bugs of jip

* fix bugs of hip

* update

* update according to review

* update

* update

* update

* update
haohongxiang committed Jul 29, 2022
1 parent ec1e0d5 commit ab2aaf8
Showing 13 changed files with 1,211 additions and 132 deletions.
105 changes: 23 additions & 82 deletions paddle/fluid/operators/lstsq_op.cc
@@ -12,92 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/lstsq_op.h"

#include <string>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/infermeta/binary.h"

namespace paddle {
namespace operators {

class LstsqOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LstsqOp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "LstsqOp");

OP_INOUT_CHECK(ctx->HasOutput("Solution"), "Output", "Solution", "LstsqOp");
OP_INOUT_CHECK(ctx->HasOutput("Rank"), "Output", "Rank", "LstsqOp");
OP_INOUT_CHECK(ctx->HasOutput("SingularValues"),
"Output",
"SingularValues",
"LstsqOp");

auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int x_rank = x_dims.size();
int y_rank = y_dims.size();

PADDLE_ENFORCE_GE(x_rank,
2,
platform::errors::InvalidArgument(
"Expects input tensor x to be not less than "
"2 dimentions, but got dimention %d",
x_rank));
PADDLE_ENFORCE_GE(y_rank,
2,
platform::errors::InvalidArgument(
"Expects input tensor y to be not less than "
"2 dimentions, but got dimention %d",
y_rank));

PADDLE_ENFORCE_EQ(
x_rank,
y_rank,
platform::errors::InvalidArgument(
"Expects input tensor x and y to have the same dimension "
"but got x's dimention [%d] and y's dimention [%d]",
x_rank,
y_rank));

std::vector<int> batch_dims_vec{};
for (int i = 0; i < x_rank - 2; ++i) {
PADDLE_ENFORCE_EQ(
x_dims[i],
y_dims[i],
platform::errors::InvalidArgument(
"Expects input tensor x and y to have the same batch "
"dimension, but got x's batch dimention [%d] and "
"y's batch dimention [%d] in %d-th dim",
x_dims[i],
y_dims[i],
i));
batch_dims_vec.emplace_back(x_dims[i]);
}

PADDLE_ENFORCE_EQ(
x_dims[x_rank - 2],
y_dims[y_rank - 2],
platform::errors::InvalidArgument(
"Expects input tensor x and y to have the same row dimension "
"of the inner-most 2-dims matrix, "
"but got x's row dimention [%d] and y's row dimention [%d]",
x_dims[x_rank - 2],
y_dims[y_rank - 2]));

ctx->SetOutputDim("Rank", phi::make_ddim(batch_dims_vec));

batch_dims_vec.emplace_back(
std::min(x_dims[x_rank - 2], x_dims[x_rank - 1]));
ctx->SetOutputDim("SingularValues", phi::make_ddim(batch_dims_vec));

batch_dims_vec[x_rank - 2] = x_dims[x_rank - 1];
batch_dims_vec.emplace_back(y_dims[x_rank - 1]);
ctx->SetOutputDim("Solution", phi::make_ddim(batch_dims_vec));
}

protected:
// The output of lstsq is always complex-valued even for real-valued inputs
@@ -133,6 +58,9 @@ class LstsqOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault("gels");
AddOutput("Solution",
"(Tensor), The output Solution tensor with shape (*, n, k).");
AddOutput("Residuals",
"(Tensor), The output Residuals tensor with shape (*, k).")
.AsDispensable();
AddOutput("Rank", "(Tensor), The output Rank tensor with shape (*).");
AddOutput(
"SingularValues",
@@ -148,8 +76,21 @@ This API processes Lstsq functor for general matrices.
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(lstsq, ops::LstsqOp, ops::LstsqOpMaker)

REGISTER_OP_CPU_KERNEL(lstsq,
ops::LstsqCPUKernel<phi::CPUContext, float>,
ops::LstsqCPUKernel<phi::CPUContext, double>);
DECLARE_INFER_SHAPE_FUNCTOR(lstsq,
LstsqInferShapeFunctor,
PD_INFER_META(phi::LstsqInferMeta));

REGISTER_OPERATOR(lstsq,
ops::LstsqOp,
ops::LstsqOpMaker,
LstsqInferShapeFunctor);

REGISTER_OP_VERSION(lstsq).AddCheckpoint(
R"ROC(
Upgrade lstsq, add 1 outputs [Residuals].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewOutput(
"Residuals",
"Output tensor of lstsq operator, "
"meaning the squared residuals of the calculated solutions."));
1 change: 1 addition & 0 deletions paddle/fluid/pybind/op_function_generator.h
@@ -245,6 +245,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
"SavedMean",
"SavedVariance",
"ReserveSpace"}},
{"lstsq", {"Solution", "Residuals", "Rank", "SingularValues"}},
{"inplace_abn",
{"Y",
"MeanOut",
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
@@ -1425,6 +1425,15 @@
func : logsumexp
backward : logsumexp_grad

- api : lstsq
args : (Tensor x, Tensor y, Scalar rcond, str driver)
output : Tensor(solution), Tensor(residuals), Tensor(rank), Tensor(singular_values)
infer_meta :
func : LstsqInferMeta
dtype : x
kernel :
func : lstsq

- api : lu
args : (Tensor x, bool pivot)
output : Tensor(out), Tensor(pivots), Tensor(infos)
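The `lstsq` entry above exposes the migrated kernel through the phi API with four outputs. Below is a hedged usage sketch from the Python side, assuming it is reached via the existing public `paddle.linalg.lstsq` API on a CPU build where the `gelsd` driver is available; the driver choice and input shapes are illustrative and not taken from this commit.

```python
# Usage sketch (assumption: paddle.linalg.lstsq dispatches to the lstsq kernel
# registered above and returns the four outputs listed in the yaml entry).
import paddle

x = paddle.rand([4, 6, 3])   # batched (*, m, n) with m > n
y = paddle.rand([4, 6, 2])   # batched (*, m, k)

solution, residuals, rank, singular_values = paddle.linalg.lstsq(
    x, y, rcond=None, driver="gelsd")  # driver choice is illustrative (CPU)

# Shapes implied by the infer-meta added in this commit; whether residuals,
# rank and singular_values are actually populated depends on the driver.
print(solution.shape)         # [4, 3, 2]
print(residuals.shape)        # [4, 2]
print(rank.shape)             # [4]
print(singular_values.shape)  # [4, 3]
```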
84 changes: 84 additions & 0 deletions paddle/phi/infermeta/binary.cc
@@ -2007,6 +2007,90 @@ void TriangularSolveInferMeta(const MetaTensor& x,
out->share_lod(y);
}

void LstsqInferMeta(const MetaTensor& x,
const MetaTensor& y,
const Scalar& rcond,
const std::string& driver,
MetaTensor* solution,
MetaTensor* residuals,
MetaTensor* rank,
MetaTensor* singular_values) {
auto x_dims = x.dims();
auto y_dims = y.dims();
int x_rank = x_dims.size();
int y_rank = y_dims.size();

int m = x_dims[x_rank - 2];
int n = x_dims[x_rank - 1];
int nrhs = y_dims[x_rank - 1];

PADDLE_ENFORCE_GE(
x_rank,
2,
phi::errors::InvalidArgument("Expects input tensor x to be not less than "
"2 dimentions, but got dimention %d",
x_rank));
PADDLE_ENFORCE_GE(
y_rank,
2,
phi::errors::InvalidArgument("Expects input tensor y to be not less than "
"2 dimentions, but got dimention %d",
y_rank));

PADDLE_ENFORCE_EQ(
x_rank,
y_rank,
phi::errors::InvalidArgument(
"Expects input tensor x and y to have the same dimension "
"but got x's dimention [%d] and y's dimention [%d]",
x_rank,
y_rank));

std::vector<int> batch_dims_vec{};
for (int i = 0; i < x_rank - 2; ++i) {
PADDLE_ENFORCE_EQ(x_dims[i],
y_dims[i],
phi::errors::InvalidArgument(
"Expects input tensor x and y to have the same batch "
"dimension, but got x's batch dimention [%d] and "
"y's batch dimention [%d] in %d-th dim",
x_dims[i],
y_dims[i],
i));
batch_dims_vec.emplace_back(x_dims[i]);
}

PADDLE_ENFORCE_EQ(
m,
y_dims[y_rank - 2],
phi::errors::InvalidArgument(
"Expects input tensor x and y to have the same row dimension "
"of the inner-most 2-dims matrix, "
"but got x's row dimention [%d] and y's row dimention [%d]",
m,
y_dims[y_rank - 2]));

rank->set_dims(phi::make_ddim(batch_dims_vec));

if (m > n) {
batch_dims_vec.emplace_back(nrhs);
residuals->set_dims(phi::make_ddim(batch_dims_vec));
batch_dims_vec.pop_back();
} else {
residuals->set_dims(phi::make_ddim({0}));
}
residuals->set_dtype(y.dtype());

batch_dims_vec.emplace_back(std::min(m, n));
singular_values->set_dims(phi::make_ddim(batch_dims_vec));
singular_values->set_dtype(y.dtype());

batch_dims_vec[x_rank - 2] = n;
batch_dims_vec.emplace_back(nrhs);
solution->set_dims(phi::make_ddim(batch_dims_vec));
solution->set_dtype(y.dtype());
}

void YoloBoxInferMeta(const MetaTensor& x,
const MetaTensor& img_size,
const std::vector<int>& anchors,
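The dimension handling in `LstsqInferMeta` above can be summarized in a few lines. Below is a plain-Python mirror of that shape logic (a hypothetical helper with no Paddle dependency, for illustration only): for x of shape (*, m, n) and y of shape (*, m, nrhs), it reproduces the output shapes the infer-meta assigns.

```python
# Plain-Python mirror of the shape logic in LstsqInferMeta (sketch only; the
# helper name and structure are hypothetical, not Paddle code).
def lstsq_output_shapes(x_shape, y_shape):
    assert len(x_shape) >= 2 and len(x_shape) == len(y_shape)
    *batch, m, n = x_shape
    *y_batch, y_m, nrhs = y_shape
    assert batch == y_batch and m == y_m   # same batch dims and row dim

    rank = tuple(batch)
    # Residuals only get a real shape for overdetermined systems (m > n).
    residuals = (tuple(batch) + (nrhs,)) if m > n else (0,)
    singular_values = tuple(batch) + (min(m, n),)
    solution = tuple(batch) + (n, nrhs)
    return {"solution": solution, "residuals": residuals,
            "rank": rank, "singular_values": singular_values}

print(lstsq_output_shapes((4, 6, 3), (4, 6, 2)))
# {'solution': (4, 3, 2), 'residuals': (4, 2), 'rank': (4,),
#  'singular_values': (4, 3)}
```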
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/binary.h
@@ -288,6 +288,15 @@ void TriangularSolveInferMeta(const MetaTensor& x,
bool unitriangular,
MetaTensor* out);

void LstsqInferMeta(const MetaTensor& x,
const MetaTensor& y,
const Scalar& rcond,
const std::string& driver,
MetaTensor* solution,
MetaTensor* residuals,
MetaTensor* rank,
MetaTensor* singular_values);

void YoloBoxInferMeta(const MetaTensor& x,
const MetaTensor& img_size,
const std::vector<int>& anchors,
