
Expand to_torchscript to support also TorchScript's trace method #4140

Closed
NumesSanguis opened this issue Oct 14, 2020 · 7 comments · Fixed by #4142
Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on)

Comments

@NumesSanguis
Contributor

🚀 Feature

Allow the user to easily choose between TorchScript's script or trace method to create a module.

Motivation

While TorchScript's script method works for simple models, it does not always work out of the box when a model relies on Python variables being set at runtime. In those cases the user has to manually annotate the model to avoid errors from script().

TorchScript's trace method, on the other hand, creates a traced module by running an example Tensor through the network and recording what happens during that run.
This nearly always produces a module, but any data-dependent control flow present in the model is lost.

Both script and trace have their use cases, and with a minimal extension of this function, both methods can be used.
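For reference, the two methods differ mainly in what they need from the user; a minimal sketch (the model architecture and shapes are arbitrary):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# script() compiles the model from its Python source: no example input needed
scripted = torch.jit.script(model)

# trace() runs the model once on an example input and records the tensor ops
traced = torch.jit.trace(model, torch.randn(1, 4))
```

Both calls return a ScriptModule that can be saved with torch.jit.save() and loaded from C++.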

Pitch

  • Add a method argument that can be set to either script or trace (defaulting to script, which preserves the current behaviour).
  • Add an example_inputs argument that defaults to None and can be set to any Tensor. If None is provided, the function automatically tries to use self.example_input_array. The example input is automatically sent to the correct device.
    Note: the name example_inputs cannot be changed, as it is the argument name that trace() expects. If named otherwise, it could conflict with kwargs.

This change should not break any existing scripts, as the default method is script.
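The pitch above could be sketched roughly as follows. This is only an illustration of the proposed dispatch, not Lightning's actual implementation; the free-standing module argument, the example_input_array fallback, and the file_path handling are assumptions drawn from the issue text:

```python
import torch

def to_torchscript(module, method="script", example_inputs=None, file_path=None):
    """Sketch of the proposed API: dispatch between script() and trace()."""
    if method == "script":
        torchscript_module = torch.jit.script(module.eval())
    elif method == "trace":
        # fall back to the module's example_input_array if no input was given
        if example_inputs is None:
            example_inputs = module.example_input_array
        torchscript_module = torch.jit.trace(module.eval(), example_inputs)
    else:
        raise ValueError(f"method must be 'script' or 'trace', got {method!r}")
    if file_path is not None:
        torch.jit.save(torchscript_module, file_path)
    return torchscript_module
```

With method left at its default, the behaviour is identical to the current script-only implementation.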

Alternatives

Make no change and require the user to override this function to use trace.

Additional context

Please assign me for this request.

@NumesSanguis added the feature and help wanted labels on Oct 14, 2020
@awaelchli
Member

awaelchli commented Oct 14, 2020

There is a problem with your PR regarding the inputs. The example inputs passed to the trace function are not the same as LightningModule.example_input_array.
From the documentation of trace:

example_inputs (tuple) – A tuple of example inputs that will be passed to the function while tracing. The resulting trace can be run with inputs of different types and shapes assuming the traced operations support those types and shapes. example_inputs may also be a single Tensor in which case it is automatically wrapped in a tuple.
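The auto-wrapping behaviour quoted above can be seen with a tiny example (the double function is just for illustration):

```python
import torch

def double(x):
    return x * 2

t = torch.ones(2)

# a bare Tensor passed as example_inputs is automatically wrapped in a tuple...
traced_a = torch.jit.trace(double, t)

# ...so these two calls are equivalent
traced_b = torch.jit.trace(double, (t,))
```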

@NumesSanguis
Contributor Author

@awaelchli I addressed this in my reply to your comment on the pull request review: #4142 (comment)

@NumesSanguis
Contributor Author

To move the discussion from #4142 (comment) to here, replying to @ananthsub's comments:

IDK how much future support there will be for tracing vs scripting (scripting is strongly recommended).

Could you elaborate on why scripting is strongly recommended? According to the TorchScript documentation, scripting and tracing simply target different use cases, without a stated preference for either. Scripting is positioned as useful when you need control flow, while tracing is simpler to use (no code changes required).

Rather than adding more trace support at the top-level of the PL module, why not override to_torchscript in your lightning module to determine how you want to export? then you have way more flexibility with tracing

What flexibility would tracing need? Unlike scripting, where you build in some logic, tracing just means throwing in an example batch and getting back the resulting TorchScript module.

About "determine how you want to export?", could you elaborate? The to_torchscript() function supports a file_path: Optional[str] argument, so you can choose where to save the module. The function also returns Union[ScriptModule, Dict[str, ScriptModule]], which can be used however you like.


I would like to ask why to_torchscript should NOT implement tracing, instead asking every user to implement this piece of engineering themselves?

Lightning is also designed for the fast inference AI researchers and production teams need to scale up things like BERT and self-supervised learning. Lightning can automatically export to ONNX or TorchScript for those cases.

Imagine a team with a data scientist who is good at training a model (Python) but knows little about engineering (in this example, C++), and an engineer who is good at C++ but has no idea what the model does internally. The engineer just wants the TorchScript module, and the data scientist has never brought a model to production but has heard that TorchScript is useful for transferring the model to production.

In the case of scripting, the data scientist is likely the one who needs to figure out how TorchScript works and how to annotate his/her model for proper scripting. Tracing, on the other hand, requires no extra knowledge: just call to_torchscript(method='trace', file_path="some_where", example_inputs=batch). The data scientist now has a model that can be handed to the engineer after calling a single function.

@ananthsub
Contributor

@suo is there guidance that PyTorch/JIT can share for tracing vs scripting? My understanding is that more use cases should adopt scripting, but maybe that's too naive on my part

@NumesSanguis
Contributor Author

With PR #4360 merged, example_inputs can now handle torch.Tensor and Tuple[torch.Tensor]:

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/lightning.py#L1599

While according to the TorchScript docs (under Automatic Trace Checking), example_inputs does not support a dict, @rohitgr7 has found that TorchScript actually does support a dict.
However, if you try a dict now, you get the error:

AttributeError                            Traceback (most recent call last)

<ipython-input-11-1f9f6fbe4f6c> in <module>()
----> 1 test_x(tmpdir)

10 frames

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
   1670         if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
   1671             return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
-> 1672     if input.dim() == 2 and bias is not None:
   1673         # fused op is marginally faster
   1674         ret = torch.addmm(bias, input, weight.t())

AttributeError: 'dict' object has no attribute 'dim'

Whether Lightning, PyTorch, or TorchScript is the cause of this error still needs to be determined. Therefore, it's better to keep this issue open until dicts are supported (or it is determined that dicts cannot be used).

Related issue: #4378

@rohitgr7
Contributor

The forward function above doesn't accept a dict; that's why that error comes up.
If the forward function is something like:

def forward(self, batch):
    x = batch['x']
    return self.model(x)

then running torch.jit.trace() on it with a dict as example_inputs won't throw an error and will give you a TorchScript module. The reason PyTorch specifies that it should be a Tensor or tuple of Tensors is, I think, that those are acceptable in other environments such as C++ with no Python dependency, where there is no such thing as a dict. I'm not sure this is the correct reason since I haven't worked with TorchScript before, so we need to check whether it's a bug, a documentation error, or something else.
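For reference, the forward above can be completed into a small runnable sketch (the DictModel class, layer sizes, and key name are illustrative):

```python
import torch
import torch.nn as nn

class DictModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 2)

    def forward(self, batch):
        # dict keys are recorded as constants during tracing
        x = batch["x"]
        return self.model(x)

example = {"x": torch.randn(1, 4)}
# passing a dict (wrapped in a tuple) as example_inputs traces without error
traced = torch.jit.trace(DictModel(), (example,))
```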

@suo

suo commented Oct 29, 2020

@ananthsub Scripting is generally recommended. Since the tracer can only record observed tensor operations, there are a number of corner cases where the resulting graph may not generalize in surprising ways. Common pitfalls include: control flow, device-specific code, different sizes, etc.

That said, if tracing works (and tracing generally works well for a large class of models that don't have any control flow), then it is perfectly fine to use.
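The control-flow pitfall suo mentions can be demonstrated with a small sketch (the gate function is illustrative): a trace records only the branch taken for the example input, while a scripted version keeps both branches.

```python
import torch

def gate(x):
    # data-dependent control flow: a classic tracing pitfall
    if x.sum() > 0:
        return x + 1
    return x - 1

# tracing with a positive input records only the `x + 1` branch
# (PyTorch emits a TracerWarning about the data-dependent condition)
traced = torch.jit.trace(gate, torch.ones(3))

# scripting compiles the source, so both branches are preserved
scripted = torch.jit.script(gate)
```

Running both on a negative input shows the divergence: the scripted function takes the `x - 1` branch, while the traced one still applies `x + 1`.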
