Skip to content

Commit

Permalink
Merge branch 'main' into thiagofc/gradient-2
Browse files Browse the repository at this point in the history
Signed-off-by: Thiago Crepaldi <thiagofc@microsoft.com>
  • Loading branch information
thiagocrepaldi committed Apr 26, 2024
2 parents 527f4dd + 4e7289d commit fcc8731
Show file tree
Hide file tree
Showing 235 changed files with 7,029 additions and 664 deletions.
2,367 changes: 2,359 additions & 8 deletions docs/Changelog.md

Large diffs are not rendered by default.

402 changes: 231 additions & 171 deletions docs/Operators.md

Large diffs are not rendered by default.

Binary file modified onnx/backend/test/data/node/test_acos/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acos_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acosh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_acosh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asin/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asin_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asinh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_asinh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atan/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atan_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atanh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_atanh_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_averagepool_2d_ceil/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_averagepool_2d_pads/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_double/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_bernoulli_seed/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_1d/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_3d/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_pad/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_convtranspose_pads/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cos/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cos_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cosh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cosh_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_det_2d/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_det_nd/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_dropout_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_dropout_default_mask/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_elu_example/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_eyelike_with_dtype/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_globalaveragepool/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_globalmaxpool/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample_bicubic/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample_bilinear/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gridsample_nearest/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gru_batchwise/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gru_defaults/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_gru_seq_length/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardsigmoid/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardsigmoid_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardsigmoid_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardswish/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_hardswish_expanded/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_instancenorm_epsilon/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_instancenorm_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_1d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_dilations/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_pads/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_same_lower/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_same_upper/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_2d_strides/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lppool_3d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lstm_batchwise/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lstm_defaults/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_lstm_with_peepholes/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_1d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_ceil/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_dilations/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_pads/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_strides/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_2d_uint8/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_3d_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_maxpool_3d_dilations/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_mish/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_mish_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NC/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NC_expanded/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1_ii/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1_weight/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_nllloss_NCd1d2/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
�1�?
Binary file modified onnx/backend/test/data/node/test_rnn_seq_length/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_roialign_mode_max/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_round/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu_default/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_selu_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_simple_rnn_batchwise/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_simple_rnn_defaults/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sin/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sin_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sinh/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_sinh_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softplus/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softplus_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softsign/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_softsign_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_tan/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_tan_example/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_thresholdedrelu/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_training_dropout/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
116 changes: 32 additions & 84 deletions onnx/defs/generator/defs.cc
Expand Up @@ -126,7 +126,7 @@ ONNX_OPERATOR_SET_SCHEMA(
}
}));

static const char* EyeLike_ver9_doc = R"DOC(
static const char* EyeLike_ver22_doc = R"DOC(
Generate a 2D tensor (matrix) with ones on the diagonal and zeros everywhere else. Only 2D
tensors are supported, i.e. input T1 must be of rank 2. The shape of the output tensor is the
same as the input tensor. The data type can be specified by the 'dtype' argument. If
Expand All @@ -138,9 +138,9 @@ TensorProto message and be valid as an output type.

