
Inconsistencies between nn.functional.interpolate and tf.image.resize #18811

Closed

sayakpaul opened this issue Aug 30, 2022 · 15 comments
@sayakpaul
Member

System Info

Upsampling of intermediate feature maps is common for computer vision models. In PyTorch, it's usually implemented using nn.functional.interpolate. For TensorFlow, it's usually tf.image.resize.

But there is an inconsistency between what these two methods yield (Colab Notebook). Sometimes the differences in their outputs are small enough to ignore (ViT, for example). However, when interpolation is repeated several times in a model (MobileViT, for example), these small differences can add up and noticeably shift the final outputs of a model. More details here.
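
A minimal sketch of the kind of comparison involved (not the notebook's exact code; the shapes and layout handling here are illustrative assumptions):

import numpy as np
import tensorflow as tf
import torch

x = np.random.rand(1, 8, 8, 3).astype(np.float32)  # NHWC

# PyTorch works on NCHW, so permute before and after
pt_out = torch.nn.functional.interpolate(
    torch.from_numpy(x).permute(0, 3, 1, 2), size=(16, 16), mode="bilinear"
)
pt_out = pt_out.permute(0, 2, 3, 1).numpy()

# TensorFlow resizes NHWC directly
tf_out = tf.image.resize(x, (16, 16), method="bilinear").numpy()

print(np.abs(pt_out - tf_out).max())  # small but non-zero in float32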

@hollance wrote an amazing blog post discussing this issue.

Who can help?

@amyeroberts @gante @Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The Colab Notebook mentioned in the description.

Expected behavior

We should work on a TF utility that yields the same output as nn.functional.interpolate.

@sayakpaul sayakpaul added the bug label Aug 30, 2022
@hollance
Contributor

Note that align_corners=None does give the same result as tf.image.resize, to an absolute tolerance of 1e-6 or so:

import torch

# align_corners=None falls back to the default (align_corners=False) sampling grid
upsampled_logits_pt = torch.nn.functional.interpolate(
    dummy_logits_pt, size=interp_shape, mode="bilinear", align_corners=None
)
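
For reference, the TF call this is compared against would look roughly like the following (the tensor name dummy_logits_tf and the NHWC layout are assumptions mirroring the snippet above):

import tensorflow as tf

upsampled_logits_tf = tf.image.resize(
    dummy_logits_tf, size=interp_shape, method="bilinear"
)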

@ariG23498
Contributor

I faced a similar problem.

With torch.nn.functional.interpolate (a rough sketch follows at the end of this comment):

  • With align_corners=True, the best option is tf.compat.v1.image.resize.
  • With align_corners=False, tf.image.resize does the trick.

Here is a colab notebook that details the solution that I proposed.

@amyeroberts and @gante were against using the tf.compat.v1.image.resize for obvious reasons. @amyeroberts did come up with a solution which is documented in this comment. I hope this provides some value to this thread.
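As a rough illustration of the mapping above (the function name is made up, bilinear mode and NHWC inputs are assumed, and this is not @amyeroberts's solution):

import tensorflow as tf

def resize_like_torch(images, size, align_corners=False):
    # images: NHWC; bilinear resizing assumed throughout
    if align_corners:
        # v1 resize with align_corners=True mirrors
        # torch interpolate(..., align_corners=True); bilinear is the default method here
        return tf.compat.v1.image.resize(images, size, align_corners=True)
    # TF2 resize (half-pixel centers) mirrors align_corners=False
    return tf.image.resize(images, size, method="bilinear")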

@sayakpaul
Member Author

Thanks @hollance and @ariG23498 for chiming in.

When there are only one or two interpolation ops, the small differences are fine, but when they are applied several times (as mentioned earlier), these differences compound and create mismatches in the final outputs.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Rocketknight1
Member

Pinging this to keep it open - we've implemented TF versions of Torch ops like AdaptivePool, so this might be something to explore as well. I hope a performant solution can be implemented with native ops (with or without XLA compilation) without needing to write our own CUDA, though!
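For reference, a naive TF counterpart of torch.nn.AdaptiveAvgPool1d might look like the sketch below; this is illustrative only and not the implementation referenced in this thread:

import tensorflow as tf

def naive_adaptive_avg_pool1d(x, output_size):
    # x: (batch, length, channels) with a statically known length.
    # Mirrors PyTorch's AdaptiveAvgPool1d bin boundaries:
    # start = floor(i * L / out), end = ceil((i + 1) * L / out).
    length = x.shape[1]
    pooled = []
    for i in range(output_size):
        start = (i * length) // output_size
        end = -((-(i + 1) * length) // output_size)  # ceil division
        pooled.append(tf.reduce_mean(x[:, start:end, :], axis=1))
    return tf.stack(pooled, axis=1)  # (batch, output_size, channels)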

@sayakpaul
Member Author

Pinging this to keep it open - we've implemented TF versions of Torch ops like AdaptivePool, so this might be something to explore as well.

Could you point me to something relevant? Would love to see the new implementations.

@Rocketknight1
Member

Hi @sayakpaul, I'm extremely sorry! I thought we'd implemented it in data2vec2 already, but checking the code it seems like we still have the old version there. Here's a much, much more performant version.

@Rocketknight1
Member

The new version will also allow XLA compilation, unlike the original sparse implementation.

@sayakpaul
Member Author

Oh yeah, I remember this one. This reminded me that I need to update the data2vec code with your latest implementation.

Thanks, Matt!

@Rocketknight1
Member

Although, checking out that notebook, it seems like PyTorch's interpolate and TF's resize are using the same algorithm, and the differences are mostly numerical; when I switch the dtype to float64, the max absolute difference is 1e-7.

I think this will make it extremely hard to bring the accuracies any closer - we would have to rewrite the internals of the algorithm so that they're not just mathematically equivalent, but they lose precision in the same places as well. I think we're stuck with the issue, unfortunately!
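One way to run the dtype check described above (illustrative only; note that tf.image.resize may compute and return float32 internally regardless of the input dtype, so the exact numbers can vary):

import numpy as np
import tensorflow as tf
import torch

x = np.random.rand(1, 8, 8, 3)  # float64 by default

pt_out = torch.nn.functional.interpolate(
    torch.from_numpy(x).permute(0, 3, 1, 2), size=(16, 16), mode="bilinear"
).permute(0, 2, 3, 1).numpy()
tf_out = tf.image.resize(x, (16, 16), method="bilinear").numpy()

# Reported above to be around 1e-7 when working in float64
print(np.abs(pt_out - tf_out).max())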

@sayakpaul
Member Author

Fair enough. The tiny differences just add up when there are multiple such interpolations. Otherwise, it's not a big deal.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ariG23498
Contributor

Ping.

I don't think it is resolved yet.

@sayakpaul
Member Author

I doubt it will ever be resolved in an apples-to-apples manner, considering #18811 (comment).

We need to be careful when comparing predictions of models (PT vs. TF) that stack interpolations, especially about the tolerances we're using.
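In practice that might look like an explicit tolerance check in cross-framework tests, along these lines (the helper name and tolerance value are illustrative assumptions, not an agreed-upon standard):

import numpy as np

def assert_close(pt_outputs: np.ndarray, tf_outputs: np.ndarray, atol: float = 1e-4):
    # Hypothetical helper for PT vs. TF comparisons with a relaxed tolerance
    max_diff = np.abs(pt_outputs - tf_outputs).max()
    assert max_diff < atol, f"max abs diff {max_diff} exceeds tolerance {atol}"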

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
