python, pytorch

Passing sections of a tensor to a function (can I avoid looping)?


If I have this tensor

ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])

and a function

def myfunc(k: torch.Tensor, s: float) -> torch.Tensor:
    # do something
    return torch.tensor([2.3, 3.4, 5.1, 2.2])

myfunc does not actually return these constant numbers. They only illustrate the size of the returned tensor (four values per call). At the moment, I am looping as follows:

ks_indices = torch.tensor(list(range(4)))
for i in range(5):
    k_set = torch.index_select(ks, 0, ks_indices.view(-1))
    result_tensor = myfunc(k_set, ks[0])
    ks = torch.roll(ks, -1)

Is it possible to avoid the loop and get all five result tensors as a single 5x4 tensor in one call? To clarify: I need a sliding window that always processes four consecutive values together, with the start index shifted by one on each step.
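For reference, each pass of the loop above works on ks[i:i + 4] (the roll shifts the start index by one) and passes ks[i] as s. The following sketch checks this by comparing the loop's selections with a 5x4 index matrix idx[i, j] = i + j. Indexing ks with that matrix yields all five windows at once, which is what makes a loop-free formulation possible. It assumes ks is a tensor; the names window, num_windows, loop_windows and index_windows are illustrative only.

```python
import torch

ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
window, num_windows = 4, 5

# Windows selected by the original loop (roll + index_select)
rolled = ks.clone()
loop_windows = []
for i in range(num_windows):
    loop_windows.append(torch.index_select(rolled, 0, torch.arange(window)))
    rolled = torch.roll(rolled, -1)  # shift left by one element
loop_windows = torch.stack(loop_windows)

# The same windows via broadcasting: (5, 1) + (4,) -> (5, 4) index matrix
idx = torch.arange(num_windows)[:, None] + torch.arange(window)
index_windows = ks[idx]

assert torch.equal(loop_windows, index_windows)
print(index_windows[:, 0])  # the s value of each iteration: tensor([1., 2., 3., 4., 5.])
```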


Solution

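  • You can avoid the loop by passing a 2-d "sliding window view" of ks to myfunc(). Such a view can be created with the tensor's unfold(dimension, size, step) method, which returns a view (no data is copied) containing all slices of length size along dimension, with consecutive slices starting step elements apart: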

    import torch
    
    
    def myfunc(k: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        # Do something with the input values: result = k * s
        # (elementwise, so it also works on 2-d inputs via broadcasting)
        return k * s
    
    # Basically, the code from the question
    
    num_adjacent_indices = 4
    ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    ks_indices = torch.tensor(list(range(num_adjacent_indices)))
    
    all_result_tensors = []  # Collect all intermediate results
    for i in range(5):
        k_set = torch.index_select(ks, 0, ks_indices.view(-1))
        result_tensor = myfunc(k_set, ks[0])
        all_result_tensors.append(result_tensor)
        ks = torch.roll(ks, -1)
    # Assemble intermediate results into a 5x4 tensor
    all_result_tensors_given = torch.stack(all_result_tensors)
    
    # Proposed adjustments
    
    num_adjacent_indices = 4
    # Recreate ks, since the loop above has rolled it
    ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    
    ks_rolling = ks.unfold(0, num_adjacent_indices, 1)  # Sliding window view, shape (5, 4)
    # Pass the first element of each window as a (5, 1) column, so it broadcasts row-wise
    all_result_tensors_proposed = myfunc(ks_rolling, ks_rolling[:, 0:1])
    
    # Compare results
    assert torch.allclose(all_result_tensors_given, all_result_tensors_proposed)
    

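    Note that I made some adjustments to the code from your question; in particular, I:

      • turned ks into a tensor, since torch.index_select() and torch.roll() expect a tensor rather than a Python list,
      • gave myfunc() a concrete implementation (k * s), so that the results of both approaches can be compared,
      • introduced num_adjacent_indices for the window size,
      • collected the intermediate results in a list and stacked them into a single 5x4 tensor with torch.stack().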

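    Note that both the proposed code and the adjusted code from your question work with the same implementation of myfunc(). This works because k * s is elementwise and broadcasts: in the vectorized call, k is the 5x4 window view and s is the 5x1 column ks_rolling[:, 0:1], so each row is multiplied by its own first element. If your actual myfunc() contains operations that assume a 1-d input (for example, reductions without a dim argument), you will need to adapt them to operate along the last dimension.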
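If your real myfunc() is written for a single 1-d window and does not broadcast, passing the 2-d view directly can silently give wrong results. One alternative, assuming PyTorch 2.0 or later, is torch.func.vmap, which maps a function over the first dimension of its inputs without a Python loop. The following is a minimal sketch; the myfunc body is a hypothetical stand-in chosen because its sum() without a dim argument does not broadcast correctly.

```python
import torch
from torch.func import vmap  # assumes PyTorch >= 2.0


def myfunc(k: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in written for ONE 1-d window: k.sum() reduces
    # over all elements of k, so calling it on the 2-d window view would
    # normalize by the sum of all windows instead of each window's sum.
    return k / k.sum() * s


window = 4
ks = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
windows = ks.unfold(0, window, 1)  # sliding window view, shape (5, 4)

# vmap maps myfunc over dim 0 of both arguments: row i of `windows`
# is paired with the 0-d tensor windows[i, 0] (the window's first value).
result = vmap(myfunc)(windows, windows[:, 0])

# Reference: explicit loop over the same windows
expected = torch.stack(
    [myfunc(ks[i:i + window], ks[i]) for i in range(windows.shape[0])]
)
assert torch.allclose(result, expected)

# Calling myfunc on the 2-d view directly sums over all 20 values instead
wrong = myfunc(windows, windows[:, 0:1])
assert not torch.allclose(wrong, expected)
print(result.shape)  # torch.Size([5, 4])
```

Keep in mind that vmap has restrictions of its own: for example, it cannot handle functions that convert tensors to Python numbers (such as .item()) or branch on tensor values.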