Using the KMeans tf.contrib.learn Estimator as part of a graph in TensorFlow

I am trying to use tensorflow.contrib.learn.KMeansClustering as one component of a larger TensorFlow graph, so that it gives me cluster predictions and centers. The relevant part of the code is the following:

with tf.variable_scope('kmeans'):
    kmeans = KMeansClustering(num_clusters=num_clusters,
                              relative_tolerance=0.0001)
    kmeans.fit(input_fn= (lambda : [X, None]))
    clusters = kmeans.clusters()

init_vars = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_vars, feed_dict={X: full_data_x})
clusters_np = sess.run(clusters, feed_dict={X: full_data_x})

However, I get the following error:

ValueError: Tensor("kmeans/strided_slice:0", shape=(), dtype=int32) must be from the same graph as Tensor("sub:0", shape=(), dtype=int32).

I believe this is because KMeansClustering is a tf.contrib.learn Estimator, which is more akin to a whole graph than to a single module. Is that correct? Can I turn it into a module of the default graph? If not, is there a function that does k-means within my own graph?

Thanks!

asked Nov 17 '25 06:11 by etal



2 Answers

Internally, the KMeansClustering Estimator uses ops from tf.contrib.factorization, and those ops can be used directly in your own graph. The factorization MNIST example uses KMeans this way, without an Estimator.
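To make concrete what such k-means ops compute, here is a plain-Python sketch of full-batch k-means (Lloyd's algorithm): assign every point to its nearest center, move each center to the mean of its assigned points, and repeat until the loss stops changing. This is an illustration only, not the tf.contrib.factorization API. The `relative_tolerance` argument is named after the Estimator's parameter, but the stopping rule used here is an assumed, simplified one and may differ from TensorFlow's exact formula.

```python
import random


def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points, k, relative_tolerance=1e-4, max_iter=100, seed=0):
    """Full-batch Lloyd's k-means (illustrative sketch); returns (centers, labels)."""
    rng = random.Random(seed)
    # Initial centers: k randomly chosen input points.
    centers = [list(p) for p in rng.sample(points, k)]
    prev_loss = None
    for _ in range(max_iter):
        # Assignment step: index of the nearest center for every point.
        labels = [min(range(k), key=lambda c: sq_dist(p, centers[c])) for p in points]
        loss = sum(sq_dist(p, centers[lab]) for p, lab in zip(points, labels))
        # Update step: move each center to the mean of its assigned points.
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:  # an empty cluster keeps its previous center
                centers[c] = [sum(col) / len(members) for col in zip(*members)]
        # Assumed stopping rule: stop once the loss barely changes.
        if prev_loss is not None and abs(prev_loss - loss) <= relative_tolerance * prev_loss:
            break
        prev_loss = loss
    return centers, labels
```

For example, `kmeans([(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)], k=2)` puts the first three and the last three points in separate clusters.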

answered Nov 18 '25 21:11 by Allen Lavoie



The KMeansClustering Estimator API builds its own tf.Graph and manages its tf.Session by itself, so you don't need to run a tf.Session to feed values (that is done by input_fn). Mixing tensors from your default graph with the Estimator's own graph is why the ValueError arises.

The correct usage of the KMeansClustering Estimator is just:

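import tensorflow as tf
from tensorflow.contrib.learn import KMeansClustering  # TF 1.x only

kmeans = KMeansClustering(num_clusters=num_clusters,
                          relative_tolerance=0.0001)
# Create the input tensor inside input_fn so it belongs to the Estimator's graph.
kmeans.fit(input_fn=lambda: (tf.constant(full_data_x, dtype=tf.float32), None))
clusters = kmeans.clusters()

where full_data_x is the NumPy array holding the values. The tensor must be created inside input_fn (with tf.constant or tf.convert_to_tensor), because the Estimator calls input_fn inside the graph it builds; a tensor created beforehand in the default graph would raise the same ValueError. It is also not a tf.placeholder that needs to be fed at a tf.Session run.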
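The graph mismatch can be illustrated without TensorFlow. The toy model below is purely hypothetical (`ToyGraph`, `ToyTensor` and `ToyEstimator` are made-up names, not TensorFlow classes): each tensor remembers which graph was active when it was created, and the estimator builds a fresh graph and calls `input_fn` inside it. A tensor built up front, like the placeholder `X` in the question, belongs to the outer graph and trips the same-graph check, while one built inside `input_fn` does not.

```python
import contextlib


class ToyGraph:
    """Hypothetical stand-in for a computation graph; not tf.Graph."""

    @contextlib.contextmanager
    def as_default(self):
        global _current_graph
        previous, _current_graph = _current_graph, self
        try:
            yield self
        finally:
            _current_graph = previous


_current_graph = ToyGraph()  # the module-level "default graph"


class ToyTensor:
    """Hypothetical tensor that records the graph it was created in."""

    def __init__(self, value):
        self.value = value
        self.graph = _current_graph

    def __add__(self, other):
        # Mimic the rule that ops cannot connect tensors from different graphs.
        if other.graph is not self.graph:
            raise ValueError("tensors must be from the same graph")
        return ToyTensor(self.value + other.value)


class ToyEstimator:
    """Hypothetical estimator that builds a fresh graph on every fit()."""

    def fit(self, input_fn):
        with ToyGraph().as_default():
            features = input_fn()   # input_fn runs inside the new graph
            offset = ToyTensor(1)   # an op the estimator creates itself
            return (features + offset).value


placeholder_like = ToyTensor(41)  # built up front, in the default graph
try:
    ToyEstimator().fit(lambda: placeholder_like)
except ValueError as err:
    print("fails:", err)  # fails: tensors must be from the same graph
print("works:", ToyEstimator().fit(lambda: ToyTensor(41)))  # works: 42
```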

Update for TensorFlow 1.4:

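Use the tf.contrib.factorization.KMeansClustering API to find the cluster centers:

kmeans = tf.contrib.factorization.KMeansClustering(num_clusters=num_clusters)
# Create the tensor inside input_fn; pass steps because a constant input never ends.
kmeans.train(input_fn=lambda: (tf.constant(full_data_x, dtype=tf.float32), None),
             steps=10)
centers = kmeans.cluster_centers()

Without steps (or another stopping condition), train() keeps running until input_fn is exhausted, which a constant tensor never is; the value 10 is only an example. To predict the closest center for new features (another_X, also a NumPy array), use:

predictions = kmeans.predict(input_fn=lambda: (tf.train.limit_epochs(
    tf.constant(another_X, dtype=tf.float32), num_epochs=1), None))

Here limit_epochs stops the input after one pass over another_X, and predict() returns a generator that yields one result per example.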
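Conceptually, predicting a cluster means assigning each new point to the closest learned center, with squared Euclidean distance as the Estimator's default metric. The plain-Python sketch below illustrates that assignment; `assign_to_centers` is a hypothetical helper, not part of TensorFlow, and it returns, for each point, the index of the nearest center together with the distances to all centers.

```python
def assign_to_centers(points, centers):
    """For each point, return (nearest center index, squared distances to all centers)."""
    results = []
    for p in points:
        distances = [sum((x - c) ** 2 for x, c in zip(p, center)) for center in centers]
        results.append((distances.index(min(distances)), distances))
    return results


centers = [[0.0, 0.0], [10.0, 10.0]]
print(assign_to_centers([(1, 2), (9, 8)], centers))
# [(0, [5.0, 145.0]), (1, [145.0, 5.0])]
```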
answered Nov 18 '25 20:11 by J.E.K



