Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Extend Optional op #5043

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnx/defs/operator_sets.h
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 19, AveragePool);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 19, Pad);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 19, Resize);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 19, DeformConv);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 19, Optional);

// Iterate over schema from ai.onnx version 19
class OpSet_Onnx_ver19 {
Expand All @@ -1074,6 +1075,7 @@ class OpSet_Onnx_ver19 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 19, Pad)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 19, Resize)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 19, DeformConv)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 19, Optional)>());
}
};

Expand Down
206 changes: 177 additions & 29 deletions onnx/defs/optional/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,190 @@ static std::vector<std::string> optional_and_tensor_types() {
return optional_types;
}

static const char* Optional_ver15_doc = R"DOC(
// getOptionalType: returns the type of a value-attribute, if specified, of an Optional op.
// Returns true if a value-attribute is present and type was filled in.
// getOptionalType: fills `tensorTypeProto` with the element type and shape
// implied by the value-carrying attribute of an Optional op, if one is present.
//
// At most one of the attributes 'value', 'sparse_value', 'value_int',
// 'value_ints', 'value_float', 'value_floats', 'value_string' or
// 'value_strings' may be specified; more than one fails inference.
//
// Returns true if a value attribute was found and `tensorTypeProto` was
// filled in; false if no value attribute is present.
static bool getOptionalType(InferenceContext& ctx, TypeProto_Tensor& tensorTypeProto) {
  const auto* value = ctx.getAttribute("value");
  const auto* sparse_value = ctx.getAttribute("sparse_value");
  const auto* value_int = ctx.getAttribute("value_int");
  const auto* value_ints = ctx.getAttribute("value_ints");
  const auto* value_float = ctx.getAttribute("value_float");
  const auto* value_floats = ctx.getAttribute("value_floats");
  const auto* value_string = ctx.getAttribute("value_string");
  const auto* value_strings = ctx.getAttribute("value_strings");

  // Count how many value-carrying attributes were specified; they are
  // mutually exclusive.
  const int num_value_attrs = (nullptr != value) + (nullptr != sparse_value) + (nullptr != value_int) +
      (nullptr != value_ints) + (nullptr != value_float) + (nullptr != value_floats) + (nullptr != value_string) +
      (nullptr != value_strings);

  if (num_value_attrs > 1) {
    fail_shape_inference(
        "Only one of the attributes 'value', 'value_*' or 'sparse_value' must be specified for an Optional node.");
  }

  // Start from an empty (rank-0) shape; the helpers below add dimensions.
  tensorTypeProto.mutable_shape()->clear_dim();

  // Scalar: element type only, rank-0 shape.
  auto set_scalar_type = [&](int dtype) { tensorTypeProto.set_elem_type(dtype); };

  // 1-D tensor of the given length.
  auto set_1D_type = [&](int dtype, int64_t size) {
    tensorTypeProto.set_elem_type(dtype);
    tensorTypeProto.mutable_shape()->add_dim()->set_dim_value(size);
  };

  // N-D tensor with explicitly given dimensions.
  auto set_ND_type = [&](int dtype, const google::protobuf::RepeatedField<int64_t>& dims) {
    tensorTypeProto.set_elem_type(dtype);
    for (auto d : dims) {
      tensorTypeProto.mutable_shape()->add_dim()->set_dim_value(d);
    }
  };

  if (nullptr != value) {
    // Dense tensor attribute: take dtype and dims directly from the proto.
    const TensorProto& tensor_proto = value->t();
    set_ND_type(tensor_proto.data_type(), tensor_proto.dims());
    return true;
  }

  if (nullptr != value_int) {
    set_scalar_type(TensorProto::INT64);
    return true;
  }

  if (nullptr != value_ints) {
    set_1D_type(TensorProto::INT64, value_ints->ints_size());
    return true;
  }

  if (nullptr != value_float) {
    set_scalar_type(TensorProto::FLOAT);
    return true;
  }

  if (nullptr != value_floats) {
    set_1D_type(TensorProto::FLOAT, value_floats->floats_size());
    return true;
  }

  if (nullptr != value_string) {
    set_scalar_type(TensorProto::STRING);
    return true;
  }

  if (nullptr != value_strings) {
    set_1D_type(TensorProto::STRING, value_strings->strings_size());
    return true;
  }

  if (nullptr != sparse_value) {
    // Sparse tensor attribute: dtype comes from the values tensor, dims from
    // the sparse proto's dense shape.
    const SparseTensorProto& sparse = sparse_value->sparse_tensor();
    set_ND_type(sparse.values().data_type(), sparse.dims());
    return true;
  }

  // No value-carrying attribute present.
  return false;
}

