Commit 527f4dd
Add Gradient-2 with bfloat16
Signed-off-by: Thiago Crepaldi <thiagofc@microsoft.com>
thiagocrepaldi committed Apr 24, 2024
1 parent 5ecc0a9 commit 527f4dd
Showing 5 changed files with 218 additions and 8 deletions.
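The net effect of the commit, seen from the user's side: with Gradient bumped to version 2 and its float type constraint widened, a Gradient node can now operate on bfloat16 tensors, which the version-1 constraint (float16/float/double only) rejected. A minimal sketch, assuming the standard onnx.helper API; the graph and tensor names are illustrative:

```
import onnx
from onnx import TensorProto, helper

# Illustrative graph: Y = X * W, plus a Gradient node requesting dY/dX and dY/dW.
X = helper.make_tensor_value_info("X", TensorProto.BFLOAT16, [2, 2])
W = helper.make_tensor_value_info("W", TensorProto.BFLOAT16, [2, 2])
Y = helper.make_tensor_value_info("Y", TensorProto.BFLOAT16, [2, 2])
dW = helper.make_tensor_value_info("dY_dW", TensorProto.BFLOAT16, [2, 2])

mul = helper.make_node("Mul", ["X", "W"], ["Y"])
grad = helper.make_node(
    "Gradient",
    inputs=["X", "W"],            # values for the tensors listed in "xs"
    outputs=["dY_dX", "dY_dW"],
    domain="ai.onnx.preview.training",
    xs=["X", "W"],                # differentiated inputs
    y="Y",                        # target tensor
)

graph = helper.make_graph([mul, grad], "g", [X, W], [Y, dW])
model = helper.make_model(
    graph,
    opset_imports=[
        helper.make_opsetid("", 21),
        helper.make_opsetid("ai.onnx.preview.training", 2),
    ],
)
onnx.checker.check_model(model)  # should pass now that T1/T2 admit bfloat16
```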
15 changes: 14 additions & 1 deletion onnx/defs/operator_sets_preview.h
@@ -8,7 +8,7 @@

namespace ONNX_NAMESPACE {

-// Declare training operators.
+// Declare training operators version 1

class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient);
class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum);
@@ -26,12 +26,25 @@ class OpSet_OnnxPreview_ver1 {
}
};

+// Declare training operators version 2
+
+class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(2, Gradient);
+
+// Iterate over schema from ai.onnx.training version 2
+class OpSet_OnnxPreview_ver2 {
+ public:
+  static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
+    fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(2, Gradient)>());
+  }
+};

// Register preview operators.
inline void RegisterOnnxPreviewOperatorSetSchema() {
// Preview operators should have only one version.
// If changes are needed for a specific preview operator,
// its spec should be modified without increasing its version.
RegisterOpSetSchema<OpSet_OnnxPreview_ver1>();
+RegisterOpSetSchema<OpSet_OnnxPreview_ver2>();
}

} // namespace ONNX_NAMESPACE
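A quick way to confirm the registration above took effect, assuming the standard onnx.defs Python bindings:

```
import onnx.defs

D = "ai.onnx.preview.training"
print(onnx.defs.get_schema("Gradient", domain=D).since_version)  # expected: 2
print(onnx.defs.get_schema("Momentum", domain=D).since_version)  # still 1
```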
4 changes: 2 additions & 2 deletions onnx/defs/schema.h
@@ -1181,14 +1181,14 @@ class OpSchemaRegistry final : public ISchemaRegistry {
// ONNX's preview domain contains operators subject to change, so
// versioning is not meaningful and that domain should have only one
// version.
-map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = std::make_pair(1, 1);
+map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = std::make_pair(1, 2);
// Version corresponding to the last release of ONNX. Update this to match
// the max version above in a *release* version of ONNX. But in other
// versions, the max version may be ahead of the last-release-version.
last_release_version_map_[ONNX_DOMAIN] = 21;
last_release_version_map_[AI_ONNX_ML_DOMAIN] = 5;
last_release_version_map_[AI_ONNX_TRAINING_DOMAIN] = 1;
-last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 1;
+last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 2;
}

const std::unordered_map<std::string, std::pair<int, int>>& Map() const {
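The (1, 2) pair above is the (min, max) version range for the preview domain, and version resolution honors it. A small check, again assuming the onnx.defs bindings:

```
import onnx.defs

D = "ai.onnx.preview.training"
# Resolution picks the newest schema whose since_version does not
# exceed the requested opset-import version.
v1 = onnx.defs.get_schema("Gradient", max_inclusive_version=1, domain=D)
v2 = onnx.defs.get_schema("Gradient", max_inclusive_version=2, domain=D)
print(v1.since_version, v2.since_version)  # expected: 1 2
```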
10 changes: 5 additions & 5 deletions onnx/defs/training/defs.cc
@@ -10,7 +10,7 @@

namespace ONNX_NAMESPACE {

-static const char* Gradient_ver1_doc = R"DOC(
+static const char* Gradient_ver2_doc = R"DOC(
The Gradient operator computes the partial derivatives of a specific tensor w.r.t.
some other tensors. This operator is widely used in gradient-based training
algorithms. To illustrate its use, let's consider a computation graph,
@@ -138,9 +138,9 @@ auto-differentiation.

ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(
Gradient,
-1,
+2,
OpSchema()
-.SetDoc(Gradient_ver1_doc)
+.SetDoc(Gradient_ver2_doc)
.Input(
0,
"Inputs",
@@ -187,10 +187,10 @@ ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(
"\"zs\" are the minimal independent variable set that determines "
"the value of \"y\".",
AttributeProto::STRING)
.TypeConstraint("T1", OpSchema::all_tensor_types(), "Allow outputs to be any kind of tensor.")
.TypeConstraint("T1", OpSchema::all_tensor_types_ir4(), "Allow outputs to be any kind of tensor.")
.TypeConstraint(
"T2",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
OpSchema::all_float_types_ir4(),
"Allow inputs to be any kind of floating-point tensor."));

static const char* Adagrad_ver1_doc = R"DOC(
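all_float_types_ir4() is what actually brings in bfloat16 for the gradient outputs (and all_tensor_types_ir4() does the same for the inputs). One way to inspect the widened constraints, assuming the type_constraints accessor exposed by the Python bindings:

```
import onnx.defs

schema = onnx.defs.get_schema("Gradient", domain="ai.onnx.preview.training")
for tc in schema.type_constraints:
    print(tc.type_param_str, sorted(tc.allowed_type_strs))
# "tensor(bfloat16)" should now appear under both T1 and T2.
```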
196 changes: 196 additions & 0 deletions onnx/defs/training/old.cc
@@ -0,0 +1,196 @@
/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include <algorithm>
#include <cmath>

#include "onnx/defs/schema.h"

namespace ONNX_NAMESPACE {

static const char* Gradient_ver1_doc = R"DOC(
The Gradient operator computes the partial derivatives of a specific tensor w.r.t.
some other tensors. This operator is widely used in gradient-based training
algorithms. To illustrate its use, let's consider a computation graph,
```
X -----.
|
v
W --> Conv --> H --> Gemm --> Y
^
|
Z
```
where W and Z are trainable tensors. Note that operators' attributes are
omitted for the sake of simplicity. Let dY/dW (dY/dZ) be the gradient of
Y with respect to W (Z). The user can compute the gradients by inserting a
Gradient operator to form the graph shown below.
```
W --> Conv --> H --> Gemm --> Y
| ^ ^
| | |
| X Z
| | |
| | .----------'
| | | (W/Z/X is the 1st/2nd/3rd input of Gradient as shown in
| | | "xs" followed by "zs")
| v v
'---> Gradient(xs=["W", "Z"], zs=["X"], y="Y")
| |
| '-----------------------------------> dY/dW (1st output of Gradient)
|
'---------------------------------------> dY/dZ (2nd output of Gradient)
```
By definition, the tensor "y" is a function of independent variables in "xs"
and "zs". Since we only compute the gradient of "y" w.r.t. the differentiable
variables in "xs", this Gradient only outputs dY/dW and dY/dZ. Note that "H"
cannot appear in "xs" and "zs". The reason is that "H" can be determined by
tensors "W" and "X" and therefore "H" is not an independent variable.
All outputs are optional. For example, the user can assign an empty string
to the 1st output name of that Gradient to skip the generation of dY/dW.
Note that the concept of optional outputs can also be found in ONNX's RNN, GRU,
and LSTM.
The Gradient operator can also compute derivatives with respect to intermediate
tensors. For example, the gradient of Y with respect to H can be computed via
```
W --> Conv --> H --> Gemm --> Y
^ | ^
| | |
X | Z
.-------' |
| .----------'
| | (H/Z is the 1st/2nd input of Gradient as shown in "xs")
v v
Gradient(xs=["H", "Z"], y="Y")
| |
| '-----------------------------------> dY/dH (1st output of Gradient)
|
'---------------------------------------> dY/dZ (2nd output of Gradient)
```
It is possible to represent higher-order differentiation using Gradient operators.
For example, given the following linear model:
```
W --> Gemm --> Y --> Loss --> O
^ ^
| |
X L
```
To compute the 2nd order derivative of O with respect to W (denoted by
d^2O/dW^2), one can do
```
W --> Gemm --> Y --> Loss --> O
| ^ ^
| | |
| X .------------L
| | | |
| | | v
+------+-+> Gradient(xs=["X", "W"], zs=["L"], y="O") ---> dO/dX (1st output of Gradient)
| | | |
| | | '---> dO/dW (2nd output of Gradient)
| v v
'---> Gradient(xs=["X", "W"], zs=["L"], y="dO/dW") ---> d(dO/dW)/dX (1st output of
| Gradient)
|
|
'---> d^2O/dW^2 (2nd output of Gradient)
```
The tensors named in attributes "xs", "zs", and "y" define the differentiated
computation graph, and the inputs to the Gradient node define the values at
which the gradient is computed. We can feed different tensors to the identified
graph. For example, one can compute the gradient of Y with respect to H at
a specific value of H, H_1, by providing that value as an input to the Gradient
node.
```
W --> Conv --> H --> Gemm --> Y
^ ^
| |
X Z
Z_1 (2nd input of Gradient)
|
v
H_1 --> Gradient(xs=["H", "Z"], y="Y") ---> dY/dH when H = H_1 and Z = Z_1.
|
'------------------------------> dY/dZ (2nd output of Gradient)
```
When the inputs of Gradient are the tensors named in "xs" and "zs", the
computation can be optimized. More specifically, intermediate variables in the
forward pass can be reused if the gradient is computed via reverse-mode
auto-differentiation.
)DOC";

ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(
Gradient,
1,
OpSchema()
.SetDoc(Gradient_ver1_doc)
.Input(
0,
"Inputs",
"The values fed into graph identified by the attributes. "
"The i-th input is the value of the i-th tensor specified in the "
"concatenated list of the attribute \"xs\" and the attribute "
" \"zs\". For example, if xs=[\"A\", \"B\"] and zs=[\"C\"], the "
"first input is used as the value of symbol \"A\" and the 3rd "
"input is substituted for all the occurrences of \"C\".",
"T1",
OpSchema::Variadic,
false)
.Output(
0,
"Outputs",
"The gradient of the tensor specified by the attribute \"y\" "
"with respect to each of tensors specified in the "
"attribute \"xs\". The i-th output is the gradient of \"y\" with "
"respect to the i-th tensor specified in the attribute \"xs\".",
"T2",
OpSchema::Variadic,
false)
.Attr(
"xs",
"Input tensor names of the differentiated sub-graph. It "
"contains only the necessary differentiated "
"inputs of a (sub-)graph. Variables (usually called "
"intermediate variables) that can be generated from inputs "
"cannot be included in this attribute.",
AttributeProto::STRINGS)
.Attr(
"zs",
"Input tensor names of the differentiated sub-graph. It "
"contains only the necessary non-differentiated "
"inputs of a (sub-)graph. Variables (usually called "
"intermediate variables) that can be generated from inputs "
"cannot be included in this attribute.",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr(
"y",
"The targeted tensor. It can be viewed as the output of the "
"differentiated function. The attribute \"xs\" and attribute "
"\"zs\" are the minimal independent variable set that determines "
"the value of \"y\".",
AttributeProto::STRING)
.TypeConstraint("T1", OpSchema::all_tensor_types(), "Allow outputs to be any kind of tensor.")
.TypeConstraint(
"T2",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Allow inputs to be any kind of floating-point tensor."));

}  // namespace ONNX_NAMESPACE

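old.cc preserves the version-1 schema, so both versions remain queryable side by side. A sketch of the contrast, under the same assumptions as the snippets above:

```
import onnx.defs

D = "ai.onnx.preview.training"
old = onnx.defs.get_schema("Gradient", max_inclusive_version=1, domain=D)
t2 = {tc.type_param_str: tc.allowed_type_strs for tc in old.type_constraints}["T2"]
print("tensor(bfloat16)" in t2)  # expected: False for Gradient-1
```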
1 change: 1 addition & 0 deletions onnx/helper.py
@@ -76,6 +76,7 @@
("1.14.1", 9, 19, 3, 1),
("1.15.0", 9, 20, 4, 1),
("1.16.0", 10, 21, 5, 1),
("1.17.0", 10, 21, 5, 1),
]

VersionMapType = Dict[Tuple[str, int], int]
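Each row of this table is (release, IR version, ai.onnx opset, ai.onnx.ml opset, ai.onnx.training opset); the new 1.17.0 row carries the same numbers as 1.16.0 at the time of this commit. Helpers such as find_min_ir_version_for consult it:

```
from onnx import helper

# Minimal IR version for a model importing ai.onnx opset 21.
print(helper.find_min_ir_version_for([helper.make_opsetid("", 21)]))  # expected: 10
```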
