Skip to content

Commit

Permalink
Add shape inference test for custom op (#6068)
Browse files Browse the repository at this point in the history
Add a unit test for the fix introduced by
#6049

---------

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
  • Loading branch information
gramalingam committed Apr 12, 2024
1 parent 3f24ef3 commit 77127a7
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions onnx/test/cpp/shape_inference_test.cc
Expand Up @@ -507,15 +507,19 @@ TEST(GraphInferencerImplTest, Scan9_BasicTest) {
doInferencingTest(false);
}

void RunReshapeShapeInfTest(const char* modelStr, TensorShapeProto& expectedShape) {
ModelProto model;
void ParseAndInfer(ModelProto& model, const char* modelStr) {
OnnxParser parser(modelStr);
auto status = parser.Parse(model);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";

ShapeInferenceOptions options{true, 1, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, ONNX_NAMESPACE::OpSchemaRegistry::Instance(), options);
}

void RunReshapeShapeInfTest(const char* modelStr, TensorShapeProto& expectedShape) {
ModelProto model;
ParseAndInfer(model, modelStr);

const auto inferredShape = model.graph().output(0).type().tensor_type().shape();
EXPECT_TRUE(inferredShape.dim_size() == expectedShape.dim_size());
Expand Down Expand Up @@ -620,5 +624,37 @@ TEST(ShapeInferenceTest, CheckShapesAndTypesTest) {
#endif
}

TEST(ShapeInferenceTest, CustomOpTest) {
const char* modelStr = R"ONNX(
<ir_version: 8, opset_import: ["" : 15, "custom.domain" : 1]>
agraph (float[256, 768, 3] x) => (z1, z2)
{
z1 = custom.domain.CustomOp (x)
# Inference cannot determine the type/shape of z1
z2 = Abs(x)
# Inference SHOULD determine the type/shape of z2 (same as that of x)
}
)ONNX";

ModelProto model;
ParseAndInfer(model, modelStr);

auto& z1_value_info = model.graph().output(0);
// Check no inferred type for z1 (It's a quirk of the implementation that it
// has a dummy TypeProto, but it should have no values filled in.)
ASSERT_TRUE(z1_value_info.has_type());
ASSERT_FALSE(z1_value_info.type().has_tensor_type());

// Check inferred type for z2:
auto& z2_value_info = model.graph().output(1);
ASSERT_TRUE(z2_value_info.has_type());
ASSERT_TRUE(z2_value_info.type().has_tensor_type());
EXPECT_EQ(z2_value_info.type().tensor_type().elem_type(), TensorProto_DataType_FLOAT);
EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim_size(), 3);
EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim(0).dim_value(), 256);
EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim(1).dim_value(), 768);
EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim(2).dim_value(), 3);
}

} // namespace Test
} // namespace ONNX_NAMESPACE

0 comments on commit 77127a7

Please sign in to comment.