
TensorFlow session in a class in Python?

I want to use TensorFlow gradients to compute other quantities later on. I need to evaluate the objective function and its gradient numerically through methods of a class (the class is then used by the rest of the suite). However, I am getting an error with the code below:

import tensorflow as tf
class MyClass:
    def __init__(self):
        x=tf.Variable(tf.zeros(2))
        func = tf.cos(14.5 * x[0] - 0.3) + (x[1] + 0.2) * x[1] + (x[0] + 0.2) * x[0]
        diff_func = tf.gradients(func,x)

        sess = tf.Session()

    def getFunc(self,coords):
        return self.sess.run(self.func,feed_dict={self.x:coords})

    def getGrad(self,coords):
        grad = self.sess.run(self.diff_func,feed_dict={self.x:coords})
        return grad

MyClass = MyClass()
MyClass.getFunc([0.362,0.556])
MyClass.getGrad([0.362,0.556])

The error I am getting is:

Traceback (most recent call last):
  File "", line 19, in <module>
    MyClass.getFunc([0.362,0.556])
  File "", line 11, in getFunc
    return self.sess.run(self.func,feed_dict={self.x:coords})
AttributeError: MyClass instance has no attribute 'sess'

Not sure how I can get this class running correctly. Thanks.

Asked by Dave, Mar 29 '18 22:03




2 Answers

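In Python, a name assigned inside __init__ without the self. prefix is a local variable that is discarded when __init__ returns. Only names assigned through self become attributes of the instance, which other methods can then reach. So x, func, diff_func and sess must all be stored on self. Modify the code as below: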

# Written for the TensorFlow 1.x API; tf.Session is not in the top-level namespace in 2.x.
import tensorflow as tf


class MyClass:
    def __init__(self):
        self.x = tf.Variable(tf.zeros(2))
        self.func = tf.cos(14.5 * self.x[0] - 0.3) + (self.x[1] + 0.2) * self.x[1] + (self.x[0] + 0.2) * self.x[0]
        self.diff_func = tf.gradients(self.func, self.x)

        self.sess = tf.Session()

    def getFunc(self, coords):
        return self.sess.run(self.func, feed_dict={self.x: coords})

    def getGrad(self, coords):
        grad = self.sess.run(self.diff_func, feed_dict={self.x: coords})
        return grad


# Give the instance its own name so that the class MyClass is not shadowed.
my_obj = MyClass()
print(my_obj.getFunc([0.362, 0.556]))
print(my_obj.getGrad([0.362, 0.556]))
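
To check what getFunc and getGrad return, the same objective can be evaluated without TensorFlow. The following is only a sketch for cross-checking: the function names func, grad and numeric_grad are made up for this illustration, and the partial derivatives are derived by hand from the expression in the question. TensorFlow works in float32 here, so its values would be expected to agree only to about that precision.

```python
import math


def func(x0, x1):
    """Objective from the question, written with plain Python floats."""
    return math.cos(14.5 * x0 - 0.3) + (x1 + 0.2) * x1 + (x0 + 0.2) * x0


def grad(x0, x1):
    """Hand-derived partial derivatives of func with respect to x0 and x1."""
    d0 = -14.5 * math.sin(14.5 * x0 - 0.3) + 2.0 * x0 + 0.2
    d1 = 2.0 * x1 + 0.2
    return [d0, d1]


def numeric_grad(x0, x1, h=1e-6):
    """Central finite differences, used only to cross-check grad."""
    d0 = (func(x0 + h, x1) - func(x0 - h, x1)) / (2 * h)
    d1 = (func(x0, x1 + h) - func(x0, x1 - h)) / (2 * h)
    return [d0, d1]


print(func(0.362, 0.556))
print(grad(0.362, 0.556))
print(numeric_grad(0.362, 0.556))
```

Note that tf.gradients returns a list with one entry per tensor it differentiates against, so getGrad yields a one-element list holding an array of both partial derivatives.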

Answered by Aravindh Kuppusamy, Sep 21 '22 16:09


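Replace sess = tf.Session() with self.sess = tf.Session(), and prefix x, func and diff_func with self. in the same way, since the methods access all four through self.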
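
The underlying Python rule can be seen without TensorFlow at all. A minimal sketch, assuming two hypothetical classes Broken and Fixed and a helper has_sess (none of them from the original post), with a plain object() standing in for tf.Session():

```python
class Broken:
    def __init__(self):
        # Plain assignment creates a local variable that is discarded
        # when __init__ returns; nothing is stored on the instance.
        sess = object()  # stand-in for tf.Session()


class Fixed:
    def __init__(self):
        # Assigning through self stores the object on the instance,
        # so other methods can reach it later as self.sess.
        self.sess = object()  # stand-in for tf.Session()


def has_sess(obj):
    """Return True if obj has an attribute named 'sess'."""
    return hasattr(obj, "sess")


print(has_sess(Broken()))  # False
print(has_sess(Fixed()))   # True
```

Accessing self.sess on a Broken instance raises the same kind of AttributeError as in the question.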

Answered by Christian NH, Sep 23 '22 16:09