diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index af08972a6..39f86a411 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -7665,6 +7665,7 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_ else: raise TypeError("%s: unexpected factor_t %s" % (self, factor_t)) self.output.dim_tags[axis].dyn_size_ext = dyn_size_ext + self.output.dim_tags[axis].set_tag_on_size_tensor(dyn_size_ext.placeholder, batch=dyn_size_ext.batch) self.output.placeholder = x def get_dep_layers(self):