From 45b6d0756be6e1ece4bf7ecb1a98eab3200ee1dc Mon Sep 17 00:00:00 2001 From: Chun-Wei Chen Date: Tue, 5 Oct 2021 17:00:05 -0400 Subject: [PATCH] return if input does not have type Signed-off-by: Chun-Wei Chen --- onnx/defs/tensor/defs.cc | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/onnx/defs/tensor/defs.cc b/onnx/defs/tensor/defs.cc index 5515441a273..34b46c69fc1 100644 --- a/onnx/defs/tensor/defs.cc +++ b/onnx/defs/tensor/defs.cc @@ -433,18 +433,20 @@ ONNX_OPERATOR_SET_SCHEMA( auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); auto* output_length = output_shape->add_dim(); - if (ctx.getInputType(0)->tensor_type().has_shape()) { - int64_t rank = static_cast(ctx.getInputType(0)->tensor_type().shape().dim_size()); - int64_t start = getAttribute(ctx, "start", 0); - if (start < 0) - start += rank; - start = (start < 0) ? 0 : (start > rank) ? rank : start; - int64_t end = getAttribute(ctx, "end", rank); - if (end < 0) - end += rank; - end = (end < 0) ? 0 : (end > rank) ? rank : end; - output_length->set_dim_value((end - start) < 0 ? 0 : (end - start)); + if (!hasNInputShapes(ctx, 1)) { + return; } + + int64_t rank = static_cast(ctx.getInputType(0)->tensor_type().shape().dim_size()); + int64_t start = getAttribute(ctx, "start", 0); + if (start < 0) + start += rank; + start = (start < 0) ? 0 : (start > rank) ? rank : start; + int64_t end = getAttribute(ctx, "end", rank); + if (end < 0) + end += rank; + end = (end < 0) ? 0 : (end > rank) ? rank : end; + output_length->set_dim_value((end - start) < 0 ? 0 : (end - start)); }) .PartialDataPropagationFunction([](DataPropagationContext& ctx) { if (ctx.getInputType(0)->tensor_type().has_shape()) {