🔨 refactor and minor bug fixes
kartik4949 committed Dec 1, 2020
1 parent dfece30 commit 60ee020
Showing 2 changed files with 76 additions and 21 deletions.
90 changes: 72 additions & 18 deletions pytorch_lightning/core/memory.py
@@ -16,7 +16,7 @@
 import shutil
 import subprocess
 from collections import OrderedDict
-from typing import Tuple, Dict, Union, List, Any
+from typing import Tuple, Optional, Dict, Union, List, Any
 
 import numpy as np
 import torch
@@ -181,6 +181,8 @@ def __init__(self, model, mode: str = MODE_DEFAULT):
         self._mode = mode
         self._layer_summary = self.summarize()
         self._precision_bytes = self._model.precision / 8.0  # 1 byte -> 8 bits
+        self.total_output_dsize = 0.0
+        self.total_params_dsize = 0.0
 
     @property
     def named_modules(self) -> List[Tuple[str, nn.Module]]:
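The _precision_bytes factor initialized above converts the model's numeric precision (in bits) into bytes per element; every size estimate below multiplies an element count by it. A minimal sketch of that conversion (illustration only, with assumed fp32 precision and a hypothetical 1M-parameter model):

    precision = 32                                          # e.g. torch.float32
    precision_bytes = precision / 8.0                       # 4.0 bytes per parameter
    params_mb = 1_000_000 * precision_bytes / (1024 ** 2)   # ~3.81 MB of weights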
@@ -214,14 +216,18 @@ def out_sizes(self) -> List:
     def param_nums(self) -> List[int]:
         return [layer.num_parameters for layer in self._layer_summary.values()]
 
-    @property
-    def total_out_params(self) -> int:
+    def total_out_params(self, batch_size_dim: int) -> int:
+        """ Finds total output parameters to calculate forward/backward pass size. """
         _total_out_params = 0
 
-        def _get_out_size_params(out_sizes):
+        # Recursive traversal to calculate the output size.
+        # Recursion is used to handle nested output sizes, e.g. [[[1, 2, 3], [[12, 2, 3], [1, 3, 4]]], [2, 3, 4]].
+        def _get_out_size_params(out_sizes, batch_size_dim=batch_size_dim):
             nonlocal _total_out_params
             if not any(isinstance(i, list) for i in out_sizes):
                 try:
+                    if out_sizes:
+                        out_sizes = out_sizes[:batch_size_dim] + [-1] + out_sizes[batch_size_dim + 1 :]
                     # try to find prod, i.e. check for unknown sizes.
                     _total_out_params += np.prod(out_sizes)
                 except TypeError:
@@ -230,15 +236,34 @@ def _get_out_size_params(out_sizes):
             else:
                 _ = [_get_out_size_params(out_size) for out_size in out_sizes if isinstance(out_size, list)]
 
-        _get_out_size_params(self.out_sizes)
+        import copy
+
+        _out_sizes = copy.deepcopy(self.out_sizes)
+        _get_out_size_params(_out_sizes)
 
         return _total_out_params
 
     @property
     def total_params(self) -> int:
         return sum(self.param_nums)
 
-    def model_size(self, input_size=None) -> float:
+    def model_size(self, batch_size_dim: Optional[int] = 0) -> float:
+        """
+        Estimates total model size, i.e. input size + forward/backward pass size + total params size, in MB.
+
+        Total params size accounts for the total number of model parameters.
+        Forward/backward pass size accounts for the output shape of each individual layer.
+        Input size is the total input size in MB, including multiple inputs, batch size, etc.
+
+        Example::
+
+            >> model = LitModel()
+            >> summary = ModelSummary(model, mode='top')  # doctest: +NORMALIZE_WHITESPACE
+            >> summary.model_size()
+
+        Returns:
+            float: Total estimated model size (MB) if an example input array is passed, else total model params size (MB).
+        """
 
         if isinstance(self._model.example_input_array, (list, tuple)):
             in_features = (
                 sum(
@@ -255,16 +280,30 @@ def model_size(self, input_size=None) -> float:
             in_features = (self._model.example_input_array.numel(),)
         else:
             # if example_input_array is NoneType
-            in_features = (1,)
+            in_features = None
 
-        return self._get_total_size(in_features if not input_size else input_size)
+        return self._get_total_size(in_features, batch_size_dim)
 
-    def _get_total_size(self, input_size: tuple) -> float:
-        total_input_dsize = abs(np.prod(np.array(input_size))) * self._precision_bytes / (1024 ** 2.0)
+    def _get_total_size(self, input_size: tuple, batch_size_dim: int) -> float:
+        """
+        Helper function to find the total model size in MB.
+
+        Args:
+            input_size (tuple): input size used to calculate the model input size (MB).
+
+        Returns:
+            float: Total estimated model size if an example input array is passed, else total model params size.
+        """
+        self.total_params_dsize = abs(self.total_params * self._precision_bytes / (1024 ** 2.0))
+        if not input_size:
+            self.total_output_dsize = 0.0
+            return self.total_params_dsize
+        self.total_input_dsize = abs(np.prod(np.array(input_size))) * self._precision_bytes / (1024 ** 2.0)
         # 2x for gradients.
-        total_output_dsize = abs(2.0 * self.total_out_params * self._precision_bytes / (1024 ** 2.0))
-        total_params_dsize = abs(self.total_params * self._precision_bytes / (1024 ** 2.0))
-        return total_params_dsize + total_output_dsize + total_input_dsize
+        self.total_output_dsize = abs(
+            2.0 * self.total_out_params(batch_size_dim) * self._precision_bytes / (1024 ** 2.0)
+        )
+        return self.total_params_dsize + self.total_output_dsize + self.total_input_dsize
 
     def summarize(self) -> Dict[str, LayerSummary]:
         summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
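Taken together, model_size/_get_total_size estimate: params size + 2x the per-layer output size (forward activations plus their gradients) + input size, all in MB. The batch dimension of each output shape is masked with -1 so the activation count is per sample, and out_sizes is deep-copied first so the stored shapes are not mutated; the abs() calls later discard the sign the -1 introduces. A minimal standalone sketch of the same arithmetic for flat (non-nested) output sizes, as an illustration rather than the library API:

    import numpy as np

    def estimate_model_size_mb(total_params, out_sizes, input_numel,
                               precision_bits=32, batch_size_dim=0):
        precision_bytes = precision_bits / 8.0
        params_mb = total_params * precision_bytes / (1024 ** 2)
        if input_numel is None:
            # No example input: report parameter size only.
            return params_mb
        # Mask the batch dimension with -1 so the count is per sample;
        # abs() removes the sign that the -1 introduces.
        out_params = sum(
            np.prod(s[:batch_size_dim] + [-1] + s[batch_size_dim + 1:]) for s in out_sizes
        )
        fwd_bwd_mb = abs(2.0 * out_params * precision_bytes / (1024 ** 2))  # 2x for gradients
        input_mb = input_numel * precision_bytes / (1024 ** 2)
        return params_mb + fwd_bwd_mb + input_mb

    # e.g. estimate_model_size_mb(20_000, [[1, 128, 28, 28], [1, 10]], 28 * 28)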
@@ -275,7 +314,7 @@ def summarize(self) -> Dict[str, LayerSummary]:
         return summary
 
     def _forward_example_input(self) -> None:
-        """ Run the example input through each layer to get input- and output sizes. """
+        """ Run the example input through each layer to get input and output sizes. """
         model = self._model
         trainer = self._model.trainer
 
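The out_sizes consumed by total_out_params are recorded while _forward_example_input runs the example input through the model. A rough sketch of how per-layer output shapes can be captured with forward hooks; this is an assumed mechanism for illustration, while the actual bookkeeping lives in LayerSummary:

    import torch
    import torch.nn as nn

    def collect_out_sizes(model: nn.Module, example_input: torch.Tensor) -> list:
        out_sizes, handles = [], []
        for module in model.children():
            def hook(mod, inputs, output, _sizes=out_sizes):
                # Record the output shape, or a placeholder when it is unknown.
                _sizes.append(list(output.shape) if isinstance(output, torch.Tensor) else "?")
            handles.append(module.register_forward_hook(hook))
        with torch.no_grad():
            model(example_input)
        for handle in handles:
            handle.remove()
        return out_sizes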
@@ -315,8 +354,12 @@ def __str__(self):
 
         trainable_parameters = sum(p.numel() for p in self._model.parameters() if p.requires_grad)
         total_parameters = self.total_params
-        total_model_size = self.model_size()
-        return _format_summary_table(total_parameters, trainable_parameters, total_model_size, *arrays)
+        total_model_dsize = self.model_size()
+        total_params_dsize = self.total_params_dsize
+        total_output_dsize = self.total_output_dsize
+        return _format_summary_table(
+            total_parameters, trainable_parameters, total_model_dsize, total_output_dsize, total_params_dsize, *arrays
+        )
 
     def __repr__(self):
         return str(self)
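A minimal usage sketch of the new reporting, where LitModel is a stand-in for any LightningModule that defines example_input_array:

    model = LitModel()
    summary = ModelSummary(model, mode="top")
    print(summary)                  # table footer now lists params, forward/backward, and total size (MB)
    size_mb = summary.model_size()  # per-sample estimate; batch dimension 0 is masked out

Note that model_size() must run before the total_params_dsize/total_output_dsize attributes are read, since (as in __str__ above) it populates them as a side effect.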
@@ -333,7 +376,14 @@ def parse_batch_shape(batch: Any) -> Union[str, List]:
         return UNKNOWN_SIZE
 
 
-def _format_summary_table(total_parameters: int, trainable_parameters: int, total_model_size: float, *cols) -> str:
+def _format_summary_table(
+    total_parameters: int,
+    trainable_parameters: int,
+    total_model_dsize: float,
+    total_output_dsize: float,
+    total_params_dsize: float,
+    *cols,
+) -> str:
     """
     Takes in a number of arrays, each specifying a column in
     the summary table, and combines them all into one big
@@ -369,7 +419,11 @@ def _format_summary_table(total_parameters: int, trainable_parameters: int, total_model_size: float, *cols) -> str:
     summary += "Non-trainable params"
     summary += "\n" + s.format(get_human_readable_count(total_parameters), 10)
     summary += "Total params"
-    summary += "\n" + s.format(get_formatted_model_size(total_model_size), 10)
+    summary += "\n" + s.format(get_formatted_model_size(total_params_dsize), 10)
+    summary += "Total Estimated Params Size (MB)"
+    summary += "\n" + s.format(get_formatted_model_size(total_output_dsize), 10)
+    summary += "Total Estimated Forward/Backward Size (MB)"
+    summary += "\n" + s.format(get_formatted_model_size(total_model_dsize), 10)
     summary += "Total Estimated Model Size (MB)"
 
     return summary
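The footer rows above rely on a get_formatted_model_size helper whose definition is outside this diff; a plausible minimal version, offered as an assumption rather than the actual implementation:

    def get_formatted_model_size(total_model_size: float) -> str:
        # Render the size with three decimal places, e.g. 0.318 -> "0.318".
        return f"{total_model_size:,.3f}"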
7 changes: 4 additions & 3 deletions tests/core/test_memory.py
@@ -259,9 +259,10 @@ def test_summary_layer_types(mode):
 @pytest.mark.parametrize(
     ["example_input", "expected_model_size"],
     [
-        pytest.param(torch.zeros(1, 1, 28, 28), 0.668),
-        pytest.param(torch.zeros(1, 1, 224, 224), 93.57),
-        pytest.param(torch.zeros(10, 1, 512, 512), 5176.78),
+        pytest.param(torch.zeros(1, 1, 28, 28), 0.318),
+        pytest.param(torch.zeros(1, 1, 224, 224), 31.84),
+        pytest.param(torch.zeros(10, 1, 512, 512), 183.425),
+        pytest.param(None, 0.075),
     ],
 )
 def test_known_model_sizes(example_input, expected_model_size, mode):
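The updated expectations reflect the batch-dimension masking: the forward/backward estimate is now per sample (which is why the batch-10 case shrinks from 5176.78 MB to 183.425 MB), and the new None case exercises the params-only path. A quick sanity check of that path, with an assumed parameter count chosen to match the expected value (illustrative numbers, not taken from the diff):

    # 0.075 MB * 1024**2 bytes/MB / 4 bytes/param ~= 19,661 fp32 parameters.
    params = 19_661
    size_mb = params * 4 / (1024 ** 2)   # ~0.075 MB, matching pytest.param(None, 0.075)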
