I have a tensor a of shape (n/f, c, c) that I want to multiply by another tensor b of shape (n, c, 1). Each row of a represents f rows of b, such that the naive way of implementing this would be to simply repeat each row of a f times before performing the multiplication:
n = 100
c = 5
f = 10
a = tf.constant(np.random.rand(n//f, c, c))
b = tf.constant(np.random.rand(n, c, 1))
a_prime = tf.repeat(a, f, 0)
result = a_prime @ b
This works, but for large n and f I'm worried about the memory footprint of the repeat. I could of course loop through each row and perform dot-products manually, but that would have implications on performance. Is there a better way?
We can avoid the explicit repetition by reshaping the tensors and letting tf.matmul's batch broadcasting do the work. Each matrix in a is then multiplied against its own block of f rows of b, with no intermediate tensor larger than the result itself.
import tensorflow as tf
import numpy as np
n = 100
c = 5
f = 10
a = tf.constant(np.random.rand(n // f, c, c))
b = tf.constant(np.random.rand(n, c, 1))

# Group b into blocks of f rows so that block k lines up with matrix k of a
b_blocked = tf.reshape(b, (n // f, f, c, 1))

# Give a a broadcast dimension: (n // f, 1, c, c)
a_expanded = tf.expand_dims(a, 1)

# tf.matmul broadcasts the batch dimensions: (n // f, f, c, 1),
# then flatten the blocks back to (n, c, 1)
result = tf.reshape(a_expanded @ b_blocked, (n, c, 1))

print(result.shape)
output:
(100, 5, 1)
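As a sanity check (not part of the original answer), the same trick can be verified with plain NumPy, whose matmul broadcasts batch dimensions the same way as tf.matmul. The variable names below are illustrative:

```python
import numpy as np

n, c, f = 100, 5, 10
a = np.random.rand(n // f, c, c)
b = np.random.rand(n, c, 1)

# Naive approach: repeat each matrix in a f times, then batch-multiply
naive = np.repeat(a, f, axis=0) @ b  # shape (n, c, 1)

# Broadcast approach: group b into blocks of f rows; a[:, None] has shape
# (n // f, 1, c, c) and broadcasts against (n // f, f, c, 1)
broadcast = (a[:, None] @ b.reshape(n // f, f, c, 1)).reshape(n, c, 1)

print(np.allclose(naive, broadcast))  # True
```

Since row i of b belongs to block i // f, reshaping b to (n // f, f, c, 1) pairs each block of f rows with exactly one matrix of a, which is precisely what the repeat was simulating.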