Using a Tensorflow input pipeline with skflow/tf learn

I've followed the Tensorflow Reading Data guide to get my app's data in the form of TFRecords, and am using TFRecordReader in my input pipelines to read this data.

I'm now reading the guides on using skflow/tf.learn to build a simple regressor, but I can't see how to use my input data with these tools.

In the following code, the app fails on the regressor.fit(..) call with ValueError: setting an array element with a sequence.

Error:

Traceback (most recent call last):
  File ".../tf.py", line 138, in <module>
    run()
  File ".../tf.py", line 86, in run
    regressor.fit(x, labels)
  File ".../site-packages/tensorflow/contrib/learn/python/learn/estimators/base.py", line 218, in fit
    self.batch_size)
  File ".../site-packages/tensorflow/contrib/learn/python/learn/io/data_feeder.py", line 99, in setup_train_data_feeder
    return data_feeder_cls(X, y, n_classes, batch_size)
  File ".../site-packages/tensorflow/contrib/learn/python/learn/io/data_feeder.py", line 191, in __init__
    self.X = check_array(X, dtype=x_dtype)
  File ".../site-packages/tensorflow/contrib/learn/python/learn/io/data_feeder.py", line 161, in check_array
    array = np.array(array, dtype=dtype, order=None, copy=False)

ValueError: setting an array element with a sequence.

Code:

import tensorflow as tf
import tensorflow.contrib.learn as learn

def inputs():
    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer([filename])

        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)

        features = tf.parse_single_example(serialized_example, feature_spec)
        labels = features.pop('actual')
        some_feature = features['some_feature']

        features_batch, labels_batch = tf.train.shuffle_batch(
            [some_feature, labels], batch_size=batch_size, capacity=capacity,
            min_after_dequeue=min_after_dequeue)

        return features_batch, labels_batch


def run():
    with tf.Graph().as_default():
        x, labels = inputs()

        # regressor = learn.TensorFlowDNNRegressor(hidden_units=[10, 20, 10])
        regressor = learn.TensorFlowLinearRegressor()

        regressor.fit(x, labels)
        ...

It looks like the check_array call is expecting a real array, not a tensor. Is there anything I can do to massage my data into the right shape?
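For illustration, the same message can be reproduced with NumPy alone, without TensorFlow. This is only a sketch under assumptions: to_float_array is a hypothetical helper that mimics the np.array(..., dtype=...) conversion shown in the traceback, and the ragged list is a stand-in for any object NumPy cannot lay out as a numeric array, not the actual Tensor from the code above.

```python
import numpy as np


def to_float_array(data):
    """Hypothetical helper: coerce data to a float32 array, as check_array does."""
    return np.array(data, dtype=np.float32)


# Rectangular numeric data converts cleanly.
ok = to_float_array([[1.0, 2.0], [3.0, 4.0]])

# A ragged list stands in for input that NumPy cannot turn into numbers.
try:
    to_float_array([[1.0, 2.0], [3.0]])
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(ok.shape)
print(error_message)
```

The point is that this conversion step expects concrete numeric data, which is consistent with the suspicion that a symbolic Tensor cannot be passed where an array is expected.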

Mark McDonald asked May 30 '16 01:05



1 Answer

It looks like the API you were working with is deprecated. With the more recent tf.contrib.learn.LinearRegressor (I think >= 1.0), you are supposed to pass an input_fn, which produces the inputs and labels. I think in your example, that would be as simple as changing your run function to:

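def run():
    with tf.Graph().as_default():
        # feature_columns is a required argument. The column name below is an
        # assumption and must match a key in the dict returned by my_input_fn.
        columns = [tf.contrib.layers.real_valued_column('some_feature')]
        regressor = tf.contrib.learn.LinearRegressor(feature_columns=columns)
        # Pass the function itself, not the Tensors it returns.
        regressor.fit(input_fn=my_input_fn)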

and then defining an input function called my_input_fn. From the docs, this input function takes the form:

def my_input_fn():

    # Preprocess your data here...

    # ...then return 1) a mapping of feature columns to Tensors with
    # the corresponding feature data, and 2) a Tensor containing labels
    return feature_cols, labels

I think the documentation can get you the rest of the way. It is difficult from here for me to say how you should proceed without seeing your data.
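To make the input_fn contract concrete, here is a minimal framework-free sketch. Everything in it is an assumption for illustration: make_input_fn and toy_fit are hypothetical names, not TensorFlow APIs, and plain Python lists stand in for the Tensors a real input function would return. It shows only the shape of the contract: the caller hands over a zero-argument function, and the trainer calls it to obtain a (features mapping, labels) pair.

```python
def make_input_fn(rows, label_key):
    """Hypothetical helper: build a zero-argument input function over rows."""
    def input_fn():
        # Split each row into its label and its remaining named features.
        labels = [row[label_key] for row in rows]
        names = [key for key in rows[0] if key != label_key]
        features = {name: [row[name] for row in rows] for name in names}
        return features, labels
    return input_fn


def toy_fit(input_fn):
    """Hypothetical stand-in for fit(input_fn=...): calls input_fn itself."""
    features, labels = input_fn()
    xs = features['some_feature']
    # Least-squares slope of a line through the origin.
    return sum(x * y for x, y in zip(xs, labels)) / sum(x * x for x in xs)


# Toy data reusing the 'some_feature' and 'actual' keys from the question.
rows = [
    {'some_feature': 1.0, 'actual': 2.0},
    {'some_feature': 2.0, 'actual': 4.0},
    {'some_feature': 3.0, 'actual': 6.0},
]

slope = toy_fit(make_input_fn(rows, 'actual'))
print(slope)
```

In the real API, the returned values would be Tensors, for example the batches produced by the queue-based reader in the question, wrapped in a dict keyed by feature name.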

Engineero answered Nov 14 '22 23:11