
3D-plot of the error function in a linear regression

I would like to plot a 3D graph of the error function of a linear regression, computed over a range of slope and y-intercept values. The graph will be used to illustrate an application of gradient descent.

Let’s suppose we want to model a set of points with a line. To do this we’ll use the standard y=mx+b line equation where m is the line’s slope and b is the line’s y-intercept. To find the best line for our data, we need to find the best set of slope m and y-intercept b values.

A standard approach to solving this type of problem is to define an error function (also called a cost function) that measures how “good” a given line is. This function takes an (m,b) pair and returns an error value based on how well the line fits the data. To compute this error for a given line, we iterate through each (x,y) point in the data set and sum the squared distances between each point’s y value and the candidate line’s y value (computed as mx+b). The function below also divides that sum by the number of points, so it returns the mean squared error. Squaring the distance is conventional because it keeps every term non-negative and makes the error function differentiable. In Python, computing the error for a given line looks like this:

# y = mx + b
# m is slope, b is y-intercept
def computeErrorForLineGivenPoints(b, m, points):
    totalError = 0
    for i in range(0, len(points)):
        totalError += (points[i].y - (m * points[i].x + b)) ** 2
    return totalError / float(len(points))
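For comparison, here is a minimal NumPy sketch of the same calculation without the explicit loop. The name `mean_squared_error` and the choice to pass separate `xs` and `ys` arrays instead of point objects are assumptions made for this illustration, not part of the original code.

```python
import numpy as np

def mean_squared_error(m, b, xs, ys):
    """Mean squared vertical distance between the points and the line y = m*x + b."""
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    residuals = ys - (m * xs + b)          # one residual per data point
    return float(np.mean(residuals ** 2))

xs = [0.0, 1.0, 2.0]
ys = [1.0, 3.0, 5.0]                       # points lying exactly on y = 2x + 1
print(mean_squared_error(2.0, 1.0, xs, ys))  # 0.0 (perfect fit)
print(mean_squared_error(2.0, 0.0, xs, ys))  # 1.0 (every residual is 1)
```

The result should match computeErrorForLineGivenPoints for the same data; the vectorized form mainly helps when the error has to be evaluated for many (m,b) pairs.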

Since the error function takes two parameters (m and b), we can visualize it as a two-dimensional surface, with the error as the height above the (m,b) plane.
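Written out, the function above corresponds to the following formula, where N is the number of points. The partial derivatives are included only because they are what a gradient descent step would use; they follow from differentiating the formula and are not taken from the original post.

```latex
E(m, b) = \frac{1}{N} \sum_{i=1}^{N} \bigl( y_i - (m x_i + b) \bigr)^2

\frac{\partial E}{\partial m} = -\frac{2}{N} \sum_{i=1}^{N} x_i \bigl( y_i - (m x_i + b) \bigr)
\qquad
\frac{\partial E}{\partial b} = -\frac{2}{N} \sum_{i=1}^{N} \bigl( y_i - (m x_i + b) \bigr)
```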

Now my question: how can we plot such a 3D graph using Python?

Here is skeleton code for building a 3D plot. The snippet is unrelated to the regression problem, but it shows the basics of building a 3D plot. For my example, the x-axis should be the slope, the y-axis the y-intercept, and the z-axis the error.

Can someone help me build an example of such a graph?

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

def fun(x, y):
  return x**2 + y

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
x = y = np.arange(-3.0, 3.0, 0.05)
X, Y = np.meshgrid(x, y)
zs = np.array([fun(x,y) for x,y in zip(np.ravel(X), np.ravel(Y))])
Z = zs.reshape(X.shape)

ax.plot_surface(X, Y, Z)

ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')

plt.show()

The code above produces a surface plot that is very similar to what I am looking for.

asked Feb 16 '15 13:02 by Michael


1 Answer

Simply replace fun with computeErrorForLineGivenPoints:

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import collections

def error(m, b, points):
    totalError = 0
    for i in range(0, len(points)):
        totalError += (points[i].y - (m * points[i].x + b)) ** 2
    return totalError / float(len(points))

x = y = np.arange(-3.0, 3.0, 0.05)
Point = collections.namedtuple('Point', ['x', 'y'])

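# True slope and intercept used to generate the sample data
m, b = 3, 2
# np.random.random draws uniform samples in [0, 1), so the noise has mean 0.5
noise = np.random.random(x.size)
points = [Point(xp, m*xp+b+err) for xp,err in zip(x, noise)]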

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ms = np.linspace(2.0, 4.0, 10)
bs = np.linspace(1.5, 2.5, 10)

M, B = np.meshgrid(ms, bs)
zs = np.array([error(mp, bp, points) 
               for mp, bp in zip(np.ravel(M), np.ravel(B))])
Z = zs.reshape(M.shape)

ax.plot_surface(M, B, Z, rstride=1, cstride=1, color='b', alpha=0.5)

ax.set_xlabel('m')
ax.set_ylabel('b')
ax.set_zlabel('error')

plt.show()

This yields a surface plot of the error as a function of m and b.
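As a possible variation, the grid of error values can also be computed with NumPy broadcasting instead of a list comprehension over the raveled grids. The sketch below is only an illustration: the helper name `error_surface`, the fixed random seed, and the comparison against `np.polyfit` are assumptions added here, not part of the answer's code. The returned `M`, `B` and `Z` can be passed to `ax.plot_surface` exactly as above.

```python
import numpy as np

def error_surface(ms, bs, xs, ys):
    """Return M, B, Z where Z[i, j] is the mean squared error
    for slope ms[j] and intercept bs[i]."""
    M, B = np.meshgrid(ms, bs)               # both shaped (len(bs), len(ms))
    # Add a trailing axis so each (m, b) pair broadcasts against all points.
    residuals = ys - (M[..., None] * xs + B[..., None])
    return M, B, np.mean(residuals ** 2, axis=-1)

rng = np.random.default_rng(0)               # fixed seed, chosen arbitrarily
xs = np.arange(-3.0, 3.0, 0.05)
ys = 3 * xs + 2 + rng.random(xs.size)        # same kind of data as above

ms = np.linspace(2.0, 4.0, 10)
bs = np.linspace(1.5, 2.5, 10)
M, B, Z = error_surface(ms, bs, xs, ys)

# Grid cell with the smallest error, to compare against a least-squares fit.
i, j = np.unravel_index(np.argmin(Z), Z.shape)
fit_m, fit_b = np.polyfit(xs, ys, 1)         # returns [slope, intercept]
print("grid minimum: ", M[i, j], B[i, j])
print("least squares:", fit_m, fit_b)
```

Because the noise is uniform on [0, 1), the fitted intercept should land near 2.5 rather than 2, which is at the edge of the `bs` range used here; widening that range would show the minimum of the surface more clearly.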

Tip: I renamed computeErrorForLineGivenPoints as error. Generally, there is no need to name a function compute... since almost all functions compute something. You also do not need to specify "GivenPoints" since the function signature shows that points is an argument. If you have other error functions or variables in your program, line_error or total_error might be a better name for this function.

answered Oct 14 '22 00:10 by unutbu