I'm trying to solve quite a simple task (I thought it to be), which is replicating a tensor in custom layer on TPU.
My input is 2 tensors of shapes A=(BS, H, n, C) and B = (BS, n, W, C), where n in my case can be (1, 3, 5, 7), but should probably also work with other numbers.
My task is to repeat both tensors A & B to shape (BS, H, W, C) and them sum them for the output. It would be easy if H (or W) were always divisible by n, but they are not. So the number of repeats for each slice (BS, H, 1, C) of A would differ. Thus the output is calculated using the following pseudocode:
for i in range(W):
A1[BS, H, i, C] = A[BS, H, floor(n*i/W), C]
I tried implementing it in a multiple ways:
class StripPoolingCombine(tf.keras.layers.Layer):
def __init__(self, n=1):
super(StripPoolingCombine, self).__init__()
self.n = n
def call(self, v, h, training=False):
H, W = v.shape[1], h.shape[2]
v_repeats = tf.unique_with_counts(tf.math.floor(tf.range(W) * self.n / W))[-1]
h_repeats = tf.unique_with_counts(tf.math.floor(tf.range(H) * self.n / H))[-1]
v = tf.repeat(v, repeats=v_repeats, axis=2)
h = tf.repeat(h, repeats=h_repeats, axis=1)
return Add()([v, h])
Or by replacing unique_with_counts
with the following logic:
tf.math.bincount(tf.cast(tf.math.floor(tf.range(W) * self.n / W), dtype=tf.int32)
f = tf.cast(tf.math.ceil(W / self.n), dtype=tf.int32)
s = tf.cast(tf.math.floor(W / self.n), dtype=tf.int32)
b = tf.cast(f!=s, dtype=tf.int32)
r = W - f - s * (self.n - 1)
x1 = s * tf.ones(self.n-1, dtype=tf.int32)
x2 = (1 - tf.range(r*2) % 2) * b
x2 = tf.pad(x2, paddings=[[0, self.n-r*2-1]])
x3 = tf.concat([[f], tf.add(x1, x2)], axis=0)
But as could be seen at Available TensorFlow Ops for TPU, it doesn't support dynamic tf.range
, tf.unique_with_counts
or tf.math.bincount
, and my implementations all result in errors when bulding a model and calling model.fit()
or model.predict()
. Yet I still hope that tensorflow has provided some way to work with dynamic shapes in a way that would suit my task, and won't me rewrite whole Ops module for such a trivial issue. Please, help!
Full reproducible example (using Colab TPU):
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Add
try:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
print(f'Running on TPU: {tpu.master()}')
except ValueError:
print('Could not connect to TPU')
tpu = None
if tpu:
try:
print('Initializing TPU...')
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.TPUStrategy(tpu)
print('TPU initialized!')
except Exception:
print('Failed to initialize TPU')
# class StripPoolingCombine(tf.keras.layers.Layer):
# def __init__(self, n=1):
# super(StripPoolingCombine, self).__init__()
# self.n = n
# def call(self, v, h, training=False):
# H, W = v.shape[1], h.shape[2]
# v_repeats = tf.unique_with_counts(tf.math.floor(tf.range(W) * self.n / W))[-1]
# h_repeats = tf.unique_with_counts(tf.math.floor(tf.range(H) * self.n / H))[-1]
# v = tf.repeat(v, repeats=v_repeats, axis=2)
# h = tf.repeat(h, repeats=h_repeats, axis=1)
# return Add()([v, h])
class StripPoolingCombine(tf.keras.layers.Layer):
def __init__(self, n=1):
super(StripPoolingCombine, self).__init__()
self.n = n
def call(self, v, h, training=False):
H, W = tf.shape(v)[1], tf.shape(h)[2]
f = tf.cast(tf.math.ceil(W / self.n), dtype=tf.int32)
s = tf.cast(tf.math.floor(W / self.n), dtype=tf.int32)
b = tf.cast(f!=s, dtype=tf.int32)
r = W - f - s * (self.n - 1)
x1 = s * tf.ones(self.n-1, dtype=tf.int32)
x2 = (1 - tf.range(r*2) % 2) * b
x2 = tf.pad(x2, paddings=[[0, self.n-r*2-1]])
x3 = tf.concat([[f], tf.add(x1, x2)], axis=0)
v = tf.repeat(v, repeats=x3, axis=2)
h = tf.repeat(h, repeats=x3, axis=1)
output = tf.add(v, h)
return output
def build_model(n=7):
v = Input(shape=(256, n, 3))
h = Input(shape=(n, 256, 3))
outputs = StripPoolingCombine()(v, h)
model = Model(inputs=[v, h], outputs=outputs)
return model
tf.keras.backend.clear_session()
with strategy.scope():
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999)
model = build_model()
model.compile(optimizer=optimizer, loss='mean_squared_error')
rng_1 = tf.random.uniform([1, 256, 7, 3])
rng_2 = tf.random.uniform([1, 7, 256, 3])
model.predict([rng_1, rng_2])
Use tf.gather
:
def call(self, v, h, training=False):
def out(A, H, axis):
r = tf.range(H)
inds = tf.floor(self.n * r / H)
inds = tf.cast(inds, tf.int32)
return tf.gather(A, inds, axis=axis)
H, W = tf.shape(v)[1], tf.shape(h)[2]
v = out(v, W, 2)
h = out(h, H, 1)
output = tf.add(v, h)
return output