
Random Choice with Pytorch?


I have a tensor of pictures, and would like to randomly select from it. I'm looking for the equivalent of np.random.choice().

import torch

pictures = torch.randint(0, 256, (1000, 28, 28, 3))

Let's say I want 10 of these pictures.
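For reference, here is a sketch of the NumPy behaviour being replicated. `pictures_np` is an illustrative NumPy stand-in for the tensor above. The legacy `np.random.choice` only accepts a 1-D array or an int `n` (meaning `arange(n)`), so the usual pattern for a 4-D array is to sample row indices and then index the first axis:

```python
import numpy as np

# Illustrative NumPy stand-in for the `pictures` tensor above.
pictures_np = np.random.randint(0, 256, (1000, 28, 28, 3))

# np.random.choice(len(pictures_np), ...) samples from arange(1000),
# i.e. row indices, because the 4-D array itself is not accepted.
idx = np.random.choice(len(pictures_np), 10, replace=False)
sample = pictures_np[idx]  # shape (10, 28, 28, 3)
```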


Solution

  • torch has no direct equivalent of np.random.choice(); see the discussion here. The alternative is to index the tensor with a shuffled index or with random integers.

    To do it with replacement:

    1. Generate n random indices
    2. Index your original tensor with these indices
    pictures[torch.randint(len(pictures), (10,))]  
    

    To do it without replacement:

    1. Shuffle the index
    2. Take the first n elements
    indices = torch.randperm(len(pictures))[:10]
    
    pictures[indices]
    

    Read more about torch.randint and torch.randperm. The second code snippet is inspired by this post on the PyTorch Forums.
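The two recipes above can be wrapped into a small helper that loosely mirrors np.random.choice. This is a sketch only: `torch_choice` is a hypothetical name, not part of PyTorch. The weighted branch swaps in `torch.multinomial` to stand in for NumPy's `p` argument, which the answer above does not cover. The optional `generator` (a `torch.Generator`) makes draws reproducible.

```python
import torch


def torch_choice(x, k, replace=True, p=None, generator=None):
    """Hypothetical helper: draw k items along dim 0 of x, loosely
    mirroring np.random.choice(len(x), k, replace=replace, p=p)."""
    n = x.shape[0]
    if p is not None:
        # Weighted sampling (assumption: p is a float tensor of shape (n,)
        # with non-negative weights); torch.multinomial draws indices
        # with probability proportional to those weights.
        idx = torch.multinomial(p, k, replacement=replace, generator=generator)
    elif replace:
        # With replacement: k independent uniform indices in [0, n).
        idx = torch.randint(n, (k,), generator=generator)
    else:
        # Without replacement: shuffle all indices, keep the first k.
        if k > n:
            raise ValueError("k cannot exceed len(x) when replace=False")
        idx = torch.randperm(n, generator=generator)[:k]
    # Move the indices to x's device before indexing.
    return x[idx.to(x.device)]


pictures = torch.randint(0, 256, (1000, 28, 28, 3))
with_repl = torch_choice(pictures, 10)                    # duplicates possible
without_repl = torch_choice(pictures, 10, replace=False)  # all distinct
```

`randperm` shuffles all n indices even when only k are needed. For very large n with small k, the with-replacement `randint` path or `torch.multinomial` over uniform weights may be cheaper. That trade-off is worth measuring rather than assuming.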