Skip to content

Commit

Permalink
ReduceLayer: fix reduce_mean over multiple axes, including dynamic axes
Browse files Browse the repository at this point in the history
Fix #1242
  • Loading branch information
albertz committed Dec 6, 2022
1 parent f4c8d92 commit 767b939
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
20 changes: 16 additions & 4 deletions returnn/tf/layers/basic.py
Expand Up @@ -6236,12 +6236,24 @@ def reduce(cls, input_data, mode, axes=None, keep_dims=False, enforce_batch_dim_

x_ = tf_util.where_bc(mask, x_, replacement_value, name="x_masked_axis_%i" % axis)
if f == tf.reduce_mean:
tag = x.dim_tags[axis]
assert tag.dyn_size_ext is not None # checked above
size_all = tf.shape(x.placeholder)[axis]
size_actual = tag.dyn_size_ext
while any(d not in out_data.dim_tags for d in size_actual.dim_tags):
# We have some axis (e.g. B) which is not in the output.
# We need to remove this.
# https://github.com/rwth-i6/returnn/issues/1242
i, d = [(i, d) for i, d in enumerate(size_actual.dim_tags) if d not in out_data.dim_tags][0]
assert not d.is_dynamic() # not implemented
size_all *= d.get_dim_value()
s = tf.reduce_sum(size_actual.placeholder, axis=i)
size_actual = size_actual.copy_template_excluding_axis(i)
size_actual.placeholder = s
seq_len_bc = (
x.dim_tags[axis].dyn_size_ext
.copy_compatible_to(out_data, check_sparse=False, check_dtype=False)
.placeholder)
size_actual.copy_compatible_to(out_data, check_sparse=False, check_dtype=False).placeholder)
seq_len_bc = tf.maximum(seq_len_bc, 1) # avoid nan
correction_factor_ = tf.cast(tf.shape(x.placeholder)[axis], tf.float32) / tf.cast(seq_len_bc, tf.float32)
correction_factor_ = tf.cast(size_all, tf.float32) / tf.cast(seq_len_bc, tf.float32)
correction_factor = tf_util.optional_mul(correction_factor, correction_factor_)
if mode in arg_funcs:
assert len(axes) == 1, "For argmax/argmin, only one reduction axis is supported"
Expand Down
23 changes: 23 additions & 0 deletions tests/test_TFNetworkLayer.py
Expand Up @@ -10567,6 +10567,29 @@ def test_reduce_mean_batch_time():
numpy.testing.assert_allclose(ref, v, rtol=1e-5)


def test_ReduceLayer_mean_btf():
    """Reduce with mode "mean" over batch, time and feature axes at once.

    The mean must ignore padded frames beyond each sequence's length,
    so the result should match :func:`numpy.nanmean` after masking the
    padding with NaN. Regression test for
    https://github.com/rwth-i6/returnn/issues/1242.
    """
    config = Config(dict(
        extern_data={"data": {"shape": (None, 4)}}
    ))
    net_dict = {
        "output": {"class": "reduce", "mode": "mean", "from": "data", "axis": ["B", "T", "F"]}
    }
    with make_scope() as session:
        net = TFNetwork(config=config)
        net.construct_from_dict(net_dict)
        data_in = net.extern_data.get_default_input_data()
        data_out = net.get_default_output_layer().output
        input_np, lens, output_np = session.run(
            (data_in.placeholder, data_in.get_sequence_lengths(), data_out.placeholder),
            feed_dict=make_feed_dict(net.extern_data))
        assert input_np.shape[0] == lens.shape[0]
        # Mask the padded region of every sequence with NaN, then compare
        # against the NaN-ignoring mean over all remaining entries.
        for batch_idx, length in enumerate(lens):
            input_np[batch_idx, length:, :] = numpy.nan
        numpy.testing.assert_almost_equal(output_np, numpy.nanmean(input_np))


def test_automatic_seq_lengths():
with make_scope() as session:
n_out = 5
Expand Down

0 comments on commit 767b939

Please sign in to comment.