RF pad TF PadLayer handle_dynamic_dims, new behavior version

albertz committed Apr 25, 2024
1 parent 0d67dc6 commit ea0f143

Showing 8 changed files with 235 additions and 24 deletions.
6 changes: 6 additions & 0 deletions docs/configuration_reference/behavior_version.rst
@@ -22,6 +22,12 @@ and not listing legacy/deprecated parameters.
Version History
---------------

Behavior version 21 (2024-04-25)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

RF ``pad`` and TF ``PadLayer`` defaults changed:

* ``handle_dynamic_dims``: False → True

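For example, a minimal config sketch to opt in to the new default (assuming a standard
RETURNN config; ``handle_dynamic_dims`` can also still be set explicitly per call/layer)::

    behavior_version = 21  # RF pad / TF PadLayer now default to handle_dynamic_dims=True
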
Behavior version 20 (2024-01-05)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions returnn/frontend/_backend.py
@@ -475,6 +475,7 @@ def pad(
axes: Sequence[Dim],
padding: Sequence[Tuple[Union[Dim, int], Union[Dim, int]]],
out_dims: Sequence[Dim],
handle_dynamic_dims: bool,
mode: str = "constant",
value: Optional[Union[rf.RawTensorTypes, Tensor]] = None,
) -> Tensor:
@@ -483,6 +484,7 @@
:param axes:
:param padding:
:param out_dims:
:param handle_dynamic_dims:
:param mode:
:param value:
:return: padded tensor
59 changes: 57 additions & 2 deletions returnn/frontend/array_.py
@@ -4,6 +4,7 @@

from __future__ import annotations
from typing import Optional, Union, Type, TypeVar, Sequence, Tuple
import logging
import numpy
from returnn.tensor import Tensor, Dim
import returnn.frontend as rf
@@ -385,16 +386,21 @@ def pad(
out_dims: Optional[Sequence[Dim]] = None,
mode: str = "constant",
value: Optional[Union[rf.RawTensorTypes, Tensor]] = None,
handle_dynamic_dims: Optional[bool] = None,
) -> Tuple[Tensor, Sequence[Dim]]:
"""
Pad values left/right in the specified axes.
:param source:
:param axes: which axes to add padding to
:param padding: list of (left, right) padding for each axis
:param out_dims: (optional) predefined out dims for each padded dim in axes; will be created automatically if not given
:param mode: 'constant', 'reflect', 'replicate' or 'circular'
:param value: (optional) value to pad with in "constant" mode
:param handle_dynamic_dims: True: when right-padding a dynamic dim, the value is added directly after the seq end,
    not at the end of the dimension. False: the value is added at the end of the dimension.
    Defaults to True for behavior version >= 21, and False for older behavior versions.
:return: padded tensor, out_dims. out dims are for each dim in axes
"""
assert len(axes) == len(padding)
if not out_dims:
@@ -405,13 +411,62 @@
assert not right.need_masking(), f"padding {padding} does not support dynamic right padding"
# Note that even dynamic middle dims is not exactly correct...
out_dims = [left + middle + right for middle, (left, right) in zip(axes, padding)]
if handle_dynamic_dims is None:
handle_dynamic_dims = _pad_handle_dynamic_dims_default(axes, padding, mode=mode)
# noinspection PyProtectedMember
return (
source._raw_backend.pad(
source,
axes=axes,
padding=padding,
out_dims=out_dims,
handle_dynamic_dims=handle_dynamic_dims,
mode=mode,
value=value,
),
out_dims,
)


_pad_handle_dynamic_dims_shown_warning = False


def _pad_handle_dynamic_dims_default(
pad_axes: Sequence[Dim], padding: Sequence[Tuple[Union[Dim, int], Union[Dim, int]]], *, mode: str
) -> bool:
"""
:param pad_axes: list of axes to pad
:param padding: list of (left, right) padding for each axis
:param mode: 'constant', 'reflect', 'replicate' or 'circular'
:return: default for handle_dynamic_dims, i.e. True if dynamic dims should be handled
"""
from returnn.util.basic import BehaviorVersion

if BehaviorVersion.get() >= 21:
return True

# Check whether not handling the dynamic dims is safe. Print a warning if not safe.
global _pad_handle_dynamic_dims_shown_warning
if not _pad_handle_dynamic_dims_shown_warning:
for middle, (left, right) in zip(pad_axes, padding):
middle: Dim
if not middle.need_masking() and (isinstance(left, int) or not left.need_masking()):
continue
if mode != "circular" and isinstance(right, int) and right == 0:
continue

logging.getLogger("returnn.frontend").warning(
f"rf.pad applied on dynamic dim {middle} but handle_dynamic_dims=False used by default"
f" due to behavior version {BehaviorVersion.get()} < 21."
" Set handle_dynamic_dims explicitly to avoid the warning,"
" or switch to a new behavior version >= 21."
" (This warning is only printed once.)"
)
_pad_handle_dynamic_dims_shown_warning = True
break
return False


def cum_concat_step(
source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Optional[Dim] = None
) -> Tuple[Tensor, Dim]:
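A minimal usage sketch of the new RF API (assumed toy dims/values; requires PyTorch for the backend):

import numpy
import returnn.frontend as rf
from returnn.tensor import Dim

rf.select_backend_torch()  # assumption: run with the PyTorch backend

batch_dim = Dim(2, name="batch")
lens = rf.convert_to_tensor(numpy.array([3, 5], dtype="int32"), dims=[batch_dim])
time_dim = Dim(lens, name="time")  # dynamic dim, per-sequence lengths 3 and 5
feat_dim = Dim(4, name="feat")
x = rf.random_normal([batch_dim, time_dim, feat_dim])

# Right-pad the dynamic time dim by 2 zeros. With handle_dynamic_dims=True
# (the default from behavior version 21 on), the zeros land directly after
# each sequence's own end, not only after the max length.
y, (padded_time,) = rf.pad(x, axes=[time_dim], padding=[(0, 2)], value=0.0, handle_dynamic_dims=True)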
2 changes: 2 additions & 0 deletions returnn/tf/frontend_layers/_backend.py
@@ -355,6 +355,7 @@ def pad(
axes: Sequence[Dim],
padding: Sequence[Tuple[Union[Dim, int], Union[Dim, int]]],
out_dims: Sequence[Dim],
handle_dynamic_dims: bool,
mode: str = "constant",
value: Union[rf.RawTensorTypes, Tensor] = None,
) -> Tensor:
@@ -367,6 +368,7 @@
"axes": axes,
"padding": padding,
"out_dims": out_dims,
"handle_dynamic_dims": handle_dynamic_dims,
"mode": mode,
"value": value,
},
125 changes: 105 additions & 20 deletions returnn/tf/layers/basic.py
@@ -1981,13 +1981,15 @@ def __init__(
# Now we need to implement a similar logic as `returnn.tf.util.basic.nd_indices`, but more generic.
idxs = [
    (
        (
            tf.reshape(
                tf.range(pos_shape[i], dtype=pos_v.dtype), [1] * i + [pos_shape[i]] + [1] * (pos_ndim - i - 1)
            )
            + tf.zeros_like(pos_v)
        )
        if i != replace_common_axis
        else pos_v
    )
    for i in range(pos_ndim)
]
nd_idxs = tf.stack(idxs, axis=-1)
@@ -4173,13 +4175,28 @@ class PadLayer(_ConcatInputLayer):

layer_class = "pad"

def __init__(
self,
*,
axes: Union[Dim, str, Sequence[Union[Dim, str]]],
padding: Union[int, Tuple[int, int], Sequence[Tuple[int, int]]],
out_dims: Optional[Union[Dim, Sequence[Dim]]] = None,
handle_dynamic_dims: Optional[bool] = None,
value: Union[int, float] = 0,
mode: str = "constant",
**kwargs,
):
"""
:param axes: e.g. "F" etc., see :func:`Data.get_axes_from_description`
:param padding: how much to pad left/right in each axis
:param out_dims:
:param handle_dynamic_dims: True: when right-padding a dynamic dim, the value is added directly after the seq end,
    not at the end of the dimension. False: the value is added at the end of the dimension.
    Defaults to True for behavior version >= 21, and False for older behavior versions.
:param value: what constant value to pad, with mode=="constant"
:param mode: "constant", "reflect", "symmetric" or "replication"
"""
out_dims # noqa # handled in get_out_data_from_opts
super(PadLayer, self).__init__(**kwargs)
@@ -4190,15 +4207,42 @@ def __init__(self, axes, padding, out_dims=None, value=0, mode="constant", **kwa
paddings = [(0, 0)] * len(range(self.input_data.batch_ndim))
for i, a in enumerate(axes):
paddings[a] = padding[i]
mode = mode.lower()
if handle_dynamic_dims is None:
handle_dynamic_dims = self._handle_dynamic_dims_default(
pad_axes=[self.input_data.dims[axis] for axis in axes_],
padding=padding,
mode=mode,
)
if all(sum(p) == 0 for p in padding):
self.output.placeholder = self.input_data.placeholder
elif mode == "REPLICATION":
elif mode == "replication":
self.output.placeholder = tf_util.pad_replicate(self.input_data.placeholder, axes, padding)
else:
self.output.placeholder = tf.pad(
self.input_data.placeholder, paddings=paddings, mode=mode, constant_values=value
)
if any(dim.need_masking() for dim in out_dims) and handle_dynamic_dims:
if all(right == 0 for left, right in padding) and mode != "circular":
pass # no masking needed
else:
import returnn.frontend as rf

if mode != "constant":
raise NotImplementedError(
f"pad: mode {mode} not implemented with dynamic dims and handle_dynamic_dims=True"
)
for out_dim, middle_axis, (left, right) in zip(out_dims, axes, padding):
out_dim: Dim
middle = self.input_data.dims[middle_axis]
if middle.need_masking() or (isinstance(left, Dim) and left.need_masking()):
if isinstance(right, Dim) or right > 0:
mask = rf.compare_bc(rf.range_over_dim(out_dim), "<", (left + middle).dyn_size_ext)
self.output.raw_tensor = tf_util.where_bc(
mask.copy_compatible_to(self.output, check_sparse=False, check_dtype=False).raw_tensor,
self.output.raw_tensor,
tf.convert_to_tensor(value, dtype=self.output.dtype),
)

@classmethod
def _transform_padding(cls, padding, axes):
@@ -4218,6 +4262,45 @@ def _transform_padding(cls, padding, axes):
padding = [(padding, padding)] * len(axes)
return padding

_handle_dynamic_dims_shown_warning = False

@classmethod
def _handle_dynamic_dims_default(
cls, pad_axes: Sequence[Dim], padding: Sequence[Tuple[Union[Dim, int], Union[Dim, int]]], *, mode: str
) -> bool:
"""
:param pad_axes: list of axes to pad
:param padding: list of (left, right) padding for each axis
:param mode: 'constant', 'reflect', 'replicate' or 'circular'
:return: default for handle_dynamic_dims, i.e. True if dynamic dims should be handled
"""
from returnn.util.basic import BehaviorVersion

if BehaviorVersion.get() >= 21:
return True

# Check whether not handling the dynamic dims is safe. Print a warning if not safe.
if not cls._handle_dynamic_dims_shown_warning:
import logging

for middle, (left, right) in zip(pad_axes, padding):
middle: Dim
if not middle.need_masking() and (isinstance(left, int) or not left.need_masking()):
continue
if mode != "circular" and isinstance(right, int) and right == 0:
continue

logging.getLogger("returnn.tf").warning(
f"PadLayer applied on dynamic dim {middle} but handle_dynamic_dims=False used by default"
f" due to behavior version {BehaviorVersion.get()} < 21."
" Set handle_dynamic_dims explicitly to avoid the warning,"
" or switch to a new behavior version >= 21."
" (This warning is only printed once.)"
)
cls._handle_dynamic_dims_shown_warning = True
break
return False

@classmethod
def get_out_data_from_opts(cls, name, sources, axes, padding, out_dims=None, **kwargs):
"""
@@ -4929,14 +5012,16 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources
rem_dim = None
if not resolved_dims:
resolved_dims = tuple(
(
    Dim(
        kind=axis_dim_tag.kind if not axis_dim_tag.is_batch_dim() else Dim.Types.Spatial,
        description="%s_split_dims%i" % (name, i),
        dimension=shape_dim,
        auto_generated=True,
    )
    if rem_dim is None or i != rem_dim_idx
    else rem_dim
)
for i, shape_dim in enumerate(resolved_shape_dims)
)
out_batch = data.batch
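A sketch of the corresponding TF-layer usage in a net dict (hypothetical layer names; option values are illustrative):

network = {
    "padded": {
        "class": "pad",
        "from": "data",
        "axes": "T",
        "padding": (0, 2),
        "handle_dynamic_dims": True,  # new option; the default depends on behavior_version
        "mode": "constant",
        "value": 0,
    },
    "output": {"class": "copy", "from": "padded"},
}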
18 changes: 18 additions & 0 deletions returnn/torch/frontend/_backend.py
@@ -439,6 +439,7 @@ def pad(
axes: Sequence[Dim],
padding: Sequence[Tuple[Union[Dim, int], Union[Dim, int]]],
out_dims: Sequence[Dim],
handle_dynamic_dims: bool,
mode: str = "constant",
value: Optional[Union[rf.RawTensorTypes, Tensor]] = None,
) -> Tensor:
@@ -465,6 +466,23 @@
assert value.dims == (), f"value {value} must be a scalar"
value = value.raw_tensor
out.raw_tensor = torch.nn.functional.pad(source.raw_tensor, pad=raw_pad, mode=mode, value=value)
if any(dim.need_masking() for dim in out_dims) and handle_dynamic_dims:
if all(right == 0 for right in raw_pad[1::2]) and mode != "circular":
pass # no masking needed
else:
if mode != "constant":
raise NotImplementedError(
f"pad: mode {mode} not implemented with dynamic dims and handle_dynamic_dims=True"
)
for out_dim, middle, (left, right) in zip(out_dims, axes, padding):
if middle.need_masking() or (isinstance(left, Dim) and left.need_masking()):
if isinstance(right, Dim) or right > 0:
mask = rf.compare_bc(rf.range_over_dim(out_dim), "<", (left + middle).dyn_size_ext)
out.raw_tensor = torch.where(
mask.copy_compatible_to(out, check_dtype=False, check_sparse=False).raw_tensor,
out.raw_tensor,
value,
)
return out

@staticmethod
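A minimal pure-PyTorch sketch (toy values, not the backend code verbatim) of the re-masking that handle_dynamic_dims=True adds after torch.nn.functional.pad:

import torch

x = torch.ones(2, 5)         # batch=2, padded time=5, all ones
lens = torch.tensor([3, 5])  # true sequence lengths
left, right, value = 0, 2, 0.0
y = torch.nn.functional.pad(x, (left, right), value=value)  # shape [2, 7]
pos = torch.arange(y.shape[1])[None, :]  # [1, 7]
keep = pos < (left + lens)[:, None]      # keep everything before left + seq_len
y = torch.where(keep, y, torch.full_like(y, value))
# row 0: [1, 1, 1, 0, 0, 0, 0]  -- the pad value starts right after seq end (len 3)
# row 1: [1, 1, 1, 1, 1, 0, 0]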
2 changes: 1 addition & 1 deletion returnn/util/basic.py
@@ -203,7 +203,7 @@ class BehaviorVersion:
See :ref:`behavior_version`.
"""

_latest_behavior_version = 21
_behavior_version = None # type: typing.Optional[int]
_min_behavior_version = 0 # type: int

