
Tensorflow: using a session/graph in method

My situation is like this:

I have a script that trains a TensorFlow model. Inside this script I instantiate a class that feeds the training data. That class's initialization in turn instantiates another class called "image" that performs various operations for data augmentation and so on.

main script -> instantiates data_feed class -> instantiates image class

My problem is that I'm trying to use tensorflow to do some operations within this image class by passing along either the session itself or the graph. But I've had little success.

Approach that works (but too slow)

What I have right now works, but it is painfully slow. It looks something like this (simplified):

class image(object):
    def __init__(self, im):
        self.im = im

    def augment(self):
        aux_im = tf.image.random_saturation(self.im, 0.6)

        sess = tf.Session(graph=aux_im.graph)
        self.im = sess.run(aux_im)

class data_feed(object):
    def __init__(self, data_dir):
        self.images = load_data(data_dir)

    def process_data(self):
        for im in self.images:
            # use a different local name so the image class is not shadowed
            img = image(im)
            img.augment()

if __name__ == "__main__":
    # initialize everything tensorflow related here, including model
    sess = tf.Session()
    # next load the data
    data_feed = data_feed(TRAIN_DATA_DIR)
    train_data = data_feed.process_data()

This approach works, but it creates a new Session for every image, as the log shows:

I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0)
etc ...

Approach that doesn't work (and should have been way faster)

What does not work, and I can't figure out why, is passing along the graph or session from my main script, like so:

class image(object):
    def __init__(self, im):
        self.im = im

    def augment(self, tf_sess):
        with tf_sess.as_default():
            aux_im = tf.image.random_saturation(self.im, 0.6)

            self.im = tf_sess.run(aux_im)

class data_feed(object):
    def __init__(self, data_dir, tf_sess):
        self.images = load_data(data_dir)
        self.tf_sess = tf_sess

    def process_data(self):
        for im in self.images:
            # use a different local name so the image class is not shadowed
            img = image(im)
            img.augment(self.tf_sess)

if __name__ == "__main__":
    # initialize everything tensorflow related here, including model
    sess = tf.Session()
    # next load the data
    data_feed = data_feed(TRAIN_DATA_DIR, sess)
    train_data = data_feed.process_data()

This is the error I get:

Traceback (most recent call last):
  File "/usr/lib/python2.7/threading.py", line 801, in __bootstrap_inner
    self.run()
  File "/usr/lib/python2.7/threading.py", line 754, in run
    self.__target(*self.__args, **self.__kwargs)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 409, in data_generator_task
    generator_output = next(generator)
  File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 198, in generate
    yield self.next_batch()
  File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 192, in next_batch
    X, y, l = self.process_image(json_im, X, y, l)
  File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 131, in process_image
    im.augment_with_tf(self.tf_sess)
  File "/home/mathetes/Dropbox/ML/load_gluc_data.py", line 85, in augment_with_tf
    self.im = sess.run(saturation, {im_placeholder: np.asarray(self.im)})
  File "/home/mathetes/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 766, in run
    run_metadata_ptr)
  File "/home/mathetes/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 921, in _run
    + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(96, 96, 3), dtype=float32) is not an element of this graph.

Any help would be much appreciated!

mathetes asked Oct 18 '22 16:10


1 Answer

How about instead of having an Image class you create an ImageAugmenter class that takes a session when initialized, and then processes your images using Tensorflow? You could do something like this:

import tensorflow as tf
import numpy as np

class ImageAugmenter(object):
    def __init__(self, sess):
        self.sess = sess
        self.im_placeholder = tf.placeholder(tf.float32, shape=[1,784,3])
        # Build the op once; creating it inside augment() would add new nodes to the graph on every call
        self.augment_op = tf.image.random_saturation(self.im_placeholder, 0.6, 0.8)

    def augment(self, image):
        return self.sess.run(self.augment_op, {self.im_placeholder: image})

class DataFeed(object):
    def __init__(self, data_dir, sess):
        self.images = load_data(data_dir)
        self.augmenter = ImageAugmenter(sess)

    def process_data(self):
        processed_images = []
        for im in self.images:
            processed_images.append(self.augmenter.augment(im))
        return processed_images

def load_data(data_dir):
    # A real implementation would read images from disk
    # This is just a mockup
    images = []
    images.append(np.random.random([1,784,3]))
    images.append(np.random.random([1,784,3]))
    return images

if __name__ == "__main__":
    TRAIN_DATA_DIR = '/some/dir/'
    sess = tf.Session()
    data_feed = DataFeed(TRAIN_DATA_DIR, sess)
    train_data = data_feed.process_data()
    print(train_data)

With this you wouldn't be creating a new session for each image, and it should give you what you want.
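Related to this, it probably also helps to build the augmentation op once and reuse it, because every call to a tf.image function in graph mode adds new nodes to the graph. Here is a minimal sketch of that effect. It assumes a TensorFlow install that ships the `tensorflow.compat.v1` module (not something from your question), and the 96x96x3 shape is only taken from your error message:

```python
import numpy as np
# Assumption: the v1 compatibility API is available (late TF 1.x or TF 2.x).
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    im_placeholder = tf.placeholder(tf.float32, shape=[96, 96, 3])
    # Built once: the same nodes are reused by every sess.run() call.
    augment_op = tf.image.random_saturation(im_placeholder, 0.6, 0.8)

sess = tf.Session(graph=graph)
ops_before = len(graph.get_operations())

image = np.random.random([96, 96, 3]).astype(np.float32)
for _ in range(3):
    result = sess.run(augment_op, {im_placeholder: image})
ops_after_reuse = len(graph.get_operations())

# Anti-pattern: calling the tf.image function per image keeps adding nodes.
with graph.as_default():
    for _ in range(3):
        tf.image.random_saturation(im_placeholder, 0.6, 0.8)
ops_after_rebuild = len(graph.get_operations())

sess.close()
print(ops_before, ops_after_reuse, ops_after_rebuild)
```

The first two counts should match, while the third is larger, which is why a loop that creates ops per image tends to get slower over time.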

Note how sess.run() is called: the key I pass to its feed dict is the placeholder tensor defined above. According to your error trace, the placeholder you feed is not an element of the graph that your session is running. That usually means im_placeholder was created in a different graph from the session's (for instance after the default graph changed), so make sure the placeholder and the session share the same graph.

Additionally, you could improve the code further by letting ImageAugmenter receive the lower and upper parameters for tf.image.random_saturation() instead of hardcoding them, or by initializing the ImageAugmenter with a specific shape rather than a fixed one, for example.
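As a rough sketch of those improvements, something like the following could work. The parameter names shape, lower and upper are my own choice, and I am again assuming the `tensorflow.compat.v1` module is available. The bounds go to the constructor rather than to augment() so that the op is only built once:

```python
import numpy as np
# Assumption: the v1 compatibility API is available (late TF 1.x or TF 2.x).
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


class ImageAugmenter(object):
    def __init__(self, sess, shape, lower=0.6, upper=0.8):
        self.sess = sess
        # Create the ops in the session's own graph, once.
        with sess.graph.as_default():
            self.im_placeholder = tf.placeholder(tf.float32, shape=shape)
            self.augment_op = tf.image.random_saturation(
                self.im_placeholder, lower, upper)

    def augment(self, image):
        # Only feeds data; no new graph nodes are created here.
        return self.sess.run(self.augment_op, {self.im_placeholder: image})


sess = tf.Session()
augmenter = ImageAugmenter(sess, shape=[96, 96, 3], lower=0.5, upper=0.9)
images = [np.random.random([96, 96, 3]).astype(np.float32) for _ in range(2)]
processed = [augmenter.augment(im) for im in images]
sess.close()
```

If you need different saturation ranges, one option under this design is to create one augmenter per range instead of rebuilding the op per image.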

jabalazs answered Oct 23 '22 04:10