Add trace functionality to the function to_torchscript #4142

Merged: 3 commits, Oct 14, 2020

33 changes: 25 additions & 8 deletions — pytorch_lightning/core/lightning.py

```diff
@@ -1527,17 +1527,23 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
         torch.onnx.export(self, input_data, file_path, **kwargs)

     def to_torchscript(
-        self, file_path: Optional[str] = None, **kwargs
+        self, file_path: Optional[str] = None, method: Optional[str] = 'script',
+        example_inputs: Optional[torch.Tensor] = None, **kwargs
     ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
         """
         By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
-        If you would like to customize the modules that are scripted or you want to use tracing
-        you should override this method. In case you want to return multiple modules, we
-        recommend using a dictionary.
+        If you want to use tracing, please provide the argument `method='trace'` and make sure
+        that either the example_inputs argument is provided, or the model has
+        self.example_input_array set.
+        If you would like to customize the modules that are scripted you should override this method.
+        In case you want to return multiple modules, we recommend using a dictionary.

         Args:
             file_path: Path where to save the torchscript. Default: None (no file saved).
-            **kwargs: Additional arguments that will be passed to the :func:`torch.jit.save` function.
+            method: Whether to use TorchScript's script or trace method. Default: 'script'
+            example_inputs: Tensor to be used to do tracing when method is set to 'trace'.
+                Default: None (uses self.example_input_array)
+            **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
+                :func:`torch.jit.trace` function.

         Note:
             - Requires the implementation of the
@@ -1568,13 +1574,24 @@ def to_torchscript(
         mode = self.training
         with torch.no_grad():
-            scripted_module = torch.jit.script(self.eval(), **kwargs)
+            if method == 'script':
+                torchscript_module = torch.jit.script(self.eval(), **kwargs)
+            elif method == 'trace':
+                # if no example inputs are provided, try to see if model has example_input_array set
+                if example_inputs is None:
+                    example_inputs = self.example_input_array
+                # automatically send example inputs to the right device and use trace
+                torchscript_module = torch.jit.trace(func=self.eval(),
+                                                     example_inputs=example_inputs.to(self.device),
```
Member:

Here you assume that example_input_array is a tensor, but this is not true.
If forward takes *args, example_input_array is a tuple, and if forward takes **kwargs, example_input_array must be a dict.

Contributor Author:

Given your comment on the issue that .trace() accepts either a tuple or a torch.Tensor (that is automatically converted to a tuple), it means that the input should be: example_input_array: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]?

However, when the forward function accepts **kwargs, self.example_input_array could be a dict, in which case .trace(example_inputs=example_inputs) will fail?

What would be the best way to approach this? Does this mean that .trace() cannot be used if forward expects a dict?

Member (@awaelchli), Oct 17, 2020:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)


class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x, y):
        return self.conv(x)


class Net3(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x, y):
        return self.conv(x)

# SINGLE INPUT
net = Net()
ex_inp = torch.rand(1, 1, 3, 3)
torch.jit.trace(net, ex_inp)

# TWO INPUTS
net = Net2()
torch.jit.trace(net, (ex_inp, ex_inp))

# DICT (**kwargs)
# fails
# net = Net3()
# torch.jit.trace(net, dict(x=ex_inp, y=ex_inp))
```

Here is an example. Tracing supports a single input and a tuple, which gets unrolled into multiple positional args. In these two cases, you can use the Lightning self.example_input_array. However, dicts will not be passed as kwargs, but as a single input. In Lightning, however, a dict would mean **kwargs.

I see several ways to handle it:

  1. leave as is, user needs to know how self.example_input_array works
  2. error when self.example_input_array is a dict
  3. do not even use self.example_input_array, and require the user to give inputs to the method directly

Then there is a second issue. You should use pytorch_lightning.utilities.apply_func.move_data_to_device to move the example input to the device, since it could be a tuple.
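The device-move concern can be illustrated with a simplified stand-in for `move_data_to_device` (the real helper lives in `pytorch_lightning.utilities.apply_func`; the function below is only a sketch of its recursive behavior):

```python
import torch

def move_data_to_device_sketch(batch, device):
    # Simplified sketch of pytorch_lightning.utilities.apply_func.move_data_to_device:
    # recurse into containers and call .to(device) only on tensors, so a tuple or
    # dict example_input_array is handled, not just a bare tensor.
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_data_to_device_sketch(b, device) for b in batch)
    if isinstance(batch, dict):
        return {k: move_data_to_device_sketch(v, device) for k, v in batch.items()}
    return batch

# A tuple example input has no .to() method, so calling example_inputs.to(self.device)
# on it would raise AttributeError; the recursive helper handles it.
example = (torch.rand(1, 3), torch.rand(1, 3))
moved = move_data_to_device_sketch(example, torch.device("cpu"))
```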


Contributor Author:

1 & 2 could be combined by raising a warning instead of an error. From PL's side, throw a warning similar to:

self.example_input_array cannot be a dict. Please provide a sample Tensor/Tuple to example_inputs as argument, or set self.example_input_array to a Tensor/Tuple.

Then output the actual error produced by .trace(). If .trace() is updated in the future to support a dict, nothing needs to change on PL's side (except removing the warning).

Personally, for me PL is about removing boilerplate code. Since self.example_input_array is already a thing in PL, it's better to use it. Therefore, I would advise against option 3.
I haven't used self.example_input_array myself yet, but in how many projects would this be a dict?

Member:

Yes, makes sense.
Would you like to follow up on this with a PR? I would greatly appreciate it. For me the main concern is to properly move the input to the device with the function I referenced. On the way inputs are passed in, I don't have a strong opinion.

Contributor:

I don't know how much future support there will be for tracing vs. scripting (scripting is strongly recommended). Rather than adding more trace support at the top level of the PL module, why not override to_torchscript in your LightningModule to decide how you want to export? Then you have far more flexibility with tracing.
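The override approach can be sketched on a plain `nn.Module` standing in for a LightningModule (the class name and fixed example shape are illustrative, not from the PR):

```python
import torch
import torch.nn as nn

class TracedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def to_torchscript(self):
        # Override hook: always export via tracing with a fixed example input,
        # instead of relying on top-level method='trace' plumbing.
        example = torch.rand(1, 1, 3, 3)
        return torch.jit.trace(self.eval(), example)

exported = TracedNet().to_torchscript()
```

Because the override lives on the module itself, each model can pick its own export strategy and example inputs.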

Contributor Author:

@awaelchli Ok, I'll follow up with another pull request using the move_data_to_device function.

@ananthsub edit: moved my comment to the feature request, as that is a more relevant place for this discussion: #4140

Contributor Author:

@awaelchli I addressed your issues in a follow-up pull request (could not be added to this one due to it already being merged):
#4360

```diff
+                                                     **kwargs)
+            else:
+                raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
+                                 f"{method}")
         self.train(mode)

         if file_path is not None:
-            torch.jit.save(scripted_module, file_path)
+            torch.jit.save(torchscript_module, file_path)

-        return scripted_module
+        return torchscript_module

     @property
     def hparams(self) -> Union[AttributeDict, str]:
```
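The control flow added by this diff can be exercised standalone with a small sketch (the function below mirrors the script/trace branching; it is illustrative, not the Lightning implementation, and omits the device handling):

```python
import torch
import torch.nn as nn

def to_torchscript_sketch(model, example_inputs=None, method='script'):
    # Mirrors the branching added in the diff: pick torch.jit.script or
    # torch.jit.trace based on `method`, and reject anything else.
    with torch.no_grad():
        if method == 'script':
            return torch.jit.script(model.eval())
        elif method == 'trace':
            return torch.jit.trace(model.eval(), example_inputs)
        raise ValueError(
            f"The 'method' parameter only supports 'script' or 'trace', "
            f"but value given was: {method}"
        )

class Tiny(nn.Module):
    def forward(self, x):
        return x * 2

traced = to_torchscript_sketch(Tiny(), example_inputs=torch.rand(2), method='trace')
scripted = to_torchscript_sketch(Tiny(), method='script')
```

Both branches return a `ScriptModule`, which is what the PR's new test asserts below.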
16 changes: 16 additions & 0 deletions — tests/models/test_torchscript.py

```diff
@@ -37,6 +37,22 @@ def test_torchscript_input_output(modelclass):
     assert torch.allclose(script_output, model_output)


+@pytest.mark.parametrize("modelclass", [
+    EvalModelTemplate,
+    ParityModuleRNN,
+    BasicGAN,
+])
+def test_torchscript_input_output_trace(modelclass):
+    """ Test that traced LightningModule forward works. """
+    model = modelclass()
+    script = model.to_torchscript(method='trace')
+    assert isinstance(script, torch.jit.ScriptModule)
+    model.eval()
+    model_output = model(model.example_input_array)
+    script_output = script(model.example_input_array)
+    assert torch.allclose(script_output, model_output)
+
+
 @pytest.mark.parametrize("device", [
     torch.device("cpu"),
     torch.device("cuda", 0)
```