
Verify the convolution theorem using PyTorch

Basically, the theorem says that the Fourier transform of a convolution equals the pointwise product of the Fourier transforms:

F(f * g) = F(f) x F(g)

I know the theorem, but I simply cannot reproduce the result using PyTorch.

Below is a reproducible example:

import torch
import torch.nn.functional as F

# calculate f*g
f = torch.ones((1,1,5,5))
g = torch.tensor(list(range(9))).view(1,1,3,3).float()
conv = F.conv2d(f, g, bias=None, padding=2)

# calculate F(f*g)
F_fg = torch.rfft(conv, signal_ndim=2, onesided=False)

# calculate F x G
f = f.squeeze()
g = g.squeeze()

# need to pad into at least [w1+w2-1, h1+h2-1], which is 7 in our case.
size = f.size(0) + g.size(0) - 1 

f_new = torch.zeros((7,7))
g_new = torch.zeros((7,7))

f_new[1:6,1:6] = f
g_new[2:5,2:5] = g

F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
FxG = torch.mul(F_f, F_g)

print(FxG - F_fg)

Here is the output of print(FxG - F_fg):

tensor([[[[[ 0.0000e+00,  0.0000e+00],
       [ 4.1426e+02,  1.7270e+02],
       [-3.6546e+01,  4.7600e+01],
       [-1.0216e+01, -4.1198e+01],
       [-1.0216e+01, -2.0223e+00],
       [-3.6546e+01, -6.2804e+01],
       [ 4.1426e+02, -1.1427e+02]],

      ...

      [[ 4.1063e+02, -2.2347e+02],
       [-7.6294e-06,  2.2817e+01],
       [-1.9024e+01, -9.0105e+00],
       [ 7.1708e+00, -4.1027e+00],
       [-2.6739e+00, -1.1121e+01],
       [ 8.8471e+00,  7.1710e+00],
       [ 4.2528e+01,  9.7559e+01]]]]])

As you can see, the difference is not always 0.

Can someone tell me why this happens, and how to do it properly?

Thanks

asked Dec 31 '22 by BarCodeReader


1 Answer

So I took a closer look at what you've done so far and identified three sources of error in your code. I'll try to address each of them here.

1. Complex arithmetic

PyTorch doesn't currently support multiplication of complex numbers (AFAIK). The FFT operation simply returns a tensor with an extra trailing dimension holding the real and imaginary parts. So instead of using torch.mul or the * operator, we need to code the complex multiplication explicitly:

(a + ib) * (c + id) = (a*c - b*d) + i(a*d + b*c)
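For example, with the rfft output layout used in this (pre-1.8) version of PyTorch, where the last dimension holds the real and imaginary parts, a small helper could look like this (a sketch for illustration, not part of the original code; complex_mul is just a name I'm using here):

def complex_mul(A, B):
    # A and B have shape (..., 2), with [..., 0] the real part and [..., 1] the imaginary part
    real = A[..., 0] * B[..., 0] - A[..., 1] * B[..., 1]
    imag = A[..., 0] * B[..., 1] + A[..., 1] * B[..., 0]
    return torch.stack([real, imag], dim=-1)

This is exactly what the "complex multiply" block in the full code below does inline.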

2. The definition of convolution

The definition of "convolution" often used in CNN literature actually differs from the one used when stating the convolution theorem. I won't go into detail, but the mathematical definition flips the kernel before sliding and multiplying, whereas the convolution operation in PyTorch, TensorFlow, Caffe, etc. skips this flip (it is really a cross-correlation).

To account for this we can simply flip g (both horizontally and vertically) before applying the FFT.
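In PyTorch that flip is a one-liner with torch.flip (a minimal sketch; g_flipped is just an illustrative name):

# flip the kernel vertically (dim 0) and horizontally (dim 1)
g_flipped = torch.flip(g, dims=[0, 1])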

3. Anchor position

The anchor point when using the convolution theorem is assumed to be at the upper-left corner of the padded g. Again, I won't go into detail, but that's how the math works out.


The second and third points may be easier to understand with an example. Suppose you used the following g:

[1 2 3]
[4 5 6]
[7 8 9]

instead of g_new being

[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 2 3 0 0]
[0 0 4 5 6 0 0]
[0 0 7 8 9 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]

it should actually be

[5 4 0 0 0 0 6]
[2 1 0 0 0 0 3]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[8 7 0 0 0 0 9]

where we flip the kernel vertically and horizontally, then apply a circular shift so that the center of the kernel ends up in the upper-left corner.
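If the index arithmetic in the full code below feels opaque, an equivalent way to build this layout (a sketch, assuming g is the 3x3 kernel above as a 2-D tensor and the padded size is 7x7) is to flip, place the kernel at the top-left, and shift it circularly with torch.roll:

g_flipped = torch.flip(g, dims=[0, 1])        # flip vertically and horizontally
g_new = torch.zeros((7, 7))
g_new[:g.size(0), :g.size(1)] = g_flipped     # flipped kernel at the top-left
# circular shift so the kernel center lands at index (0, 0)
g_new = torch.roll(g_new, shifts=(-(g.size(0) // 2), -(g.size(1) // 2)), dims=(0, 1))

For the 3x3 example this produces exactly the matrix shown above.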


I ended up rewriting most of your code and generalizing it a bit. The most involved step is defining g_new properly. I decided to use a meshgrid and modulo arithmetic to simultaneously flip and shift the indices. If something here doesn't make sense to you, please leave a comment and I'll try to clarify.

import torch
import torch.nn.functional as F

def conv2d_pyt(f, g):
    assert len(f.size()) == 2
    assert len(g.size()) == 2

    f_new = f.unsqueeze(0).unsqueeze(0)
    g_new = g.unsqueeze(0).unsqueeze(0)

    pad_y = (g.size(0) - 1) // 2
    pad_x = (g.size(1) - 1) // 2

    fcg = F.conv2d(f_new, g_new, bias=None, padding=(pad_y, pad_x))
    return fcg[0, 0, :, :]

def conv2d_fft(f, g):
    assert len(f.size()) == 2
    assert len(g.size()) == 2

    # in general not necessary that inputs are odd shaped but makes life easier
    assert f.size(0) % 2 == 1
    assert f.size(1) % 2 == 1
    assert g.size(0) % 2 == 1
    assert g.size(1) % 2 == 1

    size_y = f.size(0) + g.size(0) - 1
    size_x = f.size(1) + g.size(1) - 1

    f_new = torch.zeros((size_y, size_x))
    g_new = torch.zeros((size_y, size_x))

    # copy f to center
    f_pad_y = (f_new.size(0) - f.size(0)) // 2
    f_pad_x = (f_new.size(1) - f.size(1)) // 2
    f_new[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x] = f

    # anchor of g is 0,0 (flip g and wrap circular)
    g_center_y = g.size(0) // 2
    g_center_x = g.size(1) // 2
    g_y, g_x = torch.meshgrid(torch.arange(g.size(0)), torch.arange(g.size(1)))
    g_new_y = (g_y.flip(0) - g_center_y) % g_new.size(0)
    g_new_x = (g_x.flip(1) - g_center_x) % g_new.size(1)
    g_new[g_new_y, g_new_x] = g[g_y, g_x]

    # take fft of both f and g
    F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
    F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)

    # complex multiply
    FxG_real = F_f[:, :, 0] * F_g[:, :, 0] - F_f[:, :, 1] * F_g[:, :, 1]
    FxG_imag = F_f[:, :, 0] * F_g[:, :, 1] + F_f[:, :, 1] * F_g[:, :, 0]
    FxG = torch.stack([FxG_real, FxG_imag], dim=2)

    # inverse fft
    fcg = torch.irfft(FxG, signal_ndim=2, onesided=False)

    # crop center before returning
    return fcg[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x]


# calculate f*g
f = torch.randn(11, 7)
g = torch.randn(5, 3)

fcg_pyt = conv2d_pyt(f, g)
fcg_fft = conv2d_fft(f, g)

avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()

print('Average difference:', avg_diff)

This gives me:

Average difference: 4.6866085767760524e-07

This is very close to zero, and the reason we don't get exactly zero is simply floating-point error.
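If you'd rather have a pass/fail check than eyeball the average, you could compare the two results within a floating-point tolerance (the tolerance of 1e-5 here is just a reasonable guess for these magnitudes):

# sanity check: both convolutions should agree up to floating-point error
assert torch.allclose(fcg_pyt, fcg_fft, atol=1e-5)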

answered Jan 22 '23 by jodag