I have a PyTorch tensor A like the one below:
A =
tensor([[ 4, 3, 3, ..., 0, 0, 0],
[ 13, 4, 13, ..., 0, 0, 0],
[707, 707, 4, ..., 0, 0, 0],
...,
[ 7, 7, 7, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
[195, 195, 195, ..., 0, 0, 0]], dtype=torch.int32)
I would like to identify all the columns whose entries are all equal to 0, and then delete only those columns.
I can imagine doing:
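import torch

zero_list = []
for j in range(A.size(1)):
    # a column counts as empty only if every entry is 0
    if torch.all(A[:, j] == 0):
        zero_list.append(j)  # append() modifies the list in place and returns None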
to identify the columns that contain only zeros, but I am not sure how to delete those columns from the original tensor.
How can I delete the all-zero columns from a PyTorch tensor based on their index numbers?
Thank you,
Identify all the columns whose entries are all equal to 0
non_empty_mask = A.abs().sum(dim=0).bool()
This sums up over the absolute values of each column and then converts the result to a boolean, i.e. False if the sum is zero and True otherwise.
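A small check of the mask on a toy tensor (the values are made up for illustration, not taken from A above). Column 1 holds 1 and -1, which sum to 0, so taking the absolute value first is what stops it from being treated as empty. Building the mask with `(A != 0).any(dim=0)` gives the same result without summing.

```python
import torch

# Hypothetical toy tensor: column 2 is all zeros,
# column 1 sums to zero but is NOT all zeros.
A = torch.tensor([[4, 1, 0],
                  [13, -1, 0]], dtype=torch.int32)

# True for every column with at least one non-zero entry
non_empty_mask = A.abs().sum(dim=0).bool()
print(non_empty_mask.tolist())  # [True, True, False]

# Without abs(), column 1 would wrongly look empty
print(A.sum(dim=0).bool().tolist())  # [True, False, False]

# Equivalent mask: does any entry in the column differ from 0?
same_mask = (A != 0).any(dim=0)
print(torch.equal(non_empty_mask, same_mask))  # True
```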
Delete only those columns whose entries are all equal to 0
A[:,non_empty_mask]
This simply applies the mask to the original tensor, i.e. it keeps the columns where non_empty_mask is True.
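Since the question asks about deleting columns by index number, here is a sketch of two index-based variants that give the same result as the boolean mask: converting the mask to column indices with `nonzero(as_tuple=True)` and selecting them with `torch.index_select`, and dropping the columns listed in a `zero_list` like the one built in the question. The toy tensor and the hard-coded `zero_list` are illustrative assumptions.

```python
import torch

# Hypothetical toy tensor: columns 2 and 3 are all zeros
A = torch.tensor([[4, 3, 0, 0],
                  [13, 4, 0, 0],
                  [0, 0, 0, 0]], dtype=torch.int32)

non_empty_mask = A.abs().sum(dim=0).bool()

# 1) Boolean-mask indexing, as in the answer above
B = A[:, non_empty_mask]

# 2) Turn the mask into column indices and select them along dim 1
keep_idx = non_empty_mask.nonzero(as_tuple=True)[0]
C = torch.index_select(A, 1, keep_idx)

# 3) Start from a list of column indices to delete
#    (for this toy tensor, the question's loop would collect [2, 3])
zero_list = [2, 3]
keep = [j for j in range(A.size(1)) if j not in zero_list]
D = A[:, keep]

print(B.tolist())  # [[4, 3], [13, 4], [0, 0]]
print(torch.equal(B, C), torch.equal(B, D))  # True True
```

The mask version is usually the simplest; the index-based forms are handy when the column indices come from somewhere else, such as the loop in the question.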