Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move example inputs to correct device when tracing module #4360

Merged
merged 14 commits into from
Oct 29, 2020
Merged

move example inputs to correct device when tracing module #4360

merged 14 commits into from
Oct 29, 2020

Conversation

NumesSanguis
Copy link
Contributor

Continuation of pull request (which has been merged): #4142
This pull request addresses issues raised in the previous pull request.

These pull request address the original feature request: #4140

…Tensor; not supported log error when example_inputs is a dict; commented docstring trace example
@pep8speaks
Copy link

pep8speaks commented Oct 26, 2020

Hello @NumesSanguis! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2020-10-29 05:20:59 UTC

@NumesSanguis
Copy link
Contributor Author

This pull request has a PEP8 issue. I don't know how to do a multi-line version of (line 1580):

>>> torch.jit.save(model.to_torchscript(method='trace', example_inputs=torch.randn(1, 64)), "model_trace.pt")  # doctest: +SKIP

Please let me know how to do this.

@codecov
Copy link

codecov bot commented Oct 26, 2020

Codecov Report

Merging #4360 into master will increase coverage by 0%.
The diff coverage is 100%.

@@          Coverage Diff           @@
##           master   #4360   +/-   ##
======================================
  Coverage      93%     93%           
======================================
  Files         111     111           
  Lines        8092    8127   +35     
======================================
+ Hits         7500    7547   +47     
+ Misses        592     580   -12     

Copy link
Member

@justusschock justusschock left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Some minor comments though

