
Commit

🔨 core memory refactor
kartik4949 committed Jan 20, 2021
1 parent 937a94c commit d42352b
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions pytorch_lightning/core/memory.py
@@ -222,21 +222,13 @@ def param_nums(self) -> List[int]:

    @property
    def total_params(self) -> int:
-        _total_params = sum(p.numel() for p in self._model.parameters())
-        return _total_params
+        return sum(p.numel() for p in self._model.parameters())

    def total_out_params(self, batch_size_dim: int) -> int:
        """ finds total output parameters to calculate forward/backward pass size. """
-        _total_out_params = 0
-
-        _out_sizes = [
-            l_size
-            for l_type, l_size in zip(self.layer_types, self.out_sizes)
-            if not isinstance(l_type, torch.nn.Sequential)
-        ]
-
        # recursive traversal to calculate output size.
-        # recursive is used to handle nested output sizes i.e [[[1,2,3],[[12,2,3], [1,3,4]]], [2,3,4]].
+        # recursive is used to handle nested output sizes i.e [[[1,2,3], [[12,2,3], [1,3,4]]], [2,3,4]].
        def _get_out_size_params(out_sizes, batch_size_dim=batch_size_dim):
            nonlocal _total_out_params
            if not any(isinstance(i, list) for i in out_sizes):
@@ -249,7 +241,17 @@ def _get_out_size_params(out_sizes, batch_size_dim=batch_size_dim):
                    # do nothing if tried to find prod on unknown type.
                    pass
            else:
-                _ = [_get_out_size_params(out_size) for out_size in out_sizes if isinstance(out_size, list)]
+                for out_size in out_sizes:
+                    if isinstance(out_size, list):
+                        _get_out_size_params(out_size)
+
+        _total_out_params = 0
+
+        _out_sizes = [
+            l_size
+            for l_type, l_size in zip(self.layer_types, self.out_sizes)
+            if not isinstance(l_type, torch.nn.Sequential)
+        ]

        _get_out_size_params(_out_sizes)

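For reference, a minimal standalone sketch (an illustration, not the library code) of how the nested-output-size traversal above is expected to behave: each innermost list of dimensions contributes the product of its entries with the batch dimension skipped, and nesting is handled by recursing into sub-lists. The helper name count_out_params and the use of math.prod are assumptions made for the example.

import math

def count_out_params(out_sizes, batch_size_dim=0):
    """Illustrative re-implementation of the nested output-size traversal."""
    total = 0

    def _recurse(sizes):
        nonlocal total
        if not any(isinstance(i, list) for i in sizes):
            # innermost list of dims: drop the batch dimension, multiply the rest
            dims = [d for idx, d in enumerate(sizes) if idx != batch_size_dim]
            try:
                total += math.prod(dims)
            except TypeError:
                # do nothing if prod is attempted on an unknown type
                pass
        else:
            for size in sizes:
                if isinstance(size, list):
                    _recurse(size)

    _recurse(out_sizes)
    return total

# nested output sizes as in the comment above
print(count_out_params([[[1, 2, 3], [[12, 2, 3], [1, 3, 4]]], [2, 3, 4]]))  # 36
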
@@ -264,15 +266,13 @@ def model_size(self, batch_size_dim: Optional[int] = 0) -> float:
        NOTE: Currently only Supported in Full Mode.
-        ::
-        Example:
+        Example::
            >> model = LitModel()
            >> summary = ModelSummary(model, mode='full') # doctest: +NORMALIZE_WHITESPACE
            >> summary.model_size()
        Returns:
-            float: Total estimated model size(MB) if example input array is passed else Total Model Params Size(MB).
+            Total estimated model size(MB) if example input array is passed else Total Model Params Size(MB).
        """

        if isinstance(self._model.example_input_array, (list, tuple)):
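As a hedged usage sketch of the docstring example above, assuming the ModelSummary API as it appears on this branch (model_size is the method being introduced by this work, so it may not exist in released versions) and a hypothetical LitModel that defines example_input_array:

import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.core.memory import ModelSummary


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        # an example input enables the full estimate (input + forward/backward sizes)
        self.example_input_array = torch.zeros(1, 32)

    def forward(self, x):
        return self.layer(x)


model = LitModel()
summary = ModelSummary(model, mode='full')
print(summary.model_size())  # estimated size in MB
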
@@ -295,15 +295,15 @@ def model_size(self, batch_size_dim: Optional[int] = 0) -> float:

        return self._get_total_size(in_features, batch_size_dim)

-    def _get_total_size(self, input_size: tuple, batch_size_dim: int) -> float:
+    def _get_total_size(self, input_size: Tuple[int], batch_size_dim: int) -> float:
        """_get_total_size.
-        helper function to find total model size MB
+        Function to find total model size (MB)
        Args:
-            input_size (tuple): input_size to calculate model input size (MB)
+            input_size : input_size to calculate model input size (MB)
        Returns:
-            float: Total estimated model size if example input array is passed else Total Model Params Size.
+            Total estimated model size if example input array is passed else Total Model Params Size.
        """
        self.total_params_dsize = abs(self.total_params * self._precision_megabytes)
        if not input_size:
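For intuition, a back-of-the-envelope sketch of how such an MB estimate is typically assembled from the parameter count, the example-input size, and the summed output sizes at a given precision. The numbers and the exact breakdown are assumptions for illustration, not the formula verbatim from this file:

MEGABYTES = 1024 ** 2
precision_bytes = 4                      # float32

total_params = 11_689_512                # e.g. a ResNet-18-sized model
params_mb = total_params * precision_bytes / MEGABYTES

input_numel = 1 * 3 * 224 * 224          # numel of the example input array
input_mb = input_numel * precision_bytes / MEGABYTES

total_out_params = 5_000_000             # hypothetical result of total_out_params()
fwd_bwd_mb = 2 * total_out_params * precision_bytes / MEGABYTES  # forward + backward

print(f"approx. model size: {params_mb + input_mb + fwd_bwd_mb:.3f} MB")
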
@@ -490,7 +490,7 @@ def get_gpu_memory_map() -> Dict[str, int]:


def get_formatted_model_size(total_model_size: float) -> float:
-    return "{:.3f}".format(total_model_size)
+    return f"{total_model_size:.3f}"


def get_human_readable_count(number: int) -> str:
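The last hunk only swaps str.format for an f-string; a quick illustrative check that the two forms produce identical output:

size = 42.123456
assert "{:.3f}".format(size) == f"{size:.3f}" == "42.123"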
