
Backend framework comparison

Albert Zeyer edited this page Nov 9, 2021 · 9 revisions

We want to compare TensorFlow, PyTorch, JAX and maybe other similar frameworks here.

We do not want to compare higher-level frameworks like Keras here.

Side note: The classification into low-level, mid-level and high-level frameworks comes from our tutorial (video, slides). What we call "backend framework" here probably corresponds most closely to a "mid-level framework", although the distinction is not always clear. Here we mean all frameworks which could potentially be used as a backend for RETURNN.

This comparison is specifically about their properties when used as a backend framework for RETURNN, following the RETURNN principles and core features. In particular:

  • RETURNN does a lot of extra checks and logic w.r.t. dimensions, dimension tags, sequence lengths, etc., all via Data. This is one of its main features. It would probably make RETURNN slow in eager mode, unless we optimize this away somehow (for all but the first step).
  • RETURNN performs automatic optimizations, e.g. it moves layers out of a recurrent loop when possible. This is another core feature. Redoing this in every single step would also be slow, although the result could be cached somehow. Still, this is much more natural in graph mode.
  • RETURNN is specifically used for sequence classification, so we naturally have loops over sequences. This makes the loop an important core operation which must be fast. When the loop is compiled and runs in native code (C++), as is usual in graph mode, it is probably faster than a pure Python loop. However, when the number of loop iterations is low (10-20), the Python overhead is probably negligible.
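The first two points both come down to doing expensive work once and reusing it. A minimal sketch of that caching idea, assuming the expensive checks depend only on the static structure of the inputs (rank and static dims) and not on the actual batch size or sequence length. `CheckedStep` and `template_of` are hypothetical names for illustration, not RETURNN API, and the check itself is a simplified stand-in for the real `Data` logic:

```python
from typing import Callable, Optional, Set, Tuple

Shape = Tuple[int, ...]
Template = Tuple[Optional[int], ...]


def template_of(shape: Shape, dynamic_axes: Set[int]) -> Template:
    """Replace dynamic dims (e.g. batch, time) by None, so that only the
    static structure is used as cache key (loosely like a RETURNN Data template)."""
    return tuple(None if i in dynamic_axes else d for i, d in enumerate(shape))


class CheckedStep:
    """Wraps a per-step function; the expensive dim checks run once per template."""

    def __init__(self, step_fn: Callable, expected: Template, dynamic_axes: Set[int]):
        self.step_fn = step_fn
        self.expected = expected
        self.dynamic_axes = dynamic_axes
        self._verified: Set[Template] = set()
        self.num_checks = 0  # how often the full check actually ran

    def _full_check(self, template: Template) -> None:
        # Hypothetical stand-in for RETURNN's dim tag / sequence length logic.
        self.num_checks += 1
        if len(template) != len(self.expected):
            raise ValueError(f"rank mismatch: {template} vs {self.expected}")
        for got, want in zip(template, self.expected):
            if want is not None and got != want:
                raise ValueError(f"dim mismatch: {template} vs {self.expected}")

    def __call__(self, shape: Shape, *args):
        template = template_of(shape, self.dynamic_axes)
        if template not in self._verified:  # only the first step with this template pays
            self._full_check(template)
            self._verified.add(template)
        return self.step_fn(shape, *args)


# Usage: batch and time dims vary per step, the feature dim is static.
step = CheckedStep(lambda shape: shape, expected=(None, None, 512), dynamic_axes={0, 1})
for seq_len in (7, 13, 21):
    step((32, seq_len, 512))
print(step.num_checks)  # 1
```

The main design choice here is the cache key: dynamic dims are replaced by None, so a new sequence length does not trigger a re-check, while a changed static dim (e.g. a different feature dim) does.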

Theano

  • Our initial backend.
  • The most widely used backend framework around 2016-2017, maybe?
  • Graph-mode.
  • Performs optimizations on the computation graph. These could improve the training runtime, but optimizing and compiling the computation graph often takes a lot of time (on the order of minutes).
  • It supports non-contiguous tensors (by storing strides). Any op needing a contiguous tensor explicitly makes it contiguous first. The user usually never needs to care about this.
  • It supports inplace operations. The automatic optimization finds tensors which are not used anymore after some op and replaces that op by an inplace op when possible. This can reduce memory and sometimes also runtime.
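To make the idea of strides concrete, here is a toy pure-Python model of a strided 2-D tensor. It is not Theano's implementation (the class and method names are made up for illustration), but it shows the mechanism described above, which PyTorch uses as well: a transpose only swaps shape and strides, and a copy happens only when an op explicitly asks for a contiguous layout.

```python
class StridedTensor:
    """Toy 2-D tensor: flat storage plus shape and strides (in elements)."""

    def __init__(self, storage, shape, strides):
        self.storage = storage  # shared by views, never copied by them
        self.shape = tuple(shape)
        self.strides = tuple(strides)

    @classmethod
    def from_rows(cls, rows):
        n, m = len(rows), len(rows[0])
        flat = [x for row in rows for x in row]
        return cls(flat, (n, m), (m, 1))  # row-major, i.e. contiguous

    def __getitem__(self, idx):
        i, j = idx
        return self.storage[i * self.strides[0] + j * self.strides[1]]

    def transpose(self):
        # A view: only shape and strides are swapped, no data is copied.
        return StridedTensor(self.storage, self.shape[::-1], self.strides[::-1])

    def is_contiguous(self):
        return self.strides == (self.shape[1], 1)

    def contiguous(self):
        # What an op requiring contiguous memory would do explicitly first.
        if self.is_contiguous():
            return self
        return StridedTensor.from_rows(self.tolist())

    def tolist(self):
        n, m = self.shape
        return [[self[i, j] for j in range(m)] for i in range(n)]


a = StridedTensor.from_rows([[1, 2, 3], [4, 5, 6]])
t = a.transpose()
print(t.tolist(), t.storage is a.storage)  # [[1, 4], [2, 5], [3, 6]] True
c = t.contiguous()
print(c.storage, c.is_contiguous())  # [1, 4, 2, 5, 3, 6] True
```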

TensorFlow

  • Our second backend.
  • The most widely used backend framework around 2017-2019, maybe?
  • Graph mode first (TF 1); eager mode was added later and became the default in TF 2, although graph mode is still supported. Also, eager-mode code can easily be converted into a computation graph (via tf.function), even including Python control flow like while and if. This is done by transformations on the Python AST (AutoGraph).
  • Performs only minimal optimizations on the computation graph, such that this step is almost instant.
  • Does not support non-contiguous tensors (see here, here, here). One implication is that it performs a copy of the tensor for operations like tf.transpose.
  • It does not support inplace operations. This implies that all ops are pure w.r.t. tensors (except for resource tensors, including variables, which have internal state). Tensor buffers can therefore be passed around instead of copied when possible (e.g. in reshape), and their lifetime is managed by reference counting.
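For contrast, a toy model of the TensorFlow behavior described above, under the simplifying assumption that every tensor is a flat, immutable, row-major buffer. `ImmutableTensor` and its methods are hypothetical and only mirror the stated properties: reshape can share the buffer, transpose must copy, and ops never modify their inputs.

```python
from math import prod


class ImmutableTensor:
    """Toy model of a dense TF-style tensor: always contiguous (row-major) and
    never modified in place, so its buffer can be shared safely."""

    def __init__(self, buf: tuple, shape: tuple):
        assert len(buf) == prod(shape), "buffer size must match shape"
        self.buf = buf  # a tuple, i.e. immutable
        self.shape = tuple(shape)

    def reshape(self, new_shape):
        # Element order is unchanged in a contiguous layout: share the buffer.
        return ImmutableTensor(self.buf, new_shape)

    def transpose2d(self):
        # Without strides, a transpose has to materialize a new buffer (a copy).
        n, m = self.shape
        new_buf = tuple(self.buf[i * m + j] for j in range(m) for i in range(n))
        return ImmutableTensor(new_buf, (m, n))

    def add(self, other):
        # Pure op: returns a new tensor and leaves both inputs untouched.
        assert self.shape == other.shape
        return ImmutableTensor(tuple(a + b for a, b in zip(self.buf, other.buf)), self.shape)


x = ImmutableTensor((1, 2, 3, 4, 5, 6), (2, 3))
y = x.reshape((3, 2))
t = x.transpose2d()
print(y.buf is x.buf, t.buf is x.buf)  # True False
print(t.buf)  # (1, 4, 2, 5, 3, 6)
```

Because nothing is ever written in place, sharing the buffer between x and its reshape is safe; in this sketch, CPython's own reference counting frees the shared tuple once no tensor refers to it anymore.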

PyTorch

  • Probably the most widely used framework since around 2019?
  • Eager mode is first class, which may be one reason people like it. Eager-mode code can be converted into graph mode by JIT compiling it. Tracing (torch.jit.trace) does not support Python control flow like while or if, but torch.jit.script does.
  • Has a very clean API, including a higher-level API (all in torch.nn: the base class nn.Module, and also modules like LSTM and Transformer). This is somewhat similar to TF Keras but arguably cleaner, which is another reason why people like it.
  • It supports non-contiguous tensors. transpose always just returns a view, i.e. it is a very cheap op.
  • It supports explicit inplace operations, but no automatic implicit ones. If the user applies an inplace op to a tensor whose original value is needed for backpropagation, this results in an error.
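A toy illustration of why tracing cannot capture Python control flow: the tracer runs the function once on an example input and records the ops it sees, while the `if` is evaluated by Python on the concrete example value, so only the branch that was taken ends up in the graph. This is a pure-Python sketch with made-up names (`Traced`, `trace`, `replay`), not how torch.jit.trace is implemented, but it shows the same kind of failure; torch.jit.script avoids it by compiling the Python source, which keeps both branches.

```python
class Traced:
    """Toy tracer: carries a concrete example value and records each op."""

    def __init__(self, value, graph, name):
        self.value, self.graph, self.name = value, graph, name

    def _record(self, op, other, value):
        out = f"v{len(self.graph)}"
        ref = other.name if isinstance(other, Traced) else other  # constant operand
        self.graph.append((out, op, self.name, ref))
        return Traced(value, self.graph, out)

    def __mul__(self, other):
        o = other.value if isinstance(other, Traced) else other
        return self._record("mul", other, self.value * o)

    def __add__(self, other):
        o = other.value if isinstance(other, Traced) else other
        return self._record("add", other, self.value + o)

    def __gt__(self, other):
        # Python's `if` needs a real bool, so the tracer answers with the concrete
        # example value. The comparison itself is NOT recorded in the graph.
        return self.value > other


def trace(fn, example):
    graph = []
    out = fn(Traced(example, graph, "x"))
    return graph, out.name


def replay(graph, out_name, x):
    # Re-run the recorded ops on a new input, without calling fn again.
    env = {"x": x}
    for out, op, a, b in graph:
        rhs = env[b] if isinstance(b, str) else b
        env[out] = env[a] * rhs if op == "mul" else env[a] + rhs
    return env[out_name]


def f(x):
    if x > 0:  # data-dependent Python control flow
        return x * 2
    return x + 100


graph, out = trace(f, 3.0)  # the example input takes the `x > 0` branch
print(graph)  # [('v0', 'mul', 'x', 2)]
print(f(-1.0), replay(graph, out, -1.0))  # 99.0 -2.0
```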

JAX

  • Has recently gained a lot of interest in the community.
  • Everything is based on Python code transformations. In the end, one final pass converts the code into XLA code (XLA being the compiler which TensorFlow can also use).
  • This is basically like graph-mode.
  • Code transformations allow for some optimizations and ideas similar to those in RETURNN, like abstracting away the batch dimension.
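A toy version of the "abstracting away the batch dimension" idea, loosely modeled on how jax.vmap is described: the function is written for a single example, and a tracer object carrying the whole batch applies a batched version of each primitive op. `Batched` and `vmap` here are hypothetical pure-Python stand-ins supporting only + and *, not JAX's implementation:

```python
class Batched:
    """Toy batch tracer: holds one value per example. The user function never
    sees the batch axis; each primitive op is applied per example instead."""

    def __init__(self, values):
        self.values = list(values)

    def _lift(self, other):
        if isinstance(other, Batched):
            assert len(other.values) == len(self.values), "batch size mismatch"
            return other.values
        # Unbatched operand: broadcast over the batch (like in_axes=None in jax.vmap).
        return [other] * len(self.values)

    def __add__(self, other):
        return Batched(a + b for a, b in zip(self.values, self._lift(other)))

    def __mul__(self, other):
        return Batched(a * b for a, b in zip(self.values, self._lift(other)))

    # + and * on numbers are commutative, so the reflected ops can reuse them.
    __radd__ = __add__
    __rmul__ = __mul__


def vmap(fn):
    """Turn a per-example fn into a batched one (only the first argument is batched)."""
    def batched_fn(xs, *rest):
        return fn(Batched(xs), *rest).values
    return batched_fn


def affine(x, w, b):
    # Written for a single scalar example, no batch dim anywhere.
    return w * x + b


print(vmap(affine)([1.0, 2.0, 3.0], 2.0, 0.5))  # [2.5, 4.5, 6.5]
```

With real JAX, the analogous call would be jax.vmap(affine, in_axes=(0, None, None)) on a jax.numpy array, where in_axes=None marks the unbatched arguments.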