
Is there float32 precision on Google TPU?


I plan to use Google TPUs for scientific numerical simulation (finite element analysis).

Do TPUs support float32 computation? If the precision is too low, they are unsuitable for my purposes.

There is a note about this in the documentation, but it says that XLA converts float32 to bfloat16 automatically.
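To get a rough sense of what that automatic conversion costs numerically, the sketch below rounds float32 inputs to bfloat16 and compares the resulting product against a float64 reference. It only emulates a single bfloat16 pass on the host CPU (an assumption made for illustration, not how the TPU hardware actually computes the product), and the matrix size and random data are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

# Illustration only: emulate a bfloat16 matmul by rounding the float32 inputs
# to bfloat16 and then multiplying in float32 with NumPy on the host.
# This mimics, but does not exactly reproduce, a single-pass TPU matmul.
rng = np.random.default_rng(0)
a = rng.standard_normal((256, 256)).astype(np.float32)
b = rng.standard_normal((256, 256)).astype(np.float32)

# Float64 reference product computed by NumPy.
ref = a.astype(np.float64) @ b.astype(np.float64)


def rel_error(x):
    """Relative Frobenius-norm error against the float64 reference."""
    x = np.asarray(x, dtype=np.float64)
    return float(np.linalg.norm(x - ref) / np.linalg.norm(ref))


# Plain float32 product (full float32 precision on the host).
err_f32 = rel_error(a @ b)

# Round the inputs to bfloat16 (8 significant bits), convert back to float32,
# then multiply on the host so the backend's own matmul precision is not involved.
a_bf16 = np.asarray(jnp.asarray(a).astype(jnp.bfloat16).astype(jnp.float32))
b_bf16 = np.asarray(jnp.asarray(b).astype(jnp.bfloat16).astype(jnp.float32))
err_bf16 = rel_error(a_bf16 @ b_bf16)

print(f"relative error, float32 inputs:  {err_f32:.1e}")
print(f"relative error, bfloat16 inputs: {err_bf16:.1e}")
```

The bfloat16 error comes out several orders of magnitude larger than the float32 one, which is why the default behaviour matters for accuracy-sensitive work such as finite element analysis.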


Solution

  • Yes. Full float32 precision is not the default, but it is available (see "Run a calculation on a Cloud TPU VM using JAX" in the Google Cloud documentation).

    You can set precision=jax.lax.Precision.HIGHEST on particular operations, such as matrix multiplication.

    According to that page, this setting "uses even more [Matrix Multiply Unit (MXU)] passes to achieve full float32 precision".

    E.g.:

    import jax
    jax.numpy.matmul(a, b, precision=jax.lax.Precision.HIGHEST)
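A minimal sketch of two ways to request full float32 precision in JAX, assuming a recent JAX release: per operation with `precision=Precision.HIGHEST`, or for a whole block of code with the `jax.default_matmul_precision` context manager. The function `step` is a hypothetical stand-in for a simulation kernel, and each result is measured against a float64 NumPy reference. On a CPU all three errors are typically similar; on a TPU the default one would be expected to be noticeably larger.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.lax import Precision

rng = np.random.default_rng(0)
a_np = rng.standard_normal((512, 512)).astype(np.float32)
b_np = rng.standard_normal((512, 512)).astype(np.float32)
ref = a_np.astype(np.float64) @ b_np.astype(np.float64)  # float64 reference

a = jnp.asarray(a_np)
b = jnp.asarray(b_np)


def rel_error(x):
    """Relative Frobenius-norm error against the float64 reference."""
    x = np.asarray(x, dtype=np.float64)
    return float(np.linalg.norm(x - ref) / np.linalg.norm(ref))


# 1) Backend default precision (on a TPU this may be a single bfloat16 pass).
err_default = rel_error(jnp.matmul(a, b))

# 2) Per operation: request full float32 precision for this matmul only.
err_highest = rel_error(jnp.matmul(a, b, precision=Precision.HIGHEST))


# 3) Scoped default: matmul/dot/einsum calls traced inside the `with` block
#    use float32 precision unless they pass their own `precision` argument.
@jax.jit
def step(x, y):
    # Hypothetical stand-in for a simulation kernel.
    return jnp.einsum("ij,jk->ik", x, y)


with jax.default_matmul_precision("float32"):
    err_scoped = rel_error(step(a, b))

print(f"default:  {err_default:.1e}")
print(f"HIGHEST:  {err_highest:.1e}")
print(f"scoped:   {err_scoped:.1e}")
```

For a whole program, the same default can instead be set once with `jax.config.update("jax_default_matmul_precision", "float32")`. Either way, only the matrix-multiply precision changes; the arrays themselves remain float32, and the extra MXU passes make these operations slower than the default.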