-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add pnp domain with one op: LinalgSVD #5821
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (c) ONNX Project Contributors | ||
Check warning Code scanning / lintrunner BLACK-ISORT/format Warning
Run lintrunner -a to apply this patch.
|
||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import numpy as np | ||
|
||
import onnx | ||
from onnx.backend.test.case.base import Base | ||
from onnx.backend.test.case.node import expect | ||
|
||
|
||
class LinalgSVD(Base): | ||
@staticmethod | ||
def export() -> None: | ||
threshold = 1.0 | ||
node = onnx.helper.make_node( | ||
"LinalgSVD", | ||
inputs=["A"], | ||
outputs=["U", "S", "Vh"], | ||
threshold=threshold, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is |
||
domain="ai.onnx.pnp", | ||
) | ||
|
||
A = np.array([ | ||
[ | ||
[-1.125840, -1.152360, -0.250579, -0.433879], | ||
[0.848710, 0.692009, -0.316013, -2.115219], | ||
[0.468096, -0.157712, 1.443660, 0.266049], | ||
], | ||
[ | ||
[0.166455, 0.874382, -0.143474, -0.111609], | ||
[0.931827, 1.259009, 2.004981, 0.053737], | ||
[0.618057, -0.412802, -0.841065, -2.316042] | ||
] | ||
]) | ||
U, S, Vh = np.linalg.svd(A, full_matrices=True) | ||
Check warning Code scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warning Code scanning / lintrunner RUFF/W291 Warning
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
expect(node, inputs=[A], outputs=[U, S, Vh], name="test_ai_onnx_pnp_linalg_svd") |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
"ONNX_DOMAIN", | ||
"ONNX_ML_DOMAIN", | ||
"AI_ONNX_PREVIEW_TRAINING_DOMAIN", | ||
"AI_ONNX_PNP_DOMAIN", | ||
"has", | ||
"get_schema", | ||
"get_all_schemas", | ||
|
@@ -25,6 +26,7 @@ | |
ONNX_DOMAIN = "" | ||
ONNX_ML_DOMAIN = "ai.onnx.ml" | ||
AI_ONNX_PREVIEW_TRAINING_DOMAIN = "ai.onnx.preview.training" | ||
AI_ONNX_PNP_DOMAIN = "ai.onnx.pnp" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just thinking aloud: I wonder if "ai.onnx.linalg" would be better? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would think |
||
|
||
|
||
has = C.has_schema | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "onnx/defs/schema.h" | ||
|
||
namespace ONNX_NAMESPACE { | ||
|
||
// Declare training operators. | ||
|
||
class ONNX_PNP_OPERATOR_SET_SCHEMA_CLASS_NAME(1, LinalgSVD); | ||
|
||
// Iterate over schema from ai.onnx.training version 1 | ||
class OpSet_OnnxPNP_ver1 { | ||
public: | ||
static void ForEachSchema(std::function<void(OpSchema&&)> fn) { | ||
Check warning on line 18 in onnx/defs/operator_sets_pnp.h GitHub Actions / clang-tidy-reviewclang-tidy
|
||
fn(GetOpSchema<ONNX_PNP_OPERATOR_SET_SCHEMA_CLASS_NAME(1, LinalgSVD)>()); | ||
} | ||
}; | ||
|
||
// Register preview operators. | ||
inline void RegisterOnnxPNPOperatorSetSchema() { | ||
// Preview operators should have only one version. | ||
// If changes are needed for a specific preview operator, | ||
// its spec should be modified without increasing its version. | ||
RegisterOpSetSchema<OpSet_OnnxPNP_ver1>(); | ||
} | ||
|
||
} // namespace ONNX_NAMESPACE | ||
Check warning on line 31 in onnx/defs/operator_sets_pnp.h GitHub Actions / Optional Lint
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
Check warning Code scanning / lintrunner CLANGFORMAT/format Warning
See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch. Check warning Code scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Final newline expected
|
||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
#include <algorithm> | ||
#include <cmath> | ||
#include <numeric> | ||
|
||
#include "onnx/defs/schema.h" | ||
|
||
namespace ONNX_NAMESPACE { | ||
|
||
ONNX_PNP_OPERATOR_SET_SCHEMA(LinalgSVD, 1, | ||
Check warning on line 13 in onnx/defs/pnp/def.cc GitHub Actions / clang-tidy-reviewclang-tidy
Check warning on line 13 in onnx/defs/pnp/def.cc GitHub Actions / clang-tidy-reviewclang-tidy
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could just call it "SVD". Especially if we change the domain to "ai.onnx.linalg" as discussed elsewhere. |
||
OpSchema() | ||
.SetDoc(R"DOC(For internal use.)DOC") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That would be great to have a longer description. onnx runtimes do not need to declare an operator in onnx to implement a kernel. Is it needed for onnx? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1. I don't understand what "internal use" means here. |
||
.Attr( | ||
"full_matrices", | ||
"", | ||
AttributeProto::INT, | ||
static_cast<int64_t>(1)) | ||
.Input( | ||
0, | ||
"A", | ||
"", | ||
"T") | ||
.Output( | ||
0, | ||
"U", | ||
"", | ||
"T") | ||
.Output( | ||
1, | ||
"S", | ||
"", | ||
"T") | ||
.Output( | ||
2, | ||
"Vh", | ||
"", | ||
"T") | ||
.TypeConstraint( | ||
"T", | ||
{"tensor(float)", "tensor(double)"}, | ||
"") | ||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { | ||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); | ||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); | ||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); | ||
int64_t full_matrices = ctx.getAttribute("full_matrices")->i(); | ||
|
||
const TensorShapeProto& A_shape = ctx.getInputType(0)->tensor_type().shape(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These lines should be protected by the usual |
||
const auto& M = A_shape.dim(A_shape.dim_size() - 2); | ||
const auto& N = A_shape.dim(A_shape.dim_size() - 1); | ||
if (!M.has_dim_value() || !N.has_dim_value()) { | ||
// cannot do shape inference without knowing dimension values | ||
return; | ||
} | ||
const auto& K = M.dim_value() < N.dim_value() ? M : N; | ||
auto* u_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); | ||
auto* s_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); | ||
auto* v_shape = ctx.getOutputType(2)->mutable_tensor_type()->mutable_shape(); | ||
if (A_shape.dim_size() == 3) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These shape aspects should be explained in the operator documentation: I guess the input is required to be either 2-dim or 3-dim? |
||
const auto& batch_dim = A_shape.dim(0); | ||
*u_shape->add_dim() = batch_dim; | ||
*s_shape->add_dim() = batch_dim; | ||
*v_shape->add_dim() = batch_dim; | ||
} | ||
*u_shape->add_dim() = M; | ||
*u_shape->add_dim() = full_matrices ? M : K; | ||
*s_shape->add_dim() = K; | ||
*v_shape->add_dim() = full_matrices ? N : K; | ||
*v_shape->add_dim() = N; | ||
})); | ||
|
||
} // namespace ONNX_NAMESPACE | ||
Check warning on line 75 in onnx/defs/pnp/def.cc GitHub Actions / Optional Lint
Check warning on line 75 in onnx/defs/pnp/def.cc GitHub Actions / Optional Lint
|
Check warning
Code scanning / lintrunner
RUFF/format Warning