// Type/shape inference for the Optional op (opset 19).
//
// The output optional's element type is determined by, in priority order:
//   (1) the single (optional) input value,
//   (2) a constant given via one of the 'value*'/'sparse_value' attributes,
//   (3) the 'type' attribute alone (an empty optional).
// Specifying both an input and a value attribute is an error, as is
// specifying none of the three. When a 'type' attribute accompanies an
// input or attribute value, the two are merged via UnionTypeInfo.
void OptionalInferenceFunction(InferenceContext& ctx) {
  // Guard against malformed nodes (the v15 inference did this check too).
  if (ctx.getNumOutputs() != 1) {
    fail_type_inference("Optional is expected to have exactly one output.");
  }

  const size_t numInputs = ctx.getNumInputs();

  // Type specified via "type" attribute, if any.
  const auto* type_attr_proto = ctx.getAttribute("type");
  const TypeProto* attr_type = (type_attr_proto == nullptr) ? nullptr : &type_attr_proto->tp();

  // Type of a constant value specified via some "value*" attribute, if any.
  TypeProto val_type;
  const bool const_value_specified = getOptionalType(ctx, *val_type.mutable_tensor_type());

  if ((numInputs > 0) && const_value_specified) {
    fail_type_inference("Optional must not specify both an input value and a value attribute.");
  }

  auto& target_type = *ctx.getOutputType(0)->mutable_optional_type()->mutable_elem_type();

  if (numInputs > 0) {
    // Construct an optional containing the input value.
    const auto* input_type = ctx.getInputType(0);
    if (input_type == nullptr) {
      fail_type_inference("Input type is null. Type information is expected for the input.");
    }
    target_type.CopyFrom(*input_type);
    if (attr_type != nullptr) {
      UnionTypeInfo(*attr_type, target_type);
    }
  } else if (const_value_specified) {
    // Construct an optional containing the attribute-specified constant value.
    target_type.CopyFrom(val_type);
    if (attr_type != nullptr) {
      UnionTypeInfo(*attr_type, target_type);
    }
  } else if (attr_type != nullptr) {
    // No value at all: construct an empty optional of the specified type.
    target_type.CopyFrom(*attr_type);
  } else {
    fail_type_inference("Optional must specify type attribute if no value is specified.");
  }
}

static const char* Optional_ver19_doc = R"DOC(
Constructs an optional-type value containing either an empty optional of a certain type specified by the attribute,
or a non-empty value containing the input element.
or a non-empty value containing the input element or an attribute value, whichever is specified.

This operator is used to create either a `SOME v` value or a `NONE` value.
The value `v` may be specified either as an input argument or via attributes
(exactly as in the `Constant` op).

