I've searched Stack Overflow, but I could not find anything that answers my question.
Problem Setup:
I am trying to solve a system of stiff ODEs using scipy.integrate.ode. I've reduced the code to a minimal working example:
import matplotlib.pylab as plt
import numpy as np
import scipy as sp
from scipy import integrate
# Mutable module state shared with inputspike() and f(): the spike times
# recorded so far (last entry is used as t0) and the latest conductance.
spiketrain, syn_inst = [0], 0
def synapse(t, t0, tau_1=5.3, tau_2=0.05):
    """Return the peak-normalized double-exponential synaptic conductance.

    g(t) = B * (exp(-(t - t0) / tau_1) - exp(-(t - t0) / tau_2)), where B is
    chosen so that the maximum of g over t >= t0 equals 1.

    Parameters
    ----------
    t : float or ndarray
        Time(s) at which to evaluate the conductance.
    t0 : float or ndarray
        Time of the spike that triggered the conductance.
    tau_1, tau_2 : float, optional
        The two time constants. The defaults (5.3 and 0.05) are the values
        that were previously hard-coded. With tau_1 > tau_2, tau_1 sets the
        decay and tau_2 the rise. They must differ.

    Returns
    -------
    float or ndarray
        The conductance. For t < t0 the tau_2 term dominates and the result
        is large and negative, so callers must pass a spike time <= t.

    Raises
    ------
    ValueError
        If tau_1 == tau_2 (the normalization divides by tau_1 - tau_2).
    """
    if tau_1 == tau_2:
        raise ValueError("tau_1 and tau_2 must differ")
    tau_rise = (tau_1 * tau_2) / (tau_1 - tau_2)
    ratio = tau_2 / tau_1
    # B = 1 / (value of the un-normalized curve at its peak).
    B = 1.0 / (ratio ** (tau_rise / tau_1) - ratio ** (tau_rise / tau_2))
    elapsed = t - t0
    # The culprit: strongly negative whenever t < t0.
    return B * (np.exp(-elapsed / tau_1) - np.exp(-elapsed / tau_2))
def alpha_m(v, vt):
    """Return the forward rate of gate m, as used in dm/dt = alpha_m*(1-m) - beta_m*m.

    Computes -0.32 * x / (exp(-x / 4) - 1) with x = v - vt - 13. The formula
    is 0/0 at x == 0. Its limit there, 0.32 * 4, is returned instead of NaN.
    ``np.expm1`` avoids cancellation in exp(-x / 4) - 1 for small |x|.
    Accepts scalars or arrays; a scalar input gives a NumPy scalar.
    """
    x = np.asarray(v - vt - 13, dtype=float)
    # np.where evaluates both branches; silence the 0/0 warning at x == 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        rate = np.where(x == 0, 0.32 * 4, -0.32 * x / np.expm1(-x / 4))
    return rate[()]  # 0-d array -> scalar; n-d arrays pass through
def beta_m(v, vt):
    """Return the backward rate of gate m, as used in dm/dt = alpha_m*(1-m) - beta_m*m.

    Computes 0.28 * x / (exp(x / 5) - 1) with x = v - vt - 40. The formula
    is 0/0 at x == 0. Its limit there, 0.28 * 5, is returned instead of NaN.
    ``np.expm1`` avoids cancellation for small |x|. Accepts scalars or
    arrays; a scalar input gives a NumPy scalar.
    """
    x = np.asarray(v - vt - 40, dtype=float)
    # np.where evaluates both branches; silence the 0/0 warning at x == 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        rate = np.where(x == 0, 0.28 * 5, 0.28 * x / np.expm1(x / 5))
    return rate[()]  # 0-d array -> scalar; n-d arrays pass through
def alpha_h(v, vt):
    """Return the forward rate of gate h, as used in dh/dt = alpha_h*(1-h) - beta_h*h.

    0.128 * exp(-(v - vt - 17) / 18). Uses NumPy directly: the ``scipy.exp``
    alias is deprecated and absent from newer SciPy releases.
    """
    return 0.128 * np.exp(-(v - vt - 17) / 18)
def beta_h(v, vt):
    """Return the backward rate of gate h, as used in dh/dt = alpha_h*(1-h) - beta_h*h.

    4 / (exp(-(v - vt - 40) / 5) + 1). Uses NumPy directly: the ``scipy.exp``
    alias is deprecated and absent from newer SciPy releases.
    """
    return 4 / (np.exp(-(v - vt - 40) / 5) + 1)
def alpha_n(v, vt):
    """Return the forward rate of gate n, as used in dn/dt = alpha_n*(1-n) - beta_n*n.

    Computes -0.032 * x / (exp(-x / 5) - 1) with x = v - vt - 15. The
    formula is 0/0 at x == 0. Its limit there, 0.032 * 5, is returned instead
    of NaN. ``np.expm1`` avoids cancellation for small |x|. Accepts scalars
    or arrays; a scalar input gives a NumPy scalar.
    """
    x = np.asarray(v - vt - 15, dtype=float)
    # np.where evaluates both branches; silence the 0/0 warning at x == 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        rate = np.where(x == 0, 0.032 * 5, -0.032 * x / np.expm1(-x / 5))
    return rate[()]  # 0-d array -> scalar; n-d arrays pass through
def beta_n(v, vt):
    """Return the backward rate of gate n, as used in dn/dt = alpha_n*(1-n) - beta_n*n.

    0.5 * exp(-(v - vt - 10) / 40). Uses NumPy directly: the ``scipy.exp``
    alias is deprecated and absent from newer SciPy releases.
    """
    return 0.5 * np.exp(-(v - vt - 10) / 40)
def inputspike(t):
    """Append ``t`` to the global ``spiketrain`` when ``int(t)`` is a scheduled spike.

    Reads the module-level spike schedule ``a``. Every matching call appends
    again, so repeated calls in the same time unit add several entries.
    """
    if int(t) not in a:
        return
    spiketrain.append(t)
def f(t,X):
    """Right-hand side dX/dt for scipy.integrate.ode; X holds (V, m, h, n).

    NOTE(review): this is not a pure function of (t, X). Each call may append
    t to the global ``spiketrain`` (via inputspike) and overwrites the global
    ``syn_inst``. If the solver evaluates f at times out of chronological
    order, spiketrain[-1] can end up greater than t. synapse() then returns
    large negative values.
    """
    V = X[0]
    m = X[1]
    h = X[2]
    n = X[3]
    inputspike(t)  # side effect: may append t to the global spiketrain
    g_syn = synapse(t, spiketrain[-1])  # assumes spiketrain[-1] <= t; not guaranteed here
    syn = 0.5* g_syn * (V - 0)
    global syn_inst
    syn_inst = g_syn  # side channel read by the driver loop for plotting
    # NOTE(review): sp.zeros is a deprecated NumPy alias in the SciPy
    # namespace; np.zeros(len(X)) is the supported equivalent.
    dydt = sp.zeros([1, len(X)])[0]
    dydt[0] = - 50*m**3*h*(V-60) - 10*n**4*(V+100) - syn - 0.1*(V + 70)
    dydt[1] = alpha_m(V, -45)*(1-m) - beta_m(V, -45)*m
    dydt[2] = alpha_h(V, -45)*(1-h) - beta_h(V, -45)*h
    dydt[3] = alpha_n(V, -45)*(1-n) - beta_n(V, -45)*n
    return dydt
t_start = 0.0
t_end = 2000
dt = 0.1
num_steps = int(np.floor((t_end - t_start) / dt) + 1)

# Spike schedule read by inputspike(). The first spike is at t = 500 so the
# model settles. Later spikes follow exponentially distributed intervals
# (mean 0.1 * 1000 = 100), rounded to whole time units.
a = np.zeros(int(t_end / 100))
a[0] = 500  # so the model settles
np.random.seed(0)
for i in range(1, len(a)):
    a[i] = a[i - 1] + int(round(np.random.exponential(0.1) * 1000, 0))

r = integrate.ode(f).set_integrator('vode', nsteps=num_steps, method='bdf')
X_start = [-70, 0, 1, 0]
r.set_initial_value(X_start, t_start)

t = np.zeros(num_steps)
syn = np.zeros(num_steps)
X = np.zeros((len(X_start), num_steps))
X[:, 0] = X_start
syn[0] = 0
t[0] = t_start

k = 1
while r.successful() and k < num_steps:
    r.integrate(r.t + dt)
    # Store the results to plot later
    t[k] = r.t
    syn[k] = syn_inst  # global written by f() on its most recent call
    X[:, k] = r.y
    k += 1

# Plot only the filled part, in case the solver stopped early.
plt.plot(t[:k], syn[:k])
plt.show()
Problem:
I find that when I actually run the code, the time t passed to the solver's right-hand side appears to go back and forth. This results in spiketrain[-1] being greater than t, which makes syn very negative and significantly messes up my simulations (you can see the negative values in the plot if you run the code).
I am guessing it has something to do with the solver's variable time steps, so I was wondering whether it is possible to restrict time to forward (positive) propagation only.
Thanks
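The solver does actually go back and forth, and I think that is also due to the variable time stepping. But the real difficulty is that the result of f(t, X) is not a function of t and X alone: it also depends on the previous calls made to this function, which is not a good idea. An adaptive solver may evaluate the right-hand side at trial times that it later rejects (for example, when it shrinks a step). So f must return the same value whenever it is called with the same t and X.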
Your code works by replacing:
# Stateful: inputspike() appends to spiketrain as a side effect of evaluating
# f, so the result depends on the order of earlier calls, not only on (t, X).
inputspike(t)
g_syn = synapse(t, spiketrain[-1])
by:
# Stateless: the last spike strictly before t depends only on t.
# np.max raises ValueError if no element of `a` is < t.
last_spike_date = np.max( a[a<t] )
g_syn = synapse(t, last_spike_date)
You also need to add an "old event" before the "settle time" with a = np.insert(a, 0, -1e4). This guarantees that last_spike_date is always defined, because np.max of an empty selection raises an error (see the comment in the code below).
Here is a modified version of your code:
I modified how the time of the last spike is found (this time using NumPy's searchsorted function, so that the function can be vectorized). I also modified the way the array a is created. This is not my field, so maybe I misunderstood the intent.
I used solve_ivp instead of ode, but still with a BDF solver (however, it is not the same implementation as the one in ode, which is written in Fortran).
import numpy as np # rather than scipy
import matplotlib.pylab as plt
from scipy.integrate import solve_ivp
def synapse(t, t0, tau_1=5.3, tau_2=0.05):
    """Return the peak-normalized double-exponential synaptic conductance.

    g(t) = B * (exp(-(t - t0) / tau_1) - exp(-(t - t0) / tau_2)), where B is
    chosen so that the maximum of g over t >= t0 equals 1.

    Parameters
    ----------
    t : float or ndarray
        Time(s) at which to evaluate the conductance.
    t0 : float or ndarray
        Time of the spike that triggered the conductance.
    tau_1, tau_2 : float, optional
        The two time constants. The defaults (5.3 and 0.05) are the values
        that were previously hard-coded. With tau_1 > tau_2, tau_1 sets the
        decay and tau_2 the rise. They must differ.

    Returns
    -------
    float or ndarray
        The conductance. It is 0 at t == t0 and underflows to 0 for t far
        after t0. For t < t0 it is large and negative, so callers must pass a
        spike time <= t.

    Raises
    ------
    ValueError
        If tau_1 == tau_2 (the normalization divides by tau_1 - tau_2).
    """
    if tau_1 == tau_2:
        raise ValueError("tau_1 and tau_2 must differ")
    tau_rise = (tau_1 * tau_2) / (tau_1 - tau_2)
    ratio = tau_2 / tau_1
    # B = 1 / (value of the un-normalized curve at its peak).
    B = 1.0 / (ratio ** (tau_rise / tau_1) - ratio ** (tau_rise / tau_2))
    elapsed = t - t0
    return B * (np.exp(-elapsed / tau_1) - np.exp(-elapsed / tau_2))
def alpha_m(v, vt):
    """Return the forward rate of gate m, as used in dm/dt = alpha_m*(1-m) - beta_m*m.

    Computes -0.32 * x / (exp(-x / 4) - 1) with x = v - vt - 13. The formula
    is 0/0 at x == 0. Its limit there, 0.32 * 4, is returned instead of NaN.
    ``np.expm1`` avoids cancellation in exp(-x / 4) - 1 for small |x|.
    Accepts scalars or arrays; a scalar input gives a NumPy scalar.
    """
    x = np.asarray(v - vt - 13, dtype=float)
    # np.where evaluates both branches; silence the 0/0 warning at x == 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        rate = np.where(x == 0, 0.32 * 4, -0.32 * x / np.expm1(-x / 4))
    return rate[()]  # 0-d array -> scalar; n-d arrays pass through
def beta_m(v, vt):
    """Return the backward rate of gate m, as used in dm/dt = alpha_m*(1-m) - beta_m*m.

    Computes 0.28 * x / (exp(x / 5) - 1) with x = v - vt - 40. The formula
    is 0/0 at x == 0. Its limit there, 0.28 * 5, is returned instead of NaN.
    ``np.expm1`` avoids cancellation for small |x|. Accepts scalars or
    arrays; a scalar input gives a NumPy scalar.
    """
    x = np.asarray(v - vt - 40, dtype=float)
    # np.where evaluates both branches; silence the 0/0 warning at x == 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        rate = np.where(x == 0, 0.28 * 5, 0.28 * x / np.expm1(x / 5))
    return rate[()]  # 0-d array -> scalar; n-d arrays pass through
def alpha_h(v, vt):
    """Forward rate of gate h: 0.128 * exp(-(v - vt - 17) / 18)."""
    shifted = v - vt - 17
    return 0.128 * np.exp(-shifted / 18)
def beta_h(v, vt):
    """Backward rate of gate h: a sigmoid 4 / (1 + exp(-(v - vt - 40) / 5))."""
    shifted = v - vt - 40
    return 4 / (1 + np.exp(-shifted / 5))
def alpha_n(v, vt):
    """Return the forward rate of gate n, as used in dn/dt = alpha_n*(1-n) - beta_n*n.

    Computes -0.032 * x / (exp(-x / 5) - 1) with x = v - vt - 15. The
    formula is 0/0 at x == 0. Its limit there, 0.032 * 5, is returned instead
    of NaN. ``np.expm1`` avoids cancellation for small |x|. Accepts scalars
    or arrays; a scalar input gives a NumPy scalar.
    """
    x = np.asarray(v - vt - 15, dtype=float)
    # np.where evaluates both branches; silence the 0/0 warning at x == 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        rate = np.where(x == 0, 0.032 * 5, -0.032 * x / np.expm1(-x / 5))
    return rate[()]  # 0-d array -> scalar; n-d arrays pass through
def beta_n(v, vt):
    """Backward rate of gate n: 0.5 * exp(-(v - vt - 10) / 40)."""
    decay = np.exp(-(v - vt - 10) / 40)
    return 0.5 * decay
def f(t, X):
    """Right-hand side dX/dt of the model, for ``solve_ivp``.

    Parameters
    ----------
    t : float
        Time. ``solve_ivp`` passes a scalar. The spike lookup and synapse()
        below also accept an array of times.
    X : array_like
        State (V, m, h, n). With ``vectorized=True`` it may have shape
        (4, k), in which case each returned entry has shape (k,).

    Returns
    -------
    list
        [dV/dt, dm/dt, dh/dt, dn/dt].

    The synaptic input depends only on ``t``, via the module-level spike-time
    array ``a`` (assumed sorted ascending, as searchsorted requires). It never
    depends on earlier calls, so the solver may evaluate f at any time, in
    any order.
    """
    V, m, h, n = X

    # Index of the last spike at or before t (side='right' keeps t itself).
    idx = a.searchsorted(t, side='right') - 1
    # idx == -1 means no spike has occurred yet. Plain a[idx] would then
    # silently select the LAST spike and synapse() would blow up. Using
    # t0 = t instead gives synapse(t, t) == 0.
    t0 = np.where(idx >= 0, a[np.maximum(idx, 0)], t)
    g_syn = synapse(t, t0)
    syn = 0.5 * g_syn * (V - 0)  # same g*(V - E) form as the other currents, E = 0

    dVdt = - 50*m**3*h*(V-60) - 10*n**4*(V+100) - syn - 0.1*(V + 70)
    dmdt = alpha_m(V, -45)*(1-m) - beta_m(V, -45)*m
    dhdt = alpha_h(V, -45)*(1-h) - beta_h(V, -45)*h
    dndt = alpha_n(V, -45)*(1-n) - beta_n(V, -45)*n
    return [dVdt, dmdt, dhdt, dndt]
# Define the spike events:
nbr_spike = 20
beta = 100  # mean inter-spike interval (scale of the exponential)
first_spike_date = 500
np.random.seed(0)
# Cumulative sum of intervals, so the first spike falls one interval after
# first_spike_date rather than exactly at it.
a = np.cumsum(np.random.exponential(beta, size=nbr_spike)) + first_spike_date
a = np.insert(a, 0, -1e4)  # set a very old spike at t=-1e4
# it is a hack in order to set a t0 for t<first_spike_date (model settle time)
# so that `synapse(t, t0)` can be called regardless of t
# synapse(t, -1e4) = 0 for t>0 (the exponentials underflow)

# Solve:
t_start = 0.0
t_end = 2000
X_start = [-70, 0, 1, 0]
# max_step=1 presumably keeps the adaptive stepper from stepping over a brief
# synaptic pulse.
sol = solve_ivp(f, [t_start, t_end], X_start, method='BDF',
                max_step=1, vectorized=True)
print(sol.message)

# Graph
V, m, h, n = sol.y
plt.plot(sol.t, V)
plt.xlabel('time')
plt.ylabel('V')
plt.show()  # needed to display the figure when run as a script
which gives the following plot of V versus time (the image is not included here):

Note: solve_ivp also has an events parameter, which could be useful.
If you find this useful, you can donate to us via PayPal or buy us a coffee so we can maintain and grow the site. Thank you!
Donate Us With