 

tensors are from different graphs

I am new to TensorFlow and am trying to create an input pipeline from TFRecords. Below is my code snippet to create batches and feed them into my estimator:

def generate_input_fn(image,label,batch_size=BATCH_SIZE):
    logging.info('creating batches...')    
    dataset = tf.data.Dataset.from_tensors((image, label)) #<-- dataset is 'TensorDataset'
    dataset = dataset.repeat().batch(batch_size)
    iterator=dataset.make_initializable_iterator()
    iterator.initializer
    return iterator.get_next()

The line iterator = dataset.make_initializable_iterator() throws:

ValueError: Tensor("count:0", shape=(), dtype=int64, device=/device:CPU:0) must be from the same graph as Tensor("TensorDataset:0", shape=(), dtype=variant).

I think I am accidentally using tensors from different graphs, but I have no idea how or in which line of code. I also don't know which tensor is count:0 and which one is TensorDataset:0.

Could anyone please help me debug this?

Error log:

      File "task.py", line 189, in main
    estimator.train(input_fn=lambda:generate_input_fn(image=image_data, label=label_data),steps=3,hooks=[logging_hook])
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 352, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 809, in _train_model
    input_fn, model_fn_lib.ModeKeys.TRAIN))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 668, in _get_features_and_labels_from_input_fn
    result = self._call_input_fn(input_fn, mode)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 760, in _call_input_fn
    return input_fn(**kwargs)
  File "task.py", line 189, in <lambda>
    estimator.train(input_fn=lambda:generate_input_fn(image=image_data, label=label_data),steps=3,hooks=[logging_hook])
  File "task.py", line 152, in generate_input_fn
    iterator=dataset.make_initializable_iterator()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 107, in make_initializable_iterator
    initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1399, in _as_variant_tensor
    self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1156, in _as_variant_tensor
    sparse.as_dense_types(self.output_types, self.output_classes)))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_dataset_ops.py", line 1696, in repeat_dataset
    output_types=output_types, output_shapes=output_shapes, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 350, in _apply_op_helper
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 5284, in _get_graph_from_inputs
    _assert_same_graph(original_graph_element, graph_element)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 5220, in _assert_same_graph
    original_item))
ValueError: Tensor("count:0", shape=(), dtype=int64, device=/device:CPU:0) must be from the same graph as Tensor("TensorDataset:0", shape=(), dtype=variant).

If I modify the function to:

image_placeholder=tf.placeholder(image.dtype,shape=image.shape)
label_placeholder=tf.placeholder(label.dtype,shape=label.shape)
dataset = tf.data.Dataset.from_tensors((image_placeholder, label_placeholder))

i.e. add placeholders, then I get this output:

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
2018-03-18 01:56:55.902917: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
Killed
asked Mar 17 '18 by Pratik Kumar

People also ask

What is a tensor graph?

Graphs are data structures that contain a set of tf.Operation objects, which represent units of computation, and tf.Tensor objects, which represent the units of data that flow between operations. They are defined in a tf.Graph context.
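
For illustration, here is a minimal sketch (assuming TensorFlow 1.x, as used in the question) of operations and tensors living inside an explicit tf.Graph context:

import tensorflow as tf  # assumes TensorFlow 1.x, as in the question

# Ops and tensors created inside g.as_default() belong to graph g;
# anything created outside it belongs to a different graph.
g = tf.Graph()
with g.as_default():
    a = tf.constant(2, name="a")    # tf.Tensor in graph g
    b = tf.constant(3, name="b")
    c = tf.add(a, b, name="c")      # tf.Operation whose output tensor is c

print(a.graph is g)                 # True
with tf.Session(graph=g) as sess:
    print(sess.run(c))              # 5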

What are the different types of tensors?

There are many types of tensors, including scalars and vectors (which are the simplest tensors), dual vectors, multilinear maps between vector spaces, and even some operations such as the dot product.

What is the shape of a tensor?

The base tf.Tensor class requires tensors to be "rectangular", that is, along each axis every element is the same size. However, there are specialized tensor types that can handle different shapes, such as ragged tensors (tf.RaggedTensor).
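
A quick sketch of what shape and rectangularity mean in practice (the ragged example assumes a newer TensorFlow release that provides tf.ragged):

import tensorflow as tf

# A "rectangular" tensor: every row has the same length, so the shape is fixed.
t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
print(t.shape)    # (2, 3)

# Ragged tensors (available in newer TensorFlow releases) allow rows of
# different lengths; the ragged dimension shows up as None.
rt = tf.ragged.constant([[1, 2, 3], [4]])
print(rt.shape)   # (2, None)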

What are three parameters that define tensors in TensorFlow?

Each tensor object is defined by attributes such as a unique label (name), a dimension (shape), and a TensorFlow data type (dtype). You can define a tensor with decimal values or with strings by changing the type of data.
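
A small sketch of those three attributes on a concrete tensor (TensorFlow 1.x style, matching the question):

import tensorflow as tf

# Each tensor carries a name, a shape, and a dtype.
x = tf.constant([1.5, 2.5], name="x")
print(x.name)     # "x:0"
print(x.shape)    # (2,)
print(x.dtype)    # <dtype: 'float32'>

# The same values stored as strings: only the dtype changes.
s = tf.constant(["1.5", "2.5"], name="s")
print(s.dtype)    # <dtype: 'string'>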


1 Answer

When you call estimator.train(input_fn), a new graph is created, combining the graph defined in the estimator's model_fn with the graph defined in the input_fn.

Therefore, if either of these functions references tensors from outside its own scope, those tensors will not be part of the same graph and you will get an error.
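
For illustration, here is a minimal sketch of the failing pattern (the array contents and sizes are made up; only the structure mirrors the question): image_data and label_data are created in the default graph at module level, but the Estimator calls the input_fn inside its own fresh graph, so the captured tensors come from a different graph and the ValueError above is raised.

import numpy as np
import tensorflow as tf

# Created once, in the default graph, outside any input_fn.
image_data = tf.constant(np.zeros((224, 224, 3), dtype=np.float32))
label_data = tf.constant(0, dtype=tf.int64)

def bad_input_fn():
    # The Estimator runs this inside a brand-new graph, but image_data and
    # label_data still belong to the old default graph, which triggers
    # "must be from the same graph" when the dataset ops are built.
    dataset = tf.data.Dataset.from_tensors((image_data, label_data))
    dataset = dataset.repeat().batch(4)
    iterator = dataset.make_initializable_iterator()
    return iterator.get_next()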


The easy solution is to make sure that every tensor you use is defined inside the input_fn or the model_fn.

For instance:

def generate_input_fn(batch_size):
    # Create the images and labels tensors here
    images = tf.placeholder(tf.float32, [None, 224, 224, 3])
    labels = tf.placeholder(tf.int64, [None])

    dataset = tf.data.Dataset.from_tensors((images, labels))
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(1)
    iterator = dataset.make_initializable_iterator()

    return iterator.get_next()
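
As a variation (a sketch, assuming the data fits in memory as plain numpy arrays rather than pre-built tensors), passing raw arrays into the input_fn keeps every tensor inside the Estimator's graph and sidesteps the placeholder feeding question entirely; the array names here are hypothetical:

import numpy as np
import tensorflow as tf

def generate_input_fn(image_array, label_array, batch_size=32):
    # image_array and label_array are plain numpy arrays, so all tensors are
    # created here, inside the graph the Estimator builds for the input_fn.
    dataset = tf.data.Dataset.from_tensor_slices((image_array, label_array))
    dataset = dataset.shuffle(buffer_size=len(label_array))
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(1)
    # A one-shot iterator needs no explicit initialization, unlike
    # make_initializable_iterator above.
    return dataset.make_one_shot_iterator().get_next()

# Usage (hypothetical numpy arrays):
# estimator.train(
#     input_fn=lambda: generate_input_fn(images_np, labels_np), steps=3)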
answered Sep 18 '22 by Olivier Moindrot