Skip to content

Commit

Permalink
TF fix tests for new PadLayer dyn dim handling
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Apr 25, 2024
1 parent cd4eced commit 3f055e6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 25 deletions.
52 changes: 29 additions & 23 deletions tests/test_TFEngine.py
Expand Up @@ -4409,29 +4409,35 @@ def get_output_dict(train, search, t_target, beam_size=beam_size):
"loss": "as_is",
"loss_scale": 0,
},
"use_t_search": {
"class": "compare",
"kind": "less",
"from": ["existing_align_score", "extra.search:search_score"],
}
if have_existing_align
else {"class": "constant", "value": True},
"t_search_or_fallback": {
"class": "switch",
"condition": "use_t_search",
"true_from": "extra.search:t_search",
"false_from": "data:t_base",
}
if have_existing_align
else {"class": "copy", "from": "data:t_base"},
"t_search_or_fallback_score": {
"class": "switch",
"condition": "use_t_search",
"true_from": "extra.search:search_score",
"false_from": "existing_align_score",
}
if have_existing_align
else {"class": "copy", "from": "extra.search:search_score"},
"use_t_search": (
{
"class": "compare",
"kind": "less",
"from": ["existing_align_score", "extra.search:search_score"],
}
if have_existing_align
else {"class": "constant", "value": True}
),
"t_search_or_fallback": (
{
"class": "switch",
"condition": "use_t_search",
"true_from": "extra.search:t_search",
"false_from": "data:t_base",
}
if have_existing_align
else {"class": "copy", "from": "data:t_base"}
),
"t_search_or_fallback_score": (
{
"class": "switch",
"condition": "use_t_search",
"true_from": "extra.search:search_score",
"false_from": "existing_align_score",
}
if have_existing_align
else {"class": "copy", "from": "extra.search:search_score"}
),
}
)
if epoch0 is not None and epoch0 < StoreAlignmentUpToEpoch:
Expand Down
23 changes: 21 additions & 2 deletions tests/test_TFNetworkLayer.py
Expand Up @@ -362,7 +362,16 @@ def test_PadLayer_time():
padding = (2, 3)
net = TFNetwork(config=config)
net.construct_from_dict(
{"output": {"class": "pad", "axes": "T", "padding": padding, "mode": "replication", "from": "data:data"}}
{
"output": {
"class": "pad",
"axes": "T",
"padding": padding,
"handle_dynamic_dims": False, # our test below does not handle dyn seq lens
"mode": "replication",
"from": "data:data",
}
}
)
out_t = net.get_default_output_layer().output.placeholder
assert out_t.shape.as_list() == [None, None, n_in]
Expand Down Expand Up @@ -395,7 +404,16 @@ def test_PadLayer_feature():
padding = (2, 3)
net = TFNetwork(config=config)
net.construct_from_dict(
{"output": {"class": "pad", "axes": "F", "padding": padding, "mode": "replication", "from": "data:data"}}
{
"output": {
"class": "pad",
"axes": "F",
"padding": padding,
"handle_dynamic_dims": False, # our test below does not handle dyn seq lens
"mode": "replication",
"from": "data:data",
}
}
)
out_t = net.get_default_output_layer().output.placeholder
assert out_t.shape.as_list() == [None, None, None]
Expand Down Expand Up @@ -12397,6 +12415,7 @@ def test_automatic_seq_lengths():
"mode": "reflect",
"axes": "spatial",
"padding": (3, 3),
"handle_dynamic_dims": False, # not supported yet otherwise
"from": "data",
}, # len+6
"layer1": {
Expand Down

0 comments on commit 3f055e6

Please sign in to comment.