Skip to content

Commit

Permalink
TF PadLayer small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Apr 25, 2024
1 parent ea0f143 commit 349ea73
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions returnn/tf/layers/basic.py
Expand Up @@ -4222,27 +4222,26 @@ def __init__(
self.output.placeholder = tf.pad(
self.input_data.placeholder, paddings=paddings, mode=mode, constant_values=value
)
if any(dim.need_masking() for dim in out_dims) and handle_dynamic_dims:
if all(right == 0 for left, right in padding) and mode != "circular":
pass # no masking needed
else:
import returnn.frontend as rf

if mode != "constant":
raise NotImplementedError(
f"pad: mode {mode} not implemented with dynamic dims and handle_dynamic_dims=True"
)
for out_dim, middle_axis, (left, right) in zip(out_dims, axes, padding):
out_dim: Dim
middle = self.input_data.dims[middle_axis]
if middle.need_masking() or (isinstance(left, Dim) and left.need_masking()):
if isinstance(right, Dim) or right > 0:
mask = rf.compare_bc(rf.range_over_dim(out_dim), "<", (left + middle).dyn_size_ext)
self.output.raw_tensor = tf_util.where_bc(
mask.copy_compatible_to(self.output, check_sparse=False, check_dtype=False).raw_tensor,
self.output.raw_tensor,
tf.convert_to_tensor(value, dtype=self.output.dtype),
)
if all(right == 0 for left, right in padding) and mode != "circular":
pass # no masking needed
else:
import returnn.frontend as rf

for middle_axis, (left, right) in zip(axes, padding):
out_dim: Dim = self.output.dims[middle_axis]
middle = self.input_data.dims[middle_axis]
if handle_dynamic_dims and middle.need_masking() or (isinstance(left, Dim) and left.need_masking()):
if mode != "constant":
raise NotImplementedError(
f"pad: mode {mode} not implemented with dynamic dims and handle_dynamic_dims=True"
)
if isinstance(right, Dim) or right > 0:
mask = rf.compare_bc(rf.range_over_dim(out_dim), "<", (left + middle).dyn_size_ext)
self.output.raw_tensor = tf_util.where_bc(
mask.copy_compatible_to(self.output, check_sparse=False, check_dtype=False).raw_tensor,
self.output.raw_tensor,
tf.convert_to_tensor(value, dtype=self.output.dtype),
)

@classmethod
def _transform_padding(cls, padding, axes):
Expand Down

0 comments on commit 349ea73

Please sign in to comment.