diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index 1fa5726f6..691c89a64 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -7556,7 +7556,7 @@ class ResizeLayer(_ConcatInputLayer):
   def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_dropout=None, **kwargs):
     """
-    :param int factor:
+    :param int|float|LayerBase factor: out_len = in_len * factor
     :param Dim|str axis: the axis to resize
     :param Dim|None out_dim:
     :param str kind: "linear", "nn"/"nearest_neighbor", "cubic", "fill"
@@ -7565,6 +7565,7 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_
     """
     out_dim  # noqa  # via get_out_data_from_opts
     super(ResizeLayer, self).__init__(**kwargs)
+    self.factor = factor
     # self.output.shape and self.output.batch_dim_axis are already set here via self.get_out_data_from_opts().
     input_data = self.input_data.copy_as_batch_major()
     axis = input_data.get_axis_from_description(axis)
@@ -7580,7 +7581,20 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_
     remaining_shape = [shape[i] for i in remaining_axes]
     remaining_dim = optional_mul(*remaining_shape) if remaining_axes else 1
     x = tf.reshape(x, [shape[0], shape[axis], 1, remaining_dim])  # [batch,height,width,channels]
-    new_size = shape[axis] * factor
+    if isinstance(factor, (int, float)):
+      factor_t = factor
+    elif isinstance(factor, LayerBase):
+      assert factor.output.batch_shape == (), "%s: factor must be scalar, got %s" % (self, factor)
+      assert factor.output.dtype in ("int32", "float32"), "%s: factor must be int or float, got %s" % (self, factor)
+      factor_t = factor.output.placeholder
+    else:
+      raise TypeError("%s: unexpected factor type %s" % (self, type(factor).__name__))
+    if isinstance(factor_t, int) or (isinstance(factor_t, tf.Tensor) and factor_t.dtype == tf.int32):
+      new_size = shape[axis] * factor_t
+    elif isinstance(factor_t, float) or (isinstance(factor_t, tf.Tensor) and factor_t.dtype == tf.float32):
+      new_size = tf.cast(tf.math.ceil(tf.cast(shape[axis], tf.float32) * factor_t), tf.int32)
+    else:
+      raise TypeError("%s: unexpected factor_t %s" % (self, factor_t))
     if kind == "linear":
       x = tf_compat.v1.image.resize_bilinear(x, size=(new_size, 1))
     elif kind == "cubic":
@@ -7605,6 +7619,7 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_
       raise Exception("invalid kind %r for resizing" % kind)
     x = tf.reshape(x, [shape[0], new_size] + remaining_shape)  # [batch,new_size] + remaining_shape
     if fill_dropout:
+      assert isinstance(factor, int)
       from returnn.tf.util.basic import expand_dims_unbroadcast
       # We are going to build a mask over the axis. This mask will be shared over all seqs in the batch.
       # Similar to in tf.nn.dropout. Build random_tensor as uniform [keep_prob, 1.0 + keep_prob).
@@ -7626,12 +7641,44 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_
         out_dyn_size, maxlen=new_size, dtype=tf.bool)  # (batch,new_size)
       out_dyn_size = tf.reduce_sum(tf.cast(tf.logical_and(mask, orig_mask), tf.int32), axis=1)
       self.output.dim_tags[axis].dyn_size = out_dyn_size
+    elif not isinstance(factor, int):
+      dyn_size_ext = input_data.dim_tags[axis].dyn_size_ext
+      if dyn_size_ext is not None and dyn_size_ext.placeholder is not None:
+        dyn_size_ext = dyn_size_ext.copy(name="%s:dyn_size_ext" % self.name)
+        if isinstance(factor_t, int) or (isinstance(factor_t, tf.Tensor) and factor_t.dtype == tf.int32):
+          dyn_size_ext.placeholder = dyn_size_ext.placeholder * factor_t
+        elif isinstance(factor_t, float) or (isinstance(factor_t, tf.Tensor) and factor_t.dtype == tf.float32):
+          dyn_size_ext.placeholder = tf.cast(
+            tf.math.ceil(tf.cast(dyn_size_ext.placeholder, tf.float32) * factor_t), tf.int32)
+        else:
+          raise TypeError("%s: unexpected factor_t %s" % (self, factor_t))
+        self.output.dim_tags[axis].dyn_size_ext = dyn_size_ext
     self.output.placeholder = x
 
+  def get_dep_layers(self):
+    """
+    :rtype: list[LayerBase]
+    """
+    deps = super(ResizeLayer, self).get_dep_layers()
+    if isinstance(self.factor, LayerBase):
+      deps.append(self.factor)
+    return deps
+
+  @classmethod
+  def transform_config_dict(cls, d, network, get_layer):
+    """
+    :param dict[str] d:
+    :param returnn.tf.network.TFNetwork network:
+    :param (str)->LayerBase get_layer:
+    """
+    super(ResizeLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
+    if isinstance(d.get("factor"), str):
+      d["factor"] = get_layer(d["factor"])
+
   @classmethod
   def get_out_data_from_opts(cls, factor, axis, sources, name, fill_dropout=None, out_dim=None, **kwargs):
     """
-    :param int factor:
+    :param int|float|LayerBase factor:
     :param Dim|str axis:
     :param list[LayerBase] sources:
     :param str name:
@@ -7646,8 +7693,10 @@ def get_out_data_from_opts(cls, factor, axis, sources, name, fill_dropout=None,
       axis = 1
     assert axis != out.batch_dim_axis, "batch-dim resize not supported"
     tag = out.dim_tags[axis]
-    if fill_dropout:
+    if fill_dropout or not isinstance(factor, int):
       out_dim_ = Dim(kind=tag.kind, description="%s_resize" % name, auto_generated=True)  # unknown dim
+      if tag.dyn_size_ext is not None:
+        out_dim_.dyn_size_ext = tag.dyn_size_ext.copy_template(name="%s:dyn_size_ext" % name)
     else:
       out_dim_ = tag * factor
     if out_dim:
diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py
index 39cf3b1f9..7f671aa48 100644
--- a/tests/test_TFNetworkLayer.py
+++ b/tests/test_TFNetworkLayer.py
@@ -8047,6 +8047,35 @@ def test_ResizeLayer_BFT():
   assert out_v.shape == (n_batch, n_in, n_time * 2)
 
 
+def test_ResizeLayer_dynamic():
+  n_batch, n_time, n_in = 2, 5, 3
+  in_v = numpy.arange(0, n_batch * n_time * n_in).astype("float32").reshape((n_batch, n_time, n_in))
+  in_seq_lens = numpy.array([5, 4])
+  config = Config({
+    "extern_data": {
+      "data": {"shape": (None, n_in)},
+      "factor": {"shape": (), "batch_dim_axis": None, "dtype": "float32"},
+    }
+  })
+  with make_scope() as session:
+    net = TFNetwork(config=config)
+    net.construct_from_dict({
+      "output": {"class": "resize", "axis": "T", "factor": "data:factor", "kind": "nn", "from": "data"}
+    })
+    out = net.get_default_output_layer().output
+    for factor in [0.5, 2.0]:
+      out_v, out_lens = session.run(
+        (out.placeholder, out.get_sequence_lengths()), feed_dict={
+          net.extern_data.get_batch_info().dim: n_batch,
+          net.extern_data.data["data"].placeholder: in_v,
+          net.extern_data.data["data"].get_sequence_lengths(): in_seq_lens,
+          net.extern_data.data["factor"].placeholder: factor,
+        })
+      assert isinstance(out_v, numpy.ndarray)
+      assert out_v.shape == (n_batch, numpy.ceil(n_time * factor), n_in)
+      numpy.testing.assert_equal(out_lens, numpy.ceil(in_seq_lens * factor).astype("int32"))
+
+
 def test_PostfixInTimeLayer():
   with make_scope() as session:
     import numpy as np
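
Usage note (not part of the patch): with this change, `factor` in a `"resize"` layer config may name another layer instead of being a constant. A minimal net-dict sketch, mirroring `test_ResizeLayer_dynamic` above; the extern-data key `"factor"` is just the name used in that test:

```python
# Sketch only: a net dict using a scalar layer as the resize factor.
# Any scalar int32/float32 layer works; float factors round the new lengths up (ceil).
network = {
  "output": {
    "class": "resize", "from": "data", "axis": "T", "kind": "nn",
    "factor": "data:factor",  # layer name, resolved via transform_config_dict
  },
}
```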