-
Notifications
You must be signed in to change notification settings - Fork 388
/
checks.py
723 lines (592 loc) · 31 KB
/
checks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from time import perf_counter
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
import torch
from torch import Tensor
from torchmetrics.utilities.data import select_topk, to_onehot
from torchmetrics.utilities.enums import DataType
def _check_for_empty_tensors(preds: Tensor, target: Tensor) -> bool:
if preds.numel() == target.numel() == 0:
return True
return False
def _check_same_shape(preds: Tensor, target: Tensor) -> None:
"""Check that predictions and target have the same shape, else raise error."""
if preds.shape != target.shape:
raise RuntimeError("Predictions and targets are expected to have the same shape")
def _basic_input_validation(
    preds: Tensor, target: Tensor, threshold: float, multiclass: Optional[bool], ignore_index: Optional[int]
) -> None:
    """Run the input checks that do not depend on which classification case the inputs belong to.

    Args:
        preds: predictions tensor (labels or probabilities).
        target: ground-truth tensor; must be an integer tensor.
        threshold: not used by these checks.
        multiclass: when exactly ``False``, ``target`` (and integer ``preds``) may not exceed 1.
        ignore_index: when ``None`` or non-negative, negative ``target`` values are rejected; a negative
            ``ignore_index`` disables that check.

    Raises:
        ValueError: if any of the checks fails.
    """
    # Both tensors empty: nothing to validate.
    if _check_for_empty_tensors(preds, target):
        return

    if target.is_floating_point():
        raise ValueError("The `target` has to be an integer tensor.")

    # Negative targets are only tolerated when a negative ``ignore_index`` is in use.
    negatives_forbidden = ignore_index is None or ignore_index >= 0
    if negatives_forbidden and target.min() < 0:
        raise ValueError("The `target` has to be a non-negative tensor.")

    preds_are_labels = not preds.is_floating_point()
    if preds_are_labels and preds.min() < 0:
        raise ValueError("If `preds` are integers, they have to be non-negative.")

    if preds.shape[0] != target.shape[0]:
        raise ValueError("The `preds` and `target` should have the same first dimension.")

    # ``multiclass=False`` means the data must already be binary.
    if multiclass is False:
        if target.max() > 1:
            raise ValueError("If you set `multiclass=False`, then `target` should not exceed 1.")
        if preds_are_labels and preds.max() > 1:
            raise ValueError(
                "If you set `multiclass=False` and `preds` are integers, then `preds` should not exceed 1."
            )
def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> Tuple[DataType, int]:
    """Check that the shape and type of inputs are consistent with each other and with a supported case.

    The allowed input types are listed in the docstring of ``_input_format_classification``. The number of
    classes is not checked for consistency here; other functions take care of that.

    Args:
        preds: tensor with predictions (labels or probabilities)
        target: tensor with ground truth labels

    Returns:
        case: the ``DataType`` case the inputs fall into
        implied_classes: the number of classes implied by the inputs. This is the size of the ``C`` dimension
            when ``preds`` has one more dimension than ``target``. Otherwise it is the number of elements per
            sample (``preds[0].numel()``). It is ``0`` when ``preds`` is empty.

    Raises:
        ValueError:
            If the shapes do not match one of the supported layouts, if same-shaped float ``preds`` come with a
            non-binary ``target``, or if ``preds`` with an extra dimension are not floats.
    """
    preds_float = preds.is_floating_point()

    if preds.ndim == target.ndim:
        if preds.shape != target.shape:
            # Single concatenated message (the previous version passed two positional args to ValueError,
            # which rendered the message as a tuple).
            raise ValueError(
                "The `preds` and `target` should have the same shape,"
                f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}."
            )
        if preds_float and target.numel() > 0 and target.max() > 1:
            raise ValueError(
                "If `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary."
            )

        # Get the case
        if preds.ndim == 1 and preds_float:
            case = DataType.BINARY
        elif preds.ndim == 1 and not preds_float:
            case = DataType.MULTICLASS
        elif preds.ndim > 1 and preds_float:
            case = DataType.MULTILABEL
        else:
            case = DataType.MULTIDIM_MULTICLASS

        implied_classes = preds[0].numel() if preds.numel() > 0 else 0

    elif preds.ndim == target.ndim + 1:
        if not preds_float:
            raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.")
        if preds.shape[2:] != target.shape[1:]:
            raise ValueError(
                "If `preds` have one dimension more than `target`, the shape of `preds` should be"
                " (N, C, ...), and the shape of `target` should be (N, ...)."
            )

        # The extra dimension of `preds` is the class (C) dimension.
        implied_classes = preds.shape[1] if preds.numel() > 0 else 0

        if preds.ndim == 2:
            case = DataType.MULTICLASS
        else:
            case = DataType.MULTIDIM_MULTICLASS
    else:
        raise ValueError(
            "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
            " and `preds` should be (N, C, ...)."
        )

    return case, implied_classes
