Temporary hacks for ONNX conversion
patrick-wilken committed May 7, 2024
1 parent d3204ae commit a36f93c
Showing 4 changed files with 21 additions and 9 deletions.
5 changes: 4 additions & 1 deletion returnn/datasets/cached.py
@@ -77,7 +77,10 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
self._update_tag_idx()
seq_index = [self._tag_idx[tag] for tag in seq_list]
else:
-seq_index = self.get_seq_order_for_epoch(epoch, self._num_seqs, lambda s: self._get_seq_length_by_real_idx(s)[0])
+data_index = 0
+if not self.num_inputs and "data" in self.target_keys:
+  data_index = self.target_keys.index("data")
+seq_index = self.get_seq_order_for_epoch(epoch, self._num_seqs, lambda s: self._get_seq_length_by_real_idx(s)[data_index])

old_index_map = self._index_map[:]
self._index_map = range(len(seq_index)) # sorted seq idx -> seq_index idx
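What this hunk does, in short: when the dataset defines no inputs (self.num_inputs is falsy), "data" can appear among the target keys, and the per-stream length tuple is presumably ordered by target_keys in that case, so sorting by the length at index 0 can pick the wrong stream; the hack looks up the index of "data" in target_keys instead. A minimal sketch of that selection, with invented stand-ins for the dataset attributes:

# Stand-ins for the dataset attributes used in the hunk; all values are invented.
num_inputs = 0
target_keys = ["classes", "data"]
seq_lengths = (13, 42)  # per-stream lengths, as _get_seq_length_by_real_idx would return

data_index = 0
if not num_inputs and "data" in target_keys:
  data_index = target_keys.index("data")  # -> 1, so sorting uses the "data" stream's length
assert seq_lengths[data_index] == 42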
2 changes: 1 addition & 1 deletion returnn/tf/compat.py
@@ -24,7 +24,7 @@
if v2 and tf.__version__.startswith("2."):
tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_v2_tensorshape()
-tf.compat.v1.disable_control_flow_v2()
+# tf.compat.v1.disable_control_flow_v2()
# tf.compat.v1.disable_v2_behavior() -- not sure on this

try:
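Commenting out disable_control_flow_v2() leaves TF2 control-flow ops (While/If) in the graph; converters such as tf2onnx generally handle those, while the v1 raw control-flow ops tend not to convert. A hedged sketch of a switchable variant of the same hack, using a made-up environment variable instead of editing the source per run:

import os
import tensorflow as tf

# Hypothetical opt-out (the variable name is invented, not part of RETURNN):
# default behavior stays as before; an ONNX-export run sets the variable to
# keep v2 control flow and get convertible While/If ops.
if not os.environ.get("RETURNN_KEEP_CONTROL_FLOW_V2"):
  tf.compat.v1.disable_control_flow_v2()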
11 changes: 9 additions & 2 deletions returnn/tf/layers/rec.py
@@ -1984,6 +1984,7 @@ def get_output(self):
" has to be the same. Found length ", data_len, " for %s, which does not match length " % key,
common_data_len, " of the other data."])]):
data_len = tf.identity(data_len)
print("DEBUG: creating TA for", key)
data_ta = tf.TensorArray(
name=key + "_ta",
dtype=data.dtype,
@@ -2222,6 +2223,7 @@ def get_loop_loss():
if not layer.output.mark_same_time(rec_layer.output):
continue
assert fixed_seq_len is not None
print("DEBUG: creating TA for input layer", layer_name)
inp_ta = tf.TensorArray(
name="%s_ta" % layer_name,
dtype=self.layer_data_templates[layer_name].output.dtype,
@@ -2239,6 +2241,8 @@ def get_loop_loss():
name="%s_ta_unstack" % layer_name)
input_layers_moved_out_tas[layer_name] = inp_ta

print("DEBUG: outputs_to_accumulate", outputs_to_accumulate)

# Create a tensor array to store the intermediate values for each step i, e.g. of shape (batch, dim).
init_acc_tas = [
tf.TensorArray(
@@ -2686,7 +2690,8 @@ def get_choice_seq(choice_base):
# The max_seq_len might actually be one more, as it includes the EOS, but that does not matter;
# we just want to create a new acc_ta with the same length.
# (Cutting off the EOS is handled elsewhere.)
-max_seq_len = acc_ta.size()
+# max_seq_len = acc_ta.size()
+max_seq_len = tf.shape(acc_ta.stack())[0]
initial_i = tf.identity(max_seq_len - 1, name="search_resolve_initial_i") # we go backwards
latest_beam_size = latest_layer_choice.output.beam.beam_size
batch_dim = rec_layer.network.get_data_batch_dim()
@@ -2739,8 +2744,10 @@ def get_choice_seq(choice_base):
# Recombine batch and beam dims
seq_len = tf.reshape(seq_len, [batch_dim * latest_beam_size], name="merge_batch_beam")

print("DEBUG: creating TA for search resolved layer", layer.name)
new_acc_output_ta = tf.TensorArray(
name="search_resolved_%s" % os.path.basename(acc_ta.handle.op.name),
#name="search_resolved_%s" % os.path.basename(acc_ta.handle.op.name),
name="search_resolved",
dtype=layer.output.dtype,
element_shape=tf.TensorShape(layer.output.batch_shape),
size=max_seq_len,
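The recurring hack in this file: acc_ta.size() emits a TensorArraySize op, which ONNX conversion tooling typically does not support, so the element count is read off the stacked tensor instead; the hard-coded "search_resolved" name presumably sidesteps acc_ta.handle, which the v2 TensorList-based arrays do not expose, and the DEBUG prints just trace where TensorArrays get created. A minimal sketch of the size substitution (illustrative sizes, runs in plain TF2):

import tensorflow as tf

# Both expressions evaluate to the number of elements in the TensorArray.
# The second avoids the TensorArraySize op, at the cost of materializing the
# stacked tensor just to read its leading dimension.
ta = tf.TensorArray(tf.float32, size=3).unstack(tf.zeros([3, 2]))
size_via_op = ta.size()                   # TensorArraySize(V3) op
size_via_stack = tf.shape(ta.stack())[0]  # Shape + slice instead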
12 changes: 7 additions & 5 deletions returnn/tf/util/basic.py
@@ -5309,7 +5309,7 @@ def tensor_array_element_shape(ta):
return tf.TensorShape(None)


-def tensor_array_like(ta, **kwargs):
+def tensor_array_like(ta, size, **kwargs):
"""
:param tf.TensorArray ta:
:param kwargs: passed to tf.TensorArray constructor
@@ -5318,7 +5318,7 @@ def tensor_array_like(ta, **kwargs):
"""
# noinspection PyProtectedMember
return tf.TensorArray(
-dtype=ta.dtype, size=ta.size(), dynamic_size=tensor_array_is_dynamic_size(ta),
+dtype=ta.dtype, size=size, dynamic_size=tensor_array_is_dynamic_size(ta),
clear_after_read=tensor_array_is_clear_after_read(ta),
infer_shape=ta._infer_shape, element_shape=tensor_array_element_shape(ta),
**kwargs)
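With the signature change, callers supply the element count themselves, typically from the shape of the tensor they are about to unstack, so the helper no longer emits ta.size(). A sketch of the new call pattern under RETURNN's usual graph-mode setup (the shape is invented):

import tensorflow as tf
from returnn.tf.util.basic import tensor_array_like

x = tf.zeros([5, 2, 3])  # (time, batch, dim), illustrative shape
ta = tf.TensorArray(tf.float32, size=5).unstack(x)
# The caller passes the size explicitly instead of the helper calling ta.size().
ta_like = tensor_array_like(ta, size=tf.shape(x)[0]).unstack(x)

Since the extra positional parameter is not backward compatible, every call site needs updating; the _tensor_array_select_src_beams hunk below does exactly that.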
@@ -5352,8 +5352,9 @@ def tensor_array_stack(ta, start=0, stop=None, name="TensorArrayStack"):
with tf.name_scope(name):
if stop is None:
stop = ta.size()
-return ta.gather(tf.range(start, stop), name=name)
+# return ta.gather(tf.range(start, stop), name=name)
+stacked_tensor = ta.stack(name=name)
+return stacked_tensor[start:stop]

def _tensor_array_select_src_beams(ta, src_beams):
"""
@@ -5367,7 +5368,7 @@ def _tensor_array_select_src_beams(ta, src_beams):
x = swapaxes(x, 0, 1) # (batch,time,...)
x = select_src_beams(x, src_beams=src_beams)
x = swapaxes(x, 0, 1) # (time,batch,...)
-ta_new = tensor_array_like(ta)
+ta_new = tensor_array_like(ta, size=tf.shape(x)[0])
ta_new = ta_new.unstack(x)
return ta_new

@@ -5470,6 +5471,7 @@ def select_src_beams(x, src_beams, name="select_src_beams"):
:rtype: tf.Tensor|T
"""
if isinstance(x, tf.TensorArray):
print("DEBUG: creating TA select source beams", x)
return _tensor_array_select_src_beams(x, src_beams=src_beams)
assert isinstance(x, tf.Tensor)
assert isinstance(src_beams, tf.Tensor)
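The TensorArray branch of select_src_beams stacks to a tensor, swaps to batch-major, re-selects beams, swaps back, and rebuilds the array; the rebuilt array's size is tf.shape(x)[0] because after the second swapaxes the leading axis is time again, which is what the new tensor_array_like signature expects. Roughly what the beam re-selection amounts to per frame, as a hedged sketch with invented shapes and indices:

import tensorflow as tf

batch, beam_in, beam_out, dim = 2, 3, 2, 4
x = tf.reshape(tf.range(batch * beam_in * dim, dtype=tf.float32), [batch, beam_in, dim])
src_beams = tf.constant([[0, 2], [1, 1]])       # (batch, beam_out) source beam indices
picked = tf.gather(x, src_beams, batch_dims=1)  # (batch, beam_out, dim)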
