I'm struggling to understand how `torch.permute()` works. In general, how exactly is an n-D tensor permuted? An example with an explanation for a 4-D or higher-dimensional tensor would be highly appreciated.
I've searched across the web but did not find a clear explanation.
A tensor's data is stored in memory as a flat, one-dimensional array. What differs is the interface PyTorch provides for accessing it. This all revolves around the notion of stride, which describes how that flat data is navigated. At a higher level, we prefer to reason about our data in more dimensions by using tensor shapes. The following example and description also hold for higher-dimensional tensors.
The permutation operator lets you change how you access the tensor data by seemingly changing the order of its dimensions. Permutations return a view and do not copy the original tensor (as long as you do not make the data contiguous, e.g. by calling `.contiguous()`). In other words, the permuted tensor shares the same underlying data.
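A quick way to confirm the view behaviour is to write through the permuted tensor and watch the original change. This is a small sketch; the values produced by `torch.arange` are arbitrary and serve only as an illustration.

```python
import torch

x = torch.arange(12).reshape(3, 2, 2)  # contiguous tensor of shape (3, 2, 2)
y = x.permute(2, 0, 1)                 # a view: no data is copied

# Both tensors start at the same memory address.
print(x.data_ptr() == y.data_ptr())    # True

# Writing through the view modifies the original tensor.
y[0, 0, 0] = 100
print(x[0, 0, 0].item())               # 100

# The permuted view is not contiguous; .contiguous() materialises a copy.
print(y.is_contiguous())               # False
z = y.contiguous()
print(z.is_contiguous())               # True
print(z.data_ptr() == x.data_ptr())    # False: z lives in new memory
```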
At the user interface, permutation reorders the dimensions. This means the way the tensor is indexed changes depending on the order of dimensions supplied to the `torch.Tensor.permute` method.
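Since the question asks about 4-D tensors, here is a sketch of the general rule for any number of dimensions. For `y = x.permute(*dims)`, dimension `d` of `y` is dimension `dims[d]` of `x`, so both its size and its stride are taken from `x` at position `dims[d]`. The shape `(2, 3, 4, 5)` and the permutation `(3, 1, 0, 2)` are arbitrary choices for illustration.

```python
import itertools
import torch

x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
dims = (3, 1, 0, 2)
y = x.permute(*dims)

# The shape and stride of y are those of x, reordered by dims.
print(tuple(y.shape))   # (5, 3, 2, 4)
print(x.stride())       # (60, 20, 5, 1)
print(y.stride())       # (1, 20, 60, 5)

# Element rule: position d of an index into y addresses dimension dims[d] of x.
for idx_y in itertools.product(*(range(s) for s in y.shape)):
    idx_x = [0] * x.dim()
    for d, src in enumerate(dims):
        idx_x[src] = idx_y[d]
    assert y[idx_y] == x[tuple(idx_x)]
print("all elements match")
```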
Take a simple 3D tensor example: `x` shaped `(I=3, J=2, K=2)`. Given `i < I`, `j < J`, and `k < K`, `x` can naturally be accessed via `x[i,j,k]`. As for the underlying data being accessed: since `x` is contiguous, its stride is `(J*K=4, K=2, 1)`, so `x[i,j,k]` corresponds to `_x[i*J*K + j*K + k]`, where `_x` is the underlying data of `x`. By "corresponds to", I mean that the data array associated with tensor `x` is accessed at index `i*J*K + j*K + k`.
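The index arithmetic can be reproduced without PyTorch at all. This is a minimal pure-Python sketch that assumes a contiguous (row-major) layout. The helper names `contiguous_strides` and `flat_index` are hypothetical and are not part of any library.

```python
def contiguous_strides(shape):
    """Row-major strides: the last dimension moves fastest (hypothetical helper)."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def flat_index(index, strides):
    """Offset into the flat data: sum of index * stride over all dims (hypothetical helper)."""
    return sum(i * s for i, s in zip(index, strides))


I, J, K = 3, 2, 2
strides = contiguous_strides((I, J, K))
print(strides)                          # (4, 2, 1), i.e. (J*K, K, 1)

i, j, k = 2, 1, 0
print(flat_index((i, j, k), strides))   # 10 == i*J*K + j*K + k
```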
If you now permute the dimensions, say `y = x.permute(2,0,1)`, the underlying data stays the same (in fact, `data_ptr()` returns the same pointer for both), but the interface to `y` is different! `y` has a shape of `(K, I, J)`, and accessing `y[k,i,j]` translates to `x[i,j,k]`. In other words, dim 2 moves to the front and dims 0 and 1 move to the back. After permutation the stride is no longer the same, because the strides are permuted along with the shape. `y` has a stride of `(1, J*K, K)`, so `y[k,i,j]` corresponds to `_x[k*1 + i*J*K + j*K]`, which is exactly the same element as `x[i,j,k]`.
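The claims in this 3D example can be checked directly. Here is a short sketch using the same `(I=3, J=2, K=2)` shape. `x.flatten()` stands in for the underlying data `_x`, which is valid here only because `x` is contiguous.

```python
import torch

I, J, K = 3, 2, 2
x = torch.arange(I * J * K).reshape(I, J, K)
_x = x.flatten()             # x is contiguous, so this follows its memory order
y = x.permute(2, 0, 1)

print(tuple(y.shape))        # (2, 3, 2), i.e. (K, I, J)
print(y.stride())            # (1, 4, 2), i.e. (1, J*K, K)
print(y.data_ptr() == x.data_ptr())  # True

for i in range(I):
    for j in range(J):
        for k in range(K):
            # y[k, i, j] and x[i, j, k] address the same element of the flat data
            assert y[k, i, j] == x[i, j, k]
            assert y[k, i, j] == _x[k * 1 + i * J * K + j * K]
```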
To read more about views and strides, refer to this.