I am trying to import weights saved from a TensorFlow model into PyTorch. So far the results have been very similar, but I ran into a snag when the model calls for conv2d with stride=2.
To verify the mismatch, I set up a very simple comparison between TF and PyTorch. First, I compare conv2d with stride=1.
import tensorflow as tf
import numpy as np
import torch
import torch.nn.functional as F
np.random.seed(0)
sess = tf.Session()
# Create constant weights (0.05 everywhere) and a random input
weights = torch.empty(3, 3, 3, 8)
torch.nn.init.constant_(weights, 5e-2)
x = np.random.randn(1, 3, 10, 10)
weights_tf = tf.convert_to_tensor(weights.numpy(), dtype=tf.float32)
# PyTorch adopts [outputC, inputC, kH, kW]
weights_torch = torch.Tensor(weights.permute((3, 2, 0, 1)))
# Tensorflow defaults to NHWC
x_tf = tf.convert_to_tensor(x.transpose((0, 2, 3, 1)), dtype=tf.float32)
x_torch = torch.Tensor(x)
# TF Conv2D
tf_conv2d = tf.nn.conv2d(x_tf,
                         weights_tf,
                         strides=[1, 1, 1, 1],
                         padding="SAME")
# PyTorch Conv2D
torch_conv2d = F.conv2d(x_torch, weights_torch, padding=1, stride=1)
sess.run(tf.global_variables_initializer())
tf_result = sess.run(tf_conv2d)
diff = np.mean(np.abs(tf_result.transpose((0, 3, 1, 2)) - torch_conv2d.detach().numpy()))
print('Mean of Abs Diff: {0}'.format(diff))
The result of this execution is:
Mean of Abs Diff: 2.0443112092038973e-08
When I change stride to 2, the results start to vary.
# TF Conv2D
tf_conv2d = tf.nn.conv2d(x_tf,
                         weights_tf,
                         strides=[1, 2, 2, 1],
                         padding="SAME")
# PyTorch Conv2D
torch_conv2d = F.conv2d(x_torch, weights_torch, padding=1, stride=2)
The result of this execution is:
Mean of Abs Diff: 0.2104552686214447
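Before looking at padding, it is worth ruling out a plain shape mismatch. The sketch below uses the standard output-size formulas for one spatial axis (dilation assumed to be 1); the helper names are made up for illustration and are not part of either library.

```python
import math

def torch_conv_out_size(n, kernel, stride, padding):
    """Output length along one axis for PyTorch conv2d with symmetric padding (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1

def tf_same_out_size(n, stride):
    """Output length along one axis for TensorFlow conv2d with padding="SAME"."""
    return math.ceil(n / stride)

# Input is 10x10 with a 3x3 kernel, as in the comparison above.
for stride in (1, 2):
    print(stride, torch_conv_out_size(10, 3, stride, padding=1), tf_same_out_size(10, stride))
```

Both frameworks produce a 5x5 output for stride=2, so the subtraction is valid and the difference has to come from the values themselves, not from the shapes.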
According to the PyTorch documentation, conv2d uses zero-padding defined by the padding argument. Thus, zeros are added to the left, top, right, and bottom of the input in my example.
If PyTorch simply adds padding on both sides based on the input parameter, it should be easy to replicate in Tensorflow.
# Manually add padding - consistent with PyTorch
paddings = tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]])
x_tf = tf.convert_to_tensor(x.transpose((0, 2, 3, 1)), dtype=tf.float32)
x_tf = tf.pad(x_tf, paddings, "CONSTANT")
# TF Conv2D
tf_conv2d = tf.nn.conv2d(x_tf,
                         weights_tf,
                         strides=[1, 2, 2, 1],
                         padding="VALID")
The result of this comparison is:
Mean of Abs Diff: 1.6035047067930464e-08
What this tells me is that if I am somehow able to replicate the default padding behavior from Tensorflow into PyTorch, then my results will be similar.
Another question inspected the behavior of padding in TensorFlow, and the TF documentation explains how padding is added for "SAME" convolutions. I discovered these links while writing this question.
Now that I know the padding strategy of Tensorflow, I can implement it in PyTorch.
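To make the difference concrete, here is a small pure-Python sketch of the "SAME" rule for a single axis, plus the input index at which each convolution window starts. The function names are hypothetical, and dilation is assumed to be 1.

```python
def tf_same_padding(n, kernel, stride):
    """Return (pad_before, pad_after) along one axis, following TF's "SAME" rule."""
    if n % stride == 0:
        total = max(kernel - stride, 0)
    else:
        total = max(kernel - (n % stride), 0)
    before = total // 2
    after = total - before  # when the total is odd, the extra zero goes at the end
    return before, after

def window_starts(n, kernel, stride, pad_before, pad_after):
    """Input index (negative means inside the leading padding) where each window begins."""
    padded = n + pad_before + pad_after
    count = (padded - kernel) // stride + 1
    return [i * stride - pad_before for i in range(count)]

print(tf_same_padding(10, 3, 1))        # (1, 1)
print(tf_same_padding(10, 3, 2))        # (0, 1)
print(window_starts(10, 3, 2, 1, 1))    # PyTorch padding=1: [-1, 1, 3, 5, 7]
print(window_starts(10, 3, 2, 0, 1))    # TF "SAME":         [0, 2, 4, 6, 8]
```

With stride=1 the rule gives symmetric padding, which is why padding=1 in PyTorch matched. With stride=2 only one zero is needed and it goes on the bottom/right, so every window is shifted by one pixel relative to PyTorch's symmetric padding.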
To replicate the behavior, padding sizes are calculated as described in the TensorFlow documentation. Here, I test the padding behavior by setting stride=2 and padding the PyTorch input.
import tensorflow as tf
import numpy as np
import torch
import torch.nn.functional as F
np.random.seed(0)
sess = tf.Session()
# Create constant weights (0.05 everywhere) and a random input
weights = torch.empty(3, 3, 3, 8)
torch.nn.init.constant_(weights, 5e-2)
x = np.random.randn(1, 3, 10, 10)
weights_tf = tf.convert_to_tensor(weights.numpy(), dtype=tf.float32)
weights_torch = torch.Tensor(weights.permute((3, 2, 0, 1)))
# Tensorflow padding behavior. Assuming that kH == kW to keep this simple.
stride = 2
if x.shape[2] % stride == 0:
    pad = max(weights.shape[0] - stride, 0)
else:
    pad = max(weights.shape[0] - (x.shape[2] % stride), 0)
# F.pad expects (left, right, top, bottom) for the last two dimensions
if pad % 2 == 0:
    pad_val = pad // 2
    padding = (pad_val, pad_val, pad_val, pad_val)
else:
    # Odd total padding: TF puts the extra zero at the end (bottom/right)
    pad_val_start = pad // 2
    pad_val_end = pad - pad_val_start
    padding = (pad_val_start, pad_val_end, pad_val_start, pad_val_end)
x_tf = tf.convert_to_tensor(x.transpose((0, 2, 3, 1)), dtype=tf.float32)
x_torch = torch.Tensor(x)
x_torch = F.pad(x_torch, padding, "constant", 0)
# TF Conv2D
tf_conv2d = tf.nn.conv2d(x_tf,
                         weights_tf,
                         strides=[1, stride, stride, 1],
                         padding="SAME")
# PyTorch Conv2D
torch_conv2d = F.conv2d(x_torch, weights_torch, padding=0, stride=stride)
sess.run(tf.global_variables_initializer())
tf_result = sess.run(tf_conv2d)
diff = np.mean(np.abs(tf_result.transpose((0, 3, 1, 2)) - torch_conv2d.detach().numpy()))
print('Mean of Abs Diff: {0}'.format(diff))
The output is:
Mean of Abs Diff: 2.2477470551507395e-08
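The code above assumes kH == kW, a square input, and a single stride value. A possible generalization, sketched here as a hypothetical helper (same_pad_tuple is not a PyTorch or TensorFlow function, and dilation is still assumed to be 1), computes each axis separately and returns the tuple in the order F.pad expects for an NCHW tensor.

```python
def same_pad_tuple(in_hw, kernel_hw, stride_hw):
    """Build the (left, right, top, bottom) padding tuple for TF-style "SAME" padding.

    in_hw, kernel_hw and stride_hw are (height, width) pairs.
    """
    def one_axis(n, k, s):
        # TF "SAME" rule for a single axis; any odd leftover goes at the end.
        total = max(k - s, 0) if n % s == 0 else max(k - (n % s), 0)
        return total // 2, total - total // 2

    top, bottom = one_axis(in_hw[0], kernel_hw[0], stride_hw[0])
    left, right = one_axis(in_hw[1], kernel_hw[1], stride_hw[1])
    # F.pad pads the last dimension (width) first, then the one before it (height).
    return (left, right, top, bottom)

print(same_pad_tuple((10, 10), (3, 3), (2, 2)))  # (0, 1, 0, 1), the case tested above
print(same_pad_tuple((11, 10), (5, 3), (2, 1)))  # (1, 1, 2, 2)
```

The result can be passed straight to F.pad(x_torch, same_pad_tuple(...), "constant", 0) before calling F.conv2d with padding=0, exactly as in the test above.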
I wasn't quite sure why this was happening when I started writing this question, but a bit of reading clarified this very quickly. I hope this example can help others.