I did not find any setting for data_format=channels_first or data_format=channels_last in Flax modules (which are based on JAX).
In contrast, TensorFlow does have that setting. Is the choice of data_format irrelevant to JAX performance?
Unfortunately, I did not find any kind of information on this subject.
JAX/Flax has no equivalent to the concept of data_format as used in TensorFlow/Keras.
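
As a rough illustration of what this looks like in practice (a minimal sketch, not a statement about performance): rather than a data_format switch on the module, the lower-level jax.lax.conv_general_dilated takes a dimension_numbers argument that describes the layout of each array for that one call. The array names and shapes below are made up for the example. As far as I know, Flax's nn.Conv simply expects channels-last input, but that module is not used here.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical toy data: batch of 2 images, 8x8 pixels, 3 channels (NHWC order).
x_nhwc = jnp.ones((2, 8, 8, 3))
# Kernel in HWIO order: 3x3 window, 3 input channels, 4 output channels.
kernel_hwio = jnp.ones((3, 3, 3, 4))

# Channels-last convolution: the layout is declared per call.
y_nhwc = lax.conv_general_dilated(
    x_nhwc, kernel_hwio,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

# The same computation on channels-first arrays.
x_nchw = jnp.transpose(x_nhwc, (0, 3, 1, 2))            # NHWC -> NCHW
kernel_oihw = jnp.transpose(kernel_hwio, (3, 2, 0, 1))  # HWIO -> OIHW
y_nchw = lax.conv_general_dilated(
    x_nchw, kernel_oihw,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
)

print(y_nhwc.shape)  # (2, 8, 8, 4)
print(y_nchw.shape)  # (2, 4, 8, 8)
```

Both calls compute the same values; only the axis order of the inputs and outputs differs.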