Commit: Simplified Model size

kartik4949 committed Jan 20, 2021
1 parent d42352b commit 63010a8
Showing 2 changed files with 59 additions and 239 deletions.
160 changes: 24 additions & 136 deletions pytorch_lightning/core/memory.py
@@ -16,7 +16,7 @@
import shutil
import subprocess
from collections import OrderedDict
from typing import Tuple, Optional, Dict, Union, List, Any
from typing import Tuple, Dict, Union, List, Any

import numpy as np
import torch
@@ -33,17 +33,13 @@ class LayerSummary(object):
"""
Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
It collects the following information:
- Type of the layer (e.g. Linear, BatchNorm1d, ...)
- Input shape
- Output shape
- Number of parameters
The input and output shapes are only known after the example input array was
passed through the model.
Example::
>>> model = torch.nn.Conv2d(3, 8, 3)
>>> summary = LayerSummary(model)
>>> summary.num_parameters
@@ -55,10 +51,8 @@ class LayerSummary(object):
[1, 3, 5, 5]
>>> summary.out_size
[1, 8, 3, 3]
Args:
module: A module to summarize
"""

def __init__(self, module: nn.Module):
@@ -76,7 +70,6 @@ def _register_hook(self) -> RemovableHandle:
Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
If the hook is called, it will remove itself from the module, meaning that
recursive models will only record their input- and output shapes once.
Return:
A handle for the installed hook.
"""
@@ -120,25 +113,19 @@ def num_parameters(self) -> int:
class ModelSummary(object):
"""
Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
Args:
model: The model to summarize (also referred to as the root module)
mode: Can be one of
- `top` (default): only the top-level modules will be recorded (the children of the root module)
- `full`: summarizes all layers and their submodules in the root module
The string representation of this summary prints a table with columns containing
the name, type and number of parameters for each layer.
The root module may also have an attribute ``example_input_array`` as shown in the example below.
If present, the root module will be called with it as input to determine the
intermediate input- and output shapes of all layers. Tensors and nested lists and
tuples of tensors are supported. All other types of inputs will be skipped and shown as `?`
in the summary table. The summary will also display `?` for layers not used in the forward pass.
Example::
>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...
@@ -169,9 +156,7 @@ class ModelSummary(object):
132 K Trainable params
0 Non-trainable params
132 K Total params
0.506 Total Estimated Params Size (MB)
0.012 Total Estimated Forward/Backward Size (MB)
0.527 Total Estimated Model Size (MB)
0.506 Total estimated model params size (MB)
"""

MODE_TOP = "top"
@@ -183,10 +168,7 @@ def __init__(self, model, mode: str = MODE_DEFAULT):
self._model = model
self._mode = mode
self._layer_summary = self.summarize()
self._precision_bytes = self._model.precision / 8.0 # 1 byte -> 8 bits
self._precision_megabytes = self._precision_bytes / (1024 ** 2.0)
self.total_output_dsize = 0.0
self.total_params_dsize = 0.0
self._precision_megabytes = (self._model.precision / 8.0) / (1024 ** 2.0)  # 1 byte -> 8 bits
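As a sanity check of the constant above: each parameter occupies precision / 8 bytes, and dividing by 1024 ** 2 converts bytes to megabytes. A hedged standalone sketch, assuming 32-bit parameters:

import torch.nn as nn

model = nn.Linear(64, 4)  # 64 * 4 weights + 4 biases = 260 parameters
precision = 32  # bits per parameter (assumed)
precision_megabytes = (precision / 8.0) / (1024 ** 2.0)
total_params = sum(p.numel() for p in model.parameters())
print(total_params * precision_megabytes)  # 260 * 4 bytes -> ~0.00099 MB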

@property
def named_modules(self) -> List[Tuple[str, nn.Module]]:
@@ -221,98 +203,29 @@ def param_nums(self) -> List[int]:
return [layer.num_parameters for layer in self._layer_summary.values()]

@property
def total_params(self) -> int:
def total_parameters(self) -> int:
return sum(p.numel() for p in self._model.parameters())

def total_out_params(self, batch_size_dim: int) -> int:
    """ Finds total output parameters, used to calculate the forward/backward pass size. """

    # Recursive traversal to calculate the output size.
    # Recursion handles nested output sizes, e.g. [[[1, 2, 3], [[12, 2, 3], [1, 3, 4]]], [2, 3, 4]].
    def _get_out_size_params(out_sizes, batch_size_dim=batch_size_dim):
        nonlocal _total_out_params
        if not any(isinstance(i, list) for i in out_sizes):
            try:
                if out_sizes:
                    out_sizes = out_sizes[:batch_size_dim] + [-1] + out_sizes[batch_size_dim + 1:]
                # Try to take the product; unknown (non-numeric) sizes raise TypeError.
                _total_out_params += np.prod(out_sizes)
            except TypeError:
                # Skip entries whose size is unknown.
                pass
        else:
            for out_size in out_sizes:
                if isinstance(out_size, list):
                    _get_out_size_params(out_size)

    _total_out_params = 0

    _out_sizes = [
        l_size
        for l_type, l_size in zip(self.layer_types, self.out_sizes)
        if not isinstance(l_type, torch.nn.Sequential)
    ]

    _get_out_size_params(_out_sizes)

    return _total_out_params
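The same traversal as a self-contained sketch (illustrative only; the helper name and sample shapes are ours), showing how nested shape lists are reduced, with the batch dimension replaced by -1 and the sign removed via abs:

import numpy as np

def count_output_params(out_sizes, batch_size_dim=0):
    # Flat list of ints: drop the batch dimension, multiply the rest.
    if not any(isinstance(i, list) for i in out_sizes):
        if not out_sizes:
            return 0
        sizes = out_sizes[:batch_size_dim] + [-1] + out_sizes[batch_size_dim + 1:]
        return abs(np.prod(sizes))
    # Nested lists: recurse and sum.
    return sum(count_output_params(s, batch_size_dim) for s in out_sizes if isinstance(s, list))

print(count_output_params([[1, 2, 3], [[1, 2, 3], [1, 3, 4]]]))  # 2*3 + 2*3 + 3*4 = 24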
@property
def trainable_parameters(self) -> int:
return sum(p.numel() for p in self._model.parameters() if p.requires_grad)

def model_size(self, batch_size_dim: Optional[int] = 0) -> float:
def model_size(self) -> float:
"""
Estimates total model size, i.e. input size + forward/backward pass size + total params size, in MB
Estimates total model size, i.e. total params size, in MB
total params size is the memory occupied by the model parameters.
forward/backward size accounts for the output shapes of the individual layers.
input size is the total size of the model inputs in MB, including multiple inputs, batch size, etc.
NOTE: Currently only supported in full mode.
NOTE: Currently only the total params size is supported.
Example::
>> model = LitModel()
>> summary = ModelSummary(model, mode='full') # doctest: +NORMALIZE_WHITESPACE
>> summary.model_size()
Returns:
Total estimated model size (MB) if an example input array is passed, else total model params size (MB).
Total estimated model size (MB).
"""

    if isinstance(self._model.example_input_array, (list, tuple)):
        in_features = (
            sum(
                input_array.numel() if isinstance(input_array, torch.Tensor) else torch.tensor(input_array)
                for input_array in self._model.example_input_array
            ),
        )
    elif isinstance(self._model.example_input_array, dict):
        in_features = self._model.example_input_array["tensor"].numel()
    elif isinstance(self._model.example_input_array, torch.Tensor):
        in_features = (self._model.example_input_array.numel(),)
    else:
        # if example_input_array is NoneType
        in_features = None

    return self._get_total_size(in_features, batch_size_dim)

def _get_total_size(self, input_size: Tuple[int], batch_size_dim: int) -> float:
    """
    Finds the total model size (MB).
    Args:
        input_size: input size, used to calculate the model input size (MB)
        batch_size_dim: index of the batch dimension in the input size
    Returns:
        Total estimated model size if an example input array is passed, else total model params size.
    """
    self.total_params_dsize = abs(self.total_params * self._precision_megabytes)
    if not input_size:
        self.total_output_dsize = 0.0
        return self.total_params_dsize
    self.total_input_dsize = abs(np.prod(np.array(input_size)) * self._precision_megabytes)
    # 2x for gradients.
    self.total_output_dsize = abs(2.0 * self.total_out_params(batch_size_dim) * self._precision_megabytes)
    return self.total_params_dsize + self.total_output_dsize + self.total_input_dsize
return self.total_parameters * self._precision_megabytes

def summarize(self) -> Dict[str, LayerSummary]:
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
@@ -323,7 +236,7 @@ def summarize(self) -> Dict[str, LayerSummary]:
return summary

def _forward_example_input(self) -> None:
""" Run the example input through each layer to get input and output sizes. """
""" Run the example input through each layer to get input- and output sizes. """
model = self._model
trainer = self._model.trainer

@@ -348,7 +261,6 @@ def _forward_example_input(self) -> None:
def __str__(self):
"""
Makes a summary listing with:
Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes
"""
arrays = [
@@ -360,15 +272,10 @@ def __str__(self):
if self._model.example_input_array is not None:
    arrays.append(["In sizes", self.in_sizes])
    arrays.append(["Out sizes", self.out_sizes])
total_parameters = self.total_parameters
trainable_parameters = self.trainable_parameters
model_size = self.model_size()

trainable_parameters = sum(p.numel() for p in self._model.parameters() if p.requires_grad)
total_parameters = self.total_params
model_size = None
if self._mode == self.MODE_FULL:
    total_model_dsize = self.model_size()
    total_params_dsize = self.total_params_dsize
    total_output_dsize = self.total_output_dsize
    model_size = (total_params_dsize, total_output_dsize, total_model_dsize)
return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays)

def __repr__(self):
@@ -386,12 +293,7 @@ def parse_batch_shape(batch: Any) -> Union[str, List]:
return UNKNOWN_SIZE


def _format_summary_table(
    total_parameters: int,
    trainable_parameters: int,
    model_size: tuple,
    *cols,
) -> str:
def _format_summary_table(total_parameters: int, trainable_parameters: int, model_size: float, *cols) -> str:
"""
Takes in a number of arrays, each specifying a column in
the summary table, and combines them all into one big
@@ -427,31 +329,22 @@ def _format_summary_table(
summary += "Non-trainable params"
summary += "\n" + s.format(get_human_readable_count(total_parameters), 10)
summary += "Total params"
if model_size:
    summary += "\n" + s.format(get_formatted_model_size(model_size[0]), 10)
    summary += "Total Estimated Params Size (MB)"
    summary += "\n" + s.format(get_formatted_model_size(model_size[1]), 10)
    summary += "Total Estimated Forward/Backward Size (MB)"
    summary += "\n" + s.format(get_formatted_model_size(model_size[2]), 10)
    summary += "Total Estimated Model Size (MB)"
summary += "\n" + s.format(get_formatted_model_size(model_size), 10)
summary += "Total Estimated Params Size (MB)"

return summary


def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]:
"""Get a profile of the current memory usage.
""" Get a profile of the current memory usage.
Args:
mode: There are two modes:
- 'all' means return memory for all gpus
- 'min_max' means return memory for max and min
Return:
A dictionary in which the keys are device ids as integers and
values are memory usage as integers in MB.
If mode is 'min_max', the dictionary will also contain two additional keys:
- 'min_gpu_mem': the minimum memory usage in MB
- 'max_gpu_mem': the maximum memory usage in MB
"""
@@ -469,7 +362,6 @@ def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]:
def get_gpu_memory_map() -> Dict[str, int]:
"""
Get the current gpu usage.
Return:
A dictionary in which the keys are device ids as integers and
values are memory usage as integers in MB.
@@ -485,19 +377,18 @@ def get_gpu_memory_map() -> Dict[str, int]:

# Convert lines into a dictionary
gpu_memory = [float(x) for x in result.stdout.strip().split(os.linesep)]
gpu_memory_map = {f"gpu_id: {gpu_id}/memory.used (MB)": memory for gpu_id, memory in enumerate(gpu_memory)}
gpu_memory_map = {
f"gpu_id: {gpu_id}/memory.used (MB)": memory for gpu_id, memory in enumerate(gpu_memory)
}
return gpu_memory_map
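The lines that produce `result` are elided above; presumably they shell out to nvidia-smi. A hedged standalone sketch under that assumption, using nvidia-smi's documented CSV query flags:

import os
import subprocess

# Assumption: query per-GPU used memory as bare numbers, one line per GPU.
result = subprocess.run(
    ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
    encoding="utf-8",
    capture_output=True,
    check=True,
)
gpu_memory = [float(x) for x in result.stdout.strip().split(os.linesep)]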


def get_formatted_model_size(total_model_size: float) -> str:
    return f"{total_model_size:.3f}"


def get_human_readable_count(number: int) -> str:
"""
Abbreviates an integer number with K, M, B, T for thousands, millions,
billions and trillions, respectively.
Examples:
>>> get_human_readable_count(123)
'123 '
@@ -511,13 +402,10 @@ def get_human_readable_count(number: int) -> str:
'400 T'
>>> get_human_readable_count(5e15) # (more than trillion)
'5,000 T'
Args:
number: a positive integer number
Return:
A string formatted according to the pattern described above.
"""
assert number >= 0
labels = PARAMETER_NUM_UNITS
