
tensorflow periodic padding

In TensorFlow I cannot find a straightforward way to do a convolution (tf.nn.conv2d) with periodic boundary conditions.

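For example, take the tensor

[[1, 2, 3],
 [4, 5, 6],
 [7, 8, 9]]

and any 3x3 filter. A convolution with periodic boundary conditions could in principle be done by periodically padding the tensor to 5x5

[[9, 7, 8, 9, 7],
 [3, 1, 2, 3, 1],
 [6, 4, 5, 6, 4],
 [9, 7, 8, 9, 7],
 [3, 1, 2, 3, 1]]

and then convolving it with the filter in "valid" mode. However, tf.pad does not support periodic padding: its only modes are "CONSTANT", "REFLECT" and "SYMMETRIC".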

Is there a simple workaround?

Jens asked Aug 22 '16 at 20:08
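The idea in the question (periodic padding followed by a "valid" convolution) can be checked without TensorFlow. The pure-Python sketch below pads with modular indexing and confirms that the result equals a circular cross-correlation computed directly. The helper names `periodic_pad`, `valid_correlate` and `circular_correlate` are hypothetical, not TensorFlow functions. Cross-correlation (no kernel flip) is used because that is what tf.nn.conv2d computes; `valid_correlate` plays the role of conv2d with padding='VALID'.

```python
def periodic_pad(x, p):
    """Pad a 2D list periodically by p cells on every side (wrap-around)."""
    n, m = len(x), len(x[0])
    return [[x[(i - p) % n][(j - p) % m] for j in range(m + 2 * p)]
            for i in range(n + 2 * p)]


def valid_correlate(x, k):
    """'Valid'-mode 2D cross-correlation (no kernel flip, no padding)."""
    kh, kw = len(k), len(k[0])
    out_h, out_w = len(x) - kh + 1, len(x[0]) - kw + 1
    return [[sum(x[i + u][j + v] * k[u][v] for u in range(kh) for v in range(kw))
             for j in range(out_w)]
            for i in range(out_h)]


def circular_correlate(x, k):
    """Direct circular cross-correlation with a centred, odd-sized kernel."""
    n, m = len(x), len(x[0])
    kh, kw = len(k), len(k[0])
    ph, pw = kh // 2, kw // 2
    return [[sum(x[(i + u - ph) % n][(j + v - pw) % m] * k[u][v]
                 for u in range(kh) for v in range(kw))
             for j in range(m)]
            for i in range(n)]


a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
kernel = [[0, 1, 0], [1, -4, 1], [0, 1, 0]]  # any 3x3 filter would do

padded = periodic_pad(a, 1)
print(padded)
# [[9, 7, 8, 9, 7], [3, 1, 2, 3, 1], [6, 4, 5, 6, 4], [9, 7, 8, 9, 7], [3, 1, 2, 3, 1]]

# Padding by kernel_size // 2 and then correlating in "valid" mode gives
# exactly the periodic (circular) result, with the same 3x3 output size.
assert valid_correlate(padded, kernel) == circular_correlate(a, kernel)
```

For a k x k filter with odd k, the padding width is k // 2 on each side, which is why a 3x3 filter needs the 5x5 padded tensor.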


2 Answers

The following should work for your case:

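import tensorflow as tf

a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = tf.tile(a, [3, 3])  # 9x9: a repeated 3 times along each axis
result = b[2:7, 2:7]    # 5x5: the centre copy of a plus one row/column on each side
result.numpy()          # TF2 eager mode; in TF1 use tf.InteractiveSession() and result.eval()

# displays the following
array([[9, 7, 8, 9, 7],
       [3, 1, 2, 3, 1],
       [6, 4, 5, 6, 4],
       [9, 7, 8, 9, 7],
       [3, 1, 2, 3, 1]], dtype=int32)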

This is somewhat wasteful in terms of memory: tf.tile materializes a 9x9 tensor (nine copies of a) only to keep a 5x5 slice of it. If memory is an issue for you but you are willing to spend some compute, the following also works:

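# pre picks the rows of a in wrap-around order: 2, 0, 1, 2, 0
pre = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])
post = tf.transpose(pre)  # does the same for the columns
result = tf.matmul(tf.matmul(pre, a), post)  # same 5x5 tensor as above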
keveman answered Oct 14 '22 at 13:10
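The hard-coded `pre` matrix in the second approach is a selection matrix: row i has a single 1 in column (i - p) mod n, so `pre @ a` copies the rows of `a` in wrap-around order, and multiplying by its transpose on the right does the same for the columns. A pure-Python sketch of building it for any axis length `n` and padding `p` follows; `periodic_selection_matrix`, `matmul` and `transpose` are hypothetical helpers standing in for the TensorFlow ops. The nested list it returns could then be passed to tf.constant.

```python
def periodic_selection_matrix(n, p):
    """(n + 2p) x n matrix whose row i is one-hot at column (i - p) % n."""
    return [[1 if j == (i - p) % n else 0 for j in range(n)]
            for i in range(n + 2 * p)]


def matmul(x, y):
    """Plain nested-list matrix product, standing in for tf.matmul."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]


def transpose(x):
    """Nested-list transpose, standing in for tf.transpose."""
    return [list(col) for col in zip(*x)]


pre = periodic_selection_matrix(3, 1)
print(pre)  # [[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]

a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
result = matmul(matmul(pre, a), transpose(pre))
print(result)
# [[9, 7, 8, 9, 7], [3, 1, 2, 3, 1], [6, 4, 5, 6, 4], [9, 7, 8, 9, 7], [3, 1, 2, 3, 1]]
```

For a non-square input, build one matrix per axis with that axis's length. The two products cost on the order of (n + 2p)^2 · n multiply-adds for an n x n input, which is the compute-for-memory trade described above.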


A slightly more general and flexible version: periodic padding along one or more specified axes, optionally with a different padding length for each axis.

import tensorflow as tf

def periodic_padding_flexible(tensor, axis, padding=1):
    """
    Add periodic padding to a tensor along the specified axes.

    tensor: input tensor
    axis: one or more axes to pad along, int or tuple
    padding: number of cells to pad on each side, int or tuple
             (an int is applied to every axis in `axis`;
             each value must satisfy 1 <= p <= size of that axis)

    return: padded tensor
    """
    if isinstance(axis, int):
        axis = (axis,)
    if isinstance(padding, int):
        padding = (padding,) * len(axis)

    ndim = len(tensor.shape)
    for ax, p in zip(axis, padding):
        # Build slice tuples that select everything on all axes except `ax`:
        # ind_right takes the last p entries (-p:), ind_left the first p (0:p).
        # Prepending the last p and appending the first p wraps the axis around.
        ind_right = [slice(-p, None) if i == ax else slice(None) for i in range(ndim)]
        ind_left = [slice(0, p) if i == ax else slice(None) for i in range(ndim)]
        right = tensor[tuple(ind_right)]
        left = tensor[tuple(ind_left)]
        middle = tensor
        tensor = tf.concat([right, middle, left], axis=ax)

    return tensor

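a = tf.constant([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                 [[11, 12, 13], [14, 15, 16], [17, 18, 19]]])

# TF2 eager mode; in TF1 use tf.InteractiveSession() and .eval() instead of .numpy()
print(a.numpy())
result = periodic_padding_flexible(a, axis=1, padding=1)
print('padded a:')
print(result.numpy())

print(a.numpy())
result = periodic_padding_flexible(a, axis=2, padding=1)
print('padded a:')
print(result.numpy())

print(a.numpy())
result = periodic_padding_flexible(a, axis=(1, 2), padding=(1, 2))
print('padded a:')
print(result.numpy())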

Output:
[[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]]
 [[11 12 13]
  [14 15 16]
  [17 18 19]]]
padded a:
[[[ 7  8  9]
  [ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]
  [ 1  2  3]]
 [[17 18 19]
  [11 12 13]
  [14 15 16]
  [17 18 19]
  [11 12 13]]]
[[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]]
 [[11 12 13]
  [14 15 16]
  [17 18 19]]]
padded a:
[[[ 3  1  2  3  1]
  [ 6  4  5  6  4]
  [ 9  7  8  9  7]]
 [[13 11 12 13 11]
  [16 14 15 16 14]
  [19 17 18 19 17]]]
[[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]]
 [[11 12 13]
  [14 15 16]
  [17 18 19]]]
padded a:
[[[ 8  9  7  8  9  7  8]
  [ 2  3  1  2  3  1  2]
  [ 5  6  4  5  6  4  5]
  [ 8  9  7  8  9  7  8]
  [ 2  3  1  2  3  1  2]]
 [[18 19 17 18 19 17 18]
  [12 13 11 12 13 11 12]
  [15 16 14 15 16 14 15]
  [18 19 17 18 19 17 18]
  [12 13 11 12 13 11 12]]]
Sip answered Oct 14 '22 at 13:10
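Both answers can also be expressed with a single list of wrap-around indices per axis: position i of the padded axis reads position (i - p) mod n of the original. The pure-Python sketch below (the helpers `wrap_indices`, `gather`, `shape` and `periodic_pad_nd` are hypothetical) pads a nested list along several axes with per-axis widths and reproduces the last printout above. As an untested suggestion, the same index list could be passed to tf.gather(tensor, indices, axis=ax) in TensorFlow, which avoids both the 9x tile and the matrix products.

```python
def wrap_indices(n, p):
    """Indices that periodically pad an axis of length n by p cells on each side."""
    return [(i - p) % n for i in range(n + 2 * p)]


def gather(x, indices, axis):
    """Nested-list analogue of gathering `indices` along `axis` (0 = outermost)."""
    if axis == 0:
        return [x[i] for i in indices]
    return [gather(sub, indices, axis - 1) for sub in x]


def shape(x):
    """Shape of a regular (non-ragged) nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims


def periodic_pad_nd(x, axes, paddings):
    """Periodically pad nested list x along each axis in `axes` by the matching width."""
    for ax, p in zip(axes, paddings):
        n = shape(x)[ax]
        x = gather(x, wrap_indices(n, p), ax)
    return x


a = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
     [[11, 12, 13], [14, 15, 16], [17, 18, 19]]]
padded = periodic_pad_nd(a, axes=(1, 2), paddings=(1, 2))
print(shape(padded))  # [2, 5, 7]
print(padded[0][0])   # [8, 9, 7, 8, 9, 7, 8]
```

Unlike slicing with -p:, modular indices keep working when p is larger than the axis length, because the indices simply wrap around more than once.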
