I'm trying to use curve_fit to fit a simple sine wave (not even with any noise) as a test before I move on to more complex problems. Unfortunately it's not giving even remotely the right answer. Here's my syntax:
from numpy import linspace, sin
from scipy.optimize import curve_fit

x = linspace(0, 100, 300)
y = sin(1.759 * x)

def mysine(x, a):
    return sin(a * x)

popt, pcov = curve_fit(mysine, x, y)
popt
array([ 0.98679056])
And then if I try an initial guess (say 1.5):
popt, pcov = curve_fit(mysine, x, y, p0=1.5)
popt
array([ 1.49153365])
... which is still nowhere near the right answer.
I guess I'm surprised that, given how well the function is sampled, the fit doesn't work well.
Curve fitting is not always that straightforward. The curve_fit algorithm is based on least squares curve fitting and usually needs an initial guess for the input parameters. Depending on the kind of function you want to fit, your initial guess has to be a good one.
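To see why the starting value matters so much here, one can scan the sum of squared residuals over a grid of candidate values of a, using the data from the question. This is only a rough sketch: the grid limits and step (0.5 to 3.0 in steps of 0.001) and the helper name sse are my own choices for illustration, not part of curve_fit.

```python
import numpy as np

# Data from the question: a noise-free sine sampled at 300 points on [0, 100].
x = np.linspace(0, 100, 300)
y = np.sin(1.759 * x)

def sse(a):
    """Sum of squared residuals of sin(a * x) against y (hypothetical helper)."""
    return np.sum((np.sin(a * x) - y) ** 2)

# Candidate frequencies; the range and step are arbitrary choices for illustration.
a_grid = np.linspace(0.5, 3.0, 2501)
cost = np.array([sse(a) for a in a_grid])

# Grid points that are lower than both neighbours are local minima of the cost.
is_local_min = (cost[1:-1] < cost[:-2]) & (cost[1:-1] < cost[2:])
n_local_minima = int(np.sum(is_local_min))
best_a = a_grid[np.argmin(cost)]

print("local minima on the grid:", n_local_minima)
print("lowest cost at a =", best_a)
```

A local optimizer like the one behind curve_fit typically settles in a dip close to p0, which is consistent with the values 0.9868 and 1.4915 reported in the question.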
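Even though you tried an initial guess, I would say you have an additional problem, which has to do with how your x array relates to the frequency of your wave. For background on sampling, you can have a look at the Nyquist-Shannon sampling theorem on Wikipedia. In simple words, the frequency of your wave is 1.759 / (2 * pi) ≈ 0.28 cycles per unit of x, and the spacing between the points of your x array is ≈ 0.33, which corresponds to a sampling frequency of about 3 samples per unit. That is roughly ten samples per period, so the wave is sampled above the Nyquist limit of two samples per period. The more serious issue is that you have too many oscillations to fit: the range 0 to 100 contains about 28 periods, so even a small error in a puts the model out of phase with the data over most of the range, and the least-squares fit gets stuck in a nearby local minimum.
To make your code work, I would suggest that you either supply an initial guess much closer to the true value, or increase your sampling frequency and reduce the length of your x vector so that it covers only a few periods.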
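The numbers involved are easy to check. The sketch below computes the sample spacing, the number of samples per period and the number of periods covered, for the x array from the question and for a shorter, denser one. The helper name describe_sampling is hypothetical, and the second array (0 to 10 with 100 points) is just one possible choice.

```python
import numpy as np

a = 1.759  # angular frequency of the wave, in radians per unit of x

def describe_sampling(x, a):
    """Return (spacing, samples per period, periods covered) for sin(a * x) on x."""
    spacing = x[1] - x[0]            # distance between neighbouring samples
    period = 2 * np.pi / a           # length of one oscillation
    span = x[-1] - x[0]              # total length of the fitted range
    return spacing, period / spacing, span / period

for label, x in [("question", np.linspace(0, 100, 300)),
                 ("shorter", np.linspace(0, 10, 100))]:
    spacing, per_period, n_periods = describe_sampling(x, a)
    print(f"{label}: spacing={spacing:.3f}, "
          f"samples/period={per_period:.1f}, periods={n_periods:.1f}")
```

For the question's array this gives roughly 11 samples per period spread over about 28 periods; the shorter array has roughly 35 samples per period over about 2.8 periods.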
I ran the following code and obtained the results as illustrated here:
# -*- coding: utf-8 -*-
import numpy as np
import pylab as pl
from scipy.optimize import curve_fit
def mysine(x, a):
    return 1. * np.sin(a * x)
a = 1.759 # Wave frequency
x = np.linspace(0, 10, 100) # <== This is what I changed
y = np.sin(a * x) + 0. * np.random.normal(size=len(x))
# Runs curve fitting with initial guess.
popt, pcov = curve_fit(mysine, x, y, p0=[1.5])
# Calculates the fitted curve
yf = mysine(x, *popt)
# Plots results for comparison.
pl.ion()
pl.close('all')
fig = pl.figure()
ax = fig.add_subplot(111)
ax.plot(x, y, '-', c=[0.5, 0.5, 0.5])
ax.plot(x, yf, 'k-', linewidth=2.0)
ax.text(0.97, 0.97, 'a=%.4f, ã=%.4f' % (a, popt[0]), ha='right', va='top',
        fontsize=14, transform=ax.transAxes)
fig.savefig('stow_curve_fit.png')
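If you need to keep the original long x array, another option (not what the code above does) is to take the initial guess from the data itself, for example from the strongest peak of a discrete Fourier transform, and pass that to curve_fit. This is a sketch under the assumption that x is evenly spaced and the signal is dominated by a single frequency; the names spacing, peak and a_guess are my own.

```python
import numpy as np
from scipy.optimize import curve_fit

def mysine(x, a):
    return np.sin(a * x)

# Original data from the question: many oscillations over a long range.
x = np.linspace(0, 100, 300)
y = np.sin(1.759 * x)

# Locate the strongest non-zero frequency in the spectrum.
spacing = x[1] - x[0]
freqs = np.fft.rfftfreq(len(x), d=spacing)   # in cycles per unit of x
spectrum = np.abs(np.fft.rfft(y))
peak = np.argmax(spectrum[1:]) + 1           # skip the zero-frequency bin

# Convert cycles per unit to the angular frequency used by mysine.
a_guess = 2 * np.pi * freqs[peak]

popt, pcov = curve_fit(mysine, x, y, p0=[a_guess])
print("initial guess:", a_guess)
print("fitted a:", popt[0])
```

The frequency bins are spaced 2 * pi / (N * spacing) apart in angular units, so this estimate becomes finer as the range gets longer.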