/
activation.py
1495 lines (1131 loc) · 51.7 KB
/
activation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor
from .linear import NonDynamicallyQuantizableLinear
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from .module import Module
from .. import functional as F
# Public API of this module: the names exported by a star-import of it.
__all__ = ['Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh',
'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU',
'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Tanhshrink',
'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax']
class Threshold(Module):
    r"""Thresholds each element of the input Tensor.

    Threshold is defined as:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    Args:
        threshold: The value to threshold at
        value: The value to replace with
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Threshold(0.1, 20)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['threshold', 'value', 'inplace']

    threshold: float
    value: float
    inplace: bool

    def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
        super().__init__()
        self.threshold = threshold
        self.value = value
        self.inplace = inplace
        # TODO: check in THNN (if inplace == True, then assert value <= threshold)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with every element not above ``threshold`` replaced by ``value``."""
        return F.threshold(input, self.threshold, self.value, self.inplace)

    def extra_repr(self) -> str:
        """Return the configuration shown inside the module's ``repr``."""
        inplace_str = ', inplace=True' if self.inplace else ''
        return f'threshold={self.threshold}, value={self.value}{inplace_str}'
class ReLU(Module):
    r"""Applies the rectified linear unit function element-wise:

    :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ReLU.png

    Examples::

        >>> m = nn.ReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)


      An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> m = nn.ReLU()
        >>> input = torch.randn(2).unsqueeze(0)
        >>> output = torch.cat((m(input),m(-input)))
    """
    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        """Return ``max(0, x)`` for every element, optionally overwriting ``input``."""
        return F.relu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        """Return ``'inplace=True'`` when operating in-place, otherwise an empty string."""
        if self.inplace:
            return 'inplace=True'
        return ''
class RReLU(Module):
    r"""Applies the randomized leaky rectified linear unit function, element-wise,
    as described in the paper:

    `Empirical Evaluation of Rectified Activations in Convolutional Network`_.

    The function is defined as:

    .. math::
        \text{RReLU}(x) =
        \begin{cases}
            x & \text{if } x \geq 0 \\
            ax & \text{ otherwise }
        \end{cases}

    where :math:`a` is randomly sampled from uniform distribution
    :math:`\mathcal{U}(\text{lower}, \text{upper})`.

     See: https://arxiv.org/pdf/1505.00853.pdf

    Args:
        lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
        upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/RReLU.png

    Examples::

        >>> m = nn.RReLU(0.1, 0.3)
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
        https://arxiv.org/abs/1505.00853
    """
    __constants__ = ['lower', 'upper', 'inplace']

    lower: float
    upper: float
    inplace: bool

    def __init__(
        self,
        lower: float = 1. / 8,
        upper: float = 1. / 3,
        inplace: bool = False
    ):
        super().__init__()
        self.lower = lower
        self.upper = upper
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        """Apply RReLU to ``input``; ``self.training`` is forwarded to ``F.rrelu``."""
        return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)

    def extra_repr(self) -> str:
        """Return the configuration shown inside the module's ``repr``."""
        inplace_str = ', inplace=True' if self.inplace else ''
        return f'lower={self.lower}, upper={self.upper}{inplace_str}'