If no input value is specified (either via input or attribute) a `NONE` value is
constructed. In this case, the `type` attribute must be specified to enable a
(monomorphic) type to be inferred for the output.
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
Optional,
15,
19,
OpSchema()
.SetDoc(Optional_ver15_doc)
.SetDoc(Optional_ver19_doc)
.Input(0, "input", "The input element.", "V", OpSchema::Optional)
.Attr("type", "Type of the element in the optional output", AttributeProto::TYPE_PROTO, OPTIONAL_VALUE)
.Attr("value", "The value for the elements of the output tensor.", AttributeProto::TENSOR, false)
.Attr(
"sparse_value",
"The value for the elements of the output tensor in sparse format.",
AttributeProto::SPARSE_TENSOR,
false)
.Attr(
"value_int",
"The value for the sole element for the scalar, int64, output tensor.",
AttributeProto::INT,
false)
.Attr(
"value_ints",
"The values for the elements for the 1D, int64, output tensor.",
AttributeProto::INTS,
false)
.Attr(
"value_float",
"The value for the sole element for the scalar, float32, output tensor.",
AttributeProto::FLOAT,
false)
.Attr(
"value_floats",
"The values for the elements for the 1D, float32, output tensor.",
AttributeProto::FLOATS,
false)
.Attr(
"value_string",
"The value for the sole element for the scalar, UTF-8 string, output tensor.",
AttributeProto::STRING,
false)
.Attr(
"value_strings",
"The values for the elements for the 1D, UTF-8 string, output tensor.",
AttributeProto::STRINGS,
false)
.Output(0, "output", "The optional output enclosing the input element.", "O")
.TypeConstraint(
"V",
Expand All @@ -44,31 +216,7 @@ ONNX_OPERATOR_SET_SCHEMA(
"O",
OpSchema::all_optional_types(),
"Constrain output type to all optional tensor or optional sequence types.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const size_t numOutputs = ctx.getNumOutputs();
if (numOutputs != 1) {
fail_type_inference("Optional is expected to have an output.");
}

const size_t numInputs = ctx.getNumInputs();
const auto* attr_proto = ctx.getAttribute("type");

if ((numInputs == 0) && (attr_proto != nullptr)) {
if (!attr_proto->has_tp())
fail_type_inference("Attribute 'type' should be a TypeProto and it should specify a type.");
auto attr_tp = attr_proto->tp();

ctx.getOutputType(0)->mutable_optional_type()->mutable_elem_type()->CopyFrom(attr_tp);
} else if (numInputs == 1) {
auto input_type = ctx.getInputType(0);
if (input_type == nullptr) {
fail_type_inference("Input type is null. Type information is expected for the input.");
}
ctx.getOutputType(0)->mutable_optional_type()->mutable_elem_type()->CopyFrom(*input_type);
} else {
fail_type_inference("Optional is expected to have either an input or the type attribute set.");
}
}));
.TypeAndShapeInferenceFunction(OptionalInferenceFunction));

static const char* OptionalHasElement_ver18_doc = R"DOC(
Returns true if (1) the input is an optional-type and contains an element,
Expand Down
52 changes: 52 additions & 0 deletions onnx/defs/optional/old.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,56 @@ ONNX_OPERATOR_SET_SCHEMA(
ctx.getOutputType(0)->CopyFrom(input_type->optional_type().elem_type());
}));

// Schema documentation string for the (superseded) opset-15 Optional op.
// Kept verbatim in old.cc for historical opset versions.
static const char* Optional_ver15_doc = R"DOC(
Constructs an optional-type value containing either an empty optional of a certain type specified by the attribute,
or a non-empty value containing the input element.
)DOC";

