 

Is there an equivalent PyTorch function for `tf.nn.space_to_depth`

As the title says, is there an equivalent PyTorch function for tf.nn.space_to_depth?

Priyatham asked Nov 07 '25 16:11


2 Answers

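While torch.nn.functional.pixel_shuffle does what tf.nn.depth_to_space does (on NCHW tensors, though with a different channel ordering), older PyTorch versions have no function for the inverse operation, tf.nn.space_to_depth; torch.nn.functional.pixel_unshuffle was only added in PyTorch 1.8.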

That being said, it is easy to implement space_to_depth using torch.nn.functional.unfold.

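import torch

def space_to_depth(x, block_size):
    n, c, h, w = x.size()
    # Each column of unfold's output is one non-overlapping block_size x block_size patch
    unfolded_x = torch.nn.functional.unfold(x, block_size, stride=block_size)
    # (n, c * block_size**2, L) -> (n, c * block_size**2, h // block_size, w // block_size)
    return unfolded_x.view(n, c * block_size ** 2, h // block_size, w // block_size)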
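A quick way to sanity-check the unfold-based function is to compare it with torch.nn.functional.pixel_unshuffle. This is a minimal sketch that assumes PyTorch 1.8 or later (where pixel_unshuffle exists); the input tensor and the names inp and out are illustrative only. It also shows how the output channels are ordered.

```python
import torch
import torch.nn.functional as F


def space_to_depth(x: torch.Tensor, block_size: int) -> torch.Tensor:
    # Same unfold-based approach as the answer: NCHW in, NCHW out
    n, c, h, w = x.size()
    unfolded_x = F.unfold(x, block_size, stride=block_size)
    return unfolded_x.view(n, c * block_size ** 2, h // block_size, w // block_size)


# A small float input with distinct values so the rearrangement is easy to check
inp = torch.arange(2 * 3 * 4 * 6, dtype=torch.float32).reshape(2, 3, 4, 6)
out = space_to_depth(inp, 2)
print(out.shape)  # torch.Size([2, 12, 2, 3])

# Output channel c * 4 + i * 2 + j at (y, x) holds inp[:, c, 2 * y + i, 2 * x + j],
# i.e. the new channels are grouped channel-major as (c, i, j)
print(torch.equal(out[:, 1 * 4 + 0 * 2 + 1], inp[:, 1, 0::2, 1::2]))  # True

# Assumes PyTorch >= 1.8, where F.pixel_unshuffle is available
print(torch.equal(out, F.pixel_unshuffle(inp, 2)))  # True
```

So on recent PyTorch versions the unfold version and pixel_unshuffle produce the same tensor, with channels ordered (c, i, j).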

Priyatham answered Nov 09 '25 04:11


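Actually, @Priyatham's unfold version does not reproduce tf.nn.space_to_depth exactly. Both move every block_size * block_size block of pixels into block_size * block_size channels (nothing is duplicated), but unfold orders the new channels as (c, p1, p2), whereas tf.nn.space_to_depth puts the offset within the block first and the input channel last, i.e. (p1, p2, c).

So, to match TensorFlow's ordering, use einops.rearrange():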

import einops
result = einops.rearrange(x, 'b c (h p1) (w p2) -> b (p1 p2 c) h w', p1=block_size, p2=block_size)
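If you'd rather not depend on einops, the pattern 'b c (h p1) (w p2) -> b (p1 p2 c) h w' can be written with plain reshape and permute. This is a minimal sketch, assuming an NCHW input whose H and W are divisible by block_size; space_to_depth_tf_order is a hypothetical name, not a PyTorch function.

```python
import torch


def space_to_depth_tf_order(x: torch.Tensor, block_size: int) -> torch.Tensor:
    # Hypothetical helper: NCHW space_to_depth with output channels ordered (p1, p2, c),
    # the same as einops 'b c (h p1) (w p2) -> b (p1 p2 c) h w'.
    # Assumes H and W are divisible by block_size.
    b, c, h, w = x.shape
    p = block_size
    # Split H into (h // p, p1) and W into (w // p, p2)
    x = x.reshape(b, c, h // p, p, w // p, p)
    # Reorder to (b, p1, p2, c, h // p, w // p)
    x = x.permute(0, 3, 5, 1, 2, 4)
    # Merge (p1, p2, c) into a single channel axis of size p * p * c
    return x.reshape(b, p * p * c, h // p, w // p)


x = torch.arange(8, dtype=torch.float32).reshape(1, 2, 2, 2)
print(space_to_depth_tf_order(x, 2).flatten().tolist())
# [0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]
# (the unfold version / pixel_unshuffle keeps the input order: [0.0, 1.0, 2.0, ..., 7.0])
```

Calling .permute(0, 2, 3, 1) on the result gives an NHWC tensor laid out the way tf.nn.space_to_depth lays out its output for the corresponding NHWC input.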

HeCao answered Nov 09 '25 04:11


