
How to extract weights of network?

I want to extract the weights of an optimized network with Python. I have the .caffemodel file and have obtained net.params, which holds the parameters of the whole network. The problem is that when I index it for a single layer, for example net.params['ip2'], it gives me:

<caffe._caffe.BlobVec object at 0x7f1cb03c8fa0>

How can I get the weight matrix instead of the pointer?

Afshin Oroojlooy asked Dec 19 '22 14:12



1 Answer

You have to read the network using the .prototxt file and the .caffemodel file.

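import caffe
# Load the architecture (.prototxt) together with the trained weights (.caffemodel) in test mode
net = caffe.Net('path/to/conv.prototxt', 'path/to/conv.caffemodel', caffe.TEST)
W = net.params['con_1'][0].data[...]  # blob 0: weights of layer 'con_1'
b = net.params['con_1'][1].data[...]  # blob 1: biases of layer 'con_1'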
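
If you want the parameters of every layer at once rather than one layer at a time, the same indexing can be wrapped in a small helper. This is only a sketch: extract_params is a hypothetical name, not part of pycaffe, and it assumes that net.params behaves like a dictionary mapping layer names to indexable blob lists whose .data attribute is a NumPy array, as in the snippet above. Layer names such as 'ip2' or 'con_1' come from your own .prototxt.

```python
def extract_params(net):
    """Copy all parameter arrays out of a Caffe-style net (hypothetical helper).

    Assumes net.params maps layer names to indexable blob lists and that each
    blob exposes a .data array, as pycaffe's net.params does.
    """
    extracted = {}
    for layer_name, blobs in net.params.items():
        # By convention blob 0 holds the weights and blob 1 (if present) the biases.
        # .copy() detaches the result from the network's own memory.
        extracted[layer_name] = [blobs[i].data.copy() for i in range(len(blobs))]
    return extracted
```

With the net loaded as above, extract_params(net)['con_1'][0] would then be a copy of the same array as W, and the copies stay valid even if the network is modified or released later.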

Have a look at this link and this link for more information.

malreddysid answered Dec 26 '22 11:12
