Skip to content

Commit

Permalink
rec layer: add initial_state placeholder (#1238)
Browse files Browse the repository at this point in the history
See #1236

Co-authored-by: Albert Zeyer <albzey@gmail.com>
  • Loading branch information
Gerstenberger and albertz committed Dec 6, 2022
1 parent 3da98b5 commit 90f0d7b
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion returnn/tf/layers/rec.py
Expand Up @@ -4873,6 +4873,12 @@ def update_var():
s = tf.cond(tf.equal(step, 0), lambda: tf.zeros(initial_shape), lambda: var.value())
s.set_shape(shape_invariant)
return s
elif initial_state == "placeholder":
assert rec_layer is not None
with rec_layer.var_creation_scope():
ph = tf_compat.v1.placeholder(
tf.float32, shape=shape_invariant, name="initial_state_placeholder_%s" % key_name)
return ph
else:
raise Exception("invalid initial state type %r for sub-layer %r, key %r" % (initial_state, name, key))

Expand Down Expand Up @@ -4908,7 +4914,7 @@ def resolve(v):
:return:
"""
if isinstance(v, str):
if v in ["zeros", "ones", "var", "keep_over_epoch", "keep_over_epoch_no_init"]:
if v in ["zeros", "ones", "var", "keep_over_epoch", "keep_over_epoch_no_init", "placeholder"]:
return v
return get_layer(v)
if isinstance(v, (tuple, list)):
Expand Down

0 comments on commit 90f0d7b

Please sign in to comment.