#1749: Fix fused_batch_norm 5d NDHWC input reshape convert (#1769)
* fix fused_batch_norm 5d input reshape convert

* fix for #1749

Signed-off-by: hwangdeyu <dejack953@outlook.com>
hwangdeyu committed Nov 15, 2021
1 parent d87ba34 commit 5dfd36f
Showing 2 changed files with 34 additions and 1 deletion.
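Context for the change: ONNX BatchNormalization expects channel-first input, so the converter transposes TensorFlow's channel-last layouts around the node. Before this fix the transpose assumed a 4-D NHWC input, so a 5-D NDHWC input produced by fused_batch_norm (the #1749 case) was reshaped with the wrong permutation. A hypothetical repro sketch of that scenario follows; the exact calls used here (tf.compat.v1.nn.fused_batch_norm, tf2onnx.convert.from_function) are illustrative assumptions, not taken from this commit.

# Hypothetical repro sketch for the #1749 scenario (illustrative assumptions,
# not code from this commit): a 5-D NDHWC fused batch norm converted to ONNX.
import numpy as np
import tensorflow as tf
import tf2onnx

channels = 2
scale = tf.constant(np.ones(channels, np.float32))
offset = tf.constant(np.zeros(channels, np.float32))
mean = tf.constant(np.zeros(channels, np.float32))
variance = tf.constant(np.ones(channels, np.float32))

def func(x):
    # 5-D input with data_format="NDHWC": batch, three spatial dims, channels last.
    y, _, _ = tf.compat.v1.nn.fused_batch_norm(
        x, scale, offset, mean=mean, variance=variance,
        epsilon=0.001, data_format="NDHWC", is_training=False)
    return tf.identity(y, name="output")

spec = (tf.TensorSpec([1, 28, 28, 2, 2], tf.float32, name="input"),)
# With this commit the converter picks spatial=3 and inserts NDHWC -> NCDHW
# transposes; previously the 5-D input fell into the default 4-D handling.
model_proto, _ = tf2onnx.convert.from_function(tf.function(func), input_signature=spec)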
tests/test_backend.py: 25 additions, 0 deletions
@@ -2817,6 +2817,31 @@ def func(x):
            return tf.identity(y, name=_TFOUTPUT)
        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)

    @check_opset_min_version(7, "batchnorm")
    @check_tf_min_version("2.0", "tf-1.x does not support NDHWC")
    def test_fused_batchnorm_3d(self):
        x_shape = [1, 28, 28, 2, 2]
        x_dtype = np.float32
        scale_dtype = np.float32
        scale_shape = [2]
        data_format = "NDHWC"
        x_val = np.random.random_sample(x_shape).astype(x_dtype)
        scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        def func(x):
            scale = tf.constant(scale_val, name='scale')
            offset = tf.constant(offset_val, name='offset')
            mean = tf.constant(mean_val, name='mean')
            var = tf.constant(var_val, name='variance')
            epsilon = 0.001
            y, _, _ = fused_batch_norm(
                x, scale, offset, mean=mean, variance=var,
                epsilon=epsilon, data_format=data_format, is_training=False)
            return tf.identity(y, name=_TFOUTPUT)
        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)

    @check_opset_min_version(7, "batchnorm")
    @skip_tfjs("TFJS executes model incorrectly")
    def test_fused_batchnorm_training(self):
tf2onnx/onnx_opset/nn.py: 9 additions, 1 deletion
@@ -943,6 +943,14 @@ class BatchNorm:
    @classmethod
    def version_6(cls, ctx, node, **kwargs):
        tf_type = node.type
        input_rank = len(ctx.get_shape(node.input[0]))
        if input_rank == 4:
            spatial = 2
        elif input_rank == 5:
            spatial = 3
        else:
            raise ValueError("node input must be 4 or 5-dimensional, is {} now".format(input_rank))

        node.type = "BatchNormalization"
        # tf inputs: x, scale, bias, mean, variance
        # tf outputs: y, batch_mean, batch_var
@@ -973,7 +981,7 @@ def version_6(cls, ctx, node, **kwargs):
        # the setter makes a copy of new_output
        node.output = new_output

-        conv_convert_inputs(ctx, node, with_kernel=False)
+        conv_convert_inputs(ctx, node, with_kernel=False, spatial=spatial)

        inp_shape = ctx.get_shape(node.input[0])
        inp_rank = len(inp_shape) if inp_shape is not None else None
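
The new spatial argument tells conv_convert_inputs how many trailing spatial axes to assume when it adds the layout transposes around the BatchNormalization node. A minimal sketch of the permutations that choice implies (illustrative only; the helper below is hypothetical, not the tf2onnx implementation):

# Illustrative only: channel-last <-> channel-first permutations implied by the
# spatial argument (hypothetical helper, not tf2onnx code).
def channel_first_perms(spatial):
    rank = spatial + 2                                    # batch + channel + spatial axes
    to_channel_first = [0, rank - 1] + list(range(1, rank - 1))
    to_channel_last = [0] + list(range(2, rank)) + [1]    # inverse permutation
    return to_channel_first, to_channel_last

# spatial=2: NHWC <-> NCHW; spatial=3: NDHWC <-> NCDHW
assert channel_first_perms(2) == ([0, 3, 1, 2], [0, 2, 3, 1])
assert channel_first_perms(3) == ([0, 4, 1, 2, 3], [0, 2, 3, 4, 1])

With spatial=3, the 5-D NDHWC test input of shape [1, 28, 28, 2, 2] reaches ONNX BatchNormalization as [1, 2, 28, 28, 2] (NCDHW), and the result is transposed back to NDHWC afterwards.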
