
MPSCNN Weight Ordering

The Metal Performance Shaders framework provides support for building your own convolutional neural networks. Creating an MPSCNNConvolution, for instance, requires a 4D weight tensor as an init parameter, passed as a 1D float pointer.

init(device: MTLDevice,
  convolutionDescriptor: MPSCNNConvolutionDescriptor,
  kernelWeights: UnsafePointer<Float>,
  biasTerms: UnsafePointer<Float>?,
  flags: MPSCNNConvolutionFlags)

The documentation has this to say about the 4D tensor:

The layout of the filter weight is arranged so that it can be reinterpreted as a 4D tensor (array) weight[outputChannels][kernelHeight][kernelWidth][inputChannels/groups]
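
As a rough illustration of what that layout implies: if the 4D tensor is read as a C-style (row-major) array, the last index, the input channel, varies fastest in memory. The sketch below uses Python with NumPy to compute the flat offset by hand and compare it with NumPy's own row-major flattening. The function name `flat_offset` and the dimensions are made up for this example, and it assumes groups == 1.

```python
import itertools
import numpy as np

def flat_offset(o, ky, kx, i, kernel_height, kernel_width, input_channels):
    """Row-major offset of weight[o][ky][kx][i]; the input channel varies fastest."""
    return ((o * kernel_height + ky) * kernel_width + kx) * input_channels + i

# Hypothetical dimensions, chosen only for illustration.
output_channels, kernel_height, kernel_width, input_channels = 4, 3, 3, 2
shape = (output_channels, kernel_height, kernel_width, input_channels)

# Fill the tensor with distinct values so every position is identifiable.
weights = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
flat = weights.ravel()  # NumPy flattens in C (row-major) order by default

# Check that the hand-computed offset addresses the same element everywhere.
offset_matches = all(
    flat[flat_offset(o, ky, kx, i, kernel_height, kernel_width, input_channels)]
    == weights[o, ky, kx, i]
    for o, ky, kx, i in itertools.product(*(range(n) for n in shape))
)
print(offset_matches)
```

If the check holds, a flat buffer is laid out correctly as long as the source tensor really has its axes in the order [outputChannels, kernelHeight, kernelWidth, inputChannels] before it is flattened.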

Unfortunately that information doesn't really tell me how to arrange a 4D array into a one dimensional Float pointer.

I tried ordering the weights the way the BNNS counterpart requires, but without luck.

How do I properly represent the 4D tensor (array) as a 1D Float pointer (array)?

PS: I tried arranging it like a C array and getting the pointer to the flat array, but it didn't work.

UPDATE

@RhythmicFistman: This is how I stored it in a plain array, which I can convert to an UnsafePointer<Float> (but it doesn't work):

var output = Array<Float>(repeating: 0, count: weights.count)

for o in 0..<outputChannels {
    for ky in 0..<kernelHeight {
        for kx in 0..<kernelWidth {
            for i in 0..<inputChannels {
                let offset = ((o * kernelHeight + ky) * kernelWidth + kx) * inputChannels + i
                output[offset] = ...
            }
        }
    }
}
Asked by Era, Nov 10 '16 07:11


1 Answer

OK, so I figured it out. Here are the two Python functions I use to reorder my convolution and fully connected weight matrices:

import numpy as np

# Shape required for MPSCNN: [oC, kH, kW, iC]
# TensorFlow order is:       [kH, kW, iC, oC]
def convshape(a):
    a = np.swapaxes(a, 2, 3)  # -> [kH, kW, oC, iC]
    a = np.swapaxes(a, 1, 2)  # -> [kH, oC, kW, iC]
    a = np.swapaxes(a, 0, 1)  # -> [oC, kH, kW, iC]
    return a

# Fully connected weights only require an x/y swap
def fullshape(a):
    a = np.swapaxes(a, 0, 1)
    return a
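
For what it's worth, the three swaps in the convolution case amount to a single axis permutation. The sketch below is one possible way to write that permutation and then flatten the result into a contiguous float32 buffer. The helper name `to_mpscnn_flat` and the dimensions are hypothetical, and it assumes the source weights are in TensorFlow's [kH, kW, iC, oC] order as described above.

```python
import numpy as np

def to_mpscnn_flat(tf_weights):
    """Reorder [kH, kW, iC, oC] weights to [oC, kH, kW, iC], then flatten to float32."""
    reordered = np.transpose(tf_weights, (3, 0, 1, 2))
    # transpose returns a view; force a contiguous row-major copy before flattening.
    return np.ascontiguousarray(reordered, dtype=np.float32).ravel()

# Hypothetical dimensions, chosen only for illustration.
kh, kw, ic, oc = 3, 3, 2, 4
tf_weights = np.random.rand(kh, kw, ic, oc).astype(np.float32)
flat = to_mpscnn_flat(tf_weights)

# The same permutation written as the three successive swaps.
swapped = np.swapaxes(np.swapaxes(np.swapaxes(tf_weights, 2, 3), 1, 2), 0, 1)
same_as_swaps = np.array_equal(flat, swapped.ravel())
print(flat.shape, same_as_swaps)
```

The resulting 1D array could then be saved, for example with NumPy's `ndarray.tofile`, and read back on the app side as the raw floats passed to `kernelWeights`.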
Answered by Era, Oct 03 '22 11:10