I am trying to follow the official TensorFlow guide on defining new operators: https://www.tensorflow.org/extend/adding_an_op
```cpp
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

// Register an op named "ZeroOut" with one int32 input and one int32 output.
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    // Shape function: the output gets the same shape as input 0.
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });
```
However, I cannot find a line-by-line explanation of this code. In particular, I do not understand the role and the syntax of .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c). I am also puzzled by InferenceContext; my guess is that it is a way to pass the elements of an array one by one. I could not find explicit definitions anywhere, so maybe I am looking in the wrong places. Can someone help me with either an explanation or a reference? I would like to understand in depth what this code does under the hood.
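A minimal, self-contained sketch of what the syntax amounts to, assuming simplified stand-in types (OpBuilder, InferenceContext, Status and Shape below are hypothetical and are NOT TensorFlow's classes). Each registration call returns the builder itself, which is why .Input(...).Output(...).SetShapeFn(...) can be chained. The argument to SetShapeFn is a C++11 lambda: [] is an empty capture list, (InferenceContext* c) is its parameter list, and the leading :: in the original just names the type starting from the global namespace. The context does not iterate over array elements; it carries the shapes of the op's inputs, and the shape function reports output shapes through it. In the real API, c->input(0) returns a ShapeHandle rather than a plain vector.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins, not the real TensorFlow classes.
// A "shape" is a list of dimension sizes; -1 would mark an unknown dimension.
using Shape = std::vector<int64_t>;

struct Status {  // hypothetical minimal status type
  bool ok;
  std::string message;
  static Status OK() { return {true, ""}; }
};

// Hypothetical context: holds the shapes (not the values) of the op's inputs
// and collects the shapes the shape function assigns to its outputs.
class InferenceContext {
 public:
  InferenceContext(std::vector<Shape> inputs, int num_outputs)
      : inputs_(std::move(inputs)), outputs_(num_outputs) {}
  const Shape& input(int i) const { return inputs_.at(i); }
  void set_output(int i, const Shape& s) { outputs_.at(i) = s; }
  const Shape& output(int i) const { return outputs_.at(i); }

 private:
  std::vector<Shape> inputs_;
  std::vector<Shape> outputs_;
};

using ShapeFn = std::function<Status(InferenceContext*)>;

// Hypothetical builder: every setter returns *this, so calls can be chained.
class OpBuilder {
 public:
  explicit OpBuilder(std::string name) : name_(std::move(name)) {}
  OpBuilder& Input(std::string spec) { inputs_.push_back(std::move(spec)); return *this; }
  OpBuilder& Output(std::string spec) { outputs_.push_back(std::move(spec)); return *this; }
  OpBuilder& SetShapeFn(ShapeFn fn) { shape_fn_ = std::move(fn); return *this; }

  // Runs the stored shape function against a context of input shapes.
  Status InferShapes(InferenceContext* c) const { return shape_fn_(c); }
  int num_outputs() const { return static_cast<int>(outputs_.size()); }

 private:
  std::string name_;
  std::vector<std::string> inputs_, outputs_;
  ShapeFn shape_fn_;
};

// The same chained call as in the question, against the hypothetical types.
OpBuilder MakeZeroOut() {
  return OpBuilder("ZeroOut")
      .Input("to_zero: int32")
      .Output("zeroed: int32")
      // [] captures nothing; the lambda is stored and called later.
      .SetShapeFn([](InferenceContext* c) {
        c->set_output(0, c->input(0));  // output shape := input shape
        return Status::OK();
      });
}
```

Usage: build MakeZeroOut(), create an InferenceContext from one input shape such as {2, 3}, and call InferShapes; output 0 then has shape {2, 3}. No tensor values are involved at any point, only shapes.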
Did you see the section on shape functions in that guide? https://www.tensorflow.org/extend/adding_an_op#shape_functions_in_c
It discusses the shape_inference::InferenceContext class and the mechanics of writing your own shape functions in some detail. If that doesn't cover what you're interested in, could you say more specifically what is unclear?
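To illustrate the kind of logic a shape function usually contains, here is a sketch assuming simplified, hypothetical types and functions (Shape, Status, WithRank and TransposeShapeFn are illustrations, not TensorFlow code). The real InferenceContext offers a WithRank method with a similar purpose; the version here only mimics the idea. The function validates the input's rank and computes the output shape from the input shape alone, so an unknown dimension (here -1) simply propagates, and a badly shaped input is rejected before any data flows.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical simplified types, not TensorFlow's. A dimension of -1 means
// "unknown until run time", e.g. a batch size not fixed in the graph.
using Shape = std::vector<int64_t>;
constexpr int64_t kUnknownDim = -1;

struct Status {
  bool ok;
  std::string message;
};

// Hypothetical helper in the spirit of InferenceContext::WithRank:
// succeeds only if the shape has exactly `rank` dimensions.
Status WithRank(const Shape& s, size_t rank) {
  if (s.size() != rank) {
    return {false, "expected rank " + std::to_string(rank) + ", got " +
                       std::to_string(s.size())};
  }
  return {true, ""};
}

// Hypothetical shape function for a transpose-like op: input 0 must be a
// matrix [rows, cols] and the single output is [cols, rows]. Only shapes are
// read and written, so unknown dimensions pass through unchanged.
Status TransposeShapeFn(const std::vector<Shape>& inputs,
                        std::vector<Shape>* outputs) {
  Status st = WithRank(inputs.at(0), 2);
  if (!st.ok) return st;  // reject the op before any data is computed
  const Shape& in = inputs[0];
  outputs->assign(1, Shape{in[1], in[0]});
  return {true, ""};
}
```

For an input of shape {-1, 4} this yields {4, -1}; a rank-1 input such as {3} returns a failed Status with a message instead of an output shape.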