def _check_num_classes_binary(num_classes: int, multiclass: Optional[bool]) -> None:
"""This checks that the consistency of `num_classes` with the data and `multiclass` param for binary data."""
if num_classes > 2:
raise ValueError("Your data is binary, but `num_classes` is larger than 2.")
if num_classes == 2 and not multiclass:
raise ValueError(
"Your data is binary and `num_classes=2`, but `multiclass` is not True."
" Set it to True if you want to transform binary data to multi-class format."
)
if num_classes == 1 and multiclass:
raise ValueError(
"You have binary data and have set `multiclass=True`, but `num_classes` is 1."
" Either set `multiclass=None`(default) or set `num_classes=2`"
" to transform binary data to multi-class format."
)
def _check_num_classes_mc(
preds: Tensor,
target: Tensor,
num_classes: int,
multiclass: Optional[bool],
implied_classes: int,
) -> None:
"""This checks that the consistency of `num_classes` with the data and `multiclass` param for (multi-
dimensional) multi-class data."""
if num_classes == 1 and multiclass is not False:
raise ValueError(
"You have set `num_classes=1`, but predictions are integers."
" If you want to convert (multi-dimensional) multi-class data with 2 classes"
" to binary/multi-label, set `multiclass=False`."
)
if num_classes > 1:
if multiclass is False and implied_classes != num_classes:
raise ValueError(
"You have set `multiclass=False`, but the implied number of classes "
" (from shape of inputs) does not match `num_classes`. If you are trying to"
" transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`"
" should be either None or the product of the size of extra dimensions (...)."
" See Input Types in Metrics documentation."
)
if target.numel() > 0 and num_classes <= target.max():
raise ValueError("The highest label in `target` should be smaller than `num_classes`.")
if preds.shape != target.shape and num_classes != implied_classes:
raise ValueError("The size of C dimension of `preds` does not match `num_classes`.")
def _check_num_classes_ml(num_classes: int, multiclass: Optional[bool], implied_classes: int) -> None:
"""This checks that the consistency of ``num_classes`` with the data and ``multiclass`` param for multi-label
data."""
if multiclass and num_classes != 2:
raise ValueError(
"Your have set `multiclass=True`, but `num_classes` is not equal to 2."
" If you are trying to transform multi-label data to 2 class multi-dimensional"
" multi-class, you should set `num_classes` to either 2 or None."
)
if not multiclass and num_classes != implied_classes:
raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.")
def _check_top_k(top_k: int, case: str, implied_classes: int, multiclass: Optional[bool], preds_float: bool) -> None:
    """Validate the ``top_k`` argument against the input case and the ``multiclass`` setting.

    Args:
        top_k: number of highest-probability entries to keep per sample.
        case: the input case (compared against ``DataType`` members).
        implied_classes: size of the ``C`` dimension implied by the inputs.
        multiclass: the ``multiclass`` override.
        preds_float: whether ``preds`` is a floating-point tensor.

    Raises:
        ValueError:
            In any of these situations:

            - the data is binary;
            - ``top_k`` is not a positive ``int``;
            - ``preds`` are not floats;
            - ``multiclass`` is ``False``;
            - multi-label data is combined with a truthy ``multiclass``;
            - ``top_k`` is not strictly smaller than ``implied_classes``.
    """
    if case == DataType.BINARY:
        raise ValueError("You can not use `top_k` parameter with binary data.")
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError("The `top_k` has to be an integer larger than 0.")
    if not preds_float:
        raise ValueError("You have set `top_k`, but you do not have probability predictions.")
    if multiclass is False:
        raise ValueError("If you set `multiclass=False`, you can not set `top_k`.")
    if case == DataType.MULTILABEL and multiclass:
        # Note the leading space on the second fragment: the pieces are implicitly concatenated.
        raise ValueError(
            "If you want to transform multi-label data to 2 class multi-dimensional"
            " multi-class data using `multiclass=True`, you can not use `top_k`."
        )
    if top_k >= implied_classes:
        raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.")
