Understand how this lambda function works

I thought I understood how lambda functions work, though I don't use them myself. But the lambda below from this tutorial totally stumps me:

import matplotlib.pyplot as plt
import numpy as np
import sklearn
import sklearn.datasets
import sklearn.linear_model
import matplotlib

That was easy. More:

# Generate a dataset and plot it
X, y = sklearn.datasets.make_moons(200, noise=0.20)
plt.scatter(X[:,0], X[:,1], s=40, c=y, cmap=plt.cm.Spectral)
clf = sklearn.linear_model.LogisticRegressionCV()
clf.fit(X, y)

# Helper function to plot a decision boundary.
# If you don't fully understand this function don't worry, it just generates the contour plot below.

def plot_decision_boundary(pred_func):

    # Set min and max values and give it some padding
    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    h = 0.01

    # Generate a grid of points with distance h between them
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

    # Predict the function value for the whole grid
    Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    # Plot the contour and training examples
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral)

Now the line I don't understand:

plot_decision_boundary(lambda x: clf.predict(x))

I've read many times how lambdas work, but I just don't get how the x here is passed the correct values from before. How is the x mapped to the relevant values?

asked Nov 30 '22 23:11 by Ada Stra

2 Answers

Lambdas are just anonymous functions. A lambda's body can only be a single expression (a subset of what you can put into a function body) because it has to fit inline with other code.
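A minimal illustration of that equivalence, using a hypothetical square function unrelated to the tutorial: a lambda and a def build the same kind of function object, and the lambda simply has no name of its own.

```python
# Hypothetical example: the same one-expression function written both ways
square_lambda = lambda n: n * n

def square_def(n):
    return n * n

print(square_lambda(4), square_def(4))          # 16 16
print(type(square_lambda) is type(square_def))  # True: both are plain functions
print(square_lambda.__name__)                   # <lambda>
print(square_def.__name__)                      # square_def

# A statement is not allowed in a lambda body, so this would be a SyntaxError:
# bad = lambda n: return n * n
```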

plot_decision_boundary(lambda x: clf.predict(x)) could be rewritten as

def call_clf_predict(x):
    return clf.predict(x)

plot_decision_boundary(call_clf_predict)

Here, it's clearer what is going on: plot_decision_boundary receives a callable and calls it with a single argument, np.c_[xx.ravel(), yy.ravel()].
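That single argument is a 2-D array with one (x, y) point per row. A minimal sketch with a deliberately tiny grid (the bounds and the step h = 1 are illustrative assumptions; the tutorial uses h = 0.01 over the padded data range), and a hypothetical fake_predict standing in for clf.predict:

```python
import numpy as np

# Tiny illustrative grid: x in {0, 1}, y in {0, 1, 2}
h = 1.0
xx, yy = np.meshgrid(np.arange(0, 2, h), np.arange(0, 3, h))
# xx and yy both have shape (3, 2): one row per y value, one column per x value

# Flatten both grids and stack them as columns: one (x, y) point per row
points = np.c_[xx.ravel(), yy.ravel()]
# rows: (0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)

def fake_predict(p):
    # Hypothetical stand-in for clf.predict: class 1 when x + y > 1
    return (p[:, 0] + p[:, 1] > 1).astype(int)

# Same two steps as in plot_decision_boundary
Z = fake_predict(points)  # one label per row, shape (6,)
Z = Z.reshape(xx.shape)   # back to the (3, 2) grid layout that contourf expects
print(Z.tolist())         # [[0, 0], [0, 1], [1, 1]]
```

Whatever the predictor is, it only ever sees this stacked array; the reshape afterwards is what lines the labels back up with xx and yy for plotting.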

But lambda shouldn't have been used here in the first place. You could just do

plot_decision_boundary(clf.predict)

In the grand tradition of Python tutorials, lambda is abused once again.
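A sketch of why passing clf.predict directly works: a bound method is already a callable that takes one argument. The ThresholdClassifier class and apply_predictor function below are hypothetical stand-ins for the fitted sklearn model and for plot_decision_boundary, not part of the tutorial.

```python
class ThresholdClassifier:
    """Hypothetical stand-in for a fitted sklearn classifier."""

    def __init__(self, threshold):
        self.threshold = threshold

    def predict(self, values):
        # Label each value 1 if it exceeds the threshold, else 0
        return [int(v > self.threshold) for v in values]

def apply_predictor(pred_func, data):
    # Hypothetical stand-in for plot_decision_boundary: calls whatever it is given
    return pred_func(data)

clf = ThresholdClassifier(threshold=0.5)
data = [0.2, 0.7, 0.5, 0.9]

via_lambda = apply_predictor(lambda x: clf.predict(x), data)
via_method = apply_predictor(clf.predict, data)  # bound method, no wrapper needed
print(via_lambda, via_method)  # [0, 1, 0, 1] [0, 1, 0, 1]
```

One small difference: the lambda looks up the name clf each time it is called, while clf.predict captures the object at the moment it is passed. For a one-shot call like this one, the two behave the same.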

answered Dec 29 '22 11:12 by tdelaney


plot_decision_boundary(lambda x: clf.predict(x))

This line passes a function that takes a single argument into plot_decision_boundary. When that function is called with an argument x, it runs clf.predict(x) and returns the result.
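To make the "when that function is called" part concrete, here is a sketch with a hypothetical logging predict function standing in for clf.predict: building the lambda runs nothing, and the wrapped call happens only when the lambda itself is invoked.

```python
calls = []

def predict(x):
    # Hypothetical stand-in for clf.predict that logs every call it receives
    calls.append(x)
    return x * 2

pred_func = lambda x: predict(x)  # only creates a function object
print(len(calls))                 # 0: predict has not run yet

print(pred_func(21))              # 42: predict now runs with x = 21
print(calls)                      # [21]
```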

Inside plot_decision_boundary, that function is bound to the parameter pred_func, and it is called with its single argument at

Z = pred_func(np.c_[xx.ravel(), yy.ravel()])

So the code that actually runs is

Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) 
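The same flow can be traced with hypothetical stand-ins: plot_like plays the role of plot_decision_boundary and fake_predict plays the role of clf.predict. The value that ends up in the lambda's x is whatever the receiving function chooses to pass, here its own illustrative list of grid points.

```python
def plot_like(pred_func):
    # Hypothetical stand-in for plot_decision_boundary: it builds the argument itself
    grid = [[0.0, 0.0], [0.5, 1.0], [1.0, 0.5]]  # illustrative points
    return pred_func(grid)

seen = []

def fake_predict(points):
    # Hypothetical stand-in for clf.predict: record what arrived as x,
    # then label each point 1 if x + y > 1, else 0
    seen.append(points)
    return [int(px + py > 1) for px, py in points]

Z = plot_like(lambda x: fake_predict(x))
print(Z)     # [0, 1, 1]
print(seen)  # [[[0.0, 0.0], [0.5, 1.0], [1.0, 0.5]]]
```

The caller only supplies the recipe; the grid comes from inside plot_like, just as np.c_[xx.ravel(), yy.ravel()] comes from inside plot_decision_boundary.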
answered Dec 29 '22 11:12 by OneCricketeer
