I want to shift a tensor along a given axis. This is easy to do in pandas or NumPy, like this:
import numpy as np
import pandas as pd
data = np.arange(0, 6).reshape(-1, 2)
pd.DataFrame(data).shift(1).fillna(0).values
Output is:
array([[0., 0.],
       [0., 1.],
       [2., 3.]])
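For comparison, the same shift can be done in plain NumPy without going through a DataFrame. This is only a sketch. It assumes a non-negative shift `n`, and `np_shift` is a hypothetical helper name, not a NumPy function. It copies everything except the last `n` entries along `axis` into a pre-filled output array, so the first `n` positions keep `fill_value`.

```python
import numpy as np

def np_shift(x, n=1, axis=0, fill_value=0):
    """Hypothetical helper: shift `x` forward by n >= 0 positions along `axis`,
    filling the vacated positions with `fill_value` (like shift(n).fillna(fill_value))."""
    if n < 0:
        raise ValueError("this sketch only handles n >= 0")
    x = np.asarray(x)
    # Start from an array that is entirely fill_value.
    result = np.full(x.shape, fill_value, dtype=np.result_type(x, fill_value))
    size = x.shape[axis]
    if n >= size:
        return result  # everything was shifted out
    # Destination: positions n..end along `axis`.
    dst = [slice(None)] * x.ndim
    dst[axis] = slice(n, None)
    # Source: positions 0..size-n along `axis` (the last n entries fall off).
    src = [slice(None)] * x.ndim
    src[axis] = slice(0, size - n)
    result[tuple(dst)] = x[tuple(src)]
    return result

data = np.arange(0, 6).reshape(-1, 2)
print(np_shift(data, 1, axis=0))
```

Unlike the pandas version, this keeps the integer dtype, because no NaN is ever introduced.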
But in TensorFlow, the closest solution I found is tf.roll. However, it wraps the last row around to the first row, which I don't want. So I have to combine something like tf.roll + tf.slice (drop the wrapped-around row) + tf.concat (prepend a row of tf.zeros).
It's really ugly.
Is there a better way to handle shift in TensorFlow or Keras?
Thanks.
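For reference, the slice + concat workaround described above can at least be wrapped in a helper so it is written only once. A minimal sketch, assuming TensorFlow 2.x eager execution, Python ints for `n` (non-negative) and `axis`, and a hypothetical helper name `tf_shift_concat`. It skips tf.roll entirely: tf.slice keeps the first `size - n` entries along `axis`, and tf.concat prepends a block of zeros.

```python
import tensorflow as tf

def tf_shift_concat(x, n, axis):
    """Hypothetical helper: shift `x` forward by n >= 0 steps along `axis`,
    zero-filling the first n positions (tf.slice + tf.concat, no tf.roll)."""
    if n < 0:
        raise ValueError("this sketch only handles n >= 0")
    x = tf.convert_to_tensor(x)
    axis = axis % x.shape.rank
    shape = tf.shape(x)
    n = tf.minimum(n, shape[axis])  # shifting past the end gives all zeros
    # Size argument for tf.slice: same as x, except (size - n) along `axis`.
    keep_size = tf.tensor_scatter_nd_update(shape, [[axis]], [shape[axis] - n])
    kept = tf.slice(x, tf.zeros_like(shape), keep_size)
    # A block of zeros with n entries along `axis` and full size elsewhere.
    pad_shape = tf.tensor_scatter_nd_update(shape, [[axis]], [n])
    zeros = tf.zeros(pad_shape, dtype=x.dtype)
    return tf.concat([zeros, kept], axis=axis)

data = tf.reshape(tf.range(6), (-1, 2))
print(tf_shift_concat(data, 1, axis=0).numpy())
```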
I think I found a better way to solve this problem.
We can use tf.roll, then apply tf.math.multiply with a mask of zeros and ones to set the first row to zeros.
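The roll-then-multiply idea can also be generalized so the mask is not hard-coded for a length-3 axis. This is a sketch under the assumptions of TensorFlow 2.x eager execution and Python ints for `n` (non-negative) and `axis`; `tf_shift_roll` is a hypothetical name. The mask is built by comparing tf.range along `axis` with `n`, then reshaped to size 1 on every other axis so it broadcasts against the rolled tensor.

```python
import tensorflow as tf

def tf_shift_roll(x, n, axis):
    """Hypothetical helper: tf.roll by n >= 0 along `axis`, then multiply by a
    0/1 mask that zeroes the n entries that wrapped around to the front."""
    if n < 0:
        raise ValueError("this sketch only handles n >= 0")
    x = tf.convert_to_tensor(x)
    rank = x.shape.rank
    axis = axis % rank
    rolled = tf.roll(x, shift=n, axis=axis)
    # 1 for positions >= n along `axis`, 0 for the first n positions.
    keep = tf.cast(tf.range(tf.shape(x)[axis]) >= n, x.dtype)
    # Shape [1, ..., -1, ..., 1] so the mask broadcasts over the other axes.
    mask_shape = [1] * rank
    mask_shape[axis] = -1
    mask = tf.reshape(keep, mask_shape)
    return tf.math.multiply(rolled, mask)

A = tf.cast(tf.reshape(tf.range(27), (-1, 3, 3)), dtype=tf.float32)
print(tf_shift_roll(A, 1, axis=1).numpy())
```

If `n` is at least the length of the axis, the mask is all zeros, so the result is all zeros, matching what pandas would produce after `fillna(0)`.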
Sample code is as follows:
Original tensor:
import tensorflow as tf
A = tf.cast(tf.reshape(tf.range(27), (-1, 3, 3)), dtype=tf.float32)
A
Output:
<tf.Tensor: id=117, shape=(3, 3, 3), dtype=float32, numpy=
array([[[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.]],

       [[ 9., 10., 11.],
        [12., 13., 14.],
        [15., 16., 17.]],

       [[18., 19., 20.],
        [21., 22., 23.],
        [24., 25., 26.]]], dtype=float32)>
Shift along axis 1 (like pandas DataFrame.shift):
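# Mask of shape (3, 3): first row zeros, remaining rows ones
B = tf.concat((tf.zeros((1, 3)), tf.ones((2, 3))), axis=0)
# Add a leading axis so the mask broadcasts against A's shape (3, 3, 3)
C = tf.expand_dims(B, axis=0)
# Roll along axis 1, then zero out the row that wrapped around
tf.math.multiply(tf.roll(A, 1, axis=1), C)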
Output:
<tf.Tensor: id=128, shape=(3, 3, 3), dtype=float32, numpy=
array([[[ 0.,  0.,  0.],
        [ 0.,  1.,  2.],
        [ 3.,  4.,  5.]],

       [[ 0.,  0.,  0.],
        [ 9., 10., 11.],
        [12., 13., 14.]],

       [[ 0.,  0.,  0.],
        [18., 19., 20.],
        [21., 22., 23.]]], dtype=float32)>
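One caveat with zeroing by multiplication: under IEEE floating point, 0 * inf and 0 * nan are both nan. If a non-finite value wraps around, the "zeroed" row ends up containing nan instead of 0. Selecting with tf.where instead of multiplying avoids that. A sketch, assuming TensorFlow 2.x eager execution and Python ints for `n` (non-negative) and `axis`; `tf_shift_where` is a hypothetical name.

```python
import tensorflow as tf

def tf_shift_where(x, n, axis):
    """Hypothetical helper: tf.roll by n >= 0 along `axis`, then use tf.where to
    put zeros in the first n positions (no multiplication, so inf/nan cannot leak)."""
    if n < 0:
        raise ValueError("this sketch only handles n >= 0")
    x = tf.convert_to_tensor(x)
    rank = x.shape.rank
    axis = axis % rank
    rolled = tf.roll(x, shift=n, axis=axis)
    # True for positions >= n along `axis`.
    keep = tf.range(tf.shape(x)[axis]) >= n
    mask_shape = [1] * rank
    mask_shape[axis] = -1
    # Broadcast the boolean mask to the full shape of x.
    keep = tf.broadcast_to(tf.reshape(keep, mask_shape), tf.shape(x))
    return tf.where(keep, rolled, tf.zeros_like(rolled))

x = tf.constant([[1.0], [2.0], [float("inf")]])
mask = tf.constant([[0.0], [1.0], [1.0]])
print((tf.roll(x, 1, axis=0) * mask).numpy())  # first entry is nan (0 * inf)
print(tf_shift_where(x, 1, axis=0).numpy())    # first entry is 0.0
```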