Skip to content

Commit

Permalink
return if input does not have type (onnx#3757)
Browse files Browse the repository at this point in the history
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
Signed-off-by: Mark <mhamilton723@gmail.com>
  • Loading branch information
jcwchen authored and mhamilton723 committed Oct 22, 2021
1 parent 2a9487c commit b4cfc97
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions onnx/defs/tensor/defs.cc
Expand Up @@ -433,18 +433,20 @@ ONNX_OPERATOR_SET_SCHEMA(
auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
auto* output_length = output_shape->add_dim();

if (ctx.getInputType(0)->tensor_type().has_shape()) {
int64_t rank = static_cast<int64_t>(ctx.getInputType(0)->tensor_type().shape().dim_size());
int64_t start = getAttribute(ctx, "start", 0);
if (start < 0)
start += rank;
start = (start < 0) ? 0 : (start > rank) ? rank : start;
int64_t end = getAttribute(ctx, "end", rank);
if (end < 0)
end += rank;
end = (end < 0) ? 0 : (end > rank) ? rank : end;
output_length->set_dim_value((end - start) < 0 ? 0 : (end - start));
if (!hasNInputShapes(ctx, 1)) {
return;
}

int64_t rank = static_cast<int64_t>(ctx.getInputType(0)->tensor_type().shape().dim_size());
int64_t start = getAttribute(ctx, "start", 0);
if (start < 0)
start += rank;
start = (start < 0) ? 0 : (start > rank) ? rank : start;
int64_t end = getAttribute(ctx, "end", rank);
if (end < 0)
end += rank;
end = (end < 0) ? 0 : (end > rank) ? rank : end;
output_length->set_dim_value((end - start) < 0 ? 0 : (end - start));
})
.PartialDataPropagationFunction([](DataPropagationContext& ctx) {
if (ctx.getInputType(0)->tensor_type().has_shape()) {
Expand Down

0 comments on commit b4cfc97

Please sign in to comment.