I have the following code segment to generate random samples. The generated `samples1` is a list, where each entry of the list is a tensor. Each tensor has two elements. I would like to extract the first element from all tensors in the list, and also the second element from all tensors in the list. How can I perform this kind of tensor slice operation?
```python
import torch
import pyro
import pyro.distributions as dist

num_samples = 250
# note that both covariance matrices are diagonal
mu1 = torch.tensor([0., 5.])
sig1 = torch.tensor([[2., 0.], [0., 3.]])
dist1 = dist.MultivariateNormal(mu1, sig1)
samples1 = [pyro.sample('samples1', dist1) for _ in range(num_samples)]
samples1  # displays the list of 250 tensors, each of shape (2,)
```
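If the list itself is not needed, the samples can be drawn in one call instead. The sketch below assumes plain `torch.distributions` (which Pyro's distribution classes build on) and no Pyro model or inference around the sampling. Passing a `sample_shape` of `(num_samples,)` returns a single `(num_samples, 2)` tensor that can be sliced directly:

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)  # reproducible draws for this sketch

num_samples = 250
mu1 = torch.tensor([0., 5.])
sig1 = torch.tensor([[2., 0.], [0., 3.]])  # diagonal covariance matrix
dist1 = MultivariateNormal(mu1, covariance_matrix=sig1)

# One call with sample_shape=(num_samples,) returns a (num_samples, 2) tensor
samples1_t = dist1.sample((num_samples,))

col1 = samples1_t[:, 0]  # first element of every sample
col2 = samples1_t[:, 1]  # second element of every sample
print(samples1_t.shape, col1.shape, col2.shape)
```

Inside a Pyro model you would normally keep going through `pyro.sample`; this sketch only covers drawing plain samples.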
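I'd recommend `torch.cat` with a list comprehension:

```python
# slice t[0:1] rather than index t[0]: torch.cat needs tensors with at least one dimension
col1 = torch.cat([t[0:1] for t in samples1])
col2 = torch.cat([t[1:2] for t in samples1])
```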
Docs for `torch.cat`: https://pytorch.org/docs/stable/generated/torch.cat.html
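Note that `t[0]` on a 1D tensor returns a zero-dimensional tensor, and `torch.cat` rejects zero-dimensional tensors, so the pieces must either be sliced (`t[0:1]`) or combined with `torch.stack` instead. A minimal check, using a hypothetical stand-in list `samples` of random 2-element tensors in place of the Pyro samples:

```python
import torch

# Hypothetical stand-in for samples1: a list of 250 tensors of shape (2,)
samples = [torch.randn(2) for _ in range(250)]

# t[0:1] keeps a one-element 1D tensor, which torch.cat can join
col1 = torch.cat([t[0:1] for t in samples])
col2 = torch.cat([t[1:2] for t in samples])

# torch.stack adds a new dimension, so it accepts the 0-dim elements directly
col1_stacked = torch.stack([t[0] for t in samples])

# Indexing with t[0] gives 0-dim tensors, which torch.cat refuses
try:
    torch.cat([t[0] for t in samples])
except RuntimeError as err:
    print("torch.cat on 0-dim tensors failed:", err)

print(col1.shape, col2.shape, torch.equal(col1, col1_stacked))
```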
ALTERNATIVELY
You could turn your list of 1D tensors into a single big 2D tensor using `torch.stack`, then do a normal slice:

```python
samples1_t = torch.stack(samples1)  # shape (250, 2)
col1 = samples1_t[:, 0]  # : means all rows
col2 = samples1_t[:, 1]
```
Docs for `torch.stack`: https://pytorch.org/docs/stable/generated/torch.stack.html
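Once the samples are in one 2D tensor, both columns can also be pulled out in a single step with `Tensor.unbind(dim=1)`, which returns a tuple of 1D column tensors. A minimal sketch, again with a hypothetical stand-in list of random tensors in place of the Pyro samples:

```python
import torch

# Hypothetical stand-in for samples1: a list of 250 tensors of shape (2,)
samples = [torch.randn(2) for _ in range(250)]

samples_t = torch.stack(samples)  # shape (250, 2): one row per sample

# unbind(dim=1) splits along the column dimension into a tuple of 1D tensors
col1, col2 = samples_t.unbind(dim=1)

# Same result as the explicit column slices
assert torch.equal(col1, samples_t[:, 0])
assert torch.equal(col2, samples_t[:, 1])
print(samples_t.shape, col1.shape, col2.shape)
```

Both the slices and `unbind` return views into `samples_t` rather than copies; call `.clone()` on a column if you plan to modify it in place.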