
How to implement tensor multiplication of a sparse matrix with a 3D tensor in TensorFlow 2.0?


I have a sparse matrix A of shape (60400, 32600) and a 3D tensor B of shape (32600, 60400, 64). I want to multiply A and B to get a tensor C of shape (60400, 60400, 64), and then build the final tensor of shape (60400, 64) from the diagonal elements of C, i.e. C[i, i, :].
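To make the target computation concrete, here is a tiny dense numpy sketch with made-up sizes (I=4, J=3, K=2 stand in for 60400, 32600, 64). It builds the full C, pulls out C[i, i, :], and checks that the einsum string from the question gives the same result without ever forming C.

```python
import numpy as np

# Hypothetical toy sizes standing in for (60400, 32600, 64).
I, J, K = 4, 3, 2
rng = np.random.default_rng(0)
A = rng.random((I, J))        # dense stand-in for the sparse (I, J) matrix
B = rng.random((J, I, K))     # the (J, I, K) tensor

# Full product: C[i, l, k] = sum_j A[i, j] * B[j, l, k], shape (I, I, K).
C = np.tensordot(A, B, axes=([1], [0]))

# Keep only the "diagonal" C[i, i, :], shape (I, K).
idx = np.arange(I)
D = C[idx, idx, :]

# The einsum from the question computes the same (I, K) result directly.
D_einsum = np.einsum('ij,jik->ik', A, B)
print(C.shape, D.shape, np.allclose(D, D_einsum))
```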

I tried tf.sparse.sparse_dense_matmul(A, B) to get C first, but got ValueError: Shape must be rank 2 but is rank 3.

Then I tried tf.einsum('ij, jik->ik', A, B), but got TypeError: Tensors in list passed to 'inputs' of 'Einsum' Op have types [NOT CONVERTIBLE TO TENSOR, float32] that don't all match. I think this is because A is a sparse matrix, but I can't convert A to a dense matrix.

Any thoughts about this? I would really appreciate it.


Solution

  • I can't help you with the TensorFlow calculations directly, but I can make some observations based on numpy and scipy.sparse.

    scipy.sparse matrices are limited to 2d. Recent versions add sparse arrays, but they are still limited to 2d.

    Sparse 2d matmul with a dense operand is also limited to 2d dense arrays, and the result is itself dense; only sparse with sparse produces a sparse result. So even if sparse with 3d dense worked, the (60400, 60400, 64) result would be far too large as a dense tensor: about 2.3e11 values, roughly 934 GB in float32. Taking its diagonal would then throw away all but 1 in 60400 of those values.
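A small scipy.sparse check of those result-type claims, with made-up 2d shapes: a CSR matrix times a 2d dense array comes back as a plain dense numpy array, while CSR times CSR stays sparse.

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
dense = rng.random((5, 4))
dense[dense < 0.7] = 0.0                 # zero out most entries
A = sparse.csr_matrix(dense)             # 2d sparse matrix

D2 = rng.random((4, 3))                  # 2d dense operand
sparse_times_dense = A @ D2              # result is a dense ndarray
sparse_times_sparse = A @ sparse.csr_matrix(D2)   # result stays sparse

print(type(sparse_times_dense).__name__, sparse.issparse(sparse_times_sparse))
```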

    numpy `einsum` does not work with sparse arrays.

     einsum('ij, jik->ik', A, B) 
    

    suggests that if B can be rearranged to ijk order, you could treat i as a batch dimension. With matmul you could do

    'i1j, ijk -> i1k'  (last j of A with 2nd to last j of B)
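A dense numpy sketch of that batch-matmul idea on toy sizes (A is dense here only to show the index pattern): transposing B to (i, j, k) order makes i the batch axis, and inserting a length-1 axis into A turns each row into a (1, j) matrix.

```python
import numpy as np

# Hypothetical toy sizes; A is dense here purely for illustration.
I, J, K = 4, 3, 2
rng = np.random.default_rng(1)
A = rng.random((I, J))
B = rng.random((J, I, K))

Bt = B.transpose(1, 0, 2)        # (J, I, K) -> (I, J, K): i becomes the batch axis
out = A[:, None, :] @ Bt         # (I, 1, J) @ (I, J, K) -> (I, 1, K)
D = out[:, 0, :]                 # drop the singleton axis -> (I, K)

print(np.allclose(D, np.einsum('ij,jik->ik', A, B)))
```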
    

    So you could iterate over the i dimension of A and B, doing a matmul of A[i,:] (1d) and B[:,i,:] (2d). But since the last dimension of B, 64, is relatively small, maybe you could iterate over that instead.
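Both loops can be sketched with scipy.sparse on a small hypothetical problem. The row loop reads each CSR row's stored column indices and values directly (`indptr`, `indices`, `data`), so only the nonzeros of A[i, :] touch B[:, i, :]. The k loop uses the COO triples and `np.bincount` to sum contributions per row. Sizes and data are made up, and the dense copy of A is used only to check the result.

```python
import numpy as np
from scipy import sparse

# Hypothetical toy problem; A is a CSR matrix standing in for the real sparse A.
I, J, K = 6, 5, 3
rng = np.random.default_rng(2)
dense = rng.random((I, J))
dense[dense < 0.6] = 0.0
A = sparse.csr_matrix(dense)
B = rng.random((J, I, K))

# Option 1: iterate over rows i. Row i of a CSR matrix is stored as
# A.indices[start:end] (column ids j) and A.data[start:end] (values).
D_rows = np.zeros((I, K))
for i in range(I):
    start, end = A.indptr[i], A.indptr[i + 1]
    cols, vals = A.indices[start:end], A.data[start:end]
    D_rows[i] = vals @ B[cols, i, :]      # (nnz_i,) @ (nnz_i, K) -> (K,)

# Option 2: iterate over the small last axis k, using the COO triples.
coo = A.tocoo()
D_cols = np.zeros((I, K))
for k in range(K):
    # B[coo.col, coo.row, k] picks B[j, i, k] for every stored A[i, j].
    weights = coo.data * B[coo.col, coo.row, k]
    D_cols[:, k] = np.bincount(coo.row, weights=weights, minlength=I)

expected = np.einsum('ij,jik->ik', dense, B)
print(np.allclose(D_rows, expected), np.allclose(D_cols, expected))
```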

    Anyway, you may need to spend a lot more time reading the TensorFlow docs on sparse tensors.
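One more option, a swapped-in technique rather than anything described above: vectorize the whole computation as a gather over A's nonzeros followed by a scatter-add into each row, so memory scales with nnz(A) × 64 instead of 60400 × 60400 × 64. The sketch uses numpy and scipy.sparse on toy data. In TensorFlow the corresponding pieces would presumably be the `indices` and `values` of a `tf.sparse.SparseTensor`, `tf.gather_nd`, and `tf.math.unsorted_segment_sum`, but that translation is an untested assumption.

```python
import numpy as np
from scipy import sparse

# Hypothetical toy problem with the same index pattern as the real one.
I, J, K = 6, 5, 3
rng = np.random.default_rng(3)
dense = rng.random((I, J))
dense[dense < 0.6] = 0.0
coo = sparse.coo_matrix(dense)
B = rng.random((J, I, K))

# Gather one length-K slice per stored entry: A[i, j] * B[j, i, :].
gathered = coo.data[:, None] * B[coo.col, coo.row, :]   # (nnz, K)

# Scatter-add each slice into its row i (a segment sum over rows).
D = np.zeros((I, K))
np.add.at(D, coo.row, gathered)

print(np.allclose(D, np.einsum('ij,jik->ik', dense, B)))
```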