class Hardtanh(Module):
    r"""Applies the HardTanh function element-wise.

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            \text{max\_val} & \text{ if } x > \text{ max\_val } \\
            \text{min\_val} & \text{ if } x < \text{ min\_val } \\
            x & \text{ otherwise } \\
        \end{cases}

    Args:
        min_val: minimum value of the linear region range. Default: -1
        max_val: maximum value of the linear region range. Default: 1
        inplace: can optionally do the operation in-place. Default: ``False``

    Keyword arguments :attr:`min_value` and :attr:`max_value`
    have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.

    Raises:
        AssertionError: if ``max_val`` is not strictly greater than ``min_val``.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardtanh.png

    Examples::

        >>> m = nn.Hardtanh(-2, 2)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['min_val', 'max_val', 'inplace']

    min_val: float
    max_val: float
    inplace: bool

    def __init__(
        self,
        min_val: float = -1.,
        max_val: float = 1.,
        inplace: bool = False,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None
    ) -> None:
        super().__init__()
        # The deprecated spellings, when given, override the current ones.
        # stacklevel=2 attributes the warning to the code constructing the module.
        if min_value is not None:
            warnings.warn(
                "keyword argument min_value is deprecated and rename to min_val",
                stacklevel=2,
            )
            min_val = min_value
        if max_value is not None:
            warnings.warn(
                "keyword argument max_value is deprecated and rename to max_val",
                stacklevel=2,
            )
            max_val = max_value

        self.min_val = min_val
        self.max_val = max_val
        self.inplace = inplace
        # Explicit check rather than ``assert`` so the validation is not stripped
        # under ``python -O``; AssertionError is kept so callers see the same type.
        if not self.max_val > self.min_val:
            raise AssertionError(
                'max_val must be greater than min_val, got min_val={} and max_val={}'.format(
                    self.min_val, self.max_val
                )
            )

    def forward(self, input: Tensor) -> Tensor:
        """Clamp every element of ``input`` to ``[min_val, max_val]``."""
        return F.hardtanh(input, self.min_val, self.max_val, self.inplace)

    def extra_repr(self) -> str:
        """Return the configuration shown inside the module's ``repr``."""
        inplace_str = ', inplace=True' if self.inplace else ''
        return 'min_val={}, max_val={}{}'.format(
            self.min_val, self.max_val, inplace_str
        )
class ReLU6(Hardtanh):
    r"""Applies the element-wise function:

    .. math::
        \text{ReLU6}(x) = \min(\max(0,x), 6)

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.ReLU6()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, inplace: bool = False):
        # ReLU6 is Hardtanh with the linear region fixed to [0, 6].
        super().__init__(0., 6., inplace)

    def extra_repr(self) -> str:
        """Return ``'inplace=True'`` when operating in-place, otherwise an empty string."""
        if self.inplace:
            return 'inplace=True'
        return ''
class Sigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}


    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Sigmoid.png

    Examples::

        >>> m = nn.Sigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        """Return the logistic sigmoid of every element of ``input``."""
        output = torch.sigmoid(input)
        return output
class Hardsigmoid(Module):
    r"""Applies the Hardsigmoid function element-wise.

    Hardsigmoid is defined as:

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardsigmoid.png

    Examples::

        >>> m = nn.Hardsigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace : bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        """Return the piecewise-linear sigmoid of ``input``, optionally in-place."""
        return F.hardsigmoid(input, inplace=self.inplace)
class Tanh(Module):
    r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.

    Tanh is defined as:

    .. math::
        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Tanh.png

    Examples::

        >>> m = nn.Tanh()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        """Return the hyperbolic tangent of every element of ``input``."""
        output = torch.tanh(input)
        return output
