I have a 2D torch tensor that contains nan values. I would like to get the minimum of each column while ignoring the nan cells.
import torch
data = torch.tensor([[ 0., 1., float('nan'), 3.],[ 4., 5., 6., 7.], [ 8., 9., 10., 11.]])
torch.min(data,0)
# What I would like to get is
# tensor([0., 1., 6., 3.])
Any suggestions? Thanks.
You can do this by replacing every nan in the tensor with a very large value and then calling torch.min:
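# Replace every nan with 10^15 (10e14 == 1e15) so it can never be a column minimum
filled = torch.nan_to_num(data, nan=10e14)
# With a dim argument, torch.min returns a (values, indices) pair; keep the values
col_min = torch.min(filled, 0).values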
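torch.nan_to_num replaces every nan in the tensor with the value passed as its nan argument. Adjust that value to your data: it must be larger than every real value in the tensor, otherwise a replaced nan could be reported as a column minimum. Note also that a column consisting only of nan will come back as the replacement value rather than nan.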
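If you would rather not pick a magic number, a variant of the same idea is to fill the nan cells with +inf, which is never smaller than a real value. The sketch below assumes a 2D float tensor; the helper name nan_column_min is hypothetical, and it additionally restores nan for columns that contain nothing but nan:

```python
import torch

def nan_column_min(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: column-wise minimum of a 2D float tensor, ignoring nan.

    Columns that contain only nan are returned as nan.
    """
    nan_mask = torch.isnan(x)
    # +inf can never be smaller than a real value, so it is a safe filler
    filled = torch.where(nan_mask, torch.full_like(x, float('inf')), x)
    # min over dim 0 returns (values, indices); keep only the values
    col_min = filled.min(dim=0).values
    # Put nan back for columns in which every entry was nan
    all_nan = nan_mask.all(dim=0)
    return torch.where(all_nan, torch.full_like(col_min, float('nan')), col_min)

data = torch.tensor([[0., 1., float('nan'), 3.],
                     [4., 5., 6., 7.],
                     [8., 9., 10., 11.]])
print(nan_column_min(data))  # tensor([0., 1., 6., 3.])
```

The extra all-nan check is a design choice: without it, such a column would silently report inf instead of signalling that it had no valid values.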