pytorch_lightning/core/lightning.py Outdated Show resolved Hide resolved
@@ -1591,8 +1594,12 @@ def to_torchscript(
# if no example inputs are provided, try to see if model has example_input_array set
if example_inputs is None:
example_inputs = self.example_input_array
# dicts are not supported, so show the user an error; not raising an error to show the original error
if type(example_inputs) == dict:
log.error(f"`example_inputs` should be a Tensor or a tuple of Tensors, but got a dict.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
log.error(f"`example_inputs` should be a Tensor or a tuple of Tensors, but got a dict.")
raise TypeError(f"`example_inputs` should be a Tensor or a tuple of Tensors, but got a dict.")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justusschock I on purpose used log.error, because the blocker should not be to_torchscript, but the logic should stay at torch.jit.trace's side.

Say for example that TorchScript is updated at PyTorch's side (torch.jit.trace() does now accept a dict). Now the to_torchscript() function can actually support a dict, but Lightning is unnecessarily blocking it. Also, the original error might be more insightful than just a "dict is not accepted".

log.error still gives this little bit of extra information to the user, but only acts as a friendly informer, instead of taking over the guard position.
If torch.jit.trace() does accept a dict in an update, we only have a nagging error, but not a showstopper.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. However, currently this is a showstopper and we should treat it like that imo.

if this isn't any longer in the future, we can simply remove this.

Also I'd wonder why we need it at all, if we say, we let torchscript handle this...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justusschock This came forth from the discussion with @awaelchli here: #4142 (comment)

I see. However, currently this is a showstopper and we should treat it like that imo.

If the person reads the docs (generated from example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]), it should already be clear that a dict should not be used. So I would argue it's not a show stopper, just something that can be encountered depending on how self.example_input_array is set.

Also I'd wonder why we need it at all, if we say, we let torchscript handle this...

It could be removed, but since in some cases self.example_input_array is set to a dict, it's just a nice heads-up to the user if this problem is encountered. Seeing this log.error just before the beginning of the Traceback is just a bit more user-friendly.

@mergify mergify bot requested a review from a team October 26, 2020 07:28
@NumesSanguis
Copy link
Contributor Author

@justusschock Do you know how to solve my PEP8 issue I commented about above? Then I'll add that together with the fix(es) in your comments.

@justusschock
Copy link
Member

@NumesSanguis what about

>>> torch.jit.save(
        model.to_torchscript(method="trace", example_inputs=torch.randn(1, 64)), "model_trace.pt"
    )  # doctest: +SKIP

?

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
@NumesSanguis
Copy link
Contributor Author

@NumesSanguis what about

>>> torch.jit.save(
        model.to_torchscript(method="trace", example_inputs=torch.randn(1, 64)), "model_trace.pt"
    )  # doctest: +SKIP

?

@justusschock My IDE (PyCharm) didn't show correct coloring for writing it like that, but I updated my repo with your code lines. Let's see if the checks pass.

@NumesSanguis
Copy link
Contributor Author

@justusschock Seems PyCharm was right, because the checks throw this error:

 1580             >>> torch.jit.save(
UNEXPECTED EXCEPTION: SyntaxError('unexpected EOF while parsing', ('<doctest pytorch_lightning.core.lightning.LightningModule.to_torchscript[4]>', 1, 16, 'torch.jit.save(\n'))
Traceback (most recent call last):

  File "/opt/hostedtoolcache/Python/3.6.12/x64/lib/python3.6/doctest.py", line 1330, in __run
    compileflags, 1), test.globs)

  File "<doctest pytorch_lightning.core.lightning.LightningModule.to_torchscript[4]>", line 1

    torch.jit.save(

                  ^

SyntaxError: unexpected EOF while parsing

Should I just remove this docstring test? Or allow for a PEP8 here?

@justusschock
Copy link
Member

@Borda are you familiar with multiline docchecks?

@NumesSanguis
Copy link
Contributor Author

I undid the commit with the following lines using a rebase.

>>> torch.jit.save(
        model.to_torchscript(method="trace", example_inputs=torch.randn(1, 64)), "model_trace.pt"
    )  # doctest: +SKIP

Instead I commented it as multi-line to stop the PEP8 error. Once I know how to do a multi-line docstring code example, I can undo this last PEP8 fix commit and fix it properly.

@NumesSanguis
Copy link
Contributor Author

NumesSanguis commented Oct 26, 2020

@justusschock It seems the secret was ...:

>>> torch.jit.save(model.to_torchscript(file_path="model_trace.pt", method='trace', # doctest: +SKIP
...                                     example_inputs=torch.randn(1, 64)))  # doctest: +SKIP

@NumesSanguis
Copy link
Contributor Author

If the current use of log.error is acceptable, this pull request can be merged.

Copy link
Member

@awaelchli awaelchli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add a changelog message?

@@ -1591,8 +1595,13 @@ def to_torchscript(
# if no example inputs are provided, try to see if model has example_input_array set
if example_inputs is None:
example_inputs = self.example_input_array
# dicts are not supported, so show the user an error; not raising an error to show the original error
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this comment, it is saying there is an error but there is not an error, what is it?
can we just remove the comment? the code should speak for itself

Copy link
Contributor

@rohitgr7 rohitgr7 Oct 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the inputs must be a tensor or tuple of tensors. IMO a better way to handle this is by wrapping the input tensor into a tuple and checking whether each element in the tuple is an instance of torch.Tensor or not.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also with a test for the same.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@awaelchli good catch, that doesn't make sense indeed.
@rohitgr7 I think trace() already does internally this wrapping of a torch.Tensor in a tuple, so I don't think we have to add that again on Lightning's side?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the comment to hopefully make more sense

Copy link
Contributor

@rohitgr7 rohitgr7 Oct 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

give me sometime. Need to check what's the actual issue here. Is there something wrong from pytorch side or we are doing something wrong here? In the meantime can you open an issue on pytorch forums if possible? Maybe we can get a quick response there :) Would be good to resolve all issues in this PR itself to avoid any issues in the future related to to_torchscript. Also will make similar changes to to_onnx #4378.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This issue is already in the master, since this pull request was already merged: #4140
This pull request is just to add some quality of life changes to the previous one. If we merge this one, it's much easier for other people to reproduce this issue, because they will have the same error output (this pull request does not add a new problem, just 1 step closer to solving it).

We can just keep the original issue (#4140) open, and discuss this issue there, as it would be more easy to find compared to this comment thread. Then we can point a PyTorch forum issue to there. A new pull request can then target that specific dict improvement (which might be very deep), instead of making this PR huge.

Honestly, I would like to make all parts work nicely, but I'm not affected by the dict issue, and I already spend too much time on this pull request. The previous pull request already added everything needed for my use case, but this pull request is just an extra to make the previous one a little bit less rough.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok cool. Then let's remove the check for Mapping and merge this one since it doesn't throw any error with dict :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rohitgr7 Thanks. The logger error has been removed :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rohitgr7 I put a summary of the Dict issue here: #4140 (comment)
which should make the discussion a bit more visible for others.

pytorch_lightning/core/lightning.py Outdated Show resolved Hide resolved
pytorch_lightning/core/lightning.py Outdated Show resolved Hide resolved
@mergify mergify bot requested a review from a team October 27, 2020 08:40
@awaelchli awaelchli changed the title Quality of life changes to previous TorchScript trace merged pull request move example inputs to correct device when tracing module Oct 27, 2020
@awaelchli awaelchli added this to the 1.0.x milestone Oct 27, 2020
@awaelchli awaelchli added feature Is an improvement or enhancement torchscript labels Oct 27, 2020
@mergify mergify bot requested a review from a team October 27, 2020 08:49
@mergify mergify bot requested a review from a team October 27, 2020 13:00
NumesSanguis and others added 2 commits October 28, 2020 10:08
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
@rohitgr7 rohitgr7 merged commit 9cfd299 into Lightning-AI:master Oct 29, 2020
Borda pushed a commit that referenced this pull request Nov 4, 2020
* use move_data_to_device instead of to; docstring also allow tuple of Tensor; not supported log error when example_inputs is a dict; commented docstring trace example

* Use isinstance to check if example_inputs is a Mapping, instead of type

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* import Mapping for isinstance check

* multi-line docstring code to test TorchScript trace()

* Fix PEP8 f-string is missing placeholders

* minor code style improvements

* Use (possibly user overwritten) transfer_batch_to_device instead of move_data_to_device

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* fixed weird comment about trace() log error

* Remove unused import

Co-authored-by: Jeff Yang <ydcjeff@outlook.com>

* Remove logger warning about dict not example_inputs not supported by trace

Co-authored-by: stef-ubuntu <stef@webempath.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
(cherry picked from commit 9cfd299)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
feature Is an improvement or enhancement
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

9 participants