
How to remove every column that is filled with zeros from a PyTorch tensor?

I have a PyTorch tensor A like the one below:

A = 
tensor([[  4,   3,   3,  ...,   0,   0,   0],
        [ 13,   4,  13,  ...,   0,   0,   0],
        [707, 707,   4,  ...,   0,   0,   0],
        ...,
        [  7,   7,   7,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        [195, 195, 195,  ...,   0,   0,   0]], dtype=torch.int32)

I would like to:

  • identify all the columns whose entries are all equal to 0
  • delete only those columns whose entries are all equal to 0

I can imagine doing:

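zero_list = []
for j in range(A.size(1)):
    # A column sum of 0 is taken to mean the column is all zeros
    if torch.sum(A[:, j]) == 0:
        zero_list.append(j)  # append() modifies the list in place and returns None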

to identify the columns that contain only zeros, but I am not sure how to delete those columns from the original tensor.

How can I delete the all-zero columns from a PyTorch tensor based on their index numbers?

Thank you,

chico0913 asked Oct 15 '25 20:10


1 Answer

Identify all the columns whose entries are all equal to 0

non_empty_mask = A.abs().sum(dim=0).bool()

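This sums the absolute values in each column and then converts the result to a boolean, i.e. False if the sum is zero and True otherwise. Taking absolute values first prevents positive and negative entries in the same column from cancelling out.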
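As a quick check of the mask, here is a minimal sketch on a small made-up tensor (the values are hypothetical, not the asker's data). It compares the mask above with a plain column sum, and with `(A != 0).any(dim=0)`, which should give the same result without summing.

```python
import torch

# Hypothetical stand-in for A: columns 1 and 3 are all zero,
# column 2 holds 5 and -5, which cancel in a plain sum.
A = torch.tensor([[4, 0, 5, 0],
                  [13, 0, -5, 0],
                  [0, 0, 0, 0]], dtype=torch.int32)

# Mask from the answer: True where a column has at least one nonzero entry.
non_empty_mask = A.abs().sum(dim=0).bool()

# Plain sum without abs(): column 2 is wrongly treated as empty.
naive_mask = A.sum(dim=0).bool()

# Alternative formulation: test for any nonzero entry directly.
alt_mask = (A != 0).any(dim=0)

print(non_empty_mask.tolist())  # [True, False, True, False]
print(naive_mask.tolist())      # [True, False, False, False]
print(alt_mask.tolist())        # [True, False, True, False]
```

If the tensor can never contain negative values, the plain sum and the absolute-value sum agree, so the difference only matters for signed data.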

Delete only those columns whose entries are all equal to 0

A[:,non_empty_mask]

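This simply applies the mask to the original tensor along the column dimension, i.e. it keeps the columns where non_empty_mask is True. Indexing returns a new tensor, so assign the result if you want to keep it.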
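Since the question also asks about deleting by index number, the following sketch shows both routes end to end on a small made-up tensor. The tensor values and the names `A_trimmed`, `zero_cols`, `keep_cols` and `A_by_index` are illustrative assumptions, not part of the original post.

```python
import torch

# Hypothetical stand-in for A; columns 1 and 3 contain only zeros.
A = torch.tensor([[4, 0, 3, 0],
                  [13, 0, 4, 0],
                  [0, 0, 0, 0]], dtype=torch.int32)

non_empty_mask = A.abs().sum(dim=0).bool()

# Boolean-mask route: returns a new tensor, A itself is unchanged.
A_trimmed = A[:, non_empty_mask]

# Index route: recover the column numbers, then select columns by index.
zero_cols = torch.nonzero(~non_empty_mask).flatten()  # indices of all-zero columns
keep_cols = torch.nonzero(non_empty_mask).flatten()   # indices of columns to keep
A_by_index = A.index_select(1, keep_cols)

print(zero_cols.tolist())   # [1, 3]
print(A_trimmed.tolist())   # [[4, 3], [13, 4], [0, 0]]
```

Note that the all-zero row is kept, because only columns are filtered. Both routes should produce the same tensor; the mask version is shorter, while the index version is handy when the list of removed column numbers is needed elsewhere.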

Aydo answered Oct 17 '25 09:10