python tensorflow linear-algebra inversion

How to check if a matrix is invertible in tensorflow?


In my TensorFlow graph, I would like to invert a matrix if it is invertible and do something with it. If it is not invertible, I'd like to do something else.

I could not find any way to check whether the matrix is invertible in order to do something like:

is_invertible = tf.is_invertible(mat)
tf.cond(is_invertible, f1, f2)

Is there such a thing as an is_invertible function in TensorFlow? I also considered using the fact that TensorFlow raises an InvalidArgumentError (though not always) when I try to invert a non-invertible matrix, but I couldn't find a way to take advantage of this.


Solution

  • As proposed in Efficient & pythonic check for singular matrix, you can check the condition number: it is infinite for a singular matrix and very large for a nearly singular one. Unfortunately, TensorFlow does not currently provide a condition-number function as such, but it is not difficult to emulate the basic implementation of np.linalg.cond:

    import math
    import tensorflow as tf
    
    # Based on np.linalg.cond(x, p=None): largest / smallest singular value
    def tf_cond(x):
        x = tf.convert_to_tensor(x)
        s = tf.linalg.svd(x, compute_uv=False)  # singular values, descending
        r = s[..., 0] / s[..., -1]
        # Replace NaNs in r with infinity unless there were NaNs in x already
        x_nan = tf.reduce_any(tf.math.is_nan(x), axis=(-2, -1))
        r_nan = tf.math.is_nan(r)
        r_inf = tf.fill(tf.shape(r), tf.constant(math.inf, r.dtype))
        # tf.where returns a new tensor, so the result must be assigned back to r
        r = tf.where(x_nan, r, tf.where(r_nan, r_inf, r))
        return r
    
    def is_invertible(x, epsilon=1e-6):  # Epsilon may be smaller with tf.float64
        x = tf.convert_to_tensor(x)
        eps_inv = tf.cast(1 / epsilon, x.dtype)
        x_cond = tf_cond(x)
        # Singular (or numerically singular) matrices have an infinite or huge condition number
        return tf.math.is_finite(x_cond) & (x_cond < eps_inv)
    
    m = [
        # Invertible matrix
        [[1., 2., 3.],
         [6., 5., 4.],
         [7., 7., 8.]],
        # Non-invertible matrix (third row = first row + second row)
        [[1., 2., 3.],
         [6., 5., 4.],
         [7., 7., 7.]],
    ]
    # Graph mode (TF 1.x style); tf.compat.v1.Session also works under TF 2.x
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
        print(sess.run(is_invertible(m)))
        # [ True False]
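With TensorFlow 2.x (eager execution), the same condition-number check can be passed straight to `tf.cond`, which is the pattern the question asks about. The following is a minimal sketch under that assumption: `invert_or_fallback` is a hypothetical helper name, using the pseudo-inverse (`tf.linalg.pinv`) as the "something else" branch is only an illustrative choice, and the NaN handling of `tf_cond` above is omitted for brevity.

```python
import tensorflow as tf

def is_invertible(x, epsilon=1e-6):
    """Condition-number test (same idea as above, without the NaN handling)."""
    x = tf.convert_to_tensor(x)
    s = tf.linalg.svd(x, compute_uv=False)   # singular values, descending
    cond = s[..., 0] / s[..., -1]            # inf or nan when the smallest one is 0
    return tf.math.is_finite(cond) & (cond < tf.cast(1 / epsilon, x.dtype))

def invert_or_fallback(mat):
    """Hypothetical helper: invert `mat` if it passes the check, else use pinv.

    tf.cond needs a scalar predicate, so `mat` must be a single matrix.
    """
    mat = tf.convert_to_tensor(mat)
    return tf.cond(is_invertible(mat),
                   lambda: tf.linalg.inv(mat),    # only runs for invertible input
                   lambda: tf.linalg.pinv(mat))   # illustrative fallback branch

good = tf.constant([[1., 2., 3.], [6., 5., 4.], [7., 7., 8.]], dtype=tf.float64)
bad = tf.constant([[1., 2., 3.], [6., 5., 4.], [7., 7., 7.]], dtype=tf.float64)
print(invert_or_fallback(good))
print(invert_or_fallback(bad))
```

Because `tf.cond` only executes the selected branch, `tf.linalg.inv` is never run on a matrix that failed the check. For a batch of matrices, `is_invertible` returns one boolean per matrix, which `tf.cond` cannot branch on directly.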