
Replace all nonzero values by zero and all zero values by a specific value

Tags:

pytorch

I have a 3D tensor that contains both zero and nonzero values. I want to replace every nonzero value with zero and every zero with a specific value. How can I do that?

Wasi Ahmad asked Jul 29 '17 02:07


2 Answers

This works pretty much exactly as it would in NumPy, using boolean-mask assignment:

tensor[tensor!=0] = 0

To replace both the zeros and the nonzeros, chain two such assignments. Compute both masks from the original tensor and write into a copy; otherwise the first assignment changes which elements the second mask selects:

def custom_replace(tensor, on_zero, on_non_zero):
    # Write into a copy so that both masks below are computed
    # from the unmodified input tensor.
    res = tensor.clone()
    res[tensor==0] = on_zero
    res[tensor!=0] = on_non_zero
    return res

And use it like so:

>>> z
(0 ,.,.) = 
  0  1
  1  3

(1 ,.,.) = 
  0  1
  1  0
[torch.LongTensor of size 2x2x2]

>>> out = custom_replace(z, on_zero=5, on_non_zero=0)
>>> out
(0 ,.,.) = 
  5  0
  0  0

(1 ,.,.) = 
  5  0
  0  5
[torch.LongTensor of size 2x2x2]
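
To see why the copy (or at least a precomputed mask) matters, here is a minimal sketch using the same values as z above. The names naive, fixed and zero_mask are only illustrative, and torch.tensor assumes PyTorch 0.4 or later:

```python
import torch

# Same values as the z tensor shown above.
z = torch.tensor([[[0, 1], [1, 3]], [[0, 1], [1, 0]]])

# Naive in-place chaining: after the first assignment every element is zero,
# so the second mask selects (and overwrites) the whole tensor.
naive = z.clone()
naive[naive != 0] = 0
naive[naive == 0] = 5

# Computing the mask once, before any assignment, avoids the problem.
fixed = z.clone()
zero_mask = fixed == 0
fixed[~zero_mask] = 0
fixed[zero_mask] = 5
```

The naive version ends up filled with 5 everywhere, while the version that computes the mask first gives the same result as custom_replace.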

entrophy answered Oct 24 '22 23:10


Use

torch.where(<your_tensor> != 0, <tensor of zeros>, <tensor filled with the value>)

Example of the general pattern, here with x > 0 as the condition:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
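
Applied to a tensor like the one in the question, a minimal sketch might look like this. The sample data and the fill value 5 are assumptions for illustration; torch.zeros_like and torch.full_like build the two branch tensors with the same shape and dtype as the input:

```python
import torch

# Hypothetical 3D input with a mix of zero and nonzero entries.
z = torch.tensor([[[0, 1], [1, 3]], [[0, 1], [1, 0]]])

# Where z is nonzero take 0; elsewhere (z == 0) take the fill value 5.
out = torch.where(z != 0, torch.zeros_like(z), torch.full_like(z, 5))

print(out.tolist())  # [[[5, 0], [0, 0]], [[5, 0], [0, 5]]]
```

Unlike masked assignment, torch.where returns a new tensor and leaves z untouched, so no explicit clone is needed.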

See more at: https://pytorch.org/docs/stable/generated/torch.where.html

Eli Safra answered Oct 24 '22 22:10