ONNX_OPERATOR_SET_SCHEMA(
EyeLike,
9,
22,
OpSchema()
.SetDoc(EyeLike_ver9_doc)
.SetDoc(EyeLike_ver22_doc)
.Attr(
"k",
"(Optional) Index of the diagonal to be populated with ones. Default is 0."
Expand All @@ -159,33 +159,11 @@ ONNX_OPERATOR_SET_SCHEMA(
.Output(0, "output", "Output tensor, same shape as input tensor T1.", "T2")
.TypeConstraint(
"T1",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(bool)"},
OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
"Constrain input types. Strings and complex are not supported.")
.TypeConstraint(
"T2",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(bool)"},
OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
"Constrain output types. Strings and complex are not supported.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr) {
Expand All @@ -202,7 +180,7 @@ ONNX_OPERATOR_SET_SCHEMA(
propagateShapeFromInputToOutput(ctx, 0, 0);
}));

static const char* RandomUniform_ver1_doc = R"DOC(
static const char* RandomUniform_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a uniform distribution. The shape
of the tensor is specified by the `shape` argument and the range by `low` and `high`.
Expand All @@ -213,9 +191,9 @@ TensorProto message.

ONNX_OPERATOR_SET_SCHEMA(
RandomUniform,
1,
22,
OpSchema()
.SetDoc(RandomUniform_ver1_doc)
.SetDoc(RandomUniform_ver22_doc)
.Attr("low", "Lower boundary of the output values.", AttributeProto::FLOAT, 0.0f)
.Attr("high", "Upper boundary of the output values.", AttributeProto::FLOAT, 1.0f)
.Attr(
Expand All @@ -230,16 +208,13 @@ ONNX_OPERATOR_SET_SCHEMA(
static_cast<int64_t>(TensorProto::FLOAT))
.Attr("shape", "The shape of the output tensor.", AttributeProto::INTS)
.Output(0, "output", "Output tensor of random values drawn from uniform distribution", "T")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT);
propagateShapeFromAttributeToOutput(ctx, "shape", 0);
}));

static const char* RandomNormal_ver1_doc = R"DOC(
static const char* RandomNormal_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a normal distribution. The shape
of the tensor is specified by the `shape` argument and the parameter of the normal distribution
specified by `mean` and `scale`.
Expand All @@ -251,9 +226,9 @@ TensorProto message.

ONNX_OPERATOR_SET_SCHEMA(
RandomNormal,
1,
22,
OpSchema()
.SetDoc(RandomNormal_ver1_doc)
.SetDoc(RandomNormal_ver22_doc)
.Attr("mean", "The mean of the normal distribution.", AttributeProto::FLOAT, 0.0f)
.Attr("scale", "The standard deviation of the normal distribution.", AttributeProto::FLOAT, 1.0f)
.Attr(
Expand All @@ -268,16 +243,13 @@ ONNX_OPERATOR_SET_SCHEMA(
static_cast<int64_t>(TensorProto::FLOAT))
.Attr("shape", "The shape of the output tensor.", AttributeProto::INTS)
.Output(0, "output", "Output tensor of random values drawn from normal distribution", "T")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT);
propagateShapeFromAttributeToOutput(ctx, "shape", 0);
}));

static const char* RandomUniformLike_ver1_doc = R"DOC(
static const char* RandomUniformLike_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a uniform distribution.
The shape of the output tensor is copied from the shape of the input tensor,
and the parameters of the uniform distribution are specified by `low` and `high`.
Expand All @@ -289,9 +261,9 @@ TensorProto message and be valid as an output type.

