
Unexpected sensitivity of group normalisation results to batch size #2745

Open
wsmigaj opened this issue Aug 12, 2022 · 2 comments
Comments
wsmigaj commented Aug 12, 2022

System information

  • OS Platform and Distribution: Windows 10
  • TensorFlow version and how it was installed: 2.9.1 (binary, from pip)
  • TensorFlow-Addons version and how it was installed: 0.17.1 (binary, from pip)
  • Python version: 3.8
  • Is GPU used? (yes/no): no

Describe the bug

Results produced by group normalisation differ markedly depending on whether the batch size is 1 or greater than 1.

Code to reproduce the issue

The following code applies group normalisation along the last axis, first to a batch of 4 slices, then to individual slices, and finally to pairs of consecutive slices. In principle, the results obtained in each case should be the same, since group normalisation is not done along the batch dimension.

import tensorflow as tf
import tensorflow_addons as tfa

# Dimensions
b = 4
h = 600
w = 400
g = 3
c = g * 8

gn = tfa.layers.GroupNormalization(groups=g, axis=-1)
gn.build([None, None, None, c])

inputs = tf.random.stateless_uniform([b, h, w, c], seed=[1, 2])

# Apply group normalisation to the whole batch (4 slices) at once
output_b4 = gn.call(inputs)

# Apply group normalisation to each slice individually
output_0 = gn.call(inputs[0:1])
output_1 = gn.call(inputs[1:2])
output_2 = gn.call(inputs[2:3])
output_3 = gn.call(inputs[3:4])
output_b1 = tf.concat([output_0, output_1, output_2, output_3], axis=0)
tf.print("Batch size 1 vs batch size 4:", tf.reduce_max(tf.abs(output_b1 - output_b4)))

# Apply group normalisation to pairs of slices
output_01 = gn.call(inputs[0:2])
output_23 = gn.call(inputs[2:4])
output_b2 = tf.concat([output_01, output_23], axis=0)
tf.print("Batch size 2 vs batch size 4:", tf.reduce_max(tf.abs(output_b2 - output_b4)))

Output:

Batch size 1 vs batch size 4: 0.000871777534
Batch size 2 vs batch size 4: 2.38418579e-07

So the difference between results obtained for batch sizes 2 and 4 is on the order of machine precision, but that between results obtained for batch sizes 1 and 4 is three orders of magnitude larger.

Other info / logs

The difference is introduced by the call to tf.nn.moments() in GroupNormalization._apply_normalization(). The Reduce operations executed by moments() reshape the input tensor differently depending on whether its first dimension is 1 (in which case it does not matter whether that dimension is reduced over) or greater than 1 (in which case it must not be reduced over). This determines whether all elements with the same batch and group index are adjacent in memory, which in turn probably affects the order in which Eigen::Tensor::reduce() adds these elements together. The difference in the final result is then a consequence of the non-associativity of floating-point addition.
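The non-associativity mentioned above is easy to demonstrate in plain Python, independently of TensorFlow (a minimal sketch, not related to the actual reduction code):

```python
# Floating-point addition is not associative: the result depends on the
# order in which operands are combined. At magnitude 1e16, the spacing
# between adjacent doubles is 2, so adding 1.0 to -1e16 rounds away.
a, b, c = 1e16, -1e16, 1.0

left = (a + b) + c   # (0.0) + 1.0 -> 1.0
right = a + (b + c)  # (b + c) rounds back to -1e16, so the sum is 0.0

print(left, right)   # 1.0 0.0
```

The same effect, at much smaller magnitude, is what separates the batch-size-1 and batch-size-4 reductions here: the summands are identical, but the grouping differs.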

The problem can be worked around by transposing the input tensor to a channels-first format, in which all axes not participating in the reduction (batch and group index) are located at the start of the axes list. However, it would be more user-friendly for this transpose to be done automatically inside GroupNormalization. If this sounds reasonable, I'm happy to open a PR patching GroupNormalization in this way.
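To illustrate the proposed workaround, here is a minimal NumPy sketch of group normalisation in which the input is transposed to channels-first before the moments are computed, so that each (batch, group) block is contiguous in memory and is reduced in a batch-size-independent order. The helper name group_norm_channels_first is hypothetical, not part of the tfa API:

```python
import numpy as np

def group_norm_channels_first(x, groups, eps=1e-5):
    """Group normalisation of an NHWC tensor via a channels-first layout.

    Hypothetical reference implementation: transpose NHWC -> NCHW, split
    the channel axis into groups, and reduce over each contiguous
    (channels-per-group, H, W) block.
    """
    b, h, w, c = x.shape
    # NHWC -> NCHW, then split channels into (groups, channels_per_group)
    t = np.transpose(x, (0, 3, 1, 2)).reshape(b, groups, c // groups, h, w)
    mean = t.mean(axis=(2, 3, 4), keepdims=True)
    var = t.var(axis=(2, 3, 4), keepdims=True)
    t = (t - mean) / np.sqrt(var + eps)
    # Merge the group axes again and transpose back to NHWC
    return np.transpose(t.reshape(b, c, h, w), (0, 2, 3, 1))

rng = np.random.default_rng(0)
x = rng.uniform(size=(4, 6, 5, 6)).astype(np.float32)

# Whole batch at once vs. one slice at a time: with the contiguous
# layout, each (batch, group) block is reduced identically either way.
y4 = group_norm_channels_first(x, groups=3)
y1 = np.concatenate([group_norm_channels_first(x[i:i + 1], groups=3)
                     for i in range(4)], axis=0)
print(np.max(np.abs(y4 - y1)))  # stays at machine-epsilon level
```

This only demonstrates the layout argument; the actual patch would perform the equivalent transpose with TensorFlow ops inside GroupNormalization.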

bhack (Contributor) commented Aug 12, 2022

Can you try to add a new test to cover this case?

wsmigaj (Author) commented Aug 12, 2022

> Can you try to add a new test to cover this case?

Yes -- I've added a new test in #2746.