class SiLU(Module):
    r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
    The SiLU function is also known as the swish function.

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. note::
        See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
        where the SiLU (Sigmoid Linear Unit) was originally coined, and see
        `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
        in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
        a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
        where the SiLU was experimented with later.

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/SiLU.png

    Examples::

        >>> m = nn.SiLU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        """Return ``x * sigmoid(x)`` for every element, optionally in-place."""
        return F.silu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        """Return ``'inplace=True'`` when operating in-place, otherwise an empty string."""
        if self.inplace:
            return 'inplace=True'
        return ''
class Mish(Module):
    r"""Applies the Mish function, element-wise.
    Mish: A Self Regularized Non-Monotonic Neural Activation Function.

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    .. note::
        See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Mish.png

    Examples::

        >>> m = nn.Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        """Return ``x * tanh(softplus(x))`` for every element, optionally in-place."""
        return F.mish(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        """Return ``'inplace=True'`` when operating in-place, otherwise an empty string."""
        if self.inplace:
            return 'inplace=True'
        return ''
class Hardswish(Module):
    r"""Applies the Hardswish function, element-wise, as described in the paper:
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.

    Hardswish is defined as:

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) /6 & \text{otherwise}
        \end{cases}

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardswish.png

    Examples::

        >>> m = nn.Hardswish()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace : bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        """Return the hard-swish of every element of ``input``, optionally in-place."""
        return F.hardswish(input, inplace=self.inplace)
class ELU(Module):
    r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described
    in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
    Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.

    ELU is defined as:

    .. math::
        \text{ELU}(x) = \begin{cases}
        x, & \text{ if } x > 0\\
        \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
        \end{cases}

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/ELU.png

    Examples::

        >>> m = nn.ELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['alpha', 'inplace']

    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        """Return the ELU of every element of ``input``, optionally in-place."""
        return F.elu(input, alpha=self.alpha, inplace=self.inplace)

    def extra_repr(self) -> str:
        """Return the configuration shown inside the module's ``repr``."""
        suffix = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{suffix}'
class CELU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/CELU.png

    Examples::

        >>> m = nn.CELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _`Continuously Differentiable Exponential Linear Units`:
        https://arxiv.org/abs/1704.07483
    """
    __constants__ = ['alpha', 'inplace']

    alpha: float
    inplace: bool

    def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
        super().__init__()
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        """Return the CELU of every element of ``input``, optionally in-place."""
        return F.celu(input, alpha=self.alpha, inplace=self.inplace)

    def extra_repr(self) -> str:
        """Return the configuration shown inside the module's ``repr``."""
        suffix = ', inplace=True' if self.inplace else ''
        return f'alpha={self.alpha}{suffix}'
class SELU(Module):
    r"""Applied element-wise, as:

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    with :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    .. warning::
        When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
        ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
        in order to get `Self-Normalizing Neural Networks`_.
        See :func:`torch.nn.init.calculate_gain` for more information.

    More details can be found in the paper `Self-Normalizing Neural Networks`_ .

    Args:
        inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/SELU.png

    Examples::

        >>> m = nn.SELU()
        >>> input = torch.randn(2)
        >>> output = m(input)

    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
    """
    __constants__ = ['inplace']

    inplace: bool

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        """Return the SELU of every element of ``input``, optionally in-place."""
        return F.selu(input, inplace=self.inplace)

    def extra_repr(self) -> str:
        """Return ``'inplace=True'`` when operating in-place, otherwise an empty string."""
        if self.inplace:
            return 'inplace=True'
        return ''
class GLU(Module):
    r"""Applies the gated linear unit function
    :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
    of the input matrices and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::

        >>> m = nn.GLU()
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """
    __constants__ = ['dim']

    dim: int

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: Tensor) -> Tensor:
        """Split ``input`` in half along ``dim`` and gate the first half by the second."""
        return F.glu(input, dim=self.dim)

    def extra_repr(self) -> str:
        """Return the configuration shown inside the module's ``repr``."""
        return f'dim={self.dim}'
class GELU(Module):
    r"""Applies the Gaussian Error Linear Units function:

    .. math:: \text{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.

    When the approximate argument is 'tanh', Gelu is estimated with:

    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))

    Args:
        approximate (str, optional): the gelu approximation algorithm to use:
            ``'none'`` | ``'tanh'``. Default: ``'none'``

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/GELU.png

    Examples::

        >>> m = nn.GELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['approximate']

    approximate: str

    def __init__(self, approximate: str = 'none') -> None:
        super().__init__()
        self.approximate = approximate

    def forward(self, input: Tensor) -> Tensor:
        """Return the GELU of every element, using the configured approximation."""
        return F.gelu(input, approximate=self.approximate)

    def extra_repr(self) -> str:
        """Return the configuration shown inside the module's ``repr``."""
        return f'approximate={self.approximate!r}'
class Hardshrink(Module):
    r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.

    Hardshrink is defined as:

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
        x, & \text{ if } x > \lambda \\
        x, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Hardshrink.png

    Examples::

        >>> m = nn.Hardshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['lambd']

    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        """Zero every element whose magnitude does not exceed ``lambd``; keep the rest."""
        return F.hardshrink(input, self.lambd)

    def extra_repr(self) -> str:
        """Return the bare ``lambd`` value shown inside the module's ``repr``."""
        return f'{self.lambd}'
class LeakyReLU(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)


    or

    .. math::
        \text{LeakyReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        \text{negative\_slope} \times x, & \text{ otherwise }
        \end{cases}

    Args:
        negative_slope: Controls the angle of the negative slope. Default: 1e-2
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    .. image:: ../scripts/activation_images/LeakyReLU.png

    Examples::

        >>> m = nn.LeakyReLU(0.1)
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['inplace', 'negative_slope']

    inplace: bool
    negative_slope: float

    def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
        super().__init__()
        self.negative_slope = negative_slope
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        """Scale negative elements by ``negative_slope`` and pass the others through."""
        return F.leaky_relu(input, negative_slope=self.negative_slope, inplace=self.inplace)

    def extra_repr(self) -> str:
        """Return the configuration shown inside the module's ``repr``."""
        suffix = ', inplace=True' if self.inplace else ''
        return f'negative_slope={self.negative_slope}{suffix}'
class LogSigmoid(Module):
    r"""Applies the element-wise function:

    .. math::
        \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/LogSigmoid.png

    Examples::

        >>> m = nn.LogSigmoid()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        """Return the log of the logistic sigmoid of every element of ``input``."""
        output = F.logsigmoid(input)
        return output
class Softplus(Module):
    r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
    \log(1 + \exp(\beta * x))` element-wise.

    SoftPlus is a smooth approximation to the ReLU function and can be used
    to constrain the output of a machine to always be positive.

    For numerical stability the implementation reverts to the linear function
    when :math:`input \times \beta > threshold`.

    Args:
        beta: the :math:`\beta` value for the Softplus formulation. Default: 1
        threshold: values above this revert to a linear function. Default: 20

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softplus.png

    Examples::

        >>> m = nn.Softplus()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['beta', 'threshold']

    # Both are real-valued (e.g. ``beta=0.5`` is valid), hence ``float`` not ``int``.
    beta: float
    threshold: float

    # Defaults stay the integer literals 1 and 20 so ``extra_repr`` output is unchanged.
    def __init__(self, beta: float = 1, threshold: float = 20) -> None:
        super().__init__()
        self.beta = beta
        self.threshold = threshold

    def forward(self, input: Tensor) -> Tensor:
        """Return the softplus of every element of ``input``."""
        return F.softplus(input, self.beta, self.threshold)

    def extra_repr(self) -> str:
        """Return the configuration shown inside the module's ``repr``."""
        return f'beta={self.beta}, threshold={self.threshold}'
