I'm trying to grok the einops syntax for tensor reordering, but I'm somehow missing the point.
If I have the following matrix:
mat = torch.randint(1, 10, (8,4))
I understand what the following command does:
rearrange(mat, '(h n) w -> (n h) w', n = 2)
But I can't really wrap my head around the following ones:
rearrange(mat, '(n h) w -> (h n) w', n = 2)
rearrange(mat, '(n h) w -> (h n) w', n = 4)
Any help would be appreciated
rearrange(mat, '(h n) w -> (n h) w', n = 2)
and
rearrange(mat, '(n h) w -> (h n) w', n = 2)
are inverses of each other. If you can picture what one does, the other performs the reverse transformation.
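To check that the two patterns undo each other, here is a minimal sketch that spells each one out as reshape/transpose in plain NumPy. The helper names `forward` and `backward` and the `np.arange` stand-in matrix are mine, not part of einops; I'm assuming the usual row-major splitting, where the first name inside the parentheses varies slowest.

```python
import numpy as np

# Stand-in for the 8x4 matrix; distinct values make rows easy to follow.
mat = np.arange(32).reshape(8, 4)

def forward(x):
    # '(h n) w -> (n h) w', n=2: split rows into (h=4, n=2), swap, merge.
    return x.reshape(4, 2, 4).transpose(1, 0, 2).reshape(8, 4)

def backward(x):
    # '(n h) w -> (h n) w', n=2: split rows into (n=2, h=4), swap, merge.
    return x.reshape(2, 4, 4).transpose(1, 0, 2).reshape(8, 4)

roundtrip = backward(forward(mat))
print(np.array_equal(roundtrip, mat))  # True
```

Applying one and then the other gives back the original row order, which is all "inverse" means here.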
As for the last one, mat is 8x4:
rearrange(mat, '(n h) w -> (h n) w', n = 4)
So you first split the first dimension into 4x2 (below I ignore the w dimension, because nothing special happens to it):
[0, 1, 2, 3, 4, 5, 6, 7]
to
[0, 1,
2, 3,
4, 5,
6, 7]
then you change the order of the axes to 2x4 (a transpose)
[0, 2, 4, 6,
1, 3, 5, 7]
then merge two dimensions into one
[0, 2, 4, 6, 1, 3, 5, 7]
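The three steps (split, reorder axes, merge) can be reproduced on the row indices alone. This is a rough NumPy equivalent rather than the einops call itself, and the variable names are just for illustration:

```python
import numpy as np

rows = np.arange(8)            # row indices 0..7; the w axis is ignored
split = rows.reshape(4, 2)     # '(n h)' with n=4 -> shape (n=4, h=2)
swapped = split.T              # reorder axes to (h=2, n=4)
merged = swapped.reshape(-1)   # '(h n)' -> back to one axis of length 8

print(split)
print(swapped)
print(merged)  # [0 2 4 6 1 3 5 7]
```

Each entry of `merged` tells you which original row ends up at that position in the output.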
If you still don't have a feel for how that works, try simpler examples like
rearrange(np.arange(50), '(h n) -> h n', h=5)
rearrange(np.arange(50), '(h n) -> h n', h=10)
rearrange(np.arange(50), '(h n) -> n h', h=10)
etc., so that you can track the movement of each element in the matrix.
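As a sketch of what "tracking an element" looks like, here are plain NumPy equivalents of the three patterns (assuming row-major splitting; the names `a`, `b`, `c` are arbitrary), following where the value 7 lands in each result:

```python
import numpy as np

x = np.arange(50)
a = x.reshape(5, 10)      # like '(h n) -> h n', h=5
b = x.reshape(10, 5)      # like '(h n) -> h n', h=10
c = x.reshape(10, 5).T    # like '(h n) -> n h', h=10

# Position of the value 7 in each result.
pos_a = tuple(int(i) for i in np.argwhere(a == 7)[0])
pos_b = tuple(int(i) for i in np.argwhere(b == 7)[0])
pos_c = tuple(int(i) for i in np.argwhere(c == 7)[0])
print(pos_a, pos_b, pos_c)  # (0, 7) (1, 2) (2, 1)
```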