I want to split a torch tensor into consecutive chunks whose lengths are given by a list.
For example, say my input tensor is torch.arange(20):
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19])
and my list of split sizes is splits = [1, 2, 5, 10].
Then my result would be:
(tensor([0]),
tensor([1, 2]),
tensor([3, 4, 5, 6, 7]),
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]))
Assume my input tensor is always at least as long as the sum of the split sizes; any leftover elements at the end should be ignored.
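To pin down the intended behaviour, here is a minimal reference sketch that simply slices consecutive chunks in a Python loop. The helper name split_by_sizes_loop is hypothetical; it is meant as a correctness baseline rather than a vectorized solution.

```python
import torch


def split_by_sizes_loop(t, sizes):
    # Hypothetical helper: take consecutive slices of the given lengths,
    # ignoring whatever is left over at the end of `t`.
    chunks = []
    start = 0
    for size in sizes:
        chunks.append(t[start:start + size])
        start += size
    return tuple(chunks)


t = torch.arange(20)
print(split_by_sizes_loop(t, [1, 2, 5, 10]))
```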
You could use tensor_split on the cumulative sum of the splits (e.g. computed with np.cumsum), and then drop the last chunk:
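import torch
import numpy as np

t = torch.arange(20)
splits = [1, 2, 5, 10]
# The cumulative sums [1, 3, 8, 18] are the boundaries between chunks;
# [:-1] drops the final chunk, i.e. the leftover tail t[18:].
t.tensor_split(np.cumsum(splits).tolist())[:-1]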
Output:
(tensor([0]),
tensor([1, 2]),
tensor([3, 4, 5, 6, 7]),
tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]))
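If you would rather not depend on NumPy, the same boundaries can be computed with itertools.accumulate. The sketch below wraps this in a hypothetical helper, split_by_sizes, and adds an explicit length check, since tensor_split itself does not complain when the boundaries run past the end of the tensor (it just returns empty chunks).

```python
from itertools import accumulate

import torch


def split_by_sizes(t, sizes):
    # Hypothetical helper name. Boundaries are the running totals of `sizes`,
    # e.g. [1, 2, 5, 10] -> [1, 3, 8, 18].
    boundaries = list(accumulate(sizes))
    if boundaries and boundaries[-1] > t.shape[0]:
        raise ValueError("sum of sizes exceeds the length of the tensor")
    # tensor_split (along dim 0) returns len(boundaries) + 1 chunks; the last
    # one is the leftover tail (possibly empty), which is dropped here.
    return t.tensor_split(boundaries)[:-1]


print(split_by_sizes(torch.arange(20), [1, 2, 5, 10]))
```

When the sizes add up to exactly the length of the tensor, the dropped tail is simply an empty chunk, so the result still has one chunk per size.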