
How does pytorch backprop through argmax?

I'm building K-means in pytorch using gradient descent on centroid locations, instead of expectation-maximisation. The loss is the sum of squared distances of each point to its nearest centroid. To identify which centroid is nearest to each point, I use argmin, which is not differentiable everywhere. However, pytorch is still able to backprop and update the weights (centroid locations), giving similar performance to sklearn KMeans on the data.

Any ideas how this is working, or how I can figure this out within pytorch? Discussion on pytorch github suggests argmax is not differentiable: https://github.com/pytorch/pytorch/issues/1339.

Example code below (on random pts):

import numpy as np
import torch

num_pts, batch_size, n_dims, num_clusters, lr = 1000, 100, 200, 20, 1e-5

# generate random points
vector = torch.from_numpy(np.random.rand(num_pts, n_dims)).float()

# randomly pick starting centroids
idx = np.random.choice(num_pts, size=num_clusters)
kmean_centroids = vector[idx][:,None,:] # [num_clusters,1,n_dims]
kmean_centroids = torch.tensor(kmean_centroids, requires_grad=True)

for t in range(4001):
    # get batch
    idx = np.random.choice(num_pts, size=batch_size)
    vector_batch = vector[idx]

    distances = vector_batch - kmean_centroids # [num_clusters, #pts, #dims]
    distances = torch.sum(distances**2, dim=2) # [num_clusters, #pts]

    # argmin
    membership = torch.min(distances, 0)[1] # [#pts]

    # cluster distances
    cluster_loss = 0
    for i in range(num_clusters):
        subset = torch.transpose(distances,0,1)[membership==i]
        if len(subset)!=0: # to prevent NaN
            cluster_loss += torch.sum(subset[:,i])

    cluster_loss.backward()
    print(cluster_loss.item())

    with torch.no_grad():
        kmean_centroids -= lr * kmean_centroids.grad
        kmean_centroids.grad.zero_()
asked Mar 03 '19 by jammygrams

1 Answer

As alvas noted in the comments, argmax is not differentiable. However, once you compute it and assign each datapoint to a cluster, the derivative of loss with respect to the location of these clusters is well-defined. This is what your algorithm does.
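
To see this directly in PyTorch, here is a toy sketch (made-up tensors, not the code from the question): the indices returned by the min/argmin carry no gradient history, but the distances selected with those indices do, so backward() produces exactly the fixed-assignment gradient.

import torch

x = torch.randn(5, 2)                          # 5 toy points, 2 dims
c = torch.randn(3, 1, 2, requires_grad=True)   # 3 centroids, shaped as in the question

d = ((x - c) ** 2).sum(dim=2)                  # [3, 5] squared distances
assign = d.min(dim=0)[1]                       # [5] integer indices, no grad history

# pick each point's distance to its assigned centroid; gradients flow through d only
loss = d.t()[torch.arange(5), assign].sum()
loss.backward()

print(assign.requires_grad)  # False: the argmin output is a plain LongTensor
print(c.grad)                # well-defined: 2 * (c_k - x_i) summed over points assigned to k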

Why does it work? If you had only one cluster (so that the argmax operation didn't matter), your loss function would be quadratic, with its minimum at the mean of the data points. With multiple clusters, you can see that your loss function is piecewise quadratic (in higher dimensions, think volume-wise): for any set of centroids [C1, C2, C3, ...], each data point is assigned to some centroid CN and the loss is locally quadratic. The extent of this locality is given by all alternative centroids [C1', C2', C3', ...] for which the assignment coming from argmax remains the same; within this region the argmax can be treated as a constant rather than a function, and thus the derivative of the loss is well-defined.
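
You can check this "locally constant assignment" claim numerically; the snippet below is a toy sketch with made-up data, not the question's variables.

import torch

torch.manual_seed(0)
x = torch.randn(100, 2)        # toy data points
c = torch.randn(4, 1, 2)       # 4 centroids

def assignments(centroids):
    d = ((x - centroids) ** 2).sum(dim=2)   # [4, 100]
    return d.min(dim=0)[1]                  # [100]

base = assignments(c)
nudged = assignments(c + 1e-3 * torch.randn_like(c))

# For a small enough nudge the assignments typically don't change at all, so inside
# that neighbourhood the loss is a fixed sum of quadratics in the centroid locations.
print((base == nudged).float().mean().item())  # usually 1.0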

Now, in reality, you can't always treat argmax as a constant, but you can still treat the naive "argmax-is-a-constant" derivative as pointing approximately towards a minimum, because the majority of data points are likely to keep the same cluster assignment between iterations. And once you get close enough to a local minimum that the points no longer change their assignments, the process can converge to that minimum.

Another, more theoretical way to look at it is that you're doing an approximation of expectation maximization. Normally, you would have the "compute assignments" step, which is mirrored by argmax, and the "minimize" step which boils down to finding the minimizing cluster centers given the current assignments. The minimum is given by d(loss)/d([C1, C2, ...]) == 0, which for a quadratic loss is given analytically by the means of data points within each cluster. In your implementation, you're solving the same equation but with a gradient descent step. In fact, if you used a 2nd order (Newton) update scheme instead of 1st order gradient descent, you would be implicitly reproducing exactly the baseline EM scheme.
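
To make the EM connection concrete, here is a small sketch (toy 2-D data and 5 clusters, not the setup from the question) comparing the closed-form M-step with one gradient step on the fixed-assignment loss; with a per-cluster step size of 1 / (2 * n_k) the two coincide.

import torch

torch.manual_seed(0)
x = torch.randn(500, 2)                   # toy data
c = x[torch.randperm(500)[:5]].clone()    # 5 centroids seeded from data points

# "compute assignments" step (the argmin)
assign = ((x[None, :, :] - c[:, None, :]) ** 2).sum(-1).min(dim=0)[1]

# EM-style M-step: each centroid jumps straight to the mean of its assigned points
c_em = torch.stack([x[assign == k].mean(dim=0) for k in range(5)])

# Gradient step on the same fixed-assignment loss: grad_k = 2 * sum_i (c_k - x_i).
# A per-cluster step size of 1 / (2 * n_k) lands exactly on that mean, so a suitably
# scaled gradient update reproduces the M-step; a small global lr just takes many
# partial steps towards it.
c_gd = c.clone()
for k in range(5):
    pts = x[assign == k]
    grad_k = 2 * (c_gd[k] - pts).sum(dim=0)
    c_gd[k] = c_gd[k] - grad_k / (2 * len(pts))

print(torch.allclose(c_em, c_gd, atol=1e-6))  # True (up to floating point error)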

answered Sep 18 '22 by Jatentaki