ResizeLayer, support float, dynamic tensor
albertz committed Dec 6, 2022
1 parent 2c0bf36 commit f59a167
Showing 2 changed files with 82 additions and 4 deletions.
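
For orientation, a minimal, hypothetical usage sketch of the new option (not part of the commit): the "factor" argument of a resize layer can now name a scalar layer or data stream instead of being a fixed int, and a float factor resizes the axis to ceil(in_len * factor). The names below mirror the new test and are assumptions.

# Hypothetical RETURNN net dict fragment. "data:factor" is assumed to be a
# scalar float32 data stream (as in the new test below). With a float factor,
# the resized axis gets length ceil(in_len * factor), and the per-sequence
# lengths are scaled the same way.
network = {
    "output": {
        "class": "resize",
        "from": "data",
        "axis": "T",
        "kind": "nn",  # nearest-neighbor interpolation
        "factor": "data:factor",  # resolved to a layer via transform_config_dict
    },
}
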
57 changes: 53 additions & 4 deletions returnn/tf/layers/basic.py
@@ -7556,7 +7556,7 @@ class ResizeLayer(_ConcatInputLayer):

def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_dropout=None, **kwargs):
"""
:param int factor:
:param int|float|LayerBase factor: out_len = in_len * factor
:param Dim|str axis: the axis to resize
:param Dim|None out_dim:
:param str kind: "linear", "nn"/"nearest_neighbor", "cubic", "fill"
@@ -7565,6 +7565,7 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_
"""
out_dim # noqa # via get_out_data_from_opts
super(ResizeLayer, self).__init__(**kwargs)
self.factor = factor
# self.output.shape and self.output.batch_dim_axis are already set here via self.get_out_data_from_opts().
input_data = self.input_data.copy_as_batch_major()
axis = input_data.get_axis_from_description(axis)
@@ -7580,7 +7581,20 @@
remaining_shape = [shape[i] for i in remaining_axes]
remaining_dim = optional_mul(*remaining_shape) if remaining_axes else 1
x = tf.reshape(x, [shape[0], shape[axis], 1, remaining_dim]) # [batch,height,width,channels]
new_size = shape[axis] * factor
if isinstance(factor, (int, float)):
factor_t = factor
elif isinstance(factor, LayerBase):
assert factor.output.batch_shape == (), "%s: factor must be scalar, got %s" % (self, factor)
assert factor.output.dtype in ("int32", "float32"), "%s: factor must be int or float, got %s" % (self, factor)
factor_t = factor.output.placeholder
else:
raise TypeError("%s: unexpected factor type %s" % (self, type(factor).__name__))
if isinstance(factor_t, int) or (isinstance(factor_t, tf.Tensor) and factor_t.dtype == tf.int32):
new_size = shape[axis] * factor_t
elif isinstance(factor_t, float) or (isinstance(factor_t, tf.Tensor) and factor_t.dtype == tf.float32):
new_size = tf.cast(tf.math.ceil(tf.cast(shape[axis], tf.float32) * factor_t), tf.int32)
else:
raise TypeError("%s: unexpected factor_t %s" % (self, factor_t))
if kind == "linear":
x = tf_compat.v1.image.resize_bilinear(x, size=(new_size, 1))
elif kind == "cubic":
@@ -7605,6 +7619,7 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_
raise Exception("invalid kind %r for resizing" % kind)
x = tf.reshape(x, [shape[0], new_size] + remaining_shape) # [batch,new_size] + remaining_shape
if fill_dropout:
assert isinstance(factor, int)
from returnn.tf.util.basic import expand_dims_unbroadcast
# We are going to build a mask over the axis. This mask will be shared over all seqs in the batch.
# Similar to in tf.nn.dropout. Build random_tensor as uniform [keep_prob, 1.0 + keep_prob).
@@ -7626,12 +7641,44 @@ def __init__(self, factor, axis, out_dim=None, kind="nn", fill_value=None, fill_
out_dyn_size, maxlen=new_size, dtype=tf.bool) # (batch,new_size)
out_dyn_size = tf.reduce_sum(tf.cast(tf.logical_and(mask, orig_mask), tf.int32), axis=1)
self.output.dim_tags[axis].dyn_size = out_dyn_size
elif not isinstance(factor, int):
dyn_size_ext = input_data.dim_tags[axis].dyn_size_ext
if dyn_size_ext is not None and dyn_size_ext.placeholder is not None:
dyn_size_ext = dyn_size_ext.copy(name="%s:dyn_size_ext" % self.name)
if isinstance(factor_t, int) or (isinstance(factor_t, tf.Tensor) and factor_t.dtype == tf.int32):
dyn_size_ext.placeholder = dyn_size_ext.placeholder * factor_t
elif isinstance(factor_t, float) or (isinstance(factor_t, tf.Tensor) and factor_t.dtype == tf.float32):
dyn_size_ext.placeholder = tf.cast(
tf.math.ceil(tf.cast(dyn_size_ext.placeholder, tf.float32) * factor_t), tf.int32)
else:
raise TypeError("%s: unexpected factor_t %s" % (self, factor_t))
self.output.dim_tags[axis].dyn_size_ext = dyn_size_ext
self.output.placeholder = x

def get_dep_layers(self):
"""
:rtype: list[LayerBase]
"""
deps = super(ResizeLayer, self).get_dep_layers()
if isinstance(self.factor, LayerBase):
deps.append(self.factor)
return deps

@classmethod
def transform_config_dict(cls, d, network, get_layer):
"""
:param dict[str] d:
:param returnn.tf.network.TFNetwork network:
:param (str)->LayerBase get_layer:
"""
super(ResizeLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
if isinstance(d.get("factor"), str):
d["factor"] = get_layer(d["factor"])

@classmethod
def get_out_data_from_opts(cls, factor, axis, sources, name, fill_dropout=None, out_dim=None, **kwargs):
"""
:param int factor:
:param int|float|LayerBase factor:
:param Dim|str axis:
:param list[LayerBase] sources:
:param str name:
@@ -7646,8 +7693,10 @@ def get_out_data_from_opts(cls, factor, axis, sources, name, fill_dropout=None,
axis = 1
assert axis != out.batch_dim_axis, "batch-dim resize not supported"
tag = out.dim_tags[axis]
if fill_dropout:
if fill_dropout or not isinstance(factor, int):
out_dim_ = Dim(kind=tag.kind, description="%s_resize" % name, auto_generated=True) # unknown dim
if tag.dyn_size_ext is not None:
out_dim_.dyn_size_ext = tag.dyn_size_ext.copy_template(name="%s:dyn_size_ext" % name)
else:
out_dim_ = tag * factor
if out_dim:
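
As a quick check of the length rule added above (a standalone sketch, not part of the commit): an integer factor keeps the exact product, while a non-integer factor rounds the product up with ceil, for both the new static size and the per-sequence lengths.

import math

def resized_len(in_len, factor):
    """Mirrors the new_size / dyn_size_ext computation above."""
    if isinstance(factor, int):
        return in_len * factor
    return int(math.ceil(in_len * factor))

assert resized_len(5, 2) == 10   # int factor: exact multiple
assert resized_len(5, 0.5) == 3  # ceil(2.5), matches the test below
assert resized_len(4, 0.5) == 2
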
29 changes: 29 additions & 0 deletions tests/test_TFNetworkLayer.py
@@ -8047,6 +8047,35 @@ def test_ResizeLayer_BFT():
assert out_v.shape == (n_batch, n_in, n_time * 2)


def test_ResizeLayer_dynamic():
n_batch, n_time, n_in = 2, 5, 3
in_v = numpy.arange(0, n_batch * n_time * n_in).astype("float32").reshape((n_batch, n_time, n_in))
in_seq_lens = numpy.array([5, 4])
config = Config({
"extern_data": {
"data": {"shape": (None, n_in)},
"factor": {"shape": (), "batch_dim_axis": None, "dtype": "float32"},
}
})
with make_scope() as session:
net = TFNetwork(config=config)
net.construct_from_dict({
"output": {"class": "resize", "axis": "T", "factor": "data:factor", "kind": "nn", "from": "data"}
})
out = net.get_default_output_layer().output
for factor in [0.5, 2.0]:
out_v, out_lens = session.run(
(out.placeholder, out.get_sequence_lengths()), feed_dict={
net.extern_data.get_batch_info().dim: n_batch,
net.extern_data.data["data"].placeholder: in_v,
net.extern_data.data["data"].get_sequence_lengths(): in_seq_lens,
net.extern_data.data["factor"].placeholder: factor,
})
assert isinstance(out_v, numpy.ndarray)
assert out_v.shape == (n_batch, numpy.ceil(n_time * factor), n_in)
numpy.testing.assert_equal(out_lens, numpy.ceil(in_seq_lens * factor).astype("int32"))


def test_PostfixInTimeLayer():
with make_scope() as session:
import numpy as np