def _check_classification_inputs(
    preds: Tensor,
    target: Tensor,
    threshold: float,
    num_classes: Optional[int],
    multiclass: Optional[bool],
    top_k: Optional[int],
    ignore_index: Optional[int] = None,
) -> DataType:
    """Performs error checking on inputs for classification.

    This ensures that preds and target take one of the shape/type combinations that are
    specified in ``_input_format_classification`` docstring. Overrides with ``multiclass=False``
    are also validated: labels may not exceed 1, and when ``preds`` has a ``C`` dimension it must have
    size 2.

    Whenever ``preds`` and ``target`` differ in shape (i.e. ``preds`` has a ``C`` dimension), the highest
    label in ``target`` must be smaller than the size of that dimension.

    When ``num_classes`` is given, it is checked that it is consistent with the input case (binary,
    multi-label, ...), and that, if available, the implied number of classes in the ``C``
    dimension is consistent with it (as well as that max label in target is smaller than it).

    If ``top_k`` is set (not None) for inputs that do not have probability predictions (or
    are binary), an error is raised. Similarly, if ``top_k`` is set to a number that
    is higher than or equal to the ``C`` dimension of ``preds``, an error is raised.

    Preds and target tensors are expected to be squeezed already - all dimensions should be
    greater than 1, except perhaps the first one (``N``).

    Args:
        preds: Tensor with predictions (labels or probabilities)
        target: Tensor with ground truth labels, always integers (labels)
        threshold:
            Threshold value for transforming probability/logit predictions to binary
            (0,1) predictions, in the case of binary or multi-label inputs.
        num_classes:
            Number of classes. If not explicitly set, the number of classes will be inferred
            either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
            tensor, where applicable.
        multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
            :ref:`documentation section <pages/overview:using the multiclass parameter>`
            for a more detailed explanation and examples.
        top_k:
            Number of the highest probability entries for each sample to convert to 1s - relevant
            only for inputs with probability predictions. The default value (``None``) will be
            interpreted as 1 for these inputs. If this parameter is set for multi-label inputs,
            it will take precedence over threshold.

            Should be left unset (``None``) for inputs with label predictions.
        ignore_index:
            When ``None`` or non-negative, negative values in ``target`` are rejected; a negative
            ``ignore_index`` allows them.

    Return:
        case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or
            'multi-dim multi-class'

    Raises:
        ValueError: If any of the above checks fails.
    """
    # Basic validation (that does not need case/type information)
    _basic_input_validation(preds, target, threshold, multiclass, ignore_index)

    # Check that shape/types fall into one of the cases
    case, implied_classes = _check_shape_and_type_consistency(preds, target)

    # Check consistency with the `C` dimension in case of multi-class data
    if preds.shape != target.shape:
        if multiclass is False and implied_classes != 2:
            raise ValueError(
                "You have set `multiclass=False`, but have more than 2 classes in your data,"
                " based on the C dimension of `preds`."
            )
        # `max()` cannot reduce an empty tensor, so guard with `numel()` like the other checks in this module.
        if target.numel() > 0 and target.max() >= implied_classes:
            raise ValueError(
                "The highest label in `target` should be smaller than the size of the `C` dimension of `preds`."
            )

    # Check that num_classes is consistent
    if num_classes:
        if case == DataType.BINARY:
            _check_num_classes_binary(num_classes, multiclass)
        elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS):
            _check_num_classes_mc(preds, target, num_classes, multiclass, implied_classes)
        elif case == DataType.MULTILABEL:
            # Compare explicitly; the previous `case.MULTILABEL` was an always-truthy attribute access.
            _check_num_classes_ml(num_classes, multiclass, implied_classes)

    # Check that top_k is consistent
    if top_k is not None:
        _check_top_k(top_k, case, implied_classes, multiclass, preds.is_floating_point())

    return case
