Understanding the Definition of New TensorFlow Operators in C++

I am trying to follow the official guide for defining new operators in TensorFlow: https://www.tensorflow.org/extend/adding_an_op

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

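// Register an op named "ZeroOut" with one int32 input and one int32 output.
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    // Shape function: given the shapes of the inputs, set the output shapes.
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c){
      c->set_output(0, c->input(0));  // output 0 has the same shape as input 0
      return Status::OK();
    });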

However, I cannot find a line-by-line explanation of this code. In particular, I do not understand the role of .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) or its syntax. I am also puzzled by InferenceContext; my guess is that it is a way to pass the elements of an array one by one in succession. I could not find explicit definitions anywhere, so maybe I am looking in the wrong places. Can someone help me with either an explanation or a reference? I would like to understand in depth what this piece of code is doing under the hood.
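A minimal, self-contained sketch may help with the syntax question. It does not use TensorFlow at all: ToyOpDef, ToyInferenceContext, ToyStatus, ToyShape and MakeToyZeroOut are hypothetical stand-ins that only mimic the structure of the registration above. Two C++ features are at work. First, each builder method returns a reference to the same object, which is what allows .Input(...).Output(...).SetShapeFn(...) to be chained. Second, [](...InferenceContext* c){ ... } is a C++11 lambda with an empty capture list: an unnamed function that takes a context pointer and returns a status. In this sketch SetShapeFn only stores that function, and it is called later, when output shapes need to be inferred. (The leading :: in ::tensorflow::shape_inference simply means the name is looked up starting from the global namespace.)

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-ins that only mimic the structure of the
// ZeroOut registration. They are NOT TensorFlow's real classes.
using ToyShape = std::vector<long long>;  // one entry per dimension

struct ToyStatus {
  bool ok = true;
  static ToyStatus OK() { return ToyStatus{}; }
};

// Toy context: holds the shapes of the op's inputs and records output shapes.
class ToyInferenceContext {
 public:
  ToyInferenceContext(std::vector<ToyShape> inputs, std::size_t num_outputs)
      : inputs_(std::move(inputs)), outputs_(num_outputs) {}

  const ToyShape& input(std::size_t i) const {
    assert(i < inputs_.size());
    return inputs_[i];
  }
  void set_output(std::size_t i, const ToyShape& shape) {
    assert(i < outputs_.size());
    outputs_[i] = shape;
  }
  const ToyShape& output(std::size_t i) const { return outputs_.at(i); }

 private:
  std::vector<ToyShape> inputs_;
  std::vector<ToyShape> outputs_;
};

// Toy builder: every setter returns *this, which is what makes the
// .Input(...).Output(...).SetShapeFn(...) chain possible.
class ToyOpDef {
 public:
  using ShapeFn = std::function<ToyStatus(ToyInferenceContext*)>;

  explicit ToyOpDef(std::string name) : name_(std::move(name)) {}

  ToyOpDef& Input(std::string spec) {
    inputs_.push_back(std::move(spec));
    return *this;
  }
  ToyOpDef& Output(std::string spec) {
    outputs_.push_back(std::move(spec));
    return *this;
  }
  // The callable is only stored here; nothing runs yet.
  ToyOpDef& SetShapeFn(ShapeFn fn) {
    shape_fn_ = std::move(fn);
    return *this;
  }

  const std::string& name() const { return name_; }

  // Later, the stored function is invoked with a context describing the
  // input shapes; it fills in the output shapes.
  ToyShape InferFirstOutputShape(const std::vector<ToyShape>& input_shapes) const {
    ToyInferenceContext ctx(input_shapes, outputs_.size());
    ToyStatus status = shape_fn_(&ctx);
    assert(status.ok);
    return ctx.output(0);
  }

 private:
  std::string name_;
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
  ShapeFn shape_fn_;
};

// Same structure as REGISTER_OP("ZeroOut") in the question.
// "[]" is an empty capture list: the lambda uses nothing from the enclosing scope.
inline ToyOpDef MakeToyZeroOut() {
  return ToyOpDef("ZeroOut")
      .Input("to_zero: int32")
      .Output("zeroed: int32")
      .SetShapeFn([](ToyInferenceContext* c) {
        c->set_output(0, c->input(0));  // output shape = input shape
        return ToyStatus::OK();
      });
}
```

With this sketch, MakeToyZeroOut().InferFirstOutputShape(std::vector<ToyShape>{ToyShape{2, 3}}) returns {2, 3}: the lambda copies the input's shape to the output, which is what the ZeroOut shape function expresses. Only shapes pass through the context, never tensor elements.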

ricvo asked Oct 29 '22 02:10


1 Answer

Did you spot the section on shape inference functions here? https://www.tensorflow.org/extend/adding_an_op#shape_functions_in_c

That section discusses the InferenceContext class (tensorflow::shape_inference::InferenceContext) and the mechanics of writing your own shape functions in some detail. If it doesn't cover what you're interested in, could you give more details?
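On the guess in the question that InferenceContext passes array elements one by one: in the guide's description, a shape function works on shapes rather than element values. The context exposes the shapes of the op's inputs (c->input(i)), the function records the shapes of the outputs (c->set_output(i, ...)), and it can reject incompatible shapes by returning an error status. The sketch below models that idea with hypothetical types (Dims, ToyContext, kUnknown, MatMulLikeShapeFn, and a made-up "MatMulLike" op), using -1 for a dimension that is not known yet. It is an illustration only; the real InferenceContext works with ShapeHandle and DimensionHandle objects and provides helpers such as c->WithRank and c->Merge for checks like these.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy model, NOT TensorFlow's API: a shape is a list of
// dimension sizes, and -1 marks a dimension that is unknown at graph-build time.
using Dims = std::vector<long long>;
constexpr long long kUnknown = -1;

struct ToyContext {
  std::vector<Dims> inputs;   // shapes of the inputs (never their values)
  std::vector<Dims> outputs;  // filled in by the shape function
  std::string error;          // non-empty if the shapes are incompatible
};

// Shape function for a hypothetical "MatMulLike" op: [m, k] x [k, n] -> [m, n].
// It only reads and writes shapes; no tensor element is ever touched.
bool MatMulLikeShapeFn(ToyContext* c) {
  assert(c != nullptr);
  const Dims& a = c->inputs.at(0);
  const Dims& b = c->inputs.at(1);
  if (a.size() != 2 || b.size() != 2) {
    c->error = "both inputs must be rank 2";
    return false;
  }
  // Inner dimensions must agree, unless either one is still unknown.
  if (a[1] != kUnknown && b[0] != kUnknown && a[1] != b[0]) {
    c->error = "inner dimensions do not match";
    return false;
  }
  // One output whose shape is [rows of a, columns of b].
  c->outputs.assign(1, Dims{a[0], b[1]});
  return true;
}
```

For inputs of shape [2, 3] and [3, 5] the sketch produces [2, 5]; for [2, 3] and [4, 5] it returns false and sets an error message, which is roughly the role a non-OK Status plays in a real shape function.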

Pete Warden answered Jan 02 '23 20:01