I wish to create a custom pooling layer which can efficiently work on GPUs.
For instance, I have the following input tensor
in = <tf.Tensor: shape=(4, 5), dtype=float32, numpy=
array([[0., 1., 2., 3., 4.],
       [5., 1., 7., 3., 2.],
       [9., 9., 2., 3., 5.],
       [2., 6., 2., 8., 4.]], dtype=float32)>
I wish to provide a list of column numbers over which to perform pooling. For instance, I wish to perform max pooling over the following column indices
pool_cols =
[<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>,
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 3, 4], dtype=int32)>]
The resulting pooled output would then look like
pooled_out = <tf.Tensor: shape=(4, 2), dtype=float32, numpy=
array([[1., 4.],
       [5., 7.],
       [9., 5.],
       [6., 8.]], dtype=float32)>
What would be the most efficient way to do this?
IIUC, you could try something like this using only tf operations, but I'm not sure how efficient that will be on the GPU:
import tensorflow as tf

tensor = tf.constant([[0., 1., 2., 3., 4.],
                      [5., 1., 7., 3., 2.],
                      [9., 9., 2., 3., 5.],
                      [2., 6., 2., 8., 4.]])
pool_cols = [tf.constant([0, 1]), tf.constant([2, 3, 4])]

def column_max_pooling(tensor, pool_cols):
    results = []
    tensor_shape = tf.shape(tensor)
    for col in pool_cols:
        col_shape = tf.shape(col)
        # Build (row, col) index pairs covering every row for every column in the group.
        t = tf.gather_nd(tensor, tf.transpose(tf.stack(
            [tf.tile(tf.range(tensor_shape[0]), [col_shape[0]]),
             tf.repeat(col, [tensor_shape[0]])])))
        # Reshape to (group_size, num_rows), transpose, and take the row-wise max.
        t = tf.reduce_max(tf.transpose(tf.reshape(t, (col_shape[0], tensor_shape[0]))),
                          axis=-1, keepdims=True)
        results.append(t)
    return tf.concat(results, axis=-1)

print(column_max_pooling(tensor, pool_cols))
tf.Tensor(
[[1. 4.]
[5. 7.]
[9. 5.]
[6. 8.]], shape=(4, 2), dtype=float32)
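For comparison (this variant is my own sketch, not part of the original answer), the same per-group reduction can be written more directly with tf.gather, which pulls out whole columns at once and avoids building explicit index pairs:

```python
import tensorflow as tf

tensor = tf.constant([[0., 1., 2., 3., 4.],
                      [5., 1., 7., 3., 2.],
                      [9., 9., 2., 3., 5.],
                      [2., 6., 2., 8., 4.]])
pool_cols = [tf.constant([0, 1]), tf.constant([2, 3, 4])]

def column_max_pooling_gather(tensor, pool_cols):
    # For each group, gather its columns (axis=1) and take the row-wise max,
    # then stack the per-group results into the output columns.
    return tf.stack(
        [tf.reduce_max(tf.gather(tensor, col, axis=1), axis=1)
         for col in pool_cols],
        axis=-1)

print(column_max_pooling_gather(tensor, pool_cols))
# [[1. 4.]
#  [5. 7.]
#  [9. 5.]
#  [6. 8.]]
```

Note that the Python loop over pool_cols still launches one gather per group, so for many groups the segment-based approach below should scale better.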
If you can guarantee the order of pool_cols, you could also try using tf.math.unsorted_segment_max:
import tensorflow as tf

tensor = tf.constant([[0., 1., 2., 3., 4.],
                      [5., 1., 7., 3., 2.],
                      [9., 9., 2., 3., 5.],
                      [2., 6., 2., 8., 4.]])
pool_cols = [tf.constant([0, 1]), tf.constant([2, 3, 4])]

# Map each column to the index of its pooling group, then take the
# segment-wise max over the transposed tensor in a single op.
segment_ids = tf.concat([tf.repeat(idx, tf.shape(col)[0])
                         for idx, col in enumerate(pool_cols)], axis=0)
result = tf.transpose(tf.math.unsorted_segment_max(
    tf.transpose(tensor), segment_ids, num_segments=len(pool_cols)))
print(result)
tf.Tensor(
[[1. 4.]
[5. 7.]
[9. 5.]
[6. 8.]], shape=(4, 2), dtype=float32)
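Since the question asks for a custom pooling layer, here is one possible way to package the segment-max approach as a tf.keras.layers.Layer (the class name and structure are my own sketch, not an established API). It precomputes the segment ids once at construction, and carries the same caveat as above: the columns in pool_cols, concatenated in order, must cover the input columns in order.

```python
import tensorflow as tf

class ColumnMaxPooling(tf.keras.layers.Layer):
    """Hypothetical layer: max-pools groups of columns given by pool_cols."""

    def __init__(self, pool_cols, **kwargs):
        super().__init__(**kwargs)
        self.num_segments = len(pool_cols)
        # Precompute segment ids once: position j in column order -> group index.
        self.segment_ids = tf.concat(
            [tf.fill([tf.shape(col)[0]], i) for i, col in enumerate(pool_cols)],
            axis=0)

    def call(self, inputs):
        # Segment-max over columns: transpose, pool, transpose back.
        return tf.transpose(tf.math.unsorted_segment_max(
            tf.transpose(inputs), self.segment_ids,
            num_segments=self.num_segments))

tensor = tf.constant([[0., 1., 2., 3., 4.],
                      [5., 1., 7., 3., 2.],
                      [9., 9., 2., 3., 5.],
                      [2., 6., 2., 8., 4.]])
layer = ColumnMaxPooling([tf.constant([0, 1]), tf.constant([2, 3, 4])])
print(layer(tensor))
```

Because everything reduces to a single unsorted_segment_max per call, this should place well on the GPU regardless of the number of groups.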