"""
Many canonical basic layers.
"""
from __future__ import print_function
import tensorflow as tf
import contextlib
import typing
import returnn.tf.compat as tf_compat
import returnn.tf.util.basic as tf_util
from returnn.util.basic import unicode, NotSpecified
from returnn.tf.util.data import Data, SearchBeam, Dim, FeatureDim, SpatialDim
from returnn.tf.util.basic import OutputWithActivation, dimshuffle, swapaxes
from returnn.log import log
from .base import LayerBase, Loss, InternalLayer, SearchChoices
class SourceLayer(LayerBase):
"""
This gives access to some entry from network.extern_data (:class:`ExternData`).
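
Example usage in a network config dict (a minimal sketch; the layer name and data key are illustrative)::

    "raw_input": {"class": "source", "data_key": "data"}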
"""
layer_class = "source"
def __init__(self, network, data_key=None, sources=(), **kwargs):
"""
:param returnn.tf.network.TFNetwork network:
:param str|None data_key:
:param tuple sources:
"""
if data_key is None:
data_key = network.extern_data.default_input
assert not sources, "source layer does not expect sources"
data = network.get_extern_data(data_key, mark_data_key_as_used=True).copy()
super(SourceLayer, self).__init__(network=network, **kwargs)
# Note: No check on data.placeholder. We allow to behave similar as DataNotAvailableLayer.
self.output = data
if self.output.beam:
# This can happen if register_as_extern_data was used on a layer with a beam.
search_choices = network.get_search_choices_from_beam(self.output.beam)
self.sources.append(search_choices.owner)
@classmethod
def transform_config_dict(cls, d, network, get_layer):
"""
:param dict[str] d: will modify inplace
:param returnn.tf.network.TFNetwork network:
:param ((str) -> LayerBase) get_layer: function to get or construct another layer
"""
d.setdefault("from", []) # source does not make sense
super(SourceLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
@classmethod
def get_out_data_from_opts(cls, network, data_key=None, **kwargs):
"""
:param returnn.tf.network.TFNetwork network:
:param str|None data_key:
:rtype: Data
"""
if data_key is None:
data_key = network.extern_data.default_input
return network.get_extern_data(data_key, mark_data_key_as_used=False).copy()
@contextlib.contextmanager
def _name_scope_for_concat_src_layers(src_layers, postfix):
"""
:param list[LayerBase] src_layers:
:param str postfix:
:return: yields scope via reuse_name_scope()
"""
assert src_layers
if len(src_layers) == 1:
name_scope = src_layers[0].get_absolute_name_scope_prefix() + postfix
else:
base = src_layers[0].network.get_absolute_name_scope_prefix()
name = "concat_" + "_".join([layer.tf_scope_name for layer in src_layers])
name_scope = base + name + "/" + postfix
from returnn.tf.util.basic import reuse_name_scope
with reuse_name_scope(name_scope, absolute=True) as scope:
yield scope
def concat_sources(src_layers, out_dim=None, allow_broadcast_all_sources=NotSpecified):
"""
:param list[LayerBase] src_layers:
:param Dim|None out_dim:
:param bool|NotSpecified allow_broadcast_all_sources:
:return: data with placeholders set
:rtype: Data
"""
assert src_layers, "need source layers"
if len(src_layers) == 1:
data = src_layers[0].output.copy()
if out_dim:
if out_dim == data.feature_dim_or_sparse_dim:
pass # good
elif out_dim in data.dim_tags:
# We found out_dim in the input but it is not marked as the feature dim.
# This is explicitly allowed. Follow-up code will expect this to be the feature-dim though,
# so we mark it accordingly.
assert not data.sparse
axis = data.get_axis_from_description(out_dim)
data.feature_dim_axis = axis
else:
raise Exception("%s not found in %s" % (out_dim, data))
return data
network = src_layers[0].network
cache_key = (tuple(src_layers), out_dim, 0.0, None)
if cache_key in network.concat_sources_dropout_cache:
return network.concat_sources_dropout_cache[cache_key].copy()
data = get_concat_sources_data_template(
src_layers, out_dim=out_dim, allow_broadcast_all_sources=allow_broadcast_all_sources)
layers_data = []
with _name_scope_for_concat_src_layers(src_layers, "concat_sources"):
for layer in src_layers:
assert not layer.output.sparse, "sparse concat not supported"
assert layer.output.dtype == data.dtype, "incompatible dtype with layer %r" % layer
# unbroadcast is needed for tf.concat.
layers_data.append(layer.output.copy_compatible_to(data, unbroadcast=True, except_feature=True))
data.placeholder = tf.concat(
axis=data.feature_dim_axis,
values=[layer_data.placeholder for layer_data in layers_data])
axes_split_info = [None] * data.batch_ndim # type: typing.List[typing.Optional[typing.List[int]]]
axes_split_info[data.feature_dim_axis] = [layer_data.dim for layer_data in layers_data]
tf_util.set_param_axes_split_info(data.placeholder, axes_split_info)
# Note: We will lose this info for any further op (e.g. dropout, activation, etc.). This should be handled better,
# maybe in the Data class instead?
# Also note: even for tf.Variable, e.g. with weight noise, we might lose this?
network.concat_sources_dropout_cache[cache_key] = data.copy()
return data
def get_concat_sources_data_template(src_layers, out_dim=None, allow_broadcast_all_sources=NotSpecified, name=None):
"""
This just creates a template :class:`Data` instance,
without creating any real TF tensors.
:func:`concat_sources` (and related) are the equivalent functions
which would create a :class:`Data` together with the tensor.
:param list[LayerBase]|tuple[LayerBase] src_layers:
:param Dim|None out_dim:
:param bool|NotSpecified allow_broadcast_all_sources:
:param str|None name: name of the Data
:return: data with no placeholders set. it is always a copy or new instance, so safe to manipulate
:rtype: Data
"""
assert src_layers, "need source layers"
if len(src_layers) == 1:
data = src_layers[0].output.copy_template(name=name)
if out_dim:
assert out_dim == data.feature_dim_or_sparse_dim
return data
if not name:
name = "concat_" + "_".join([layer.name for layer in src_layers])
dim = None
common_source = Data.get_common_data(
[s.output for s in src_layers], ignore_feature_dim=True, allow_broadcast_all_sources=allow_broadcast_all_sources)
if not common_source.have_feature_axis(): # e.g. during template construction
return common_source
for layer in src_layers:
# Note: We do not perform much compatibility checks at this point,
# as this is for a template only anyway.
# The real checks are in concat_sources.
if layer.output.have_feature_axis(): # just ignore at this point if None (e.g. during template construction)
layer_dim = layer.output.feature_dim_or_sparse_dim
if layer_dim.dimension is not None: # maybe during template construction
if dim is None:
dim = layer_dim
else:
dim = dim + layer_dim
if out_dim:
assert out_dim.dimension == dim.dimension
else:
out_dim = dim
return common_source.copy_template_replace_dim_tag(
name=name,
axis=common_source.feature_dim_axis,
new_dim_tag=out_dim)
def concat_sources_with_opt_dropout(src_layers, out_dim=None,
dropout=0, dropout_axis=None, dropout_noise_shape=None, dropout_on_forward=False,
allow_broadcast_all_sources=NotSpecified):
"""
Concatenates in the feature dim (see :func:`concat_sources`),
and then optionally applies dropout.
:param list[LayerBase] src_layers:
:param Dim|None out_dim:
:param float dropout: dropout rate that will be applied if train_flag is set or dropout_on_forward is enabled
:param Dim|str|list[Dim|str]|None dropout_axis:
:param tuple|list|dict[Dim|str|list[Dim|str]|tuple[Dim|str],int|str|None]|None dropout_noise_shape:
provide 1 for broadcasting or None otherwise for each axis.
The default "None" will broadcast across all dynamic axes including the batch axis.
Use {"*": None} to disable broadcasting for all axes.
:param bool dropout_on_forward: apply dropout also during inference
:param bool|NotSpecified allow_broadcast_all_sources:
:return: data with placeholders set
:rtype: Data
"""
assert src_layers, "need source layers"
data = concat_sources(src_layers, out_dim=out_dim, allow_broadcast_all_sources=allow_broadcast_all_sources)
network = src_layers[0].network
if network.train_flag is False and not dropout_on_forward:
# If we know that we are not training, we always disable dropout.
dropout = 0
if not dropout:
return data.copy()
assert not data.sparse, "need dense data when dropout is used; sources: %r" % (src_layers,)
if dropout_axis is not None:
dropout_axis = data.get_axes_from_description(dropout_axis, allow_int=False)
assert not dropout_noise_shape, (
"do not provide both dropout_axis %r and dropout_noise_shape %r" % (dropout_axis, dropout_noise_shape))
dropout_noise_shape = [dim if i in dropout_axis else 1 for i, dim in enumerate(data.batch_shape)]
if isinstance(dropout_noise_shape, dict) or not dropout_noise_shape:
# The default noise_shape behavior matches the old behavior for now:
# all dynamic dimensions (batch, time) will use the same dropout mask, broadcasted.
dropout_noise_shape = data.get_bc_shape(dropout_noise_shape)
cache_key = (tuple(src_layers), out_dim, float(dropout), tuple(dropout_noise_shape))
if cache_key in network.concat_sources_dropout_cache:
return network.concat_sources_dropout_cache[cache_key].copy()
data = data.copy()
assert 0.0 < dropout < 1.0
with _name_scope_for_concat_src_layers(src_layers, "dropout_in_train"):
if dropout_on_forward:
data.placeholder = tf_util.dropout(
data.placeholder,
keep_prob=1 - dropout,
noise_shape=dropout_noise_shape,
seed=network.random.randint(2 ** 31))
else:
data.placeholder = network.cond_on_train(
fn_train=lambda: tf_util.dropout(
data.placeholder,
keep_prob=1 - dropout,
noise_shape=dropout_noise_shape,
seed=network.random.randint(2 ** 31)),
fn_eval=lambda: data.placeholder)
network.concat_sources_dropout_cache[cache_key] = data.copy()
return data
class _ConcatInputLayer(LayerBase):
"""
Base layer which concatenates all incoming source layers in the feature dimension,
and provides that as `self.input_data`, which is of type :class:`Data`.
This is the most common thing that many layers do with the input sources.
If there is only a single source, it will not do anything.
This layer can also optionally apply dropout to the input.
"""
def __init__(self, in_dim=None, out_shape=None,
dropout=0, dropout_axis=None, dropout_noise_shape=None, dropout_on_forward=False,
mask=None,
**kwargs):
"""
:param Dim|None in_dim:
:param set[Dim|returnn.tf.util.data._MarkedDim]|tuple|list|None out_shape:
:param float dropout: 0.0 means to apply no dropout. dropout will only be applied during training
:param Dim|str|list[Dim|str]|None dropout_axis:
:param dict[Dim|str|list[Dim|str]|tuple[Dim|str],int|str|None]|None dropout_noise_shape:
see :func:`Data.get_bc_shape`
:param bool dropout_on_forward: apply dropout during inference
:param str|None mask: "dropout" or "unity" or None. This is obsolete and only here for historical reasons.
"""
super(_ConcatInputLayer, self).__init__(in_dim=in_dim, out_shape=out_shape, **kwargs)
assert mask in ['dropout', 'unity', None], "invalid mask: %r" % mask
if mask == "unity":
assert not dropout
elif mask == "dropout":
assert dropout > 0
self.dropout = dropout
self.input_data = None # type: typing.Optional[Data]
if self.sources:
self.input_data = concat_sources_with_opt_dropout(
self.sources, out_dim=in_dim,
dropout=dropout, dropout_axis=dropout_axis, dropout_noise_shape=dropout_noise_shape,
dropout_on_forward=dropout_on_forward,
allow_broadcast_all_sources=True if out_shape else NotSpecified)
class CopyLayer(_ConcatInputLayer):
"""
This layer does nothing, it copies its input.
If multiple sources are provided, they are concatenated in the feature-dim.
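
Example usage (a minimal sketch; layer names are illustrative), concatenating two sources in the feature dim::

    "combined": {"class": "copy", "from": ["layer1", "layer2"]}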
"""
layer_class = "copy"
def __init__(self, in_dim=None, out_dim=None, extra_deps=(), **kwargs):
"""
:param Dim|None in_dim:
:param Dim|None out_dim:
:param list[LayerBase] extra_deps: Just add as an additional dependency, without really using it.
This can have an effect though on the search beam, via :class:`SelectSearchSourcesLayer`.
We only have this here for the :class:`CopyLayer` because the :func:`get_out_data_from_opts`
must know about it and define the right beam.
Also see the option ``collocate_with``, which is different in that it does *not* add a dependency.
"""
if in_dim and out_dim:
assert in_dim == out_dim
in_dim = in_dim or out_dim
out_dim = in_dim
super(CopyLayer, self).__init__(in_dim=in_dim, out_dim=out_dim, **kwargs)
self.extra_deps = extra_deps
self.output = self.input_data.copy(name="%s_output" % self.name)
if len(self.sources) == 1:
self.output_loss = self.sources[0].output_loss
if not self.dropout:
self.output_before_activation = self.sources[0].output_before_activation
for src in self.sources:
if src.allow_inf_in_output:
self.allow_inf_in_output = True
def get_dep_layers(self):
"""
:rtype: list[LayerBase]
"""
return super(CopyLayer, self).get_dep_layers() + list(self.extra_deps)
@classmethod
def get_out_data_from_opts(cls, name, sources=(), extra_deps=(),
out_type=None, out_dim=None, n_out=NotSpecified, out_shape=None,
**kwargs):
"""
:param str name:
:param list[LayerBase] sources:
:param list[LayerBase] extra_deps:
:param dict[str]|None out_type:
:param Dim|None out_dim:
:param int|None|NotSpecified n_out:
:param set[Dim|returnn.tf.util.data._MarkedDim]|tuple|list|None out_shape:
:rtype: Data
"""
# If all sources are defined, use them to get the exact out_type.
out = get_concat_sources_data_template(
sources, out_dim=out_dim, name="%s_output" % name,
allow_broadcast_all_sources=True if out_shape else NotSpecified)
# Instead of checking or raising an exception, just overwrite, as this could be the template construction.
if out_type or n_out is not NotSpecified:
if not out_type:
out_type = {}
else:
out_type = out_type.copy()
if out.sparse:
out_type["sparse"] = True # otherwise the default get_out_data_from_opts would assume dense
if n_out is not NotSpecified:
out_type["dim"] = n_out
elif out.dim is not None:
out_type.setdefault("dim", out.dim)
out = super(CopyLayer, cls).get_out_data_from_opts(
name=name, out_type=out_type, n_out=n_out, out_dim=out_dim, out_shape=out_shape, sources=sources, **kwargs)
out.beam = SearchBeam.get_combined_beam(out.beam, *[dep.output.beam for dep in extra_deps if dep])
return out
@classmethod
def transform_config_dict(cls, d, network, get_layer):
"""
:param dict[str] d: will modify inplace
:param returnn.tf.network.TFNetwork network:
:param ((str) -> LayerBase) get_layer: function to get or construct another layer
"""
super(CopyLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
if "extra_deps" in d:
extra_deps = d["extra_deps"]
if not isinstance(extra_deps, (list, tuple)):
extra_deps = [extra_deps]
d["extra_deps"] = [get_layer(src_name) for src_name in extra_deps]
class ConcatLayer(LayerBase):
"""
Concatenates the inputs in specified axes.
This generalizes :class:`CopyLayer` which concatenates in the feature dim.
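
Example usage (a sketch with illustrative layer names), concatenating two sources along the time axis.
Note that each entry in ``from`` is a pair of layer name and axis::

    "stacked": {"class": "concat", "from": [("layer1", "T"), ("layer2", "T")]}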
"""
layer_class = "concat"
def __init__(self, sources, allow_broadcast=False, out_dim=None, **kwargs):
"""
:param list[(LayerBase,str|Dim)] sources:
:param bool allow_broadcast:
:param Dim|None out_dim:
"""
sources, axes = zip(*sources) # unzip
super(ConcatLayer, self).__init__(sources=sources, **kwargs)
sources_data = [layer.output for layer in sources] # type: typing.List[Data]
axes_int = [src.get_axis_from_description(axis) for (src, axis) in zip(sources_data, axes)]
concat_dim_tags = [src.dim_tags[axis] for (src, axis) in zip(sources_data, axes_int)] # type: typing.List[Dim]
if not out_dim:
out_dim = sum(concat_dim_tags)
out_concat_axis = self.output.get_axis_from_description(out_dim)
def _copy_compatible(x, axis):
"""
:param Data x: input
:param int axis:
:rtype: Data
"""
dummy_ref = self.output.copy_template()
dummy_ref = dummy_ref.copy_template_replace_dim_tag(
axis=out_concat_axis, new_dim_tag=x.dim_tags[axis])
return x.copy_compatible_to(dummy_ref, add_dims=allow_broadcast, unbroadcast=False)
sources_data = [_copy_compatible(src, axis) for (src, axis) in zip(sources_data, axes_int)]
self.output.placeholder = tf_util.concat_with_opt_broadcast(
[src.placeholder for src in sources_data], axis=out_concat_axis,
allow_broadcast=[allow_broadcast] * len(sources_data))
@classmethod
def get_out_data_from_opts(cls, name, sources, out_dim=None, **kwargs):
"""
:param str name:
:param list[(LayerBase,str|Dim)] sources:
:param Dim|None out_dim:
:rtype: Data
"""
assert sources
sources, axes = zip(*sources) # unzip
axes_int = [layer.output.get_axis_from_description(axis) for (layer, axis) in zip(sources, axes)]
concat_dim_tags = [
layer.output.dim_tags[axis] for (layer, axis) in zip(sources, axes_int)] # type: typing.List[Dim]
if any(tag.dimension is None for tag in concat_dim_tags):
dimension = None
else:
dimension = 0
for tag in concat_dim_tags:
dimension += tag.dimension
if not out_dim:
out_dim = sum(concat_dim_tags)
assert isinstance(out_dim, Dim)
else:
sum(concat_dim_tags).declare_same_as(out_dim)
assert out_dim.dimension == dimension
def _as_common(x, axis):
"""
:param Data x: input
:param int axis:
:rtype: Data
"""
return x.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim)
sources_data = [_as_common(layer.output, axis) for (layer, axis) in zip(sources, axes_int)]
# Always allow broadcast here, for template construction. We will check it in __init__.
return Data.get_common_data(sources_data, allow_broadcast_all_sources=True, name="%s_output" % name)
@classmethod
def transform_config_dict(cls, d, network, get_layer):
"""
:param dict[str] d: will modify inplace
:param returnn.tf.network.TFNetwork network:
:param ((str) -> LayerBase) get_layer: function to get or construct another layer
"""
sources_and_axes = d.pop("from")
d["from"], axes = zip(*sources_and_axes)
super(ConcatLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
d["sources"] = list(zip(d["sources"], axes))
class DropoutLayer(CopyLayer):
"""
Just the same as :class:`CopyLayer`, because that one already supports dropout.
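
Example usage (a minimal sketch; layer names are illustrative)::

    "regularized": {"class": "dropout", "from": "layer1", "dropout": 0.1}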
"""
layer_class = "dropout"
class ScaledGradientLayer(CopyLayer):
"""
Just :func:`tf.identity` in the forward pass.
Scales the gradient by some factor in backprop.
Can be used as gradient reversal layer (with negative factor).
Uses :func:`returnn.tf.util.basic.scaled_gradient`, or :func:`tf.stop_gradient`
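
Example usage as a gradient reversal layer (a sketch with illustrative layer names)::

    "grad_reverse": {"class": "scaled_grad", "from": "layer1", "scale": -1.0}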
"""
layer_class = "scaled_grad"
def __init__(self, scale, shift=None, scale_shift_by_sum_over_axis=None, clip_max_axis=None, **kwargs):
"""
:param float|LayerBase scale: if 0. and no shift, will use tf.stop_gradient
:param float|LayerBase|None shift:
:param Dim|str|None scale_shift_by_sum_over_axis: if given, calculates the sum over this axis (absolute values)
and multiplies the shift value by this sum.
:param Dim|str|None clip_max_axis: if given, clips the gradient to the max value in this axis
before the transformation, for all values in the axis
"""
super(ScaledGradientLayer, self).__init__(**kwargs)
self.scale = scale
self.shift = shift
from returnn.tf.util.basic import scaled_gradient
if isinstance(scale, (int, float)) and scale == 0. and shift is None:
self.output.placeholder = tf.stop_gradient(self.output.placeholder)
else:
scale_t = scale.output.copy_compatible_to(self.output).placeholder if isinstance(scale, LayerBase) else scale
shift_t = shift.output.copy_compatible_to(self.output).placeholder if isinstance(shift, LayerBase) else shift
if scale_shift_by_sum_over_axis is not None:
scale_shift_by_sum_over_axis = self.output.get_axis_from_description(
scale_shift_by_sum_over_axis, allow_int=False)
if clip_max_axis is not None:
clip_max_axis = self.output.get_axis_from_description(clip_max_axis, allow_int=False)
self.output.placeholder = scaled_gradient(
self.output.placeholder, scale=scale_t, shift=shift_t,
scale_shift_by_sum_over_axis=scale_shift_by_sum_over_axis, clip_max_axis=clip_max_axis)
def get_dep_layers(self):
"""
:rtype: list[LayerBase]
"""
deps = super(ScaledGradientLayer, self).get_dep_layers()
if isinstance(self.scale, LayerBase):
deps.append(self.scale)
if isinstance(self.shift, LayerBase):
deps.append(self.shift)
return deps
@classmethod
def transform_config_dict(cls, d, network, get_layer):
"""
:param dict[str] d: will modify inplace
:param returnn.tf.network.TFNetwork network:
:param ((str) -> LayerBase) get_layer: function to get or construct another layer
"""
super(ScaledGradientLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
if isinstance(d.get("scale"), str):
d["scale"] = get_layer(d["scale"])
if isinstance(d.get("shift"), str):
d["shift"] = get_layer(d["shift"])
class SelectSearchSourcesLayer(InternalLayer):
"""
Selects the corresponding search beams from the source, given current search choices
(determined by a layer).
Like :class:`InternalLayer`, only for internal purpose at the moment.
"""
@classmethod
def select_if_needed(cls, layer, search_choices):
"""
:param LayerBase layer:
:param SearchChoices|None search_choices:
:rtype: LayerBase
"""
assert isinstance(layer, LayerBase)
if not search_choices:
return layer
if layer.network.is_extra_internal_template_construction():
assert layer.output.placeholder is None # we expect a template
return layer
layer_search_choices = layer.get_search_choices()
if layer_search_choices and layer_search_choices.keep_raw:
return layer
if layer_search_choices == search_choices:
assert layer.output.beam == search_choices.get_beam_info(), "%r != %r. %s" % (
layer.output.beam, search_choices.get_beam_info(),
layer.network.debug_search_choices(layer) or "debug search dumped")
return layer
if layer.output.batch_dim_axis is None: # e.g. VariableLayer, ConstantLayer, or so
return layer
layer = SelectSearchSourcesLayer(sources=[layer], search_choices_layer=search_choices.owner)
assert layer.output.beam == search_choices.get_beam_info(), "%r != %r. %s" % (
layer.output.beam, search_choices.get_beam_info(),
layer.network.debug_search_choices(layer) or "debug search dumped")
return layer
def __init__(self, search_choices_layer, sources, **kwargs):
"""
:param LayerBase search_choices_layer:
:param list[LayerBase] sources:
"""
from returnn.tf.util.basic import select_src_beams, get_valid_scope_name_from_str, Dim
from pprint import pformat
assert len(sources) == 1
search_choices = search_choices_layer.get_search_choices()
src = sources[0]
kwargs = kwargs.copy()
kwargs["sources"] = sources
if "output" not in kwargs:
kwargs["output"] = src.output # will be reset later
if "network" not in kwargs:
kwargs["network"] = src.network
if "name" not in kwargs:
kwargs["name"] = src.name
if "_src_common_search_choices" not in kwargs:
kwargs["_src_common_search_choices"] = search_choices
super(SelectSearchSourcesLayer, self).__init__(**kwargs)
self.search_choices_layer = search_choices_layer
self.used_search_choices_beams = False
self.search_choices_from_layer = search_choices
self.output = src.output.copy_as_batch_major()
self.rec_vars_outputs = src.rec_vars_outputs.copy()
src_search_choices = src.get_search_choices()
self.transform_func = None # type: typing.Optional[typing.Callable[[tf.Tensor],tf.Tensor]]
self.search_choices_seq = None # type: typing.Optional[typing.List[SearchChoices]]
if not search_choices:
assert not src_search_choices
assert not self.output.beam
elif search_choices == src_search_choices:
pass
elif not src_search_choices:
assert not self.output.beam, ("no src %r search choices but beam?" % src, src.network.debug_search_choices(src))
self.output = self.output.copy_extend_with_beam(search_choices.get_beam_info())
else:
assert search_choices and search_choices != src_search_choices
search_choices_seq = search_choices.get_src_choices_seq()
assert src_search_choices in search_choices_seq, self.network.debug_search_choices(self.search_choices_layer) or (
("%s: No common search base:\n"
"from layer %s\n"
"search choices %s,\n"
"to layer %s\n"
"search choices\n%s.") % (
self, src, src_search_choices, self.search_choices_layer, pformat(search_choices_seq)))
search_choices_seq = search_choices_seq[:search_choices_seq.index(src_search_choices)]
assert src_search_choices not in search_choices_seq
assert search_choices_seq
self.output.beam = search_choices.get_beam_info()
if self.output.batch:
self.output.batch = self.output.batch.copy_set_beam(self.output.beam)
def transform(v):
"""
:param tuple|list|tf.Tensor|tf.TensorArray|T v:
:rtype: T
"""
if isinstance(v, (tuple, list)):
from returnn.util.basic import make_seq_of_type
return make_seq_of_type(type(v), [transform(v_) for v_ in v])
assert isinstance(v, (tf.Tensor, tf.TensorArray))
if isinstance(v, tf.Tensor) and v.get_shape().ndims == 0:
return v # leave scalars as-is
if isinstance(v, tf.Tensor) and getattr(v, "_RETURNN_beam_expanded_base_data", None):
# This tensor was just expanded by a beam. Selecting beams is not needed.
return v
for i, base_src_choices in enumerate(reversed(search_choices_seq)):
assert isinstance(base_src_choices, SearchChoices)
assert base_src_choices.src_beams is not None, (
self.network.debug_search_choices(self.search_choices_layer) or (
("Cannot transform %r,\n"
"search choices %r,\n"
"to search choices %r.\n"
"Missing beam idxs.") % (src, src_search_choices, search_choices_seq)))
tag = Dim.get_tag_from_size_tensor(v)
if tag:
assert tag.dyn_size_ext.is_batch_major
v = select_src_beams(
v, src_beams=base_src_choices.src_beams,
name="%s_select_src_beams_%i_%s_%i_%s" % (
get_valid_scope_name_from_str(self.name),
i, get_valid_scope_name_from_str(base_src_choices.owner.name),
len(search_choices_seq), get_valid_scope_name_from_str(search_choices.owner.name)))
if tag:
tag.set_tag_on_size_tensor(v, batch=self.output.batch.copy_set_beam(base_src_choices.get_beam_info()))
self.used_search_choices_beams = True
return v
self.search_choices_seq = search_choices_seq
self.transform_func = transform
# It's possible that src.output.placeholder is not set, e.g. in a prev-layer where the
# prev output is not needed, only the prev state. See _TemplateLayer.copy_as_prev_time_frame.
src_output = src.output.copy_as_batch_major()
if src_output.placeholder is not None:
self.output.placeholder = transform(src_output.placeholder)
for src_tag, out_tag in zip(src_output.dim_tags, self.output.dim_tags):
assert src_tag.dimension == out_tag.dimension
if src_tag.is_batch_dim():
assert out_tag.batch == self.output.batch
continue
if src_tag.dimension is not None:
continue
if out_tag.dyn_size_ext is None:
if src_tag.dyn_size_ext.have_batch_axis():
out_tag.dyn_size_ext = src_tag.dyn_size_ext.copy_template()
out_tag.dyn_size_ext.beam = None
out_tag.dyn_size_ext = out_tag.dyn_size_ext.copy_extend_with_beam(self.output.beam)
else:
out_tag.dyn_size_ext = src_tag.dyn_size_ext.copy()
if out_tag.dyn_size_ext.placeholder is None:
assert out_tag.dyn_size_ext.have_batch_axis() and out_tag.dyn_size_ext.is_batch_major
out_tag.dyn_size_ext.placeholder = transform(src_tag.dyn_size_ext.placeholder)
if out_tag.dyn_size_ext.have_batch_axis():
assert out_tag.dyn_size_ext.batch == out_tag.batch == self.output.batch
self.rec_vars_outputs = {k: transform(v) for (k, v) in src.rec_vars_outputs.items()} # assumes batch-major
for src in self.sources:
if src.allow_inf_in_output:
self.allow_inf_in_output = True
def __repr__(self):
return "<%s %r %r out_type=%s>" % (
self.__class__.__name__, self.name, self.search_choices_from_layer,
self.output.get_description(with_name=False) if self.output else None)
def get_dep_layers(self):
"""
:rtype: list[LayerBase]
"""
dep_layers = super(SelectSearchSourcesLayer, self).get_dep_layers()
if self.used_search_choices_beams: # only in that case, it is really a dependency
dep_layers.append(self.search_choices_layer)
return dep_layers
@classmethod
def transform_config_dict(cls, d, network, get_layer):
"""
:param dict[str] d:
:param returnn.tf.network.TFNetwork network:
:param get_layer:
"""
super(SelectSearchSourcesLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
d["search_choices"] = get_layer(d["search_choices"])
@classmethod
def get_out_data_from_opts(cls, name, sources, search_choices, **kwargs):
"""
:param str name:
:param list[LayerBase] sources:
:param LayerBase search_choices:
:rtype: Data
"""
assert len(sources) == 1
search_choices_ = search_choices.get_search_choices()
data = sources[0].output.copy_template().copy_as_batch_major()
if data.beam or search_choices_:
assert search_choices_
data = data.copy_extend_with_beam(search_choices_.get_beam_info())
return data
class ActivationLayer(_ConcatInputLayer):
"""
This layer just applies an activation function.
See :func:`returnn.tf.util.basic.get_activation_function` about supported functions.
Also see :class:`EvalLayer` and :class:`CombineLayer` for similar layers.
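
Example usage (a minimal sketch; layer names are illustrative)::

    "activated": {"class": "activation", "from": "layer1", "activation": "relu"}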
"""
layer_class = "activation"
def __init__(self, activation, opts=None, **kwargs):
"""
:param str activation: e.g. "relu", "tanh", etc
:param dict[str]|None opts: for activation function, e.g. eps for safe_log
"""
super(ActivationLayer, self).__init__(**kwargs)
x = self.input_data.copy_compatible_to(self.output, check_dtype=False).placeholder
if activation:
if "softmax" in activation:
assert not opts # do not set axis or anything. This is handled automatically; we moved the feature dim to the last axis.
if self.output.dim_tags[-1].is_dynamic():
self.recurrent = True
from returnn.tf.util.basic import get_activation_function
act_func = get_activation_function(activation)
self.output_before_activation = OutputWithActivation(x, act_func=act_func, act_func_opts=opts)
else:
self.output_before_activation = OutputWithActivation(x)
if self.output_before_activation:
self.output.placeholder = self.output_before_activation.y
@classmethod
def get_out_data_from_opts(cls, activation, **kwargs):
"""
:param str activation:
:rtype: Data
"""
# Just the same as the input.
# Use CopyLayer.get_out_data_from_opts for potential extra logic for out_type.
out = CopyLayer.get_out_data_from_opts(**kwargs)
# Modify dtype if needed based on activation function
if activation in ["abs", "angle"] and out.dtype == "complex64":
out.dtype = "float32"
if "softmax" in activation:
# Make sure we use the right axis.
out = out.copy_with_feature_last()
return out
class BatchNormLayer(CopyLayer):
"""
Implements batch-normalization (https://arxiv.org/abs/1502.03167) as a separate layer.
Also see :class:`NormLayer`.
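
Example usage with the default settings (a sketch; layer names are illustrative)::

    "normed": {"class": "batch_norm", "from": "layer1"}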
"""
layer_class = "batch_norm"
def __init__(self,
in_dim=None,
use_shift=NotSpecified, use_std=NotSpecified,
use_sample=NotSpecified, force_sample=NotSpecified,
momentum=NotSpecified, epsilon=NotSpecified,
update_sample_only_in_training=NotSpecified,
delay_sample_update=NotSpecified,
param_version=NotSpecified,
gamma_init=NotSpecified, beta_init=NotSpecified,
masked_time=NotSpecified, **kwargs):
"""
:param returnn.tf.util.data.Dim|None in_dim:
:param bool use_shift:
:param bool use_std:
:param float use_sample: defaults to 0.0 which is used in training
:param bool force_sample: even in eval, use the use_sample factor
:param float momentum: for the running average of sample_mean and sample_std
:param bool update_sample_only_in_training:
:param bool delay_sample_update:
:param int param_version: 0 or 1 or 2
:param float epsilon:
:param str|float gamma_init: see :func:`returnn.tf.util.basic.get_initializer`, for the scale
:param str|float beta_init: see :func:`returnn.tf.util.basic.get_initializer`, for the mean
:param bool masked_time: flatten and mask input tensor
The default settings for these variables are set in the function :func:`batch_norm` of :class:`LayerBase`.
If you do not want to change them you can leave them undefined here.
With our default settings:
- In training: use_sample=0, i.e. not using running average, using current batch mean/var.
- Not in training (e.g. eval): use_sample=1, i.e. using running average, not using current batch mean/var.
- The running average includes the statistics of the current batch.
- The running average is also updated when not training.
"""
local = locals()
from returnn.util.basic import getargspec
batch_norm_kwargs = getargspec(self.batch_norm).args[1:] # first is self, ignore
batch_norm_opts = {key: local[key] for key in batch_norm_kwargs if key in local and local[key] != NotSpecified}
super(BatchNormLayer, self).__init__(batch_norm=batch_norm_opts or True, in_dim=in_dim, **kwargs)
if in_dim:
# should be the case via get_out_data_from_opts
assert self.output.dim_tags[self.output.feature_dim_axis] == in_dim
# batch norm is now applied via post_init
class LayerNormLayer(_ConcatInputLayer):
"""
Applies `layer-normalization <https://arxiv.org/abs/1607.06450>`__.
Note that we *just* normalize over the feature-dim axis here.
This is consistent to the default behavior of :class:`tf.keras.layers.LayerNormalization`
and also how it is commonly used in many models, including Transformer.
However, there are cases where it would be common to normalize over all axes except batch-dim,
or all axes except batch and time.
For a more generic variant, see :class:`NormLayer`.
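
Example usage (a minimal sketch with illustrative layer names), normalizing over the feature dim::

    "layer_normed": {"class": "layer_norm", "from": "layer1"}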
"""
layer_class = "layer_norm"
def __init__(self, in_dim=None, out_dim=None, epsilon=1e-6, **kwargs):
"""
:param Dim|None in_dim: axis to normalize over. feature-dim by default
:param Dim|None out_dim: just the same as in_dim
:param float epsilon:
"""
super(LayerNormLayer, self).__init__(**kwargs)
assert not self.input_data.sparse
x = self.input_data.placeholder
if not in_dim and out_dim:
in_dim = out_dim
if in_dim:
if out_dim:
assert in_dim == out_dim
assert isinstance(in_dim, Dim)
axis = self.input_data.get_axis_from_description(in_dim)
else:
axis = self.input_data.feature_dim_axis
dim = self.input_data.batch_shape[axis]
assert dim is not None, "%s: in_dim %s must be static in input %s" % (self, in_dim or axis, self.input_data)
with self.var_creation_scope():
scale = self.add_param(tf_compat.v1.get_variable("scale", [dim], initializer=tf.ones_initializer()))
bias = self.add_param(tf_compat.v1.get_variable("bias", [dim], initializer=tf.zeros_initializer()))
mean = tf.reduce_mean(x, axis=[axis], keepdims=True, name="mean")
variance = tf.reduce_mean(tf.square(x - mean), axis=[axis], keepdims=True, name="variance")
with tf.name_scope("normalized"):
norm_x = (x - mean) * tf_compat.v1.rsqrt(variance + epsilon)
if axis != self.input_data.batch_ndim - 1:
ndim = self.input_data.batch_ndim
scale_bc = tf.reshape(scale, [dim if i == axis else 1 for i in range(ndim)])
bias_bc = tf.reshape(bias, [dim if i == axis else 1 for i in range(ndim)])
self.output.placeholder = norm_x * scale_bc + bias_bc
else:
self.output.placeholder = norm_x * scale + bias
self.output.size_placeholder = self.input_data.size_placeholder.copy()
@classmethod
def get_out_data_from_opts(cls, sources, name, **kwargs):
"""
:param list[LayerBase] sources:
:param str name:
:rtype: Data
"""
return get_concat_sources_data_template(sources, name="%s_output" % name)
class NormLayer(_ConcatInputLayer):
"""
Normalize over specified axes, e.g. time and/or feature axis.
Note: For calculating a norm, see :class:`MathNormLayer` instead.
In case of just feature (``axes="F"``),
this corresponds to `layer normalization <https://arxiv.org/abs/1607.06450>`__ (see :class:`LayerNormLayer`).
In case of time and feature (``axes="TF"``) for a 3D input,
or more general all except batch (``axes="except_batch"``),
this corresponds to `group normalization <https://arxiv.org/abs/1803.08494>`__ with G=1,
or non-standard layer normalization.
(The definition of layer-normalization is not clear on what axes should be normalized over.
In many other frameworks, the default axis is just the last axis,
which is usually the feature axis.
However, in certain implementations and models,
it is also common to normalize over all axes except batch.)
The statistics are calculated just on the input.
There are no running statistics (in contrast to batch normalization, see :class:`BatchNormLayer`).
For some discussion on the definition of layer-norm vs group-norm,
also see
`here <https://stats.stackexchange.com/questions/485550/is-group-norm-with-g-1-equiv-to-layer-norm>`__
and `here <https://github.com/tensorflow/addons/issues/2143>`__.
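
Example usage (a sketch; layer names are illustrative). ``"axis": "F"`` corresponds to layer norm,
``"axis": "TF"`` to group norm with G=1 for a 3D input::

    "normed": {"class": "norm", "from": "layer1", "axis": "F"}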
"""
layer_class = "norm"
def __init__(self, axis=NotSpecified, axes=NotSpecified,
param_shape=NotSpecified, scale=True, bias=True, epsilon=1e-6, **kwargs):
"""
:param Dim|str|list[Dim|str] axis: axis or axes over which the mean and variance are computed, e.g. "F" or "TF"
:param Dim|str|list[Dim|str] axes: axis or axes over which the mean and variance are computed, e.g. "F" or "TF"
:param Dim|str|list[Dim|str]|tuple[Dim|str] param_shape: shape of the scale and bias parameters.
You can also refer to (static) axes of the input, such as the feature-dim.
This is also the default, i.e. a param-shape of [F], independent of the axes to normalize over.
:param bool scale: add trainable scale parameters
:param bool bias: add trainable bias parameters
:param float epsilon: epsilon for numerical stability
"""
if axis is not NotSpecified:
assert axes is NotSpecified
axes = axis
super(NormLayer, self).__init__(**kwargs)
assert not self.input_data.sparse
x = self.input_data.placeholder
if scale or bias:
if param_shape is NotSpecified:
param_shape = "F"
if isinstance(param_shape, (list, tuple)):
param_axes = [self.input_data.get_axis_from_description(a, allow_int=False) for a in param_shape]
else:
param_axes = [self.input_data.get_axis_from_description(param_shape, allow_int=False)]
assert sorted(set(param_axes)) == sorted(param_axes), "%s: param_shape %r should be unique" % (self, param_shape)
param_shape = [self.input_data.batch_shape[axis] for axis in param_axes]
assert all(isinstance(dim, int) for dim in param_shape), "%s: only static param shape allowed" % self
param_dim_tags = [self.input_data.dim_tags[axis] for axis in param_axes]
else:
assert param_shape is NotSpecified or not param_shape
param_dim_tags = None
axes = self.input_data.get_axes_from_description(axes)
mean = tf.reduce_mean(x, axis=axes, keepdims=True, name="mean")
variance = tf.reduce_mean(tf.square(x - mean), axis=axes, keepdims=True, name="variance")
with tf.name_scope("normalized"):
norm_x = (x - mean) * tf_compat.v1.rsqrt(variance + epsilon)
if scale:
with self.var_creation_scope():
scale_param = self.add_param(tf_compat.v1.get_variable("scale", param_shape, initializer=tf.ones_initializer()))
norm_x *= (
Data(name="scale_param", dim_tags=param_dim_tags, placeholder=scale_param)
.copy_compatible_to(self.output).placeholder)
if bias:
with self.var_creation_scope():
bias_param = self.add_param(tf_compat.v1.get_variable("bias", param_shape, initializer=tf.zeros_initializer()))
norm_x += (
Data(name="bias_param", dim_tags=param_dim_tags, placeholder=bias_param)
.copy_compatible_to(self.output).placeholder)
self.output.placeholder = norm_x
self.output.size_placeholder = self.input_data.size_placeholder.copy()
@classmethod
def get_out_data_from_opts(cls, sources, name, **kwargs):
"""
:param list[LayerBase] sources:
:param str name:
:rtype: Data
"""
return get_concat_sources_data_template(sources, name="%s_output" % name)
class MathNormLayer(_ConcatInputLayer):
"""
Calculates sum(abs(x) ** p) ** (1./p).
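
Example usage (a sketch with illustrative layer names), computing the Euclidean (L2) norm over the feature axis::

    "l2_norm": {"class": "math_norm", "from": "layer1", "p": 2, "axis": "F"}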
"""
layer_class = "math_norm"
def __init__(self, p, axis=NotSpecified, axes=NotSpecified, keep_dims=False, **kwargs):
"""
:param int|float p:
:param Dim|str|list[Dim|str] axis:
:param Dim|str|list[Dim|str] axes:
:param bool keep_dims:
"""
if axis is not NotSpecified:
assert axes is NotSpecified
axes = axis
super(MathNormLayer, self).__init__(**kwargs)
x = self.input_data.copy()
x.placeholder = tf.abs(x.placeholder) ** p
self.output.placeholder = ReduceLayer.reduce(x, mode="sum", axes=axes, keep_dims=keep_dims) ** (1. / p)