
Suspected corner case in XLA compilation - vectorized_map, conditional, scatter_nd_update complains about dynamic shape when we should know it #66535

Open
stellarpower opened this issue Apr 26, 2024 · 1 comment
Labels
comp:xla XLA stat:awaiting response Status - Awaiting response from author TF 2.15 For issues related to 2.15.x type:bug Bug

Comments

@stellarpower

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

binary

TensorFlow version

2.15.0; nightly

Custom code

Yes

OS platform and distribution

No response

Mobile device

No response

Python version

No response

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

I seem to have hit what I think is a really weird corner case in XLA compilation. My code may have a problem in it (this is the first time I'm using XLA), but a particular combination of factors seems to be required for the bug to reproduce.

The compiler complains about not being able to operate on dynamic shapes, although these should be known (or inferable) by the graph compiler. I don't know whether some errors that could have come up are elided as the optimisations are applied, but reproducing the problem required a combination of a conditional, a vectorised map, and a matrix generated using reduce_sum, as well as passing these between functions. Changing some or all of these elements to try to pin the bug down made the code compile fine.

I'm working on an implementation of Soft Dynamic Time Warping as a loss function. For LSTM-format tensors (sequences in batch, timesteps in sequence, elements of feature vector), the basics of the algorithm are: we compute a distance matrix using some metric across all timesteps of the predictions and the ground truth, then calculate a loss matrix iteratively from this distance matrix using two for loops.
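
For reference, a minimal sketch of the distance-matrix step described above (the function name and the choice of Euclidean distance as the metric are just for illustration; this is not the reporter's actual code):

```python
import tensorflow as tf

def pairwise_euclidean(y_true, y_pred):
    # For one sequence: y_true and y_pred are (timesteps, features)
    # tensors; D[i, j] is the Euclidean distance between y_true[i]
    # and y_pred[j]. Broadcast to (T_true, T_pred, features), then
    # reduce over the feature axis.
    diff = tf.expand_dims(y_true, 1) - tf.expand_dims(y_pred, 0)
    return tf.sqrt(tf.reduce_sum(tf.square(diff), axis=-1))
```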

In order to help parallelisation, we can use vectorized_map, as each sequence in the batch is independent. The issue appears to come from this vectorisation, combined with a distance matrix constructed using a suitable expression (Euclidean distance is used as an example). Change either of those two things (hardcode the calculation for one sequence in the batch and do not loop, or build the distance matrix using e.g. tf.ones) and the error does not reproduce. The actual line at fault is a conditional assignment within the loops: if we have infinity in the loss matrix for the current (i, j) pair, we need to change it to minus infinity. The compiler spits out a long trace (repeating itself because of the vectorisation) saying that a reshape operation inside the translation of the scatter_nd_update has a dynamic shape. This is particularly confusing given that all I am doing is assigning a scalar value to a scalar element of a matrix at known indices; when operating eagerly a reshape would be totally unnecessary, as it essentially amounts to one comparison and a memory access on CPU. I can paste the traceback if required, but it may be easier to see using the nightly package in the colab notebook.
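
A stripped-down sketch of the pattern in question (names and indices are hypothetical; this simplified, non-vectorised form is one of the variants that, per the description above, compiles fine on its own):

```python
import tensorflow as tf

@tf.function
def update_if_inf(loss):
    # loss: (T, T) loss matrix for one sequence.
    i, j = 0, 1  # known, constant indices
    value = loss[i, j]
    # Traced by AutoGraph as a tf.cond; note tf.math.is_inf matches
    # both +inf and -inf.
    if tf.math.is_inf(value):
        # Scalar assignment at known indices; under XLA the internal
        # reshape inside the scatter update was reported as dynamic.
        loss = tf.tensor_scatter_nd_update(
            loss, [[i, j]], [-float("inf")])
    return loss
```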

I did try every way I could think of to inform the compiler that the shapes of the tensors are known and to fix the shapes of intermediate values (we can infer all the required information from the shape of the ground truth/predictions; the rest is deterministic), as well as trying to force it to understand that the indexing into the distance matrix and the scatter_nd_update only make scalar changes: e.g. fixing the input_signature of the tf.function, hardcoding the shapes as constants, etc.

I suspected that the conditional branch only optionally assigning makes it not very functional in style, and that this could be causing problems, but I tried expanding it into a ternary-style if in a few ways and the same issue applied. Ultimately my only solution was to scatter_nd_update on every iteration, with either the original value or the constant I need to replace it with: terrible for performance, but I will just pray that the compiler can melt this away.
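
A sketch of that workaround (again with hypothetical names): the update runs unconditionally on every iteration, with tf.where selecting between the original value and the replacement, so no data-dependent branch remains:

```python
import tensorflow as tf

def replace_inf_at(loss, i, j):
    # Unconditional scalar update at known indices (i, j): write back
    # either the original value or -inf, chosen by tf.where.
    value = loss[i, j]
    neg_inf = tf.constant(float("-inf"), dtype=loss.dtype)
    new_value = tf.where(tf.math.is_inf(value), neg_inf, value)
    return tf.tensor_scatter_nd_update(
        loss, [[i, j]], tf.reshape(new_value, [1]))
```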

I hope this provides some level of explanation, but as I am struggling to work out what the real issue is here, and my attempts to isolate the problem failed, I'm afraid this is the most minimal example I could come up with; I'm not really sure what is going on.

As an aside, I'm sure it's a work in progress, but the traceback I get from the XLA compiler made this really, really hard work. It makes the template errors from the range-v3 library look like a breeze, and dumping the intermediaries was not feasible to look through to work out anything sensible. It took about two days of going round and round all of my code before I finally pinned the actual failure site down to the if statement. Aside from general improvements, given that static shapes are required, a traceback of how these shapes were worked out would probably go a long way (in the same way that, for an error in a C++ template, modern compilers will walk back through each deduction and tell me how they arrived at the particular type they used).

I'd probably also add that, at a high level, the interaction between symbolic tensors and XLA has caused rather a lot of friction. When using symbolic tensors, the shape is often not available, so we have to write our code so that lazy operations that understand symbolic shapes can function fine. But then, when it comes to XLA, we absolutely do need the shape, and it's impossible for me to guide the graph (when I absolutely do know what shape I am expecting) because of all the None values. As a user-programmer it really feels like the two systems are fighting each other, and whilst I'd like to see how XLA performs and compares, I could probably achieve this more easily and quickly by writing something against the HLO directly, or using SYCL and making a custom Op. In an ideal world, I'd quite like to see some sort of DSL that builds upon the HLO, targets XLA, and could interoperate with TensorFlow, as doing this in Python and then going through the graph compiler before converting that to XLA is a lot of overhead.

Standalone code to reproduce the issue

https://colab.research.google.com/drive/1PrrUgSHfGjvTKAjoZtBQEJDk2ZwXuEoP?usp=drive_link

Please note that my own code is much more involved. I have tried to cut this back to an MRE and tested it locally, but there may be some differences in behaviour between this colab notebook and my local code; I've pasted it in and checked it quickly in colab, but not picked it apart in the same way I have locally.

Relevant log output

No response

@google-ml-butler google-ml-butler bot added the type:bug Bug label Apr 26, 2024
@sushreebarsa sushreebarsa added comp:xla XLA TF 2.15 For issues related to 2.15.x labels May 6, 2024
@sushreebarsa
Contributor

sushreebarsa commented May 7, 2024

@stellarpower I was able to replicate the issue reported here. Could you please confirm the results?
Could you also try without vectorization, so we can analyze whether a non-vectorized version compiles correctly?
Thank you!

@sushreebarsa sushreebarsa added the stat:awaiting response Status - Awaiting response from author label May 7, 2024