def _input_squeeze(
preds: Tensor,
target: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Remove excess dimensions."""
if preds.shape[0] == 1:
preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0)
else:
preds, target = preds.squeeze(), target.squeeze()
return preds, target
def _input_format_classification(
    preds: Tensor,
    target: Tensor,
    threshold: float = 0.5,
    top_k: Optional[int] = None,
    num_classes: Optional[int] = None,
    multiclass: Optional[bool] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[Tensor, Tensor, DataType]:
    """Convert preds and target tensors into common format.

    Preds and targets are supposed to fall into one of these categories (and are
    validated to make sure this is the case):

    * Both preds and target are of shape ``(N,)``, and both are integers (multi-class)
    * Both preds and target are of shape ``(N,)``, and target is binary, while preds
      are a float (binary)
    * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and
      is integer (multi-class)
    * preds and target are of shape ``(N, ...)``, target is binary and preds is a float
      (multi-label)
    * preds are of shape ``(N, C, ...)`` and are floats, target is of shape ``(N, ...)``
      and is integer (multi-dimensional multi-class)
    * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional
      multi-class)

    To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out.

    The returned output tensors will be binary tensors of the same shape, either ``(N, C)``
    or ``(N, C, X)``, the details for each case are described below. The function also returns
    a ``case`` string, which describes which of the above cases the inputs belonged to - regardless
    of whether this was "overridden" by other settings (like ``multiclass``).

    In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed
    into a binary tensor (elements become 1 if the probability is greater than or equal to
    ``threshold`` or 0 otherwise). If ``multiclass=True``, then both targets and preds
    become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to
    preds first.

    In multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets
    by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original
    shape was ``(N,C)``). However, if ``multiclass=False``, then targets and preds will be
    returned as ``(N,1)`` tensor.

    In multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with
    preds being binarized as in the binary case (or, if ``top_k`` is set, by selecting the ``top_k``
    largest entries). Here the ``C`` dimension is obtained by flattening all dimensions after the
    first one. However, if ``multiclass=True``, then both are returned as ``(N, 2, C)``, by an
    equivalent transformation as in the binary case.

    In multi-dimensional multi-class case, normally both target and preds are returned as
    ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and
    ``C``. The transformations performed here are equivalent to the multi-class case. However, if
    ``multiclass=False`` (and there are up to two classes), then the data is returned as
    ``(N, X)`` binary tensors (multi-label).

    Note:
        Where a one-hot transformation needs to be performed and the number of classes
        is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be
        equal to ``num_classes``, if it is given, or the maximum label value in preds and
        target plus one (and never fewer than 2).

    Args:
        preds: Tensor with predictions (labels or probabilities)
        target: Tensor with ground truth labels, always integers (labels)
        threshold:
            Threshold value for transforming probability/logit predictions to binary
            (0 or 1) predictions, in the case of binary or multi-label inputs.
        num_classes:
            Number of classes. If not explicitly set, the number of classes will be inferred
            either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
            tensor, where applicable.
        top_k:
            Number of the highest probability entries for each sample to convert to 1s - relevant
            only for (multi-dimensional) multi-class or multi-label inputs with probability
            predictions. For multi-class inputs the default value (``None``) will be interpreted
            as 1; for multi-label inputs, setting it replaces thresholding.

            Should be left unset (``None``) for all other types of inputs.
        multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
            :ref:`documentation section <pages/overview:using the multiclass parameter>`
            for a more detailed explanation and examples.
        ignore_index:
            Forwarded to input validation: when ``None`` or non-negative, negative ``target``
            values are rejected; a negative value allows them.

    Returns:
        preds: binary tensor of shape ``(N, C)`` or ``(N, C, X)``
        target: binary tensor of shape ``(N, C)`` or ``(N, C, X)``
        case: The case the inputs fall in, one of ``'binary'``, ``'multi-class'``, ``'multi-label'`` or
            ``'multi-dim multi-class'``
    """
    # Remove excess dimensions
    preds, target = _input_squeeze(preds, target)

    # Convert half precision tensors to full precision, as not all ops are supported
    # for example, min() is not supported
    if preds.dtype == torch.float16:
        preds = preds.float()

    # Validate the inputs (raises ValueError on inconsistency) and determine their case.
    case = _check_classification_inputs(
        preds,
        target,
        threshold=threshold,
        num_classes=num_classes,
        multiclass=multiclass,
        top_k=top_k,
        ignore_index=ignore_index,
    )

    # Binary / multi-label without `top_k`: binarize probabilities with `threshold`.
    # If `multiclass=True` was requested, the data is expanded to 2 classes below.
    if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
        preds = (preds >= threshold).int()
        num_classes = num_classes if not multiclass else 2

    # Multi-label with `top_k`: keep the `top_k` largest entries instead of thresholding.
    if case == DataType.MULTILABEL and top_k:
        preds = select_topk(preds, top_k)

    # Multi-class cases, or any case forced to multi-class via `multiclass=True`, get one-hot encoded.
    if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or multiclass:
        if preds.is_floating_point():
            # Probabilities: the class count is the size of dim 1 (the C dimension).
            num_classes = preds.shape[1]
            preds = select_topk(preds, top_k or 1)
        else:
            # Labels: use `num_classes` if given, otherwise the largest observed label + 1.
            num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1
            preds = to_onehot(preds, max(2, num_classes))

        # At least 2 classes are always encoded, so binary labels also get a 2-wide class dim.
        target = to_onehot(target, max(2, num_classes))  # type: ignore

        # `multiclass=False`: keep only the class-1 slice, giving binary / multi-label form.
        if multiclass is False:
            preds, target = preds[:, 1, ...], target[:, 1, ...]

    # Flatten trailing dims. Skipped for empty inputs: reshaping a zero-element tensor
    # with a -1 dimension is ambiguous and raises in torch.
    if not _check_for_empty_tensors(preds, target):
        if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and multiclass is not False) or multiclass:
            # Keep (N, C) and flatten the rest into X -> (N, C, X).
            target = target.reshape(target.shape[0], target.shape[1], -1)
            preds = preds.reshape(preds.shape[0], preds.shape[1], -1)
        else:
            # Binary / multi-label output: flatten everything after N -> (N, C).
            target = target.reshape(target.shape[0], -1)
            preds = preds.reshape(preds.shape[0], -1)

    # Some operations above create an extra dimension for MC/binary case - this removes it
    if preds.ndim > 2:
        preds, target = preds.squeeze(-1), target.squeeze(-1)

    return preds.int(), target.int(), case
def _input_format_classification_one_hot(
num_classes: int,
preds: Tensor,
target: Tensor,
threshold: float = 0.5,
multilabel: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Convert preds and target tensors into one hot spare label tensors.
Args:
num_classes: number of classes
preds: either tensor with labels, tensor with probabilities/logits or multilabel tensor
target: tensor with ground-true labels
threshold: float used for thresholding multilabel input
multilabel: boolean flag indicating if input is multilabel
Raises:
ValueError:
If ``preds`` and ``target`` don't have the same number of dimensions
or one additional dimension for ``preds``.
Returns:
preds: one hot tensor of shape [num_classes, -1] with predicted labels
target: one hot tensors of shape [num_classes, -1] with true labels
"""
if preds.ndim not in (target.ndim, target.ndim + 1):
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
if preds.ndim == target.ndim + 1:
# multi class probabilities
preds = torch.argmax(preds, dim=1)
if preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int) and num_classes > 1 and not multilabel:
# multi-class
preds = to_onehot(preds, num_classes=num_classes)
target = to_onehot(target, num_classes=num_classes)
elif preds.ndim == target.ndim and preds.is_floating_point():
# binary or multilabel probabilities
preds = (preds >= threshold).long()
# transpose class as first dim and reshape
if preds.ndim > 1:
preds = preds.transpose(1, 0)
target = target.transpose(1, 0)
return preds.reshape(num_classes, -1), target.reshape(num_classes, -1)
def _check_retrieval_functional_inputs(
    preds: Tensor,
    target: Tensor,
    allow_non_binary_target: bool = False,
) -> Tuple[Tensor, Tensor]:
    """Validate ``preds`` and ``target`` for functional retrieval metrics and normalise their dtypes.

    Args:
        preds: tensor with scores/logits
        target: tensor with ground truth labels
        allow_non_binary_target: whether to allow target to contain non-binary values

    Raises:
        ValueError:
            If ``preds`` and ``target`` differ in shape, or if ``preds`` is empty or 0-dimensional.
            Dtype and value errors are raised by ``_check_retrieval_target_and_prediction_types``.

    Returns:
        preds: flattened, as ``torch.float32``
        target: flattened, as ``torch.long`` if not floating point else ``torch.float32``
    """
    if preds.shape != target.shape:
        raise ValueError("`preds` and `target` must be of the same shape")

    is_empty = preds.numel() == 0
    is_scalar = preds.ndim == 0
    if is_empty or is_scalar:
        raise ValueError("`preds` and `target` must be non-empty and non-scalar tensors")

    return _check_retrieval_target_and_prediction_types(
        preds, target, allow_non_binary_target=allow_non_binary_target
    )
