Skip to content

Commit

Permalink
refactor: simplify Tensor import (#15959)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Dec 8, 2022
1 parent b5fa896 commit 23b12ee
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 57 deletions.
24 changes: 11 additions & 13 deletions src/lightning_lite/plugins/collectives/collective.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional

import torch
from torch import Tensor
from typing_extensions import Self

from lightning_lite.utilities.types import CollectibleGroup
Expand Down Expand Up @@ -38,45 +38,43 @@ def group(self) -> CollectibleGroup:
return self._group

@abstractmethod
def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
def broadcast(self, tensor: Tensor, src: int) -> Tensor:
...

@abstractmethod
def all_reduce(self, tensor: torch.Tensor, op: str) -> torch.Tensor:
def all_reduce(self, tensor: Tensor, op: str) -> Tensor:
...

@abstractmethod
def reduce(self, tensor: torch.Tensor, dst: int, op: str) -> torch.Tensor:
def reduce(self, tensor: Tensor, dst: int, op: str) -> Tensor:
...

@abstractmethod
def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]:
...

@abstractmethod
def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]:
def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]:
...

@abstractmethod
def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor:
def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor:
...

@abstractmethod
def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], op: str) -> torch.Tensor:
def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str) -> Tensor:
...

@abstractmethod
def all_to_all(
self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]
) -> List[torch.Tensor]:
def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]:
...

@abstractmethod
def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None:
...

@abstractmethod
def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor:
...

@abstractmethod
Expand Down
26 changes: 13 additions & 13 deletions src/lightning_lite/plugins/collectives/single_device.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, List

import torch
from torch import Tensor

from lightning_lite.plugins.collectives.collective import Collective
from lightning_lite.utilities.types import CollectibleGroup
Expand All @@ -15,42 +15,42 @@ def rank(self) -> int:
def world_size(self) -> int:
return 1

def broadcast(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
def broadcast(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
return tensor

def all_reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
def all_reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
return tensor

def reduce(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
def reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
return tensor

def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor, **__: Any) -> List[torch.Tensor]:
def all_gather(self, tensor_list: List[Tensor], tensor: Tensor, **__: Any) -> List[Tensor]:
return [tensor]

def gather(self, tensor: torch.Tensor, *_: Any, **__: Any) -> List[torch.Tensor]:
def gather(self, tensor: Tensor, *_: Any, **__: Any) -> List[Tensor]:
return [tensor]

def scatter(
self,
tensor: torch.Tensor,
scatter_list: List[torch.Tensor],
tensor: Tensor,
scatter_list: List[Tensor],
*_: Any,
**__: Any,
) -> torch.Tensor:
) -> Tensor:
return scatter_list[0]

def reduce_scatter(self, output: torch.Tensor, input_list: List[torch.Tensor], *_: Any, **__: Any) -> torch.Tensor:
def reduce_scatter(self, output: Tensor, input_list: List[Tensor], *_: Any, **__: Any) -> Tensor:
return input_list[0]

def all_to_all(
self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor], *_: Any, **__: Any
) -> List[torch.Tensor]:
self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor], *_: Any, **__: Any
) -> List[Tensor]:
return input_tensor_list

def send(self, *_: Any, **__: Any) -> None:
pass

