
How to remove every column that is filled with zeros from a PyTorch tensor?


I have a PyTorch tensor A like the one below:

A = 
tensor([[  4,   3,   3,  ...,   0,   0,   0],
        [ 13,   4,  13,  ...,   0,   0,   0],
        [707, 707,   4,  ...,   0,   0,   0],
        ...,
        [  7,   7,   7,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        [195, 195, 195,  ...,   0,   0,   0]], dtype=torch.int32)

I would like to remove every column that contains only zeros.

I can imagine doing:

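zero_list = []
for j in range(A.size(1)):
    # a zero sum implies an all-zero column only if no entry is negative
    if torch.sum(A[:, j]) == 0:
        # list.append returns None, so do not reassign zero_list
        zero_list.append(j)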

to identify the columns whose elements are all 0, but I am not sure how to delete those columns from the original tensor.

How can I delete the all-zero columns from a PyTorch tensor based on their index numbers?

Thank you,


Solution

  • It makes more sense to index the columns you want to keep rather than the ones you want to delete.

    valid_cols = []
    for col_idx in range(A.size(1)):
        if not torch.all(A[:, col_idx] == 0):
            valid_cols.append(col_idx)
    A = A[:, valid_cols]
    

    Or, a little more cryptically:

    valid_cols = [col_idx for col_idx, col in enumerate(torch.split(A, 1, dim=1)) if not torch.all(col == 0)]
    A = A[:, valid_cols]
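
The same idea can probably be written without a Python loop by using a boolean mask. The following is a minimal sketch, not the answerer's method: the small tensor A and the names keep_mask, drop_mask, A_trimmed and A_by_index are made up for illustration. It also shows one way to drop columns when you already have a list of indices, as in the question.

```python
import torch

# Hypothetical example data: columns 1 and 3 contain only zeros.
A = torch.tensor([[4, 0, 3, 0],
                  [13, 0, 0, 0],
                  [0, 0, 7, 0]], dtype=torch.int32)

# True for every column that has at least one non-zero entry.
keep_mask = (A != 0).any(dim=0)

# Boolean indexing along dim 1 keeps only the columns marked True.
A_trimmed = A[:, keep_mask]

# If you already have a list of column indices to delete (like zero_list
# in the question), build the mask from that list instead.
zero_list = [1, 3]
drop_mask = torch.ones(A.size(1), dtype=torch.bool)
drop_mask[zero_list] = False
A_by_index = A[:, drop_mask]
```

Unlike the torch.sum check, comparing against zero element by element does not assume that all entries are non-negative.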