-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
types.py
105 lines (89 loc) · 3.43 KB
/
types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convention:
- Do not include any `_TYPE` suffix
- Types used in public hooks (such as those in `LightningModule` and `Callback`) should be public (no leading `_`)
"""
from pathlib import Path
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union
import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchmetrics import Metric
from typing_extensions import Protocol, runtime_checkable, TypedDict
# Plain Python numbers.
_NUMBER = Union[int, float]
# A single metric value: a torchmetrics ``Metric``, a tensor, or a plain number.
_METRIC = Union[
    Metric,
    torch.Tensor,
    _NUMBER,
]
# Either one metric value or a string-keyed mapping of metric values.
_METRIC_COLLECTION = Union[
    _METRIC,
    Mapping[str, _METRIC],
]
# What a single step hook may return: a bare tensor or a dict with string keys.
STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]]
# A list of step outputs (presumably one entry per step of an epoch).
EPOCH_OUTPUT = List[STEP_OUTPUT]
# Evaluation results: a list holding one ``Dict[str, float]`` per DataLoader.
_EVALUATE_OUTPUT = List[Dict[str, float]]
# Prediction results: a flat list, or a list of lists (presumably one inner list per DataLoader).
_PREDICT_OUTPUT = Union[List[Any], List[List[Any]]]
# An iterator yielding ``torch.nn.Parameter`` objects.
_PARAMETERS = Iterator[torch.nn.Parameter]
# A filesystem path given either as a string or as a ``pathlib.Path``.
_PATH = Union[str, Path]
# Accepted shapes for training dataloaders: a single DataLoader, or sequences / dicts of
# DataLoaders, nested at most one level deep.
TRAIN_DATALOADERS = Union[
    DataLoader,
    Sequence[DataLoader],
    Sequence[Sequence[DataLoader]],
    Sequence[Dict[str, DataLoader]],
    Dict[str, DataLoader],
    Dict[str, Dict[str, DataLoader]],
    Dict[str, Sequence[DataLoader]],
]
# Evaluation dataloaders allow only the flat forms: one DataLoader or a sequence of them.
EVAL_DATALOADERS = Union[
    DataLoader,
    Sequence[DataLoader],
]
@runtime_checkable
class _SupportsStateDict(Protocol):
    """Structural type for stateful objects exposing ``state_dict`` and ``load_state_dict``.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj, _SupportsStateDict)`` can be
    used to detect whether an object is stateful. Note that such runtime checks only test that both
    methods are present; their signatures are not verified.
    """

    def state_dict(self) -> Dict[str, Any]:
        """Return the object's state as a string-keyed dictionary."""

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore the object's state from ``state_dict``."""
# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
class _LRScheduler(_SupportsStateDict):
    """Typing stand-in for a generic learning-rate scheduler that wraps an optimizer.

    Only declarations matter here; the constructor body is a no-op stub.
    """

    # Declared for type checkers only; ``__init__`` below does not assign it.
    optimizer: Optimizer

    def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
        """Accept an optimizer plus arbitrary extra arguments; this stub stores nothing."""
# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
class ReduceLROnPlateau(_SupportsStateDict):
    """Typing stand-in mirroring the constructor signature of torch's ``ReduceLROnPlateau``.

    Every default is the ``...`` placeholder used in ``.pyi`` stubs, and the constructor body
    does nothing: the class exists only so type checkers see these members.
    """

    # Attribute declarations only; no values are assigned anywhere in this class.
    in_cooldown: bool
    optimizer: Optimizer

    def __init__(
        self,
        optimizer: Optimizer,
        mode: str = ...,
        factor: float = ...,
        patience: int = ...,
        verbose: bool = ...,
        threshold: float = ...,
        threshold_mode: str = ...,
        cooldown: int = ...,
        min_lr: float = ...,
        eps: float = ...,
    ) -> None:
        """Accept the same arguments as the real scheduler; this stub stores nothing."""
# todo: improve LRSchedulerType naming/typing
# Since torch 2.0 the built-in schedulers (e.g. ``StepLR``) subclass the public ``LRScheduler``
# base rather than ``_LRScheduler`` (which became a thin subclass of ``LRScheduler``), so checking
# against ``_LRScheduler`` alone misses them. Use ``LRScheduler`` when available and fall back to
# ``_LRScheduler`` on older torch releases that do not define it.
try:
    from torch.optim.lr_scheduler import LRScheduler as _TorchLRScheduler
except ImportError:  # torch < 2.0
    from torch.optim.lr_scheduler import _LRScheduler as _TorchLRScheduler

# Runtime classes for ``isinstance`` checks against any supported scheduler instance.
LRSchedulerTypeTuple = (_TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
# Annotation for a scheduler *instance*.
LRSchedulerTypeUnion = Union[_TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
# Annotation for a scheduler *class*.
LRSchedulerType = Union[Type[_TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
class LRSchedulerConfig(TypedDict):
    """Typed dictionary of settings attached to a learning-rate scheduler.

    NOTE(review): the per-key descriptions below are inferred from the key names and should be
    confirmed against the code that builds and consumes these dicts.
    """

    # The scheduler object: either the generic scheduler stub or the plateau variant.
    scheduler: Union[_LRScheduler, ReduceLROnPlateau]
    # Optional name for the scheduler; may be None.
    name: Optional[str]
    # Stepping unit, presumably "step" or "epoch" — confirm with callers.
    interval: str
    # Presumably how many intervals pass between scheduler steps — confirm.
    frequency: int
    # Presumably flags a ReduceLROnPlateau-style scheduler — confirm.
    reduce_on_plateau: bool
    # Optional name of the monitored quantity; may be None.
    monitor: Optional[str]
    # Presumably whether a missing monitored value is treated as an error — confirm.
    strict: bool
    # Optional integer index, presumably of the associated optimizer — confirm.
    opt_idx: Optional[int]