How would I implement k-means with TensorFlow?

The intro tutorial, which uses the built-in gradient descent optimizer, makes a lot of sense. However, k-means isn't just something I can plug into gradient descent. It seems like I'd have to write my own sort of optimizer, but I'm not quite sure how to do that given the TensorFlow primitives.

What approach should I take?

asked Nov 10 '15 at 02:11 by Raphie Palefsky-Smith

1 Answer

(Note: a more polished version of this code is now available as a gist on GitHub.)

You can definitely do it, but you need to write your own loop and stopping criteria (for k-means, usually a maximum iteration count plus stopping once the assignments stabilize). Here's an example of how you might do it (there are probably more efficient ways to implement it, and definitely better ways to select the initial points). It's basically how you'd do it in NumPy if you were trying really hard to avoid iterating in Python:

```python
import time

import tensorflow.compat.v1 as tf

# This example uses the TF 1.x graph/session API. Under TF 2.x, the
# compat.v1 module plus disable_v2_behavior() restores that mode.
tf.disable_v2_behavior()

N = 10000
K = 4
MAX_ITERS = 1000

start = time.time()

points = tf.Variable(tf.random.uniform([N, 2]))
cluster_assignments = tf.Variable(tf.zeros([N], dtype=tf.int64))

# Silly initialization: use the first K points as the starting
# centroids. In the real world, do this better.
centroids = tf.Variable(tf.slice(points.initialized_value(), [0, 0], [K, 2]))

# Replicate to N copies of each centroid and K copies of each
# point, then subtract and compute the sum of squared distances.
rep_centroids = tf.reshape(tf.tile(centroids, [N, 1]), [N, K, 2])
rep_points = tf.reshape(tf.tile(points, [1, K]), [N, K, 2])
sum_squares = tf.reduce_sum(tf.square(rep_points - rep_centroids), axis=2)

# Use argmin to select the nearest centroid for each point.
best_centroids = tf.argmin(sum_squares, 1)
did_assignments_change = tf.reduce_any(tf.not_equal(best_centroids,
                                                    cluster_assignments))


def bucket_mean(data, bucket_ids, num_buckets):
    # Per-cluster mean: sum of member points divided by member count.
    # An empty cluster yields 0/0 = NaN here.
    total = tf.math.unsorted_segment_sum(data, bucket_ids, num_buckets)
    count = tf.math.unsorted_segment_sum(tf.ones_like(data), bucket_ids, num_buckets)
    return total / count


means = bucket_mean(points, best_centroids, K)

# Do not write to the assigned clusters variable until after
# computing whether the assignments have changed; hence control_dependencies.
with tf.control_dependencies([did_assignments_change]):
    do_updates = tf.group(
        centroids.assign(means),
        cluster_assignments.assign(best_centroids))

sess = tf.Session()
sess.run(tf.global_variables_initializer())

changed = True
iters = 0

while changed and iters < MAX_ITERS:
    iters += 1
    [changed, _] = sess.run([did_assignments_change, do_updates])

[centers, assignments] = sess.run([centroids, cluster_assignments])
end = time.time()
print("Found in %.2f seconds" % (end - start), iters, "iterations")
print("Centroids:")
print(centers)
print("Cluster assignments:", assignments)
```
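The loop above alternates two exact minimization steps instead of following a gradient, which is why no TensorFlow optimizer is involved. Writing $x_n$ for the points, $\mu_k$ for the centroids and $c_n$ for each point's cluster (notation introduced here, not taken from the answer), the standard k-means objective and the two steps the graph computes are:

```latex
J(c, \mu) = \sum_{n=1}^{N} \lVert x_n - \mu_{c_n} \rVert^2

% Assignment step (sum_squares, then tf.argmin(sum_squares, 1)):
D_{nk} = \lVert x_n - \mu_k \rVert^2, \qquad c_n \leftarrow \arg\min_{k} D_{nk}

% Update step (bucket_mean via unsorted_segment_sum):
\mu_k \leftarrow \frac{1}{\lvert \{ n : c_n = k \} \rvert} \sum_{n : c_n = k} x_n
```

With the centroids fixed, the assignment step minimizes $J$ over the $c_n$; with the assignments fixed, the mean minimizes each cluster's sum of squared distances, so neither step can increase $J$. Once `did_assignments_change` is false, recomputing the means gives the same centroids, so the loop has reached a fixed point. That fixed point is generally only a local optimum, and the update is undefined for an empty cluster (the 0/0 case in `bucket_mean`).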

(Note that a real implementation would need to be more careful about initial cluster selection and about problem cases such as all points going to one cluster, which leaves the other clusters empty and makes bucket_mean divide by zero. This is just a quick demo. I've updated my answer from earlier to make it a bit clearer and more "example-worthy".)
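Both issues raised in the note can be handled without changing the overall loop. The sketch below is plain Python (no TensorFlow) and swaps in a technique the answer does not use: k-means++ seeding for the initial centroids, plus an update step in which an empty cluster keeps its previous centroid instead of dividing by zero. The function names `kmeans_plus_plus_init` and `kmeans` are hypothetical, chosen for this illustration; treat it as a sketch of the logic, not a drop-in replacement for the graph code.

```python
import random


def squared_distance(p, q):
    """Sum of squared coordinate differences between two points."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def kmeans_plus_plus_init(points, k, rng=None):
    """k-means++ seeding: the first centroid is drawn uniformly, each later
    one with probability proportional to its squared distance to the
    nearest centroid chosen so far."""
    if rng is None:
        rng = random.Random()
    centroids = [tuple(rng.choice(points))]
    while len(centroids) < k:
        # Already-chosen points get weight 0, so seeds are never repeated.
        weights = [min(squared_distance(p, c) for c in centroids) for p in points]
        if sum(weights) == 0:
            raise ValueError("fewer than k distinct points")
        centroids.append(tuple(rng.choices(points, weights=weights, k=1)[0]))
    return centroids


def kmeans(points, centroids, max_iters=1000):
    """Alternate assignment and mean updates until no assignment changes or
    max_iters (>= 1) is reached. An empty cluster keeps its previous
    centroid instead of producing a 0/0 mean."""
    centroids = [tuple(c) for c in centroids]
    k, dim = len(centroids), len(points[0])
    assignments = None  # so the first pass always counts as a change
    for iters in range(1, max_iters + 1):
        # Assignment step: index of the nearest centroid for every point.
        new_assignments = [
            min(range(k), key=lambda j: squared_distance(p, centroids[j]))
            for p in points
        ]
        changed = new_assignments != assignments
        assignments = new_assignments
        # Update step: move each centroid to the mean of its members.
        sums = [[0.0] * dim for _ in range(k)]
        counts = [0] * k
        for p, j in zip(points, assignments):
            counts[j] += 1
            for d in range(dim):
                sums[j][d] += p[d]
        centroids = [
            tuple(s / counts[j] for s in sums[j]) if counts[j] else centroids[j]
            for j in range(k)
        ]
        if not changed:
            break
    return centroids, assignments, iters
```

For example, `kmeans(points, kmeans_plus_plus_init(points, 4))` seeds and runs the loop in one call. To get the same seeding in the TensorFlow version, you would generate the points on the Python side and build `centroids` from the seeds instead of slicing the first K points; that wiring is not shown here.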

answered Oct 18 '22 at 14:10 by dga