// Opset-15 schema for Optional (superseded in opset 19): wraps the single
// (optional) input in an optional value, or — when no input is given —
// constructs an empty optional whose element type comes from the 'type'
// attribute.
ONNX_OPERATOR_SET_SCHEMA(
    Optional,
    15,
    OpSchema()
        .SetDoc(Optional_ver15_doc)
        .Input(0, "input", "The input element.", "V", OpSchema::Optional)
        .Attr("type", "Type of the element in the optional output", AttributeProto::TYPE_PROTO, OPTIONAL_VALUE)
        .Output(0, "output", "The optional output enclosing the input element.", "O")
        .TypeConstraint(
            "V",
            []() {
              auto t = OpSchema::all_tensor_types();
              auto s = OpSchema::all_tensor_sequence_types();
              t.insert(t.end(), s.begin(), s.end());
              return t;
            }(),
            "Constrain input type to all tensor and sequence types.")
        .TypeConstraint(
            "O",
            OpSchema::all_optional_types(),
            "Constrain output type to all optional tensor or optional sequence types.")
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Exactly one output is expected.
          if (ctx.getNumOutputs() != 1) {
            fail_type_inference("Optional is expected to have an output.");
          }

          const size_t numInputs = ctx.getNumInputs();
          const auto* attr_proto = ctx.getAttribute("type");

          if ((numInputs == 0) && (attr_proto != nullptr)) {
            // Empty optional: element type comes from the 'type' attribute.
            if (!attr_proto->has_tp()) {
              fail_type_inference("Attribute 'type' should be a TypeProto and it should specify a type.");
            }
            // Bind by const reference to avoid copying the TypeProto.
            const auto& attr_tp = attr_proto->tp();
            ctx.getOutputType(0)->mutable_optional_type()->mutable_elem_type()->CopyFrom(attr_tp);
          } else if (numInputs == 1) {
            // Non-empty optional: element type is the input's type.
            const auto* input_type = ctx.getInputType(0);
            if (input_type == nullptr) {
              fail_type_inference("Input type is null. Type information is expected for the input.");
            }
            ctx.getOutputType(0)->mutable_optional_type()->mutable_elem_type()->CopyFrom(*input_type);
          } else {
            fail_type_inference("Optional is expected to have either an input or the type attribute set.");
          }
        }));

} // namespace ONNX_NAMESPACE
28 changes: 28 additions & 0 deletions onnx/test/model_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,34 @@ def test_mi_constant_in_function(self):
"""
self._check_shape(model, [4, 4], [8, 8, 8])

# NOTE(review): indentation below reflects the diff scrape; the textual model
# must stay byte-identical. The local function 'mul' promotes its optional
# attribute 'alpha' to an optional value via the opset-19 Optional op
# (value_int = @alpha), then uses OptionalHasElement + If to choose between
# the scaled path and the identity path.
def test_mi_optional_attribute(self):
"""Test promotion of optional attribute parameters to optional values"""
model = """
<
ir_version: 7,
opset_import: [ "" : 19, "local" : 1]
>
main (int64[128] x) => (y, z) {
y = local.mul (x)
z = local.mul <alpha = 2> (x)
}
<
opset_import: [ "" : 19 ],
domain: "local"
>
mul <alpha> (x) => (y) {
alpha_opt = Optional <value_int : int = @alpha, type = int64> ()
cond = OptionalHasElement(alpha_opt)
y = If (cond) <
then_branch = g1 () => (y_then) {
alpha_val = OptionalGetElement (alpha_opt)
y_then = Mul (alpha_val, x)
},
else_branch = g2 () => (y_else) { y_else = Identity (x) }
>
}
"""
# Both outputs (alpha absent and alpha = 2) must infer shape [128].
self._check_shape(model, [128], [128])

if __name__ == "__main__":
unittest.main()
13 changes: 13 additions & 0 deletions onnx/test/shape_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7014,6 +7014,19 @@ def test_optional_construct_tensor(self) -> None:
)
self._assert_inferred(graph, [optional_val_info]) # type: ignore

def test_optional_float_val(self) -> None:
    """Optional with a 'value_float' attribute infers optional(tensor(float)) with a scalar shape."""
    graph = self._make_graph(
        [],
        [make_node("Optional", [], ["output"], value_float=10.0)],
        [],
    )
    # 'value_float' denotes a single scalar, so the enclosed tensor type is
    # FLOAT with rank-0 shape, wrapped in an optional type.
    tensor_type = helper.make_tensor_type_proto(elem_type=TensorProto.FLOAT, shape=[])
    optional_type = helper.make_optional_type_proto(tensor_type)
    inferred_val_info = helper.make_value_info(name="output", type_proto=optional_type)
    self._assert_inferred(graph, [inferred_val_info])  # type: ignore

def test_optional_construct_sequence(self) -> None:
tensor_type_proto = helper.make_tensor_type_proto(
elem_type=TensorProto.INT64, shape=[2, 3, 0]
Expand Down