I want to index into the last axis of a tensor of arbitrary shape, where only the last dimension is fixed at size 2.
For example, let x have shape (1, 2, 2). I index the last axis with:
x_0 = x[:, :, 0] # x_0, x_1 shapes are (1,2)
x_1 = x[:, :, 1]
Or let x have shape (1, 2, 3, 4, 2). Then I index the last axis with:
x_0 = x[:, :, :, :, 0] # x_0, x_1 shapes are (1,2,3,4)
x_1 = x[:, :, :, :, 1]
I haven't been able to find a TensorFlow function or syntax for slicing a tensor of arbitrary rank.
I need a general way to index the last axis of a tensor of any shape.
Slicing in TensorFlow uses almost the same syntax as NumPy, so you can use an ellipsis (...) here. From the NumPy indexing documentation:
Ellipsis expands to the number of : objects needed for the selection tuple to index all dimensions. In most cases, this means that the length of the expanded selection tuple is x.ndim. There may only be a single ellipsis present.
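To make the quoted rule concrete, here is a small NumPy sketch using the (1, 2, 3, 4, 2) shape from the question. It checks that `...` expands to exactly the `:` entries you would otherwise write by hand, that it can also sit between other indices, and that a second ellipsis is rejected. The array contents are arbitrary (just `np.arange`), chosen so the selected values are easy to check.

```python
import numpy as np

# A 5-D array whose last axis has size 2, matching the (1, 2, 3, 4, 2) example.
x = np.arange(48).reshape(1, 2, 3, 4, 2)

# `...` expands to as many `:` as needed, so these two selections are identical.
explicit = x[:, :, :, :, 0]
with_ellipsis = x[..., 0]
assert np.array_equal(explicit, with_ellipsis)
print(with_ellipsis.shape)  # (1, 2, 3, 4)

# The ellipsis can also sit between other indices; it fills the axes in the middle.
assert np.array_equal(x[0, ..., 1], x[0, :, :, :, 1])

# Only a single ellipsis is allowed per index; NumPy raises IndexError otherwise.
try:
    x[..., 0, ...]
except IndexError as err:
    print("IndexError:", err)
```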
In your case,
x_0 = x[..., 0]
x_1 = x[..., 1]
indexes the last axis of a tensor of any shape: x_0 and x_1 have the shape of x with the last axis dropped.
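A minimal TensorFlow sketch applying this to both shapes from the question, assuming TensorFlow 2.x with eager execution (the default). The helper name `split_last_axis` is hypothetical, introduced only for this illustration; it just wraps the two `x[..., i]` slices so the same call works regardless of rank.

```python
import numpy as np
import tensorflow as tf


def split_last_axis(x):
    """Hypothetical helper: return the two slices along a size-2 last axis."""
    # `...` stands in for every leading axis, however many there are.
    return x[..., 0], x[..., 1]


# Same shapes as in the question; the last axis always has size 2.
for shape in [(1, 2, 2), (1, 2, 3, 4, 2)]:
    n = int(np.prod(shape))
    x = tf.reshape(tf.range(n), shape)
    x_0, x_1 = split_last_axis(x)
    print(shape, "->", x_0.shape.as_list(), x_1.shape.as_list())
```

Because `...` adapts to the rank of the input, the same slice expression covers the 3-D and the 5-D case without any shape-dependent code.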
You can also look at the answer to the question "What is the difference between the slice (:) and the ellipsis (...) operators in numpy?".