
Passing sections of a tensor to a function (can I avoid looping?)


If I have this tensor:

ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])

and a function

import torch

def myfunc(k: torch.Tensor, s: float) -> torch.Tensor:
    # do something with the window k and the scalar s
    return torch.tensor([2.3, 3.4, 5.1, 2.2])

The return value of myfunc is not actually these constant numbers; they only illustrate the size of the output (four values). At the moment, I am looping as follows:

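ks_indices = torch.tensor(list(range(4)))  # indices of the current window: 0..3
for i in range(5):
    k_set = torch.index_select(ks, 0, ks_indices.view(-1))  # first 4 elements of ks
    result_tensor = myfunc(k_set, ks[0])  # one result of shape (4,) per iteration
    ks = torch.roll(ks, -1)  # shift ks left by one position (with wrap-around)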

Is it possible to avoid the loop and get all five result tensors as a single 5x4 tensor in one call?


Solution

  • Your question is lacking a lot of detail, so I am making the following assumptions:

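    If these assumptions (in particular, the first two of them) are indeed correct, then you can avoid the loop by passing a 2-d "sliding window view" of ks to myfunc(). Such a view can be created with the tensor's unfold(dimension, size, step) method, which returns all slices of length size along dimension, taken every step elements. For the 8 elements of ks with size 4 and step 1, that gives 8 - 4 + 1 = 5 windows, i.e. a 5x4 tensor: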

    import torch
    
    
    def myfunc(k: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        # Do something with the input values: result = k * s
        # (works for a single window as well as for the 2-d window view, via broadcasting)
        return k * s
    
    # Basically, the code from the question
    
    num_adjacent_indices = 4
    ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    ks_indices = torch.arange(num_adjacent_indices)
    
    all_result_tensors = []  # Collect all intermediate results
    for i in range(5):
        k_set = torch.index_select(ks, 0, ks_indices)
        result_tensor = myfunc(k_set, ks[0])
        all_result_tensors.append(result_tensor)
        ks = torch.roll(ks, -1)
    # Assemble intermediate results into a single 5x4 tensor
    all_result_tensors_given = torch.stack(all_result_tensors)
    
    # Proposed adjustments
    
    num_adjacent_indices = 4
    ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    
    ks_rolling = ks.unfold(0, num_adjacent_indices, 1)  # Sliding window view, shape (5, 4)
    # ks_rolling[:, 0:1] has shape (5, 1): the first element of each window
    all_result_tensors_proposed = myfunc(ks_rolling, ks_rolling[:, 0:1])
    
    # Compare results
    assert torch.allclose(all_result_tensors_given, all_result_tensors_proposed)
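One difference to keep in mind: the loop in the question uses torch.roll(), which wraps around, whereas unfold() only yields windows that fit inside the tensor. With 8 elements and a window size of 4 there are exactly 5 such windows, which happens to match the five loop iterations, so no wrap-around ever occurs. If you need more iterations than that, a minimal sketch, assuming a 1-d ks and a window size no larger than len(ks), is to append the first size - 1 elements before unfolding. The helper name rolling_windows is hypothetical:

```python
import torch


def rolling_windows(ks: torch.Tensor, size: int, num_windows: int) -> torch.Tensor:
    """Hypothetical helper: row i equals torch.roll(ks, -i)[:size].

    Assumes ks is 1-d, size <= len(ks) and num_windows <= len(ks).
    Wrap-around is emulated by appending the first size - 1 elements.
    """
    padded = torch.cat([ks, ks[: size - 1]])  # length len(ks) + size - 1
    return padded.unfold(0, size, 1)[:num_windows]  # (num_windows, size)


ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
windows = rolling_windows(ks, size=4, num_windows=8)  # 8 windows, the last 3 wrap around

# Check against the roll-based loop from the question
rolled = ks
for i in range(8):
    assert torch.equal(windows[i], rolled[:4])
    rolled = torch.roll(rolled, -1)
```

With num_windows=5 this returns the same 5x4 tensor as ks.unfold(0, 4, 1).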
    

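    Note that I made some adjustments to the code from your question; in particular, I turned ks into a tensor, gave myfunc() an actual computation (k * s) so that the results can be compared, and collected the five intermediate results into a single 5x4 tensor with torch.stack().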

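    Note that both the proposed code and the adjusted code from your question work with the same implementation of myfunc(). This is because k * s broadcasts: in the loop, k has shape (4,) and s is a 0-d tensor, while in the proposed code k has shape (5, 4) and s has shape (5, 1), so each row is multiplied by its own first element.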
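The approach above relies on myfunc() being written with broadcasting-friendly operations. If your real myfunc() only works on a single 1-d window and a scalar (for example, because it reduces over the whole window), one option is to let torch.vmap map it over the rows of the sliding-window view. This is a swapped-in technique, not part of the answer above, and the sketch assumes PyTorch 2.0 or later and a myfunc() built only from PyTorch operations without Python control flow that depends on tensor values. The myfunc below is a stand-in:

```python
import torch


def myfunc(k: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Stand-in: k is one window of shape (4,), s is a 0-d tensor.
    # k.sum() reduces over the window only, which plain broadcasting on a
    # (5, 4) input would get wrong (it would sum over all 20 elements).
    return k * s + k.sum()


num_adjacent_indices = 4
ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])

windows = ks.unfold(0, num_adjacent_indices, 1)  # (5, 4) sliding-window view
firsts = windows[:, 0]  # (5,) first element of each window, used as s

# vmap maps myfunc over dim 0 of both arguments, one window per "call"
result = torch.vmap(myfunc)(windows, firsts)  # (5, 4)

# Reference: the equivalent explicit loop
expected = torch.stack([myfunc(windows[i], firsts[i]) for i in range(windows.shape[0])])
assert torch.allclose(result, expected)
```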

    Hope this helps. If that is not what you are looking for, however, please provide more detail in your question.