From 63010a8d11765b42a1eb3befca24eb1fe9cd7dc4 Mon Sep 17 00:00:00 2001
From: Kartik Sharma
Date: Wed, 20 Jan 2021 19:12:50 +0530
Subject: [PATCH] Simplified Model size

---
 pytorch_lightning/core/memory.py | 160 +++++--------------------------
 tests/core/test_memory.py        | 138 +++++++------------------------
 2 files changed, 59 insertions(+), 239 deletions(-)

diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py
index a74eb3c8089b4..de551cabd30df 100644
--- a/pytorch_lightning/core/memory.py
+++ b/pytorch_lightning/core/memory.py
@@ -16,7 +16,7 @@
 import shutil
 import subprocess
 from collections import OrderedDict
-from typing import Tuple, Optional, Dict, Union, List, Any
+from typing import Tuple, Dict, Union, List, Any
 
 import numpy as np
 import torch
@@ -33,17 +33,13 @@ class LayerSummary(object):
     """
     Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
     It collects the following information:
-
     - Type of the layer (e.g. Linear, BatchNorm1d, ...)
     - Input shape
    - Output shape
     - Number of parameters
-
     The input and output shapes are only known after the example input array was passed through the model.
-
     Example::
-
         >>> model = torch.nn.Conv2d(3, 8, 3)
         >>> summary = LayerSummary(model)
         >>> summary.num_parameters
@@ -55,10 +51,8 @@ class LayerSummary(object):
         [1, 3, 5, 5]
         >>> summary.out_size
         [1, 8, 3, 3]
-
     Args:
         module: A module to summarize
-
     """
 
     def __init__(self, module: nn.Module):
@@ -76,7 +70,6 @@ def _register_hook(self) -> RemovableHandle:
         Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
         If the hook is called, it will remove itself from the module, meaning that
         recursive models will only record their input- and output shapes once.
-
         Return:
             A handle for the installed hook.
         """
@@ -120,25 +113,19 @@ def num_parameters(self) -> int:
 class ModelSummary(object):
     """
     Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
-
     Args:
         model: The model to summarize (also referred to as the root module)
         mode: Can be one of
-
             - `top` (default): only the top-level modules will be recorded (the children of the root module)
             - `full`: summarizes all layers and their submodules in the root module
-
     The string representation of this summary prints a table with columns containing
     the name, type and number of parameters for each layer.
-
     The root module may also have an attribute ``example_input_array`` as shown in the example below.
     If present, the root module will be called with it as input to determine the
     intermediate input- and output shapes of all layers. Supported are tensors and
     nested lists and tuples of tensors. All other types of inputs will be
     skipped and show as `?` in the summary table. The summary will also display
     `?` for layers not used in the forward pass.
-
     Example::
-
         >>> import pytorch_lightning as pl
         >>> class LitModel(pl.LightningModule):
         ...
@@ -169,9 +156,7 @@ class ModelSummary(object):
     132 K     Trainable params
     0         Non-trainable params
     132 K     Total params
-    0.506     Total Estimated Params Size (MB)
-    0.012     Total Estimated Forward/Backward Size (MB)
-    0.527     Total Estimated Model Size (MB)
+    0.506     Total estimated model params size (MB)
     """
 
     MODE_TOP = "top"
@@ -183,10 +168,7 @@ def __init__(self, model, mode: str = MODE_DEFAULT):
         self._model = model
         self._mode = mode
         self._layer_summary = self.summarize()
-        self._precision_bytes = self._model.precision / 8.0  # 1 byte -> 8 bits
-        self._precision_megabytes = self._precision_bytes / (1024 ** 2.0)
-        self.total_output_dsize = 0.0
-        self.total_params_dsize = 0.0
+        self._precision_megabytes = (self._model.precision / 8.0) / (1024 ** 2.0)  # 1 byte -> 8 bits
 
     @property
     def named_modules(self) -> List[Tuple[str, nn.Module]]:
@@ -221,50 +203,19 @@ def param_nums(self) -> List[int]:
         return [layer.num_parameters for layer in self._layer_summary.values()]
 
     @property
-    def total_params(self) -> int:
+    def total_parameters(self) -> int:
         return sum(p.numel() for p in self._model.parameters())
 
-    def total_out_params(self, batch_size_dim: int) -> int:
-        """ finds total output parameters to calculate forward/backward pass size. """
-
-        # recursive traversal to calculate output size.
-        # recursive is used to handle nested output sizes i.e [[[1,2,3], [[12,2,3], [1,3,4]]], [2,3,4]].
-        def _get_out_size_params(out_sizes, batch_size_dim=batch_size_dim):
-            nonlocal _total_out_params
-            if not any(isinstance(i, list) for i in out_sizes):
-                try:
-                    if out_sizes:
-                        out_sizes = out_sizes[:batch_size_dim] + [-1] + out_sizes[batch_size_dim + 1:]
-                        # try to find prod, i.e check for unknown sizes.
-                        _total_out_params += np.prod(out_sizes)
-                except TypeError:
-                    # do nothing if tried to find prod on unknown type.
-                    pass
-            else:
-                for out_size in out_sizes:
-                    if isinstance(out_size, list):
-                        _get_out_size_params(out_size)
-
-        _total_out_params = 0
-
-        _out_sizes = [
-            l_size
-            for l_type, l_size in zip(self.layer_types, self.out_sizes)
-            if not isinstance(l_type, torch.nn.Sequential)
-        ]
-
-        _get_out_size_params(_out_sizes)
-
-        return _total_out_params
+    @property
+    def trainable_parameters(self) -> int:
+        return sum(p.numel() for p in self._model.parameters() if p.requires_grad)
 
-    def model_size(self, batch_size_dim: Optional[int] = 0) -> float:
+    def model_size(self) -> float:
         """
-        Estimates total model size i.e input_size + forward/backward pass size + total params size in MBs
-        total params size gives model size in accounting total model params.
-        forward/backward model size accounts model size acounting output shape of individual layers.
-        input size gives the total input size in MBs including multiple inputs, batch size, etc.
-        NOTE: Currently only Supported in Full Mode.
+        Estimates the total model size, i.e. the total size of all model parameters, in MB.
+        NOTE: Currently only the params size is accounted for; input and
+        forward/backward (activation) sizes are not included.
 
         Example::
             >> model = LitModel()
             >> summary = ModelSummary(model)
             >> summary.model_size()
 
         Returns:
-            Total estimated model size(MB) if example input array is passed else Total Model Params Size(MB).
+            Total estimated model size (MB).
""" - - if isinstance(self._model.example_input_array, (list, tuple)): - in_features = ( - sum( - [ - input_array.numel() if isinstance(input_array, torch.Tensor) else torch.tensor(input_array) - for input_array in self._model.example_input_array - ] - ), - ) - - elif isinstance(self._model.example_input_array, dict): - in_features = self._model.example_input_array["tensor"].numel() - elif isinstance(self._model.example_input_array, torch.Tensor): - in_features = (self._model.example_input_array.numel(),) - else: - # if example_input_array is NoneType - in_features = None - - return self._get_total_size(in_features, batch_size_dim) - - def _get_total_size(self, input_size: Tuple[int], batch_size_dim: int) -> float: - """_get_total_size. - Function to find total model size (MB) - - Args: - input_size : input_size to calculate model input size (MB) - - Returns: - Total estimated model size if example input array is passed else Total Model Params Size. - """ - self.total_params_dsize = abs(self.total_params * self._precision_megabytes) - if not input_size: - self.total_output_dsize = 0.0 - return self.total_params_dsize - self.total_input_dsize = abs(np.prod(np.array(input_size)) * self._precision_megabytes) - # 2x for gradients. - self.total_output_dsize = abs(2.0 * self.total_out_params(batch_size_dim) * self._precision_megabytes) - return self.total_params_dsize + self.total_output_dsize + self.total_input_dsize + return self.total_parameters * self._precision_megabytes def summarize(self) -> Dict[str, LayerSummary]: summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules) @@ -323,7 +236,7 @@ def summarize(self) -> Dict[str, LayerSummary]: return summary def _forward_example_input(self) -> None: - """ Run the example input through each layer to get input and output sizes. """ + """ Run the example input through each layer to get input- and output sizes. 
""" model = self._model trainer = self._model.trainer @@ -348,7 +261,6 @@ def _forward_example_input(self) -> None: def __str__(self): """ Makes a summary listing with: - Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes """ arrays = [ @@ -360,15 +272,10 @@ def __str__(self): if self._model.example_input_array is not None: arrays.append(["In sizes", self.in_sizes]) arrays.append(["Out sizes", self.out_sizes]) + total_parameters = self.total_parameters + trainable_parameters = self.trainable_parameters + model_size = self.model_size() - trainable_parameters = sum(p.numel() for p in self._model.parameters() if p.requires_grad) - total_parameters = self.total_params - model_size = None - if self._mode == self.MODE_FULL: - total_model_dsize = self.model_size() - total_params_dsize = self.total_params_dsize - total_output_dsize = self.total_output_dsize - model_size = (total_params_dsize, total_output_dsize, total_model_dsize) return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays) def __repr__(self): @@ -386,12 +293,7 @@ def parse_batch_shape(batch: Any) -> Union[str, List]: return UNKNOWN_SIZE -def _format_summary_table( - total_parameters: int, - trainable_parameters: int, - model_size: tuple, - *cols, -) -> str: +def _format_summary_table(total_parameters: int, trainable_parameters: int, model_size: float, *cols) -> str: """ Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big @@ -427,31 +329,22 @@ def _format_summary_table( summary += "Non-trainable params" summary += "\n" + s.format(get_human_readable_count(total_parameters), 10) summary += "Total params" - if model_size: - summary += "\n" + s.format(get_formatted_model_size(model_size[0]), 10) - summary += "Total Estimated Params Size (MB)" - summary += "\n" + s.format(get_formatted_model_size(model_size[1]), 10) - summary += "Total Estimated Forward/Backward Size (MB)" - summary += "\n" + s.format(get_formatted_model_size(model_size[2]), 10) - summary += "Total Estimated Model Size (MB)" + summary += "\n" + s.format(get_formatted_model_size(model_size), 10) + summary += "Total Estimated Params Size (MB)" return summary def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: - """Get a profile of the current memory usage. - + """ Get a profile of the current memory usage. Args: mode: There are two modes: - - 'all' means return memory for all gpus - 'min_max' means return memory for max and min - Return: A dictionary in which the keys are device ids as integers and values are memory usage as integers in MB. If mode is 'min_max', the dictionary will also contain two additional keys: - - 'min_gpu_mem': the minimum memory usage in MB - 'max_gpu_mem': the maximum memory usage in MB """ @@ -469,7 +362,6 @@ def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: def get_gpu_memory_map() -> Dict[str, int]: """ Get the current gpu usage. - Return: A dictionary in which the keys are device ids as integers and values are memory usage as integers in MB. 
@@ -485,19 +377,18 @@ def get_gpu_memory_map() -> Dict[str, int]:
 
     # Convert lines into a dictionary
     gpu_memory = [float(x) for x in result.stdout.strip().split(os.linesep)]
-    gpu_memory_map = {f"gpu_id: {gpu_id}/memory.used (MB)": memory for gpu_id, memory in enumerate(gpu_memory)}
+    gpu_memory_map = {
+        f"gpu_id: {gpu_id}/memory.used (MB)": memory for gpu_id, memory in enumerate(gpu_memory)
+    }
     return gpu_memory_map
 
-
 def get_formatted_model_size(total_model_size: float) -> float:
     return f"{total_model_size:.3f}"
 
-
 def get_human_readable_count(number: int) -> str:
     """
     Abbreviates an integer number with K, M, B, T for thousands, millions,
     billions and trillions, respectively.
-
     Examples:
         >>> get_human_readable_count(123)
         '123 '
@@ -511,13 +402,10 @@ def get_human_readable_count(number: int) -> str:
         '400 T'
         >>> get_human_readable_count(5e15)  # (more than trillion)
         '5,000 T'
-
     Args:
         number: a positive integer number
-
     Return:
         A string formatted according to the pattern described above.
-
     """
     assert number >= 0
     labels = PARAMETER_NUM_UNITS
diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py
index af4274d98a64e..cb68ad04459e8 100644
--- a/tests/core/test_memory.py
+++ b/tests/core/test_memory.py
@@ -21,54 +21,33 @@
 from tests.base.models import ParityModuleRNN
 
 
-def almost_equals(a, b, rel_tol=0.0, abs_tol=0.0):
-    _almost_close = lambda a, b: abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
-    return _almost_close(a, b)
-
+class EmptyModule(LightningModule):
+    """ A module that has no layers """
 
-class LitModel(LightningModule):
     def __init__(self):
         super().__init__()
-        self.net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
+        self.parameter = torch.rand(3, 3, requires_grad=True)
+        self.example_input_array = torch.zeros(1, 2, 3, 4, 5)
 
-    def forward(self, x):
-        return self.net(x)
+    def forward(self, *args, **kwargs):
+        return {'loss': self.parameter.sum()}
 
 
-class KnownNet(LightningModule):
-    """ Pre calculated known model """
+class PreCalculatedModel(LightningModule):
+    """ A module with precalculated total params size in MB. """
""" def __init__(self): super().__init__() - self.conv1 = nn.Conv2d(1, 10, kernel_size=5) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.conv3 = nn.Conv2d(20, 30, kernel_size=3) - self.conv4 = nn.Conv2d(30, 30, kernel_size=3) - self.fc1 = nn.Linear(10, 50) - self.fc2 = nn.Linear(50, 10) + self.layer1 = nn.Linear(10, 100) + self.layer2 = nn.Linear(100, 2) + self.pre_calculated_model_size = 0.005 def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) - x = self.conv3(x) - x = self.conv4(x) - x = x.view(-1, 10) - x = self.fc1(x) - x = self.fc2(x) + x = self.layer1(x) + x = self.layer2(x) return x -class EmptyModule(LightningModule): - """ A module that has no layers """ - - def __init__(self): - super().__init__() - self.parameter = torch.rand(3, 3, requires_grad=True) - self.example_input_array = torch.zeros(1, 2, 3, 4, 5) - - def forward(self, *args, **kwargs): - return {'loss': self.parameter.sum()} - class UnorderedModel(LightningModule): """ A model in which the layers not defined in order of execution """ @@ -250,7 +229,6 @@ def test_summary_layer_types(mode): ] -<<<<<<< HEAD @pytest.mark.parametrize(['mode'], [ pytest.param(ModelSummary.MODE_FULL), pytest.param(ModelSummary.MODE_TOP), @@ -264,74 +242,6 @@ def test_summary_layer_types(mode): pytest.param([torch.zeros(2, 3), torch.zeros(4, 5)], [[2, 3], [4, 5]]), pytest.param((torch.zeros(2, 3), torch.zeros(4, 5)), [[2, 3], [4, 5]]), ]) -======= -@pytest.mark.parametrize( - ["mode"], - [ - pytest.param(ModelSummary.MODE_FULL), - pytest.param(ModelSummary.MODE_TOP), - ], -) -@pytest.mark.parametrize( - ["example_input", "expected_model_size"], - [ - pytest.param(torch.zeros(1, 1, 28, 28), 0.318), - pytest.param(torch.zeros(1, 1, 224, 224), 31.84), - pytest.param(torch.zeros(10, 1, 512, 512), 183.425), - pytest.param(None, 0.075), - ], -) -def test_known_model_sizes(example_input, expected_model_size, mode): - """ Test the knownet model on example input arrays and corresponding known model size """ - - model = KnownNet() - model.example_input_array = example_input - summary = model.summarize(mode=mode) - assert almost_equals(summary.model_size(), expected_model_size, rel_tol=1e-3, abs_tol=1e-3) - - -@pytest.mark.parametrize( - ["mode"], - [ - pytest.param(ModelSummary.MODE_FULL), - ], -) -@pytest.mark.parametrize( - ["example_input", "expected_model_size"], - [ - pytest.param(torch.zeros(10, 256), 0.527), - pytest.param(None, 0.505), - ], -) -def test_nested_seq_model_sizes(example_input, expected_model_size, mode): - """ Test the knownet model on example input arrays and corresponding known model size """ - - model = LitModel() - model.example_input_array = example_input - summary = model.summarize(mode=mode) - assert almost_equals(summary.model_size(), expected_model_size, rel_tol=1e-3, abs_tol=1e-3) - - -@pytest.mark.parametrize( - ["mode"], - [ - pytest.param(ModelSummary.MODE_FULL), - pytest.param(ModelSummary.MODE_TOP), - ], -) -@pytest.mark.parametrize( - ["example_input", "expected_size"], - [ - pytest.param([], UNKNOWN_SIZE), - pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3), - pytest.param(torch.tensor(0), UNKNOWN_SIZE), - pytest.param(dict(tensor=torch.zeros(1, 2, 3)), UNKNOWN_SIZE), - pytest.param(torch.zeros(2, 3, 4), [2, 3, 4]), - pytest.param([torch.zeros(2, 3), torch.zeros(4, 5)], [[2, 3], [4, 5]]), - pytest.param((torch.zeros(2, 3), torch.zeros(4, 5)), [[2, 3], [4, 5]]), - ], -) ->>>>>>> :hammer: Simplified tests def test_example_input_array_types(example_input, expected_size, mode): """ Test the types of 
@@ -353,3 +263,25 @@ def forward(self, *args, **kwargs):
     model.example_input_array = example_input
     summary = model.summarize(mode=mode)
     assert summary.in_sizes == [expected_size]
+
+@pytest.mark.parametrize(['mode'], [
+    pytest.param(ModelSummary.MODE_FULL),
+    pytest.param(ModelSummary.MODE_TOP),
+])
+def test_model_size(mode):
+    """ Test that model size is calculated correctly. """
+    model = PreCalculatedModel()
+    summary = model.summarize(mode=mode)
+    pre_calculated_model_size = torch.tensor(model.pre_calculated_model_size)
+    model_size = torch.tensor(summary.model_size())
+    assert torch.isclose(model_size, pre_calculated_model_size, atol=1e-4)
+
+@pytest.mark.parametrize(['mode'], [
+    pytest.param(ModelSummary.MODE_FULL),
+    pytest.param(ModelSummary.MODE_TOP),
+])
+def test_empty_model_size(mode):
+    """ Test that empty model size is zero. """
+    model = EmptyModule()
+    summary = model.summarize(mode=mode)
+    assert 0.0 == summary.model_size()
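
The simplified estimate reduces to parameter count × bytes per parameter, converted to MB.
Below is a minimal standalone sketch of the same arithmetic, assuming float32 (32-bit)
parameters and reusing the 10 -> 100 -> 2 linear stack from PreCalculatedModel; the `net`
object here is illustrative and not part of the diff itself:

    import torch.nn as nn

    # Same shapes as PreCalculatedModel: (10 * 100 + 100) + (100 * 2 + 2) = 1302 parameters.
    net = nn.Sequential(nn.Linear(10, 100), nn.Linear(100, 2))
    total_parameters = sum(p.numel() for p in net.parameters())

    precision = 32                                           # bits per float32 parameter
    precision_megabytes = (precision / 8.0) / (1024 ** 2.0)  # bits -> bytes -> megabytes

    model_size_mb = total_parameters * precision_megabytes
    print(f"{model_size_mb:.3f}")                            # prints 0.005

The exact value is 1302 * 4 / 1024 ** 2 ≈ 0.004966 MB, which is why test_model_size compares
against the rounded 0.005 with atol=1e-4 rather than requiring exact equality.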