def recv(self, tensor: torch.Tensor, *_: Any, **__: Any) -> torch.Tensor:
def recv(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
return tensor

def barrier(self, *_: Any, **__: Any) -> None:
Expand Down
25 changes: 12 additions & 13 deletions src/lightning_lite/plugins/collectives/torch_collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch.distributed as dist
from torch import Tensor
from typing_extensions import Self

from lightning_lite.plugins.collectives.collective import Collective
Expand Down Expand Up @@ -33,49 +34,47 @@ def rank(self) -> int:
def world_size(self) -> int:
return dist.get_world_size(self.group)

def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
def broadcast(self, tensor: Tensor, src: int) -> Tensor:
dist.broadcast(tensor, src, group=self.group)
return tensor

def all_reduce(self, tensor: torch.Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor:
def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
op = self._convert_to_native_op(op)
dist.all_reduce(tensor, op=op, group=self.group)
return tensor

def reduce(self, tensor: torch.Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> torch.Tensor:
def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
op = self._convert_to_native_op(op)
dist.reduce(tensor, dst, op=op, group=self.group)
return tensor

def all_gather(self, tensor_list: List[torch.Tensor], tensor: torch.Tensor) -> List[torch.Tensor]:
def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]:
dist.all_gather(tensor_list, tensor, group=self.group)
return tensor_list

def gather(self, tensor: torch.Tensor, gather_list: List[torch.Tensor], dst: int = 0) -> List[torch.Tensor]:
def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]:
dist.gather(tensor, gather_list, dst, group=self.group)
return gather_list

def scatter(self, tensor: torch.Tensor, scatter_list: List[torch.Tensor], src: int = 0) -> torch.Tensor:
def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor:
dist.scatter(tensor, scatter_list, src, group=self.group)
return tensor

def reduce_scatter(
self, output: torch.Tensor, input_list: List[torch.Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
) -> torch.Tensor:
self, output: Tensor, input_list: List[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
) -> Tensor:
op = self._convert_to_native_op(op)
dist.reduce_scatter(output, input_list, op=op, group=self.group)
return output

def all_to_all(
self, output_tensor_list: List[torch.Tensor], input_tensor_list: List[torch.Tensor]
) -> List[torch.Tensor]:
def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]:
dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
return output_tensor_list

def send(self, tensor: torch.Tensor, dst: int, tag: Optional[int] = 0) -> None:
def send(self, tensor: Tensor, dst: int, tag: Optional[int] = 0) -> None:
dist.send(tensor, dst, tag=tag, group=self.group)

def recv(self, tensor: torch.Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> torch.Tensor:
def recv(self, tensor: Tensor, src: Optional[int] = None, tag: Optional[int] = 0) -> Tensor:
dist.recv(tensor, src, tag=tag, group=self.group)
return tensor

Expand Down
3 changes: 2 additions & 1 deletion src/lightning_lite/plugins/precision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from typing import Union

import torch
from torch import Tensor


def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor:
def _convert_fp_tensor(tensor: Tensor, dst_type: Union[str, torch.dtype]) -> Tensor:
return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor
2 changes: 1 addition & 1 deletion src/pytorch_lightning/callbacks/stochastic_weight_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

self.n_averaged: Optional[torch.Tensor] = None
self.n_averaged: Optional[Tensor] = None
self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
self._annealing_epochs = annealing_epochs
Expand Down
6 changes: 3 additions & 3 deletions src/pytorch_lightning/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def log(
" but it should not contain information about `dataloader_idx`"
)

value = apply_to_collection(value, (torch.Tensor, numbers.Number), self.__to_tensor, name)
value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name)

if self.trainer._logger_connector.should_reset_tensors(self._current_fx_name):
# if we started a new epoch (running its first batch) the hook name has changed
Expand Down Expand Up @@ -544,10 +544,10 @@ def __check_not_nested(value: dict, name: str) -> None:
def __check_allowed(v: Any, name: str, value: Any) -> None:
raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

def __to_tensor(self, value: Union[torch.Tensor, numbers.Number], name: str) -> Tensor:
def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
value = (
value.clone().detach().to(self.device)
if isinstance(value, torch.Tensor)
if isinstance(value, Tensor)
else torch.tensor(value, device=self.device)
)
if not torch.numel(value) == 1:
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from time import time
from typing import Any, Dict, Mapping, Optional, Union

import torch
import yaml
from lightning_utilities.core.imports import module_available
from torch import Tensor
from typing_extensions import Literal

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
Expand Down Expand Up @@ -332,7 +332,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
for t, p, s, tag in checkpoints:
metadata = {
# Ensure .item() is called to store Tensor contents
"score": s.item() if isinstance(s, torch.Tensor) else s,
"score": s.item() if isinstance(s, Tensor) else s,
"original_filename": Path(p).name,
"Checkpoint": {
k: getattr(checkpoint_callback, k)
Expand Down
3 changes: 2 additions & 1 deletion src/pytorch_lightning/serve/servable_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Callable, Dict, Tuple

import torch
from torch import Tensor


class ServableModule(torch.nn.Module):
Expand Down Expand Up @@ -70,7 +71,7 @@ def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callab
"""
...

def serve_step(self, *args: torch.Tensor, **kwargs: torch.Tensor) -> Dict[str, torch.Tensor]:
def serve_step(self, *args: Tensor, **kwargs: Tensor) -> Dict[str, Tensor]:
r"""
Returns the predictions of your model as a dictionary.
Expand Down
3 changes: 2 additions & 1 deletion src/pytorch_lightning/strategies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from inspect import getmembers, isclass

import torch
from torch import Tensor

from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.strategies import _StrategyRegistry
Expand All @@ -40,7 +41,7 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) ->
mod.register_strategies(registry)


def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
def _fp_to_half(tensor: Tensor, precision: PrecisionType) -> Tensor:
if precision == PrecisionType.HALF:
return _convert_fp_tensor(tensor, torch.half)
if precision == PrecisionType.BFLOAT:
Expand Down
19 changes: 10 additions & 9 deletions src/pytorch_lightning/trainer/supporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch
from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataset import IterableDataset
Expand Down Expand Up @@ -59,18 +60,18 @@ def reset(self, window_length: Optional[int] = None) -> None:
"""Empty the accumulator."""
if window_length is not None:
self.window_length = window_length
self.memory: Optional[torch.Tensor] = None
self.memory: Optional[Tensor] = None
self.current_idx: int = 0
self.last_idx: Optional[int] = None
self.rotated: bool = False

def last(self) -> Optional[torch.Tensor]:
def last(self) -> Optional[Tensor]:
"""Get the last added element."""
if self.last_idx is not None:
assert isinstance(self.memory, torch.Tensor)
assert isinstance(self.memory, Tensor)
return self.memory[self.last_idx].float()

def append(self, x: torch.Tensor) -> None:
def append(self, x: Tensor) -> None:
"""Add an element to the accumulator."""
if self.memory is None:
# tradeoff memory for speed by keeping the memory on device
Expand All @@ -89,21 +90,21 @@ def append(self, x: torch.Tensor) -> None:
if self.current_idx == 0:
self.rotated = True

def mean(self) -> Optional[torch.Tensor]:
def mean(self) -> Optional[Tensor]:
"""Get mean value from stored elements."""
return self._agg_memory("mean")

def max(self) -> Optional[torch.Tensor]:
def max(self) -> Optional[Tensor]:
"""Get maximal value from stored elements."""
return self._agg_memory("max")

def min(self) -> Optional[torch.Tensor]:
def min(self) -> Optional[Tensor]:
"""Get minimal value from stored elements."""
return self._agg_memory("min")

def _agg_memory(self, how: str) -> Optional[torch.Tensor]:
def _agg_memory(self, how: str) -> Optional[Tensor]:
if self.last_idx is not None:
assert isinstance(self.memory, torch.Tensor)
assert isinstance(self.memory, Tensor)
if self.rotated:
return getattr(self.memory.float(), how)()
return getattr(self.memory[: self.current_idx].float(), how)()
Expand Down

0 comments on commit 23b12ee

Please sign in to comment.