I have a data tensor of shape [batch_size, 512] and a constant matrix of shape [256, 512] whose values are only 0 and 1.
For each batch element, I would like to efficiently compute the following: for every row of the constant matrix, take the product of the entries of my data vector (the second dimension of the data tensor) at the positions where that row is 1, then sum these products over all rows.
An example:
Say I have a batch of size 1, where the data tensor has the values [5,4,3,7,8,2] and my constant matrix has the values:
[0,1,1,0,0,0]
[1,0,0,0,0,0]
[1,1,1,0,0,1]
This means that I would like to compute 4*3 for the first row, 5 for the second, and 5*4*3*2 for the third.
In total for this batch, I get 4*3+5+5*4*3*2, which equals 137.
Currently, I do this by iterating over all the rows: for each row I take the product of the data entries that the row selects, and then I sum the per-row products. This runs quite slowly.
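For reference, the row-by-row computation described above can be sketched in plain Python. This is only a minimal baseline for checking results, not the asker's actual code; the function name `loop_masked_product_sum` is hypothetical, and `math.prod` assumes Python 3.8 or later.

```python
from math import prod

def loop_masked_product_sum(vector, mask_rows):
    """Sum, over all mask rows, of the product of the entries selected by that row."""
    total = 0
    for row in mask_rows:
        # Keep only the entries where the mask is 1, then multiply them together.
        # Note: a row with no 1s contributes 1 (the empty product).
        total += prod(v for v, keep in zip(vector, row) if keep)
    return total

vector = [5, 4, 3, 7, 8, 2]
mask_rows = [[0, 1, 1, 0, 0, 0],
             [1, 0, 0, 0, 0, 0],
             [1, 1, 1, 0, 0, 1]]

total = loop_masked_product_sum(vector, mask_rows)
print(total)  # 12 + 5 + 120 = 137
```

A baseline like this is handy for validating any vectorized version on small inputs.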
How about something like this:
import tensorflow as tf

# Two-element batch
data = [[5, 4, 3, 7, 8, 2],
        [4, 2, 6, 1, 6, 8]]
mask = [[0, 1, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1]]
with tf.Graph().as_default(), tf.Session() as sess:
    # Data as tensors
    d = tf.constant(data, dtype=tf.int32)
    m = tf.constant(mask, dtype=tf.int32)
    # Tile data as needed
    dd = tf.tile(d[:, tf.newaxis], (1, tf.shape(m)[0], 1))
    mm = tf.tile(m[tf.newaxis, :], (tf.shape(d)[0], 1, 1))
    # Replace values with 1 wherever the mask is 0
    w = tf.where(tf.cast(mm, tf.bool), dd, tf.ones_like(dd))
    # Multiply row-wise and sum
    result = tf.reduce_sum(tf.reduce_prod(w, axis=-1), axis=-1)
    print(sess.run(result))
Output:
[137 400]
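The same replace-with-1, multiply, then sum idea can be checked outside a TensorFlow session. The following is a sketch that assumes NumPy is available and relies on broadcasting instead of explicit tiling; the helper name `broadcast_masked_product_sum` is hypothetical and not part of any library.

```python
import numpy as np

def broadcast_masked_product_sum(data, mask):
    """For each batch row, sum over mask rows of the product of the selected entries."""
    d = np.asarray(data)[:, np.newaxis, :]              # shape (batch, 1, n)
    m = np.asarray(mask, dtype=bool)[np.newaxis, :, :]  # shape (1, rows, n)
    # Broadcasting yields shape (batch, rows, n); masked-out entries become 1
    # so that they do not affect the product.
    w = np.where(m, d, 1)
    # Product along the last axis, then sum over the mask rows.
    return w.prod(axis=-1).sum(axis=-1)

data = [[5, 4, 3, 7, 8, 2],
        [4, 2, 6, 1, 6, 8]]
mask = [[0, 1, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1]]

result = broadcast_masked_product_sum(data, mask)
print(result)
```

Two things to keep in mind with this approach: the intermediate array has shape (batch, rows, n), so for the [256, 512] matrix in the question it grows as batch_size * 256 * 512 elements, and a mask row containing no 1s contributes 1 (the empty product) rather than 0.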
EDIT:
If your input data is a single vector, then you would just have:
import tensorflow as tf

# Single data vector (no batch dimension)
data = [5, 4, 3, 7, 8, 2]
mask = [[0, 1, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1]]
with tf.Graph().as_default(), tf.Session() as sess:
    # Data as tensors
    d = tf.constant(data, dtype=tf.int32)
    m = tf.constant(mask, dtype=tf.int32)
    # Tile data so there is one copy per mask row
    dd = tf.tile(d[tf.newaxis], (tf.shape(m)[0], 1))
    # Replace values with 1 wherever the mask is 0
    w = tf.where(tf.cast(m, tf.bool), dd, tf.ones_like(dd))
    # Multiply row-wise and sum
    result = tf.reduce_sum(tf.reduce_prod(w, axis=-1), axis=-1)
    print(sess.run(result))
Output:
137