In TensorFlow I cannot find a straightforward way to do a convolution (tf.nn.conv2d) with periodic boundary conditions.
E.g. take the tensor
[[1,2,3],
[4,5,6],
[7,8,9]]
and any 3x3 filter. A convolution with periodic boundary conditions could in principle be done by doing a periodic padding to 5x5
[[9,7,8,9,7],
[3,1,2,3,1],
[6,4,5,6,4],
[9,7,8,9,7],
[3,1,2,3,1]]
and subsequently a convolution with the filter in "valid" mode. However, tf.pad does not support periodic padding; its only modes are "CONSTANT", "REFLECT" and "SYMMETRIC".
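The pad-then-"valid" idea described above can be checked outside TensorFlow. The following is a minimal NumPy sketch, not a TensorFlow solution: NumPy's np.pad does offer mode="wrap", which is exactly periodic padding, and valid_correlate2d is a hypothetical helper written here. It computes a cross-correlation (no kernel flip), which is what tf.nn.conv2d computes.

```python
import numpy as np

def valid_correlate2d(x, k):
    """'Valid'-mode 2D cross-correlation (kernel not flipped), as in tf.nn.conv2d."""
    kh, kw = k.shape
    out_h, out_w = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=np.result_type(x, k))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
    return out

# The 3x3 tensor from the question and an arbitrary 3x3 filter.
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
k = np.arange(9).reshape(3, 3)

# Periodic padding by one cell per side reproduces the 5x5 tensor above.
padded = np.pad(a, 1, mode="wrap")
print(padded)

# A 'valid' pass over the padded tensor keeps the original 3x3 size.
result = valid_correlate2d(padded, k)
print(result.shape)  # (3, 3)
```

Each output cell then equals the sum of k[di, dj] * a[(i + di - 1) % 3, (j + dj - 1) % 3], i.e. the filter wraps around the edges.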
Is there a simple workaround?
The following should work for your case: tile the tensor 3x3 and slice out the 5x5 window around the centre copy.
import tensorflow as tf
a = tf.constant([[1,2,3],[4,5,6],[7,8,9]])
b = tf.tile(a, [3, 3])
result = b[2:7, 2:7]
sess = tf.InteractiveSession()
print(result.eval())
# result.eval() returns the following
array([[9, 7, 8, 9, 7],
[3, 1, 2, 3, 1],
[6, 4, 5, 6, 4],
[9, 7, 8, 9, 7],
[3, 1, 2, 3, 1]], dtype=int32)
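Why 2:7? Index i of the tiled tensor holds index i mod 3 of the input, and the centre copy starts at index 3, so one cell of padding needs the window 3 - 1 = 2 up to 2*3 + 1 = 7. For p cells of padding on an n0 x n1 input (assuming p is at most the size of each axis, so one neighbouring copy per side is enough), the window is n0 - p : 2*n0 + p along rows and likewise along columns. A NumPy sketch of this rule, with np.tile standing in for tf.tile and a hypothetical helper name, cross-checked against np.pad(mode="wrap"):

```python
import numpy as np

def periodic_pad_by_tiling(a, p):
    """Periodically pad a 2D array by p cells on each side via tile + slice.

    Assumes 0 <= p <= min(a.shape), so one neighbouring copy per side suffices.
    """
    n0, n1 = a.shape
    tiled = np.tile(a, (3, 3))  # 3x3 copies; the centre copy starts at (n0, n1)
    return tiled[n0 - p:2 * n0 + p, n1 - p:2 * n1 + p]

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(periodic_pad_by_tiling(a, 1))

# Cross-check against NumPy's built-in periodic ("wrap") padding.
for p in range(4):
    assert np.array_equal(periodic_pad_by_tiling(a, p), np.pad(a, p, mode="wrap"))
```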
As noted in the comments, this is somewhat wasteful in terms of memory, since tf.tile creates nine copies of the input. If memory is an issue for you but you are willing to spend some compute, the following also works, using 0/1 selection matrices:
pre = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])
post = tf.transpose(pre)
result = tf.matmul(tf.matmul(pre, a), post)
print(result.eval())
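Here pre and post are 0/1 selection matrices: row i of pre has a single 1 in column (i - p) mod n, with n = 3 and p = 1, so pre @ a stacks the rows of a in the wrapped order 3, 1, 2, 3, 1 (1-based), and multiplying by post = pre^T on the right does the same for the columns. A sketch of how such matrices could be built for any size and padding, using NumPy's @ in place of tf.matmul (wrap_selection_matrix is a hypothetical name):

```python
import numpy as np

def wrap_selection_matrix(n, p):
    """Return the (n + 2p) x n 0/1 matrix whose row i selects input row (i - p) mod n."""
    m = np.zeros((n + 2 * p, n), dtype=np.int64)
    for i in range(n + 2 * p):
        m[i, (i - p) % n] = 1
    return m

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
pre = wrap_selection_matrix(a.shape[0], 1)     # pads the rows
post = wrap_selection_matrix(a.shape[1], 1).T  # pads the columns
result = pre @ a @ post
print(pre)
print(result)
```

The matrices hold (n + 2p) * n entries each instead of a 9x tiled copy, at the cost of two matrix multiplications. In TensorFlow, tf.matmul requires both operands to have the same dtype, so pre would need to be cast to the dtype of a (for example float32).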
Slightly more general and flexible: periodic padding along one or more specified axes, optionally with a different padding length for each axis.
import tensorflow as tf
def periodic_padding_flexible(tensor, axis, padding=1):
    """
    Add periodic padding to a tensor along the specified axes.
    tensor: input tensor
    axis: one or more axes to pad along, int or tuple
    padding: number of cells to pad on each side, int or tuple (one entry per axis)
    return: padded tensor
    """
    if isinstance(axis, int):
        axis = (axis,)
    if isinstance(padding, int):
        # use the same padding for every requested axis
        padding = (padding,) * len(axis)
    ndim = len(tensor.shape)
    for ax, p in zip(axis, padding):
        if p == 0:
            continue  # slice(-0, None) would select the whole axis
        # build slice tuples that select everything along all other axes,
        # but only the last p cells (-p:) or the first p cells (0:p) along ax
        ind_right = tuple(slice(-p, None) if i == ax else slice(None) for i in range(ndim))
        ind_left = tuple(slice(0, p) if i == ax else slice(None) for i in range(ndim))
        right = tensor[ind_right]  # right edge, wrapped around to the front
        left = tensor[ind_left]    # left edge, wrapped around to the back
        middle = tensor
        tensor = tf.concat([right, middle, left], axis=ax)
    return tensor
a = tf.constant([
[[1,2,3],[4,5,6],[7,8,9]],
[[11,12,13],[14,15,16],[17,18,19]],
])
sess = tf.InteractiveSession()
result = periodic_padding_flexible(a, axis=1,padding=1)
print('a:')
print(a.eval())
print('padded a:')
print(result.eval())
result = periodic_padding_flexible(a, axis=2,padding=1)
print('a:')
print(a.eval())
print('padded a:')
print(result.eval())
result = periodic_padding_flexible(a, axis=(1,2),padding=(1,2))
print('a:')
print(a.eval())
print('padded a:')
print(result.eval())
output:
a:
[[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]]
[[11 12 13]
[14 15 16]
[17 18 19]]]
padded a:
[[[ 7 8 9]
[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[ 1 2 3]]
[[17 18 19]
[11 12 13]
[14 15 16]
[17 18 19]
[11 12 13]]]
a:
[[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]]
[[11 12 13]
[14 15 16]
[17 18 19]]]
padded a:
[[[ 3 1 2 3 1]
[ 6 4 5 6 4]
[ 9 7 8 9 7]]
[[13 11 12 13 11]
[16 14 15 16 14]
[19 17 18 19 17]]]
a:
[[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]]
[[11 12 13]
[14 15 16]
[17 18 19]]]
padded a:
[[[ 8 9 7 8 9 7 8]
[ 2 3 1 2 3 1 2]
[ 5 6 4 5 6 4 5]
[ 8 9 7 8 9 7 8]
[ 2 3 1 2 3 1 2]]
[[18 19 17 18 19 17 18]
[12 13 11 12 13 11 12]
[15 16 14 15 16 14 15]
[18 19 17 18 19 17 18]
[12 13 11 12 13 11 12]]]
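The code above stops at the padding step. Assuming TensorFlow 2.x with eager execution (the snippets above use the TF1 InteractiveSession API), the complete pipeline from the question could be sketched as follows for an NHWC batch: pad height and width with the same tf.concat idea as periodic_padding_flexible, then call tf.nn.conv2d with padding="VALID". The helper names periodic_pad_hw and periodic_conv2d are hypothetical, and the sketch assumes stride 1 and a square kernel of odd size at least 3.

```python
import numpy as np
import tensorflow as tf  # assumes TF 2.x with eager execution

def periodic_pad_hw(x, p):
    """Wrap-pad an NHWC tensor by p cells on the height and width axes.

    Assumes 0 < p <= height and p <= width.
    """
    x = tf.concat([x[:, -p:], x, x[:, :p]], axis=1)        # height axis
    x = tf.concat([x[:, :, -p:], x, x[:, :, :p]], axis=2)  # width axis
    return x

def periodic_conv2d(x, filters):
    """Stride-1 conv2d with periodic boundaries.

    Assumes a square kernel of odd size >= 3, so the output keeps the
    input's height and width.
    """
    k = int(filters.shape[0])
    padded = periodic_pad_hw(x, k // 2)
    return tf.nn.conv2d(padded, filters, strides=1, padding="VALID")

# The 3x3 tensor from the question as a batch of one single-channel image.
img = tf.reshape(tf.constant([1., 2., 3., 4., 5., 6., 7., 8., 9.]), [1, 3, 3, 1])
# An arbitrary 3x3 filter: [height, width, in_channels, out_channels].
kernel = tf.reshape(tf.range(9, dtype=tf.float32), [3, 3, 1, 1])
out = periodic_conv2d(img, kernel)

# Cross-check in NumPy: tf.nn.conv2d is a cross-correlation, so each output
# cell is sum(k[di, dj] * a[(i + di - 1) % 3, (j + dj - 1) % 3]).
a = img.numpy()[0, :, :, 0]
k = kernel.numpy()[:, :, 0, 0]
expected = np.array([[sum(k[di, dj] * a[(i + di - 1) % 3, (j + dj - 1) % 3]
                          for di in range(3) for dj in range(3))
                      for j in range(3)] for i in range(3)])
print(np.allclose(out.numpy()[0, :, :, 0], expected))
```

Because tf.nn.conv2d does not flip the kernel, the cross-check indexes the input with a positive offset rather than flipping k.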