class Softshrink(Module):
    r"""Applies the soft shrinkage function elementwise:

    .. math::
        \text{SoftShrinkage}(x) =
        \begin{cases}
        x - \lambda, & \text{ if } x > \lambda \\
        x + \lambda, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Args:
        lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    .. image:: ../scripts/activation_images/Softshrink.png

    Examples::

        >>> m = nn.Softshrink()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """
    __constants__ = ['lambd']

    lambd: float

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, input: Tensor) -> Tensor:
        """Shrink every element of ``input`` toward zero by ``lambd``."""
        return F.softshrink(input, self.lambd)

    def extra_repr(self) -> str:
        """Return the bare ``lambd`` value shown inside the module's ``repr``."""
        text = str(self.lambd)
        return text
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
from different representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
Multi-Head Attention is defined as:
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
``forward()`` will use a special optimized implementation if all of the following
conditions are met:
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
restriction will be loosened in the future.)
- inputs are batched (3D) with ``batch_first==True``
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
- training is disabled (using ``.eval()``)
- ``add_bias_kv`` is ``False``
- ``add_zero_attn`` is ``False``
- ``batch_first`` is ``True`` and the input is batched
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
nor ``attn_mask`` is passed
- autocast is disabled
If the optimized implementation is in use, a
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
``query``/``key``/``value`` to represent padding more efficiently than using a
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
will be returned, and an additional speedup proportional to the fraction of the input
that is padding can be expected.
Args:
embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
Default: ``False``.
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
Examples::
>>> # xdoctest: +SKIP
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__constants__ = ['batch_first']
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
self.register_parameter('in_proj_weight', None)
else:
self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
self.register_parameter('q_proj_weight', None)
self.register_parameter('k_proj_weight', None)
self.register_parameter('v_proj_weight', None)
if bias:
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
else:
self.register_parameter('in_proj_bias', None)
self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self._reset_parameters()
def _reset_parameters(self):
if self._qkv_same_embed_dim:
xavier_uniform_(self.in_proj_weight)
else:
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
if self.in_proj_bias is not None:
constant_(self.in_proj_bias, 0.)
constant_(self.out_proj.bias, 0.)
if self.bias_k is not None: