 

Using Scipy curve_fit with piecewise function

I am getting an optimize warning:

OptimizeWarning: Covariance of the parameters could not be estimated
                 category=OptimizeWarning)

when trying to fit my piecewise function to my data using scipy.optimize.curve_fit, which means no fitting is actually happening. I can easily fit a parabola to my data, and I'm supplying curve_fit with what I feel are good initial parameters. The full code sample is below. Does anyone know why curve_fit might not be getting along with np.piecewise? Or am I making a different mistake?

import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt


def piecewise_linear(x, x0, y0, k1, k2):
    y = np.piecewise(x, [x < x0, x >= x0],
                     [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0])
    return y

def parabola(x, a, b):
    y = a * x**2 + b
    return y

x = np.array([-3, -2, -1, 0, 1, 2, 3])
y = np.array([9.15, 5.68, 2.32, 0.00, 2.05, 5.29, 8.62])


popt_piecewise, pcov = curve_fit(piecewise_linear, x, y, p0=[0.1, 0.1, -5, 5])
popt_parabola, pcov = curve_fit(parabola, x, y, p0=[1, 1])

new_x = np.linspace(x.min(), x.max(), 61)


fig, ax = plt.subplots()

ax.plot(x, y, 'o', ls='')
ax.plot(new_x, piecewise_linear(new_x, *popt_piecewise))
ax.plot(new_x, parabola(new_x, *popt_parabola))

ax.set_xlim(-4, 4)
ax.set_ylim(-2, 16)


asked Jan 13 '17 19:01 by lanery




1 Answer

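It is a problem with types. You have to change the following line so that x is given as floats:

x = np.array([-3, -2, -1, 0, 1, 2, 3], dtype=float)

Otherwise piecewise_linear may end up casting the types: np.piecewise returns an array with the same dtype as its input x, so with an integer x the model values are truncated to integers. Small changes in the parameters then usually leave the output unchanged, curve_fit sees a zero gradient, and it cannot estimate the covariance. (Plain float is used above because the np.float alias was deprecated in NumPy 1.20 and removed in NumPy 1.24.)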
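The casting can be seen directly by evaluating the question's piecewise_linear on an integer copy and a float copy of x. This is a minimal sketch: the values in params and the perturbation size eps are arbitrary choices for illustration, with eps standing in for the small finite-difference steps curve_fit takes when it estimates the gradient.

```python
import numpy as np

def piecewise_linear(x, x0, y0, k1, k2):
    # Same model as in the question: two lines meeting at (x0, y0)
    return np.piecewise(x, [x < x0, x >= x0],
                        [lambda x: k1*x + y0 - k1*x0,
                         lambda x: k2*x + y0 - k2*x0])

x_int = np.array([-3, -2, -1, 0, 1, 2, 3])
x_float = x_int.astype(float)

params = (0.1, 0.1, -2.9, 2.9)  # arbitrary illustrative values

y_int = piecewise_linear(x_int, *params)
y_float = piecewise_linear(x_float, *params)

print(y_int.dtype, y_int)      # integer dtype: the fractional parts are dropped
print(y_float.dtype, y_float)  # float64: the values curve_fit actually needs

# Nudge k1 by a tiny amount, roughly like a finite-difference step.
eps = 1e-8
y_int_shifted = piecewise_linear(x_int, params[0], params[1], params[2] + eps, params[3])
y_float_shifted = piecewise_linear(x_float, params[0], params[1], params[2] + eps, params[3])

# With integer x the truncated output usually does not move at all,
# so the estimated derivative with respect to k1 is zero.
print(np.array_equal(y_int, y_int_shifted))
# With float x the output changes, giving curve_fit a usable gradient.
print(np.array_equal(y_float, y_float_shifted))
```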

To be on the safe side, you could also make the initial parameter guesses floats:

popt_piecewise, pcov = curve_fit(piecewise_linear, x, y, p0=[0.1, 0.1, -5., 5.])
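Putting both changes together, the fitting part of the question's code might look like the sketch below (plotting left out). The second function, piecewise_linear_where, is an alternative swapped in here and is not part of the answer: np.where evaluates both branches in floating point, so its output is float even when x is an integer array. The names piecewise_linear_where, x_int and popt_where are illustrative.

```python
import numpy as np
from scipy.optimize import curve_fit

def piecewise_linear(x, x0, y0, k1, k2):
    # Two straight lines that meet at the breakpoint (x0, y0).
    # np.piecewise allocates its output with x's dtype, so x must be float.
    return np.piecewise(x, [x < x0, x >= x0],
                        [lambda x: k1*x + y0 - k1*x0,
                         lambda x: k2*x + y0 - k2*x0])

# Fix 1: build x as a float array so the model output is not truncated.
x = np.array([-3, -2, -1, 0, 1, 2, 3], dtype=float)
y = np.array([9.15, 5.68, 2.32, 0.00, 2.05, 5.29, 8.62])

# Fix 2: float initial guesses.
popt_piecewise, pcov = curve_fit(piecewise_linear, x, y, p0=[0.1, 0.1, -5., 5.])

x0, y0, k1, k2 = popt_piecewise
print(f"breakpoint=({x0:.3f}, {y0:.3f}), slopes=({k1:.3f}, {k2:.3f})")
print("parameter std. errors:", np.sqrt(np.diag(pcov)))


# Alternative (not from the answer): np.where computes both branches as
# floats, so the result is float even for an integer x.
def piecewise_linear_where(x, x0, y0, k1, k2):
    return np.where(x < x0, k1*x + y0 - k1*x0, k2*x + y0 - k2*x0)

x_int = np.array([-3, -2, -1, 0, 1, 2, 3])
popt_where, _ = curve_fit(piecewise_linear_where, x_int, y, p0=[0.1, 0.1, -5., 5.])
print("np.where variant:", popt_where)
```

Both versions compute the same model values, so they should land on the same parameters. The np.where form trades a little extra work (both branches are evaluated for every point) for not depending on the input dtype.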
answered Sep 22 '22 13:09 by J. P. Petersen