def _check_retrieval_inputs(
    indexes: Tensor,
    preds: Tensor,
    target: Tensor,
    allow_non_binary_target: bool = False,
    ignore_index: Optional[int] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Validate ``indexes``, ``preds`` and ``target`` for retrieval metrics and normalise their dtypes.

    Args:
        indexes: tensor with queries indexes; must be ``torch.long``
        preds: tensor with scores/logits
        target: tensor with ground true labels
        allow_non_binary_target: whether to allow target to contain non-binary values
        ignore_index: ignore predictions where targets are equal to this number

    Raises:
        ValueError:
            If the three tensors differ in shape, or if ``indexes`` is not ``torch.long``. It is also raised if
            nothing is left after filtering (or the inputs are 0-dimensional), or if dtypes/values of ``preds``
            and ``target`` are invalid.

    Returns:
        indexes: flattened, as ``torch.long``
        preds: flattened, as ``torch.float32``
        target: flattened, as ``torch.long`` if not floating point else ``torch.float32``
    """
    if not indexes.shape == preds.shape == target.shape:
        raise ValueError("`indexes`, `preds` and `target` must be of the same shape")

    if indexes.dtype is not torch.long:
        raise ValueError("`indexes` must be a tensor of long integers")

    if ignore_index is not None:
        # Drop every position whose target equals `ignore_index`, in all three tensors alike.
        keep = target != ignore_index
        indexes, preds, target = (tensor[keep] for tensor in (indexes, preds, target))

    if indexes.numel() == 0 or indexes.ndim == 0:
        raise ValueError("`indexes`, `preds` and `target` must be non-empty and non-scalar tensors")

    preds, target = _check_retrieval_target_and_prediction_types(
        preds, target, allow_non_binary_target=allow_non_binary_target
    )
    return indexes.long().flatten(), preds, target
def _check_retrieval_target_and_prediction_types(
preds: Tensor,
target: Tensor,
allow_non_binary_target: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Check ``preds`` and ``target`` tensors are of the same shape and of the correct data type.
Args:
preds: either tensor with scores/logits
target: tensor with ground true labels
allow_non_binary_target: whether to allow target to contain non-binary values
Raises:
ValueError:
If ``preds`` and ``target`` don't have the same shape, if they are empty or not of the correct ``dtypes``.
"""
if target.dtype not in (torch.bool, torch.long, torch.int) and not torch.is_floating_point(target):
raise ValueError("`target` must be a tensor of booleans, integers or floats")
if not preds.is_floating_point():
raise ValueError("`preds` must be a tensor of floats")
if not allow_non_binary_target and (target.max() > 1 or target.min() < 0):
raise ValueError("`target` must contain `binary` values")
target = target.float() if target.is_floating_point() else target.long()
preds = preds.float()
return preds.flatten(), target.flatten()
def _allclose_recursive(res1: Any, res2: Any, atol: float = 1e-8) -> bool:
"""Utility function for recursively asserting that two results are within a certain tolerance."""
# single output compare
if isinstance(res1, Tensor):
return torch.allclose(res1, res2, atol=atol)
elif isinstance(res1, str):
return res1 == res2
elif isinstance(res1, Sequence):
return all(_allclose_recursive(r1, r2) for r1, r2 in zip(res1, res2))
elif isinstance(res1, Mapping):
return all(_allclose_recursive(res1[k], res2[k]) for k in res1.keys())
else:
return res1 == res2
def check_forward_no_full_state(
    metric_class,  # type: ignore
    init_args: Optional[Dict[str, Any]] = None,
    input_args: Optional[Dict[str, Any]] = None,
    num_update_to_compare: Sequence[int] = (10, 100, 1000),
    reps: int = 5,
) -> bool:
    """Check whether the ``full_state_update`` property can safely be set to ``False``.

    Setting it to ``False`` will, for most metrics, result in a speedup when using ``forward``. The metric is
    instantiated twice, once with ``full_state_update=True`` and once with ``False``. Both instances are fed
    the same inputs and the outputs of ``forward`` and ``compute`` are compared. Only if they match are the
    two variants timed against each other.

    Args:
        metric_class: metric class object that should be checked
        init_args: dict containing arguments for initializing the metric class (default: no arguments)
        input_args: dict containing arguments to pass to ``forward`` (default: no arguments)
        num_update_to_compare: if we successfully detect that the flag is safe to set to ``False``
            we will run some speedup test. This arg should be a sequence of integers for how many
            steps to compare over. Its first entry is also the number of ``forward`` steps used for the
            correctness comparison.
        reps: number of repetitions of speedup test

    Returns:
        ``False`` if the two variants produce different results, or if the partial-state variant raises a
        ``RuntimeError``. Otherwise, whether the partial-state variant was faster on average for the last
        entry of ``num_update_to_compare``.

    Example (states in ``update`` are independent, safe to set ``full_state_update=False``)
        >>> from torchmetrics import ConfusionMatrix
        >>> check_forward_no_full_state(
        ...     ConfusionMatrix,
        ...     init_args = {'num_classes': 3},
        ...     input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))},
        ... )  # doctest: +ELLIPSIS
        Full state for 10 steps took: ...
        Partial state for 10 steps took: ...
        Full state for 100 steps took: ...
        Partial state for 100 steps took: ...
        Full state for 1000 steps took: ...
        Partial state for 1000 steps took: ...
        True

    Example (states in ``update`` are dependent meaning that ``full_state_update=True``):
        >>> from torchmetrics import ConfusionMatrix
        >>> class MyMetric(ConfusionMatrix):
        ...     def update(self, preds, target):
        ...         super().update(preds, target)
        ...         # by construction make future states dependent on prior states
        ...         if self.confmat.sum() > 20:
        ...             self.reset()
        >>> check_forward_no_full_state(
        ...     MyMetric,
        ...     init_args = {'num_classes': 3},
        ...     input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))},
        ... )
        False
    """
    # Avoid mutable default arguments; ``None`` means "no arguments".
    init_args = {} if init_args is None else init_args
    input_args = {} if input_args is None else input_args

    class FullState(metric_class):
        full_state_update = True

    class PartState(metric_class):
        full_state_update = False

    fullstate = FullState(**init_args)
    partstate = PartState(**init_args)

    equal = True
    for _ in range(num_update_to_compare[0]):
        out1 = fullstate(**input_args)
        try:  # if it fails, the code most likely need access to the full state
            out2 = partstate(**input_args)
        except RuntimeError:
            equal = False
            break
        equal = equal & _allclose_recursive(out1, out2)

    res1 = fullstate.compute()
    try:  # if it fails, the code most likely need access to the full state
        res2 = partstate.compute()
    except RuntimeError:
        # ``res2`` is unbound here, so stop instead of comparing against it.
        return False
    equal = equal & _allclose_recursive(res1, res2)

    if not equal:  # we can stop early because the results did not match
        return False

    # Do timings
    res = torch.zeros(2, len(num_update_to_compare), reps)
    for i, metric in enumerate([fullstate, partstate]):
        for j, t in enumerate(num_update_to_compare):
            for r in range(reps):
                start = perf_counter()
                for _ in range(t):
                    _ = metric(**input_args)
                end = perf_counter()
                res[i, j, r] = end - start
                metric.reset()

    mean = torch.mean(res, -1)
    std = torch.std(res, -1)

    for t in range(len(num_update_to_compare)):
        print(f"Full state for {num_update_to_compare[t]} steps took: {mean[0, t]:0.3f}+-{std[0, t]:0.3f}")
        print(f"Partial state for {num_update_to_compare[t]} steps took: {mean[1, t]:0.3f}+-{std[1, t]:0.3f}")

    return (mean[1, -1] < mean[0, -1]).item()  # if faster on average, we recommend upgrading