ONNX_OPERATOR_SET_SCHEMA(
RandomUniformLike,
1,
22,
OpSchema()
.SetDoc(RandomUniformLike_ver1_doc)
.SetDoc(RandomUniformLike_ver22_doc)
.Attr("low", "Lower boundary of the output values.", AttributeProto::FLOAT, 0.0f)
.Attr("high", "Upper boundary of the output values.", AttributeProto::FLOAT, 1.0f)
.Attr(
Expand All @@ -309,12 +281,9 @@ ONNX_OPERATOR_SET_SCHEMA(
.Output(0, "output", "Output tensor of random values drawn from uniform distribution", "T2")
.TypeConstraint(
"T1",
OpSchema::all_tensor_types(),
OpSchema::all_tensor_types_ir4(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
.TypeConstraint(
"T2",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeConstraint("T2", OpSchema::all_float_types_ir4(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
Expand All @@ -326,7 +295,7 @@ ONNX_OPERATOR_SET_SCHEMA(
propagateShapeFromInputToOutput(ctx, 0, 0);
}));

static const char* RandomNormalLike_ver1_doc = R"DOC(
static const char* RandomNormalLike_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a normal distribution.
The shape of the output tensor is copied from the shape of the input tensor,
and the parameters of the normal distribution are specified by `mean` and `scale`.
Expand All @@ -338,9 +307,9 @@ TensorProto message, and be valid as an output type.

ONNX_OPERATOR_SET_SCHEMA(
RandomNormalLike,
1,
22,
OpSchema()
.SetDoc(RandomNormalLike_ver1_doc)
.SetDoc(RandomNormalLike_ver22_doc)
.Attr("mean", "The mean of the normal distribution.", AttributeProto::FLOAT, 0.0f)
.Attr("scale", "The standard deviation of the normal distribution.", AttributeProto::FLOAT, 1.0f)
.Attr(
Expand All @@ -358,12 +327,9 @@ ONNX_OPERATOR_SET_SCHEMA(
.Output(0, "output", "Output tensor of random values drawn from normal distribution", "T2")
.TypeConstraint(
"T1",
OpSchema::all_tensor_types(),
OpSchema::all_tensor_types_ir4(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
.TypeConstraint(
"T2",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeConstraint("T2", OpSchema::all_float_types_ir4(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
Expand All @@ -375,16 +341,16 @@ ONNX_OPERATOR_SET_SCHEMA(
propagateShapeFromInputToOutput(ctx, 0, 0);
}));

static const char* Multinomial_ver7_doc = R"DOC(
static const char* Multinomial_ver22_doc = R"DOC(
Generate a tensor of samples from a multinomial distribution according to the probabilities
of each of the possible outcomes.
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
Multinomial,
7,
22,
OpSchema()
.SetDoc(Multinomial_ver7_doc)
.SetDoc(Multinomial_ver22_doc)
.Attr("sample_size", "Number of times to sample.", AttributeProto::INT, static_cast<int64_t>(1))
.Attr(
"seed",
Expand All @@ -406,10 +372,7 @@ ONNX_OPERATOR_SET_SCHEMA(
"output",
"Output tensor with shape [batch_size, sample_size], where sample_size is the number of times to sample. Each value along the axis zero represents the outcome of the corresponding sample in a batch.",
"T2")
.TypeConstraint(
"T1",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input types to float tensors.")
.TypeConstraint("T1", OpSchema::all_float_types_ir4(), "Constrain input types to float tensors.")
.TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "Constrain output types to integral tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto dtype = ctx.getAttribute("dtype");
Expand Down Expand Up @@ -562,7 +525,7 @@ ONNX_OPERATOR_SET_SCHEMA(
}
}));

static const char* Bernoulli_ver15_doc = R"DOC(
static const char* Bernoulli_ver22_doc = R"DOC(
Draws binary random numbers (0 or 1) from a Bernoulli distribution. The input tensor should be a tensor
containing probabilities p (a value in the range [0,1]) to be used for drawing the binary random number,
where an output of 1 is produced with probability p and an output of 0 is produced with probability (1-p).
Expand All @@ -573,9 +536,9 @@ implementations (even if a seed is specified).

ONNX_OPERATOR_SET_SCHEMA(
Bernoulli,
15,
22,
OpSchema()
.SetDoc(Bernoulli_ver15_doc)
.SetDoc(Bernoulli_ver22_doc)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
Expand All @@ -589,25 +552,10 @@ ONNX_OPERATOR_SET_SCHEMA(
OPTIONAL_VALUE)
.Input(0, "input", "All values in input have to be in the range:[0, 1].", "T1")
.Output(0, "output", "The returned output tensor only has values 0 or 1, same shape as input tensor.", "T2")
.TypeConstraint(
"T1",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input types to float tensors.")
.TypeConstraint("T1", OpSchema::all_float_types_ir4(), "Constrain input types to float tensors.")
.TypeConstraint(
"T2",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(bfloat16)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(bool)"},
OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
"Constrain output types to all numeric tensors and bool tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
Expand Down

0 comments on commit fcc8731

Please sign in to comment.