Suspected Corner Case in XLA Compilation - vectorized_sum, conditional, scatter_nd_update Complains about Dynamic Shape when we should Know it #66535
Labels: comp:xla (XLA), stat:awaiting response (Status - Awaiting response from author), TF 2.15 (For issues related to 2.15.x), type:bug (Bug)
Issue type: Bug
Have you reproduced the bug with TensorFlow Nightly? Yes
Source: binary
TensorFlow version: 2.15.0; nightly
Custom code: Yes
OS platform and distribution: No response
Mobile device: No response
Python version: No response
Bazel version: No response
GCC/compiler version: No response
CUDA/cuDNN version: No response
GPU model and memory: No response
Current behavior?
I seem to have hit a really odd corner case in XLA compilation. This is the first time I'm using XLA, so my code may well have a problem in it, but reproducing the bug requires a particular combination of factors.
The compiler complains that it cannot operate on dynamic shapes, even though these should be known to (or inferable by) the graph compiler. I don't know whether some errors that could have come up are elided away as the optimisations are applied, but reproducing it required a combination of a conditional, a vectorised map, and a matrix generated using reduce_sum, together with passing these between functions. Changing some or all of these elements to try to pin the bug down made the code compile fine.
I'm working on an implementation of Soft Dynamic Time Warping as a loss function. For LSTM-format tensors (sequences in the batch, timesteps in the sequence, elements of the feature vector), the basics of the algorithm are: compute a distance matrix using some metric across all timestep pairs of the predictions and the ground truth, then calculate a loss matrix iteratively from this distance matrix using two for loops.
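For concreteness, the recurrence described above can be sketched in plain Python (no TensorFlow), assuming a squared-Euclidean metric and the usual soft-min smoothing with a parameter gamma; all names here are illustrative, not the author's actual code:

```python
import math

def soft_dtw(x, y, gamma=1.0):
    """Soft-DTW loss sketch for two sequences of feature vectors.

    D is the pairwise distance matrix; R is the loss matrix filled
    iteratively by two nested for loops, as described above.
    """
    n, m = len(x), len(y)
    # Distance matrix: squared Euclidean distance for every timestep pair.
    D = [[sum((a - b) ** 2 for a, b in zip(x[i], y[j])) for j in range(m)]
         for i in range(n)]
    inf = float("inf")
    # Loss matrix with an extra border row/column initialised to infinity.
    R = [[inf] * (m + 1) for _ in range(n + 1)]
    R[0][0] = 0.0

    def softmin(a, b, c):
        # Smoothed minimum: -gamma * log(sum(exp(-v / gamma))),
        # skipping infinite entries (their exp term is zero anyway).
        z = [math.exp(-v / gamma) for v in (a, b, c) if v != inf]
        return -gamma * math.log(sum(z)) if z else inf

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            R[i][j] = D[i - 1][j - 1] + softmin(
                R[i - 1][j], R[i][j - 1], R[i - 1][j - 1])
    return R[n][m]
```

With gamma approaching zero this reduces to ordinary (hard) DTW; the smoothing is what makes the loss differentiable.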
To help parallelisation, we can use vectorized_map, as each sequence in the batch is independent. The issue appears to come from this vectorisation, combined with a distance matrix constructed from a suitable expression (Euclidean distance is used as an example): change either of those two things (just hardcode the calculation to one sequence in the batch and do not loop, or build the distance matrix using e.g. tf.ones) and the error does not reproduce.
The actual line at fault is a conditional assignment within the loops: if we have infinity in the loss matrix for the current (i, j) pair, we need to change that to minus infinity. The compiler spits out a long trace (repeating itself, because of the vectorisation) saying that a reshape operation happening inside the translation of the scatter_nd_update has a dynamic shape. This is particularly confusing given that all I am doing is assigning a scalar value to a scalar element of a matrix at known indices; when operating eagerly a reshape would be totally unnecessary, as it essentially amounts to one comparison and a memory access on CPU. I can paste the traceback if required, but it may be easier to see using the nightly package in the Colab notebook.
I did try every way I could think of to inform the compiler that the shapes of the tensors are known and to fix the shapes of intermediate values (all the required information can be inferred from the shape of the ground truth/predictions; the rest is deterministic), as well as trying to force it to understand that the single element of the distance matrix and the scatter_nd_update are only making scalar changes: fixing the input_signature of the tf.function, hardcoding the shapes as constants, etc.
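Run eagerly, the conditional write described above reduces to something like the following. This is an illustrative reconstruction, not the author's code; it uses tf.tensor_scatter_nd_update (the tensor variant of scatter_nd_update), and the author reports that the equivalent pattern fails to compile inside vectorized_map with jit_compile=True:

```python
import tensorflow as tf

def replace_inf(R, i, j):
    # If the loss matrix holds +inf at the current (i, j) pair,
    # overwrite that single scalar element with -inf.
    if R[i, j] == float("inf"):
        R = tf.tensor_scatter_nd_update(
            R, indices=[[i, j]], updates=[float("-inf")])
    return R
```

Eagerly this is one comparison and one scalar write; the reported failure is that XLA's translation of the scatter introduces a reshape whose shape it considers dynamic.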
I suspected that the conditional branch only optionally assigning makes it not very functional, and that this could be causing problems, but I tried expanding it into a ternary-style if in a few ways and the same issue applied. Ultimately my only solution was to scatter_nd_update on every iteration with either the original value or the constant I need to replace it with. This is terrible for performance, but I will just pray that the compiler can melt it away.
I hope this provides some level of explanation, but as I am struggling to work out what the real issue is, and my attempts to isolate the problem failed, I'm afraid this is the most minimal example I could come up with; I'm not really sure what is going on.
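The workaround described above (write on every iteration, selecting between the old value and the replacement) might look something like this; again an illustrative sketch using tf.tensor_scatter_nd_update and tf.where, not the author's actual code:

```python
import tensorflow as tf

def replace_inf_unconditional(R, i, j):
    # No data-dependent branch: always scatter, but select the value
    # to write with tf.where, so a scalar of known shape is written
    # regardless of whether the element was infinite.
    cur = R[i, j]
    new = tf.where(cur == float("inf"),
                   tf.constant(float("-inf"), dtype=cur.dtype),
                   cur)
    return tf.tensor_scatter_nd_update(
        R, indices=[[i, j]], updates=tf.reshape(new, [1]))
```

The cost is an unconditional write per loop iteration, which is what makes this unattractive for performance.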
As an aside, I'm sure it's a work in progress, but the traceback I get from the XLA compiler made this really, really hard work. It makes the template errors from the range-v3 library look like a breeze, and dumping the intermediates was not feasible to look through and make sense of. It took about two days of going round and round all of my code before I finally pinned the failure site down to the if statement. Aside from general improvements, given that static shapes are required, a traceback of how those shapes were worked out would probably go a long way (in the same way that modern C++ compilers walk back through a template error and tell me how they deduced each particular type).
I'd also add that, at a high level, the behaviour between symbolic tensors and XLA has caused rather a lot of friction. With symbolic tensors the shape is often not available, so we have to write our code such that lazy operations which understand symbolic shapes can function fine. But when it comes to XLA we absolutely do need the shape, and it is impossible for me to guide the graph (even when I know exactly what shape to expect) because of all the None values. As a user-programmer it really feels like the two systems are fighting each other, and while I'd like to see how XLA performs and compares, I could probably achieve this more easily and quickly by writing against the HLO directly, or by making a custom op with SYCL. In an ideal world I'd quite like to see some sort of DSL that builds upon the HLO, targets XLA, and interoperates with TensorFlow, as doing this in Python and then going through the graph compiler before converting to XLA is a lot of overhead.
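For what it's worth, one of the shape-pinning attempts mentioned earlier (a fully static input_signature on the tf.function) can be sketched as follows; the shapes and the toy loss here are placeholders, not the author's actual function:

```python
import tensorflow as tf

# Fully static input_signature: no dimension is None, so inside the
# traced graph every tensor has a concrete, known shape.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[8, 16, 4], dtype=tf.float32),  # batch, time, features
    tf.TensorSpec(shape=[8, 16, 4], dtype=tf.float32),
])
def toy_loss(y_true, y_pred):
    # y_true.shape is TensorShape([8, 16, 4]) here, with no None entries.
    return tf.reduce_sum(tf.square(y_true - y_pred))
```

Even with shapes pinned this way, the author reports the dynamic-shape error from the scatter inside vectorized_map persisted.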
Standalone code to reproduce the issue
Relevant log output
No response