I created a "Python" layer, "myLayer", in Caffe and use it in the net train_val.prototxt.
I insert the layer like this:
```
layer {
  name: "my_py_layer"
  type: "Python"
  bottom: "in"
  top: "out"
  python_param {
    module: "my_module_name"
    layer: "myLayer"
  }
  include { phase: TRAIN } # THIS IS THE TRICKY PART!
}
```
Now, my layer only participates in the TRAINing phase of the net. How can I know that in my layer's setup function?
```python
import caffe


class myLayer(caffe.Layer):
    def setup(self, bottom, top):
        # I want to know the phase here!
        ...
```
PS: I posted this question on the "Caffe Users" Google group as well. I'll update if anything pops up there.
As pointed out by galloguille, Caffe now exposes phase to the Python layer class. This new feature makes this answer a bit redundant. Still, it is useful to know about param_str in a Caffe Python layer for passing other parameters to the layer.
AFAIK there is no trivial way of getting the phase. However, one can pass arbitrary parameters from the net prototxt to Python, using the param_str parameter of python_param.
Here's how it's done:
```
layer {
  type: "Python"
  ...
  python_param {
    ...
    param_str: '{"phase":"TRAIN","numeric_arg":5}' # passing params as a STRING
  }
}
```
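If you generate the prototxt from a script rather than typing it by hand, it can help to build the param_str line from a Python dict so the JSON is always valid. The sketch below uses a hypothetical helper, param_str_line (not part of Caffe); it assumes the value is wrapped in single quotes as above, so it escapes backslashes and single quotes C-style, which is what the protobuf text format expects inside a quoted string.

```python
import json


def param_str_line(params):
    """Hypothetical helper: render a dict as a prototxt `param_str` line.

    The JSON text is wrapped in single quotes, so backslashes and single
    quotes inside it are escaped C-style for the protobuf text format.
    """
    text = json.dumps(params, separators=(",", ":"))  # compact JSON, no spaces
    text = text.replace("\\", "\\\\").replace("'", "\\'")
    return "param_str: '%s'" % text


print(param_str_line({"phase": "TRAIN", "numeric_arg": 5}))
# param_str: '{"phase":"TRAIN","numeric_arg":5}'
```

The compact separators reproduce exactly the string used in the prototxt above, and the same dict comes back out of json.loads on the Python side.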
In Python, you get param_str in the layer's setup function:
```python
import json

import caffe


class myLayer(caffe.Layer):
    def setup(self, bottom, top):
        param = json.loads(self.param_str)  # use JSON to convert the string to a dict
        self.phase = param['phase']
        self.other_param = int(param['numeric_arg'])  # I might want to use this as well...
```
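The setup above assumes param_str is always present and well formed. A slightly more defensive, Caffe-free sketch of the same parsing step is below. parse_param_str is a hypothetical helper (not part of Caffe), and the sketch assumes param_str arrives as an empty string when the prototxt does not set it.

```python
import json


def parse_param_str(param_str, defaults=None):
    """Hypothetical helper: parse a JSON param_str into a dict.

    Keys missing from param_str fall back to `defaults`; an empty
    param_str (assumed when the prototxt omits it) yields just the defaults.
    """
    params = dict(defaults or {})
    if param_str:
        parsed = json.loads(param_str)
        if not isinstance(parsed, dict):
            raise ValueError("param_str must encode a JSON object")
        params.update(parsed)
    phase = params.get("phase")
    if phase is not None and phase not in ("TRAIN", "TEST"):
        raise ValueError("unknown phase: %r" % (phase,))
    return params


params = parse_param_str('{"phase":"TRAIN","numeric_arg":5}', defaults={"numeric_arg": 1})
print(params["phase"], int(params["numeric_arg"]))  # TRAIN 5
```

Inside setup you would call it as `param = parse_param_str(self.param_str, {"numeric_arg": 1})`, so a typo in the prototxt fails loudly at net construction instead of deep inside forward.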
This is a very good workaround, but if you are only interested in passing the phase as a parameter, you can now access the phase as an attribute of the layer. This feature was merged just 6 days ago (at the time of writing): https://github.com/BVLC/caffe/pull/3995.
Specific commit: https://github.com/BVLC/caffe/commit/de8ac32a02f3e324b0495f1729bff2446d402c2c
With this new feature you just need to use the attribute self.phase. For example, you can do the following:
```python
import caffe


class PhaseLayer(caffe.Layer):
    """A layer for checking attribute `phase`"""

    def setup(self, bottom, top):
        pass

    def reshape(self, bottom, top):
        top[0].reshape()  # no dimensions: a scalar output blob

    def forward(self, bottom, top):
        top[0].data[()] = self.phase  # 0 for TRAIN, 1 for TEST
```
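The two answers can be combined: read self.phase when the Caffe build provides it and fall back to param_str otherwise. The sketch below is Caffe-free so it runs anywhere. resolve_phase is a hypothetical helper, the local TRAIN/TEST constants are an assumption mirroring the Phase enum in caffe.proto (TRAIN = 0, TEST = 1), and SimpleNamespace objects stand in for real layer instances.

```python
import json
from types import SimpleNamespace

# Assumption: mirrors the Phase enum in caffe.proto (TRAIN = 0, TEST = 1).
# Inside a real layer, compare against caffe.TRAIN / caffe.TEST instead.
TRAIN, TEST = 0, 1


def resolve_phase(layer):
    """Hypothetical helper: return "TRAIN", "TEST", or None for a Python layer.

    Prefers the integer `phase` attribute that newer Caffe builds set on the
    layer; otherwise falls back to a "phase" key in the JSON `param_str`.
    """
    phase = getattr(layer, "phase", None)
    if phase is not None:
        return "TRAIN" if phase == TRAIN else "TEST"
    param_str = getattr(layer, "param_str", "") or "{}"
    return json.loads(param_str).get("phase")


# Stand-ins for layer objects; in Caffe these would be caffe.Layer instances.
new_build = SimpleNamespace(phase=TEST, param_str="")
old_build = SimpleNamespace(param_str='{"phase":"TRAIN"}')
print(resolve_phase(new_build), resolve_phase(old_build))  # TEST TRAIN
```

In a layer's setup you would write `self.phase_name = resolve_phase(self)`; storing the result under a new name avoids overwriting the integer self.phase that Caffe sets.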