I would like to improve the fit of an SIR model. If I fit the model to only the first 60 data points, I get a "good" result, where "good" means that the fitted curve stays close to the data points until t=40. How can I get a better fit, ideally one based on all data points?
ydata = ['1e-06', '1.49920166169172e-06', '2.24595472686361e-06', '3.36377954575331e-06', '5.03793663882291e-06', '7.54533628058909e-06', '1.13006564683911e-05', '1.69249500601052e-05', '2.53483161761933e-05', '3.79636391699325e-05', '5.68567547875179e-05', '8.51509649182741e-05', '0.000127522555808945', '0.000189928392105942', '0.000283447055673738', '0.000423064043409294', '0.000631295993246634', '0.000941024110897193', '0.00140281896645859', '0.00209085569326554', '0.00311449589149717', '0.00463557784224762', '0.00689146863803467', '0.010227347567051', '0.0151380084180746', '0.0223233100045688', '0.0327384810150231', '0.0476330618585758', '0.0685260046667727', '0.0970432959143974', '0.134525888779423', '0.181363340075877', '0.236189247803334', '0.295374180276257', '0.353377036130714', '0.404138746080267', '0.442876028839178', '0.467273954573897', '0.477529937494976', '0.475582401936257', '0.464137179474659', '0.445930281787152', '0.423331710456602', '0.39821360956389', '0.371967226561944', '0.345577884704341', '0.319716449520481', '0.294819942458255', '0.271156813453547', '0.24887641905719', '0.228045466022105', '0.208674420183194', '0.190736203926912', '0.174179448652951', '0.158937806544529', '0.144936441326754', '0.132096533873646', '0.120338367115739', '0.10958340819268', '0.099755679236243', '0.0907826241267504', '0.0825956203546979', '0.0751302384111894', '0.0683263295744258', '0.0621279977639921', '0.0564834809370572', '0.0513449852139111', '0.0466684871328814', '0.042413516167789', '0.0385429293775096', '0.035022685071934', '0.0318216204865132', '0.0289112368382048', '0.0262654939162707', '0.0238606155312519', '0.021674906523588', '0.0196885815912485', '0.0178836058829335', '0.0162435470852779', '0.0147534385851646', '0.0133996531928511', '0.0121697868544064', '0.0110525517526551', '0.0100376781867076', '0.00911582462544914', '0.00827849534575178', '0.00751796508841916', '0.00682721019158058', '0.00619984569061827', '0.00563006790443123', 
'0.00511260205894446', '0.00464265452957236', '0.00421586931435123', '0.00382828837833139', '0.00347631553734708', '0.00315668357532714', '0.00286642431380459', '0.00260284137520731', '0.00236348540287827', '0.00214613152062159', '0.00194875883295343']
ydata = [float(d) for d in ydata]
xdata = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101']
xdata = [float(t) for t in xdata]
from scipy.optimize import minimize
from scipy import integrate
import numpy as np
import pylab as pl
def fitFunc(sir_values, time, beta, gamma, k):
    s = sir_values[0]
    i = sir_values[1]
    r = sir_values[2]
    res = np.zeros((3))
    res[0] = - beta * s * i
    res[1] = beta * s * i - gamma * i
    res[2] = gamma * i
    return res
def lsq(model, xdata, ydata, n):
    """least squares"""
    time_total = xdata
    # original record data
    data_record = ydata
    # normalize train data
    k = 1.0/sum(data_record)
    # init t = 0 values + normalized
    I0 = data_record[0]*k
    S0 = 1 - I0
    R0 = 0
    N0 = [S0,I0,R0]
    # Set initial parameter values
    param_init = [0.75, 0.75]
    param_init.append(k)
    # fitting
    param = minimize(sse(model, N0, time_total, k, data_record, n), param_init, method="nelder-mead").x
    # get the fitted model
    Nt = integrate.odeint(model, N0, time_total, args=tuple(param))
    # scale out
    Nt = np.divide(Nt, k)
    # Get the second column of data corresponding to I
    return Nt[:,1]
def sse(model, N0, time_total, k, data_record, n):
    """sum of square errors"""
    def result(x):
        Nt = integrate.odeint(model, N0, time_total[:n], args=tuple(x))
        INt = [row[1] for row in Nt]
        INt = np.divide(INt, k)
        difference = data_record[:n] - INt
        # square the difference
        diff = np.dot(difference, difference)
        return diff
    return result
result = lsq(fitFunc, xdata, ydata, 60)
# Plot data and fit
pl.clf()
pl.plot(xdata, ydata, "o")
pl.plot(xdata, result)
pl.show()
I expect something like this:
I'm converting my comment to a fully fledged answer.
The problem arises from setting up the model incorrectly. To simplify the differential equations, I will refer to dS(t)/dt and dI(t)/dt as S and I respectively.
# incorrect
S = -S * I * beta
I = S * I * beta - I * gamma
# correct
S = -S * I * beta / N
I = S * I * beta / N - I * gamma
By setting up the differential equations incorrectly, the rate of change, i.e. the change of going from y(t) to y(t+dt), is wrong. So, not only do you obtain the incorrectly integrated I(t), but you further divide it by N (or k, as you have called it), making it even more wrong.
We know the coupled system of these specific equations requires that S(t) + I(t) + R(t) = N, where N is the population constant. From the way you declare the initial conditions, we infer that N is 1. Note that this is also consistent with max(ydata), which is less than 1.
# I0 + S0 + R0 is always 1 regardless of "value"
I0 = value
S0 = 1 - I0
R0 = 0
Furthermore, the way you handle k is really dubious. Your data seems to already be normalised, but you multiply it by a factor of 0.1. As you can see, k = 1./sum(ydata) has nothing to do with the population constant. By doing I0 = ydata[0] * k and dividing I(t) by k, you effectively scale down your data only to scale them up later on. That pretty much restricts I(t) to the range 0-1, no matter what the population constant is.
You can verify that your model is wrong by simply setting all initial conditions and unknown parameters and seeing what comes out of odeint(). You'll notice that S(0), I(0) and R(0) may not correspond to the values you give them, which is a sign of doing things wrong with k. But to discover the flawed evolution of the dynamics, you simply have to review your model.
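The check described above can be sketched as follows. This is only an illustration: the values of beta and gamma are hypothetical, N = 1 is the population constant inferred earlier, and the name sir_rhs is not from the question.

```python
import numpy as np
from scipy import integrate

N = 1.0                      # population constant (assumed to be 1, as inferred above)
beta, gamma = 0.4, 0.1       # hypothetical rates, chosen only for illustration


def sir_rhs(y, t, beta, gamma):
    """Right-hand side of the SIR system, including the 1/N factor."""
    s, i, r = y
    ds = -beta * s * i / N
    di = beta * s * i / N - gamma * i
    dr = gamma * i
    return ds, di, dr


I0 = 1e-6                    # first value of ydata
y0 = (N - I0, I0, 0.0)       # S0, I0, R0
t = np.arange(1.0, 102.0)    # same grid as xdata: 1, 2, ..., 101
sol = integrate.odeint(sir_rhs, y0, t, args=(beta, gamma))

# The first row should equal the initial conditions that were passed in
print(np.allclose(sol[0], y0))
# S + I + R should stay equal to N at every time step
print(np.allclose(sol.sum(axis=1), N))
```

If either check fails after you rescale with k, the scaling is not consistent with the population constant.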
A hacky solution would be to set k = 1.0. Everything works out because multiplications and divisions have no effect, even though you're still technically doing the wrong calculations. However, if your population constant is ever supposed to be different from 1, everything will break. So, for thoroughness,
1. manually set k to the population constant, which you should know anyway unless you're also trying to fit for S0, I0 and/or R0.
2. Write the correct rate of change in the model for S and I.
3. Get rid of any np.divide(array, k) calculations you have, and
4. remove k from the fitFunc() arguments and don't append it to the param_init list. While this last action is optional and won't affect the result, it's still technically correct. This is because by passing k, the optimizing solver tries to find an optimal value for it, even if you don't end up using it anywhere to affect your computations.
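The steps above can be sketched with minimize() as follows. To keep the example short it fits synthetic data generated from hypothetical rates (beta = 0.5, gamma = 0.1) instead of the real ydata, and the initial guess and the helper names infected and sse are assumptions made for illustration.

```python
import numpy as np
from scipy import integrate
from scipy.optimize import minimize

N = 1.0                       # population constant, set by hand (step 1)
I0 = 1e-6                     # first data value
y0 = (N - I0, I0, 0.0)        # S0, I0, R0


def sir_model(y, t, beta, gamma):
    """SIR right-hand side with the 1/N factor (step 2); no k argument (step 4)."""
    s, i, r = y
    ds = -beta * s * i / N
    di = beta * s * i / N - gamma * i
    dr = gamma * i
    return ds, di, dr


def infected(t, beta, gamma):
    """Integrate the system and return I(t) without any rescaling (step 3)."""
    return integrate.odeint(sir_model, y0, t, args=(beta, gamma))[:, 1]


def sse(params, t, data):
    """Sum of squared errors between the data and the modelled I(t)."""
    residual = data - infected(t, *params)
    return float(np.dot(residual, residual))


# Synthetic stand-in for ydata, generated from hypothetical "true" rates
t = np.arange(1.0, 102.0)
data = infected(t, 0.5, 0.1)

x0 = [0.6, 0.2]               # hypothetical initial guess for beta and gamma
res = minimize(sse, x0, args=(t, data), method="Nelder-Mead")
beta_fit, gamma_fit = res.x
print(beta_fit, gamma_fit, res.fun)
```

With the real data, replace t and data with xdata and ydata; only beta and gamma are passed to the solver.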
If you want to do a least squares fit, you can use curve_fit(), which internally calls a least squares method. You will still need to create a wrapper function for the fitting, which has to numerically integrate the system for various beta and gamma values, but you won't have to manually do any SSE calculations.
curve_fit() will also return the covariance matrix, which you can use to estimate confidence intervals for your fitted variables. Further related discussion on calculating confidence intervals from the covariance matrix can be found here.
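As a rough sketch of how the covariance matrix might be used, the square roots of its diagonal give one-sigma standard errors, from which approximate intervals follow if the errors are assumed to be normally distributed. The data here are synthetic (hypothetical rates 0.5 and 0.1 plus Gaussian noise), and the starting point p0 is an assumption for illustration.

```python
import numpy as np
from scipy import integrate, optimize

N = 1.0
I0 = 1e-6
S0, R0 = N - I0, 0.0


def sir_model(y, t, beta, gamma):
    """SIR right-hand side; returns dS/dt, dI/dt, dR/dt."""
    ds = -beta * y[0] * y[1] / N
    dr = gamma * y[1]
    return ds, -(ds + dr), dr


def fit_odeint(t, beta, gamma):
    """Wrapper for curve_fit: integrate and return only I(t)."""
    return integrate.odeint(sir_model, (S0, I0, R0), t, args=(beta, gamma))[:, 1]


# Synthetic noisy data from hypothetical "true" rates (not the real ydata)
rng = np.random.default_rng(0)
t = np.arange(1.0, 102.0)
y = fit_odeint(t, 0.5, 0.1) + rng.normal(0.0, 0.005, t.size)

popt, pcov = optimize.curve_fit(fit_odeint, t, y, p0=(0.6, 0.2))

perr = np.sqrt(np.diag(pcov))    # one-sigma standard errors of beta and gamma
lower = popt - 1.96 * perr       # approximate 95% interval, assuming normal errors
upper = popt + 1.96 * perr
for name, lo, hi in zip(("beta", "gamma"), lower, upper):
    print(name, lo, hi)
```

These intervals rely on a local linear approximation of the model around the optimum, so treat them as estimates.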
import numpy as np
import matplotlib.pyplot as plt
from scipy import integrate, optimize
ydata = ['1e-06', '1.49920166169172e-06', '2.24595472686361e-06', '3.36377954575331e-06', '5.03793663882291e-06', '7.54533628058909e-06', '1.13006564683911e-05', '1.69249500601052e-05', '2.53483161761933e-05', '3.79636391699325e-05', '5.68567547875179e-05', '8.51509649182741e-05', '0.000127522555808945', '0.000189928392105942', '0.000283447055673738', '0.000423064043409294', '0.000631295993246634', '0.000941024110897193', '0.00140281896645859', '0.00209085569326554', '0.00311449589149717', '0.00463557784224762', '0.00689146863803467', '0.010227347567051', '0.0151380084180746', '0.0223233100045688', '0.0327384810150231', '0.0476330618585758', '0.0685260046667727', '0.0970432959143974', '0.134525888779423', '0.181363340075877', '0.236189247803334', '0.295374180276257', '0.353377036130714', '0.404138746080267', '0.442876028839178', '0.467273954573897', '0.477529937494976', '0.475582401936257', '0.464137179474659', '0.445930281787152', '0.423331710456602', '0.39821360956389', '0.371967226561944', '0.345577884704341', '0.319716449520481', '0.294819942458255', '0.271156813453547', '0.24887641905719', '0.228045466022105', '0.208674420183194', '0.190736203926912', '0.174179448652951', '0.158937806544529', '0.144936441326754', '0.132096533873646', '0.120338367115739', '0.10958340819268', '0.099755679236243', '0.0907826241267504', '0.0825956203546979', '0.0751302384111894', '0.0683263295744258', '0.0621279977639921', '0.0564834809370572', '0.0513449852139111', '0.0466684871328814', '0.042413516167789', '0.0385429293775096', '0.035022685071934', '0.0318216204865132', '0.0289112368382048', '0.0262654939162707', '0.0238606155312519', '0.021674906523588', '0.0196885815912485', '0.0178836058829335', '0.0162435470852779', '0.0147534385851646', '0.0133996531928511', '0.0121697868544064', '0.0110525517526551', '0.0100376781867076', '0.00911582462544914', '0.00827849534575178', '0.00751796508841916', '0.00682721019158058', '0.00619984569061827', '0.00563006790443123', 
'0.00511260205894446', '0.00464265452957236', '0.00421586931435123', '0.00382828837833139', '0.00347631553734708', '0.00315668357532714', '0.00286642431380459', '0.00260284137520731', '0.00236348540287827', '0.00214613152062159', '0.00194875883295343']
xdata = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101']
ydata = np.array(ydata, dtype=float)
xdata = np.array(xdata, dtype=float)
def sir_model(y, x, beta, gamma):
    S = -beta * y[0] * y[1] / N
    R = gamma * y[1]
    I = -(S + R)
    return S, I, R

def fit_odeint(x, beta, gamma):
    return integrate.odeint(sir_model, (S0, I0, R0), x, args=(beta, gamma))[:,1]

N = 1.0
I0 = ydata[0]
S0 = N - I0
R0 = 0.0
popt, pcov = optimize.curve_fit(fit_odeint, xdata, ydata)
fitted = fit_odeint(xdata, *popt)
plt.plot(xdata, ydata, 'o')
plt.plot(xdata, fitted)
plt.show()
You may notice some runtime warnings, but they're mostly due to the initial search of the minimization solver (Levenberg-Marquardt), which tries some values for beta and gamma that cause numerical overflows during the integration. However, it should settle down to more reasonable values soon enough. If you try different solvers for minimize(), you'll notice similar warnings.
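One possible way to keep the solver away from such extreme values, which is not part of the original answer, is to pass a starting point and bounds to curve_fit(). When bounds are given, curve_fit() uses the 'trf' method instead of Levenberg-Marquardt. The sketch below uses synthetic data, and the p0 and bounds values are assumptions chosen only for illustration.

```python
import numpy as np
from scipy import integrate, optimize

N, I0 = 1.0, 1e-6
y0 = (N - I0, I0, 0.0)           # S0, I0, R0


def sir_model(y, t, beta, gamma):
    """SIR right-hand side; returns dS/dt, dI/dt, dR/dt."""
    ds = -beta * y[0] * y[1] / N
    dr = gamma * y[1]
    return ds, -(ds + dr), dr


def fit_odeint(t, beta, gamma):
    """Wrapper for curve_fit: integrate and return only I(t)."""
    return integrate.odeint(sir_model, y0, t, args=(beta, gamma))[:, 1]


# Synthetic stand-in for ydata from hypothetical "true" rates
t = np.arange(1.0, 102.0)
data = fit_odeint(t, 0.5, 0.1)

popt, pcov = optimize.curve_fit(
    fit_odeint, t, data,
    p0=(0.6, 0.2),                       # hypothetical starting point
    bounds=([0.0, 0.0], [2.0, 2.0]),     # hypothetical limits for beta, gamma
)
print(popt)
```

Choosing the bounds requires some prior idea of plausible rates, so widen them if the fit ends up on a boundary.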