Skip to content

Commit

Permalink
Batchnorm shape inference (#3657)
Browse files Browse the repository at this point in the history
* Fix batchnorm shape inference

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Undo format change of other code

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
  • Loading branch information
gramalingam committed Aug 15, 2021
1 parent 3b27cb9 commit c24cd42
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
14 changes: 13 additions & 1 deletion onnx/defs/nn/defs.cc
Expand Up @@ -1748,9 +1748,21 @@ ONNX_OPERATOR_SET_SCHEMA(
propagateShapeAndTypeFromFirstInput(ctx);
propagateShapeFromInputToOutput(ctx, 0, 0);

// Inputs 1 to 4 must be of rank 1.
checkInputRank(ctx, 1, 1);
checkInputRank(ctx, 2, 1);
checkInputRank(ctx, 3, 1);
checkInputRank(ctx, 4, 1);

Dim num_channels;

unifyInputDim(ctx, 0, 1, num_channels);
if (hasInputShape(ctx, 0)) {
if (getInputShape(ctx, 0).dim_size() > 1)
unifyInputDim(ctx, 0, 1, num_channels);
else
unifyDim(num_channels, 1);
}

unifyInputDim(ctx, 1, 0, num_channels);
unifyInputDim(ctx, 2, 0, num_channels);
unifyInputDim(ctx, 3, 0, num_channels);
Expand Down
14 changes: 13 additions & 1 deletion onnx/defs/nn/old.cc
Expand Up @@ -1888,9 +1888,21 @@ ONNX_OPERATOR_SET_SCHEMA(
propagateShapeAndTypeFromFirstInput(ctx);
propagateShapeFromInputToOutput(ctx, 0, 0);

// Inputs 1 to 4 must be of rank 1.
checkInputRank(ctx, 1, 1);
checkInputRank(ctx, 2, 1);
checkInputRank(ctx, 3, 1);
checkInputRank(ctx, 4, 1);

Dim num_channels;

unifyInputDim(ctx, 0, 1, num_channels);
if (hasInputShape(ctx, 0)) {
if (getInputShape(ctx, 0).dim_size() > 1)
unifyInputDim(ctx, 0, 1, num_channels);
else
unifyDim(num_channels, 1);
}

unifyInputDim(ctx, 1, 0, num_channels);
unifyInputDim(ctx, 2, 0, num_channels);
unifyInputDim(ctx, 3, 0, num_channels);
Expand Down
22 changes: 22 additions & 0 deletions onnx/test/shape_inference_test.py
Expand Up @@ -1322,6 +1322,28 @@ def test_batch_norm(self): # type: () -> None
[])
self._assert_inferred(graph, [make_tensor_value_info('out', TensorProto.FLOAT, (3, 4, 5, 6, 7))])

def test_batch_norm_rank1(self):  # type: () -> None
    # A rank-1 input x is permitted by BatchNormalization shape inference;
    # the per-channel inputs (scale, b, mean, var) must then all have shape (1,).
    value_infos = [('x', TensorProto.FLOAT, (128,))]  # 1-dimensional permitted
    for per_channel in ('scale', 'b', 'mean', 'var'):
        value_infos.append((per_channel, TensorProto.FLOAT, (1,)))
    bn_node = make_node('BatchNormalization', ['x', 'scale', 'b', 'mean', 'var'], ['out'])
    graph = self._make_graph(value_infos, [bn_node], [])
    # The inferred output shape must match the rank-1 input shape.
    expected_out = make_tensor_value_info('out', TensorProto.FLOAT, (128,))
    self._assert_inferred(graph, [expected_out])

def test_batch_norm_invalid(self):  # type: () -> None
    # BatchNormalization requires inputs 1-4 to be rank 1; a rank-2 'scale'
    # must make shape inference fail with an InferenceError.
    value_infos = [
        ('x', TensorProto.FLOAT, (128,)),
        ('scale', TensorProto.FLOAT, (1, 2)),  # invalid rank
        ('b', TensorProto.FLOAT, (1,)),
        ('mean', TensorProto.FLOAT, (1,)),
        ('var', TensorProto.FLOAT, (1,)),
    ]
    bn_node = make_node('BatchNormalization', ['x', 'scale', 'b', 'mean', 'var'], ['out'])
    graph = self._make_graph(value_infos, [bn_node], [])
    self.assertRaises(onnx.shape_inference.InferenceError, self._inferred, graph)

def test_split_negative_axis(self): # type: () -> None
graph = self._make_graph(
[('x', TensorProto.FLOAT, (2, 4))],
Expand Down

0 comments on commit c24cd42

Please sign in to comment.