If I have this tensor
ks=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
and a function
def myfunc(k: torch.tensor, s: float) -> torch.tensor:
    # do something
    return torch.tensor([2.3, 3.4, 5.1, 2.2])
The return value of myfunc is not actually the constant tensor shown above; it is just there to illustrate the tensor size. At the moment, I am looping as follows.
ks_indices = torch.tensor(list(range(4)))
for i in range(5):
    k_set = torch.index_select(ks, 0, ks_indices.view(-1))
    result_tensor = myfunc(k_set, ks[0])
    ks = torch.roll(ks, -1)
Is it possible to avoid the loop and make a single tensor call that returns all five result_tensor values as a 5x4 tensor in one go?
Your question is lacking a lot of detail, so I am making the following assumptions:

- The values that are passed to myfunc() as k are not in random order, but always adjacent.
- Consecutive calls to myfunc() shift all indices by 1.
- s always corresponds to k[0].

If these assumptions (in particular, the first two of them) are indeed correct, then you can avoid the loop by providing a 2-d "sliding window view" of ks to myfunc(). Such a view can be achieved using the tensor's unfold() method:
import torch

def myfunc(k: torch.tensor, s: float) -> torch.tensor:
    # Do something with the input values: result = k * s
    return k * s
# Basically, the code from the question
num_adjacent_indices = 4
ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
ks_indices = torch.tensor(list(range(num_adjacent_indices)))
all_result_tensors = []  # Collect all intermediate results
for i in range(5):
    k_set = torch.index_select(ks, 0, ks_indices.view(-1))
    result_tensor = myfunc(k_set, ks[0])
    all_result_tensors.append(result_tensor)
    ks = torch.roll(ks, -1)
# Assemble intermediate results
all_result_tensors_given = torch.stack(all_result_tensors)
# Proposed adjustments
num_adjacent_indices = 4
ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
ks_rolling = ks.unfold(0, num_adjacent_indices, 1)  # Sliding window view
all_result_tensors_proposed = myfunc(ks_rolling, ks_rolling[:, 0:1])
# Compare results
assert torch.allclose(all_result_tensors_given, all_result_tensors_proposed)
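
To make the sliding window view concrete: for this example, ks_rolling should be a 5x4 tensor whose i-th row holds ks[i:i+4], and ks_rolling[:, 0:1] holds the corresponding s values as a (5, 1) column (unfold() returns a view, so no data is copied):

print(ks_rolling)
# tensor([[1., 2., 3., 4.],
#         [2., 3., 4., 5.],
#         [3., 4., 5., 6.],
#         [4., 5., 6., 7.],
#         [5., 6., 7., 8.]])
print(ks_rolling[:, 0:1])
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.],
#         [5.]])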
Note that I made some adjustments to the code from your question; in particular, I:

- gave myfunc() an actual calculation (k * s), which enables a meaningful comparison of calling myfunc() with different values;
- converted ks from a list to a tensor, as otherwise torch.index_select() will not work;
- collected the intermediate results (result_tensor → all_result_tensors) to enable a comparison.

Note that both the proposed code and the adjusted code from your question work with the same implementation of myfunc().
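
That works here because k * s consists only of elementwise, broadcasting-friendly operations, so the same function accepts both a single 1-d window and the 2-d batch of windows. As a minimal sketch (a hypothetical variant, not from your question), any myfunc() that reduces along the last dimension with keepdim=True vectorizes the same way:

def myfunc(k: torch.tensor, s: float) -> torch.tensor:
    # Scale by s, then subtract the per-window mean; keepdim=True lets the
    # reduction broadcast for both a 1-d window (shape (4,)) and a batch
    # of windows (shape (5, 4)).
    return k * s - k.mean(dim=-1, keepdim=True)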
Hope this helps. If that is not what you are looking for, however, please provide more detail in your question.