
Simulating a neuron spike train in Python

The model I'm working on has a neuron (modeled by the Hodgkin-Huxley equations) that receives a bunch of synaptic inputs from other neurons because it is in a network. The standard way to model the inputs is with a spike train made up of delta-function pulses that arrive at a specified rate, as a Poisson process. Some of the pulses provide an excitatory kick to the neuron, and some provide an inhibitory one. So the synaptic current should look like this:

$$I_{syn}(t) = C_m G_{ex}\left[\sum_{k=1}^{N_e}\sum_{l} h_k^l\,\delta(t - t_k^l) \;-\; K\sum_{m=1}^{N_i}\sum_{n} h_m^n\,\delta(t - t_m^n)\right]$$

Here, $N_e$ is the number of excitatory neurons and $N_i$ the number of inhibitory ones, the $h$'s are either 0 or 1 (1 with probability $p$) representing whether or not a spike was successfully transmitted, and the $t_k^l$ in the delta function is the discharge time of the $l$th spike of the $k$th neuron (likewise for the $t_m^n$). The basic idea behind how we tried coding this was to suppose I had 100 neurons providing pulses into my HH neuron (80 excitatory, 20 inhibitory). We then formed an array where one column enumerated the neurons (so that neurons #0-79 were excitatory and #80-99 were inhibitory). We then checked whether there was a spike in some time interval, and if there was, drew a random number between 0 and 1; if it fell below my specified probability $p$, we assigned it the value 1, otherwise 0. We then plotted the voltage as a function of time to see when the neuron spikes.

I think the code works, but the problem is that as soon as I add more neurons to the network (one paper claimed they used 5000 neurons total), it takes forever to run, which is infeasible for numerical simulations. My question is: is there a better way to simulate a spike train pulsing into a neuron so that the computation is substantially faster for a large number of neurons in the network? Here is the code we tried (it's a little long because the HH equations are quite detailed):

import numpy as np
import matplotlib.pyplot as plt

#Constants
C_m  =   1.0   #membrane capacitance, in uF/cm^2
g_Na = 120.0   #sodium (Na) maximum conductance, in mS/cm^2
g_K  =  36.0   #potassium (K) maximum conductance, in mS/cm^2
g_L  =   0.3   #leak maximum conductance, in mS/cm^2
E_Na =  50.0   #sodium (Na) Nernst reversal potential, in mV
E_K  = -77.0   #potassium (K) Nernst reversal potential, in mV
E_L  = -54.387 #leak Nernst reversal potential, in mV

def poisson_spikes(t, N=100, rate=1.0):
    spks = []
    dt = t[1] - t[0]
    for n in range(N):
        spkt = t[np.random.rand(len(t)) < rate*dt/1000.] #spike times for neuron n (rate in Hz, t in ms)
        idx = [n]*len(spkt) #neuron ID repeated once per spike time
        spkn = np.concatenate([[idx], [spkt]], axis=0).T #combine the two lists
        if len(spkn) > 0:
            spks.append(spkn)
    spks = np.concatenate(spks, axis=0)
    return spks

N = 100
N_ex = 80 #(0..79)
N_in = 20 #(80..99)
G_ex = 1.0
K = 4

dt = 0.01
t = np.arange(0.0, 300.0, dt) #the time to integrate over, in ms
ic = [-65, 0.05, 0.6, 0.32]

spks =  poisson_spikes(t, N, rate=10.)

def alpha_m(V):
    return 0.1*(V+40.0)/(1.0 - np.exp(-(V+40.0) / 10.0))

def beta_m(V):
    return 4.0*np.exp(-(V+65.0) / 18.0)

def alpha_h(V):
    return 0.07*np.exp(-(V+65.0) / 20.0)

def beta_h(V):
    return 1.0/(1.0 + np.exp(-(V+35.0) / 10.0))

def alpha_n(V):
    return 0.01*(V+55.0)/(1.0 - np.exp(-(V+55.0) / 10.0))

def beta_n(V):
    return 0.125*np.exp(-(V+65.0) / 80.0)

def I_Na(V, m, h):
    return g_Na * m**3 * h * (V - E_Na)

def I_K(V, n):
    return g_K * n**4 * (V - E_K)

def I_L(V):
    return g_L * (V - E_L)

def I_app(t):
    return 3 #constant applied current

def I_syn(spks, t):
    """
    Synaptic current
    spks = [[synid, t],]
    """
    exspk = spks[spks[:,0]<N_ex] # Check for all excitatory spikes
    delta_k = exspk[:,1] == t # Delta function
    if sum(delta_k) > 0:
        h_k = np.random.rand(len(delta_k)) < 0.5 # p = 0.5
    else:
        h_k = 0

    inspk = spks[spks[:,0] >= N_ex] #Check remaining neurons for inhibitory spikes
    delta_m = inspk[:,1] == t #Delta function for inhibitory neurons
    if sum(delta_m) > 0:
        h_m = np.random.rand(len(delta_m)) < 0.5 #p =0.5
    else:
        h_m = 0

    isyn = C_m*G_ex*(np.sum(h_k*delta_k) - K*np.sum(h_m*delta_m))

    return  isyn


def dALLdt(X, t):
    V, m, h, n = X
    dVdt = (I_app(t) + I_syn(spks, t) - I_Na(V, m, h) - I_K(V, n) - I_L(V)) / C_m
    dmdt = alpha_m(V)*(1.0-m) - beta_m(V)*m
    dhdt = alpha_h(V)*(1.0-h) - beta_h(V)*h
    dndt = alpha_n(V)*(1.0-n) - beta_n(V)*n
    return np.array([dVdt, dmdt, dhdt, dndt])

X = [ic]
for i in t[1:]:
    dx = (dALLdt(X[-1],i))
    x = X[-1]+dt*dx
    X.append(x)

X = np.array(X)    
V = X[:,0]        
m = X[:,1]
h = X[:,2]
n = X[:,3]
ina = I_Na(V, m, h)
ik = I_K(V, n)
il = I_L(V)

plt.figure()
plt.subplot(3,1,1)
plt.title('Hodgkin-Huxley Neuron')
plt.plot(t, V, 'k')
plt.ylabel('V (mV)')

plt.subplot(3,1,2)
plt.plot(t, ina, 'c', label='$I_{Na}$')
plt.plot(t, ik, 'y', label='$I_{K}$')
plt.plot(t, il, 'm', label='$I_{L}$')
plt.ylabel('Current')
plt.legend()

plt.subplot(3,1,3)
plt.plot(t, m, 'r', label='m')
plt.plot(t, h, 'g', label='h')
plt.plot(t, n, 'b', label='n')
plt.ylabel('Gating Value')
plt.legend()

plt.show()

I'm not familiar with other packages designed specifically for neural networks, but I wanted to write my own, mainly because I plan to do stochastic analysis, which requires quite a bit of mathematical detail, and I don't know whether those packages provide that level of detail.

asked Mar 17 '16 by Brenton


2 Answers

Profiling shows that most of your time is being spent in these two lines:

    if sum(delta_k) > 0:

and

    if sum(delta_m) > 0:

Changing each of these to:

    if np.any(...)

speeds everything up by a factor of 10. Take a look at kernprof if you'd like to do more line-by-line profiling: https://github.com/rkern/line_profiler
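
For concreteness, here is a sketch of the asker's I_syn with exactly that change applied (everything except the two checks is unchanged from the question):

def I_syn(spks, t):
    """
    Synaptic current
    spks = [[synid, t],]
    """
    exspk = spks[spks[:,0] < N_ex] #all excitatory spikes
    delta_k = exspk[:,1] == t #delta function
    if np.any(delta_k): #np.any avoids Python-level iteration of the builtin sum
        h_k = np.random.rand(len(delta_k)) < 0.5 #p = 0.5
    else:
        h_k = 0

    inspk = spks[spks[:,0] >= N_ex] #remaining (inhibitory) spikes
    delta_m = inspk[:,1] == t #delta function for inhibitory neurons
    if np.any(delta_m): #same change here
        h_m = np.random.rand(len(delta_m)) < 0.5 #p = 0.5
    else:
        h_m = 0

    return C_m*G_ex*(np.sum(h_k*delta_k) - K*np.sum(h_m*delta_m))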

answered Sep 22 '22 by welch


As a complement to welch's answer, you can use scipy.integrate.odeint to accelerate the integration: replacing

X = [ic]
for i in t[1:]:
    dx = (dALLdt(X[-1],i))
    x = X[-1]+dt*dx
    X.append(x)

by

from scipy.integrate import odeint
X = odeint(dALLdt, ic, t)

speeds up the calculation by more than a factor of 10 on my computer.
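
One caveat beyond the original answer: odeint picks its own internal time steps, so the exact-equality test spks[:,1] == t inside I_syn will rarely fire at those off-grid times. A possible workaround (a sketch only; precompute_isyn is a hypothetical helper, not from the post, and it assumes the spike times from poisson_spikes lie on the fixed grid t and that the question's other definitions are in scope) is to accumulate the synaptic input once per grid point, drawing each spike's transmission success a single time, and look it up by index inside the derivative:

def precompute_isyn(spks, t, p=0.5):
    """Hypothetical helper: isyn_grid[i] holds the synaptic current at t[i]."""
    isyn_grid = np.zeros(len(t))
    dt = t[1] - t[0]
    idx = np.round(spks[:,1] / dt).astype(int) #grid index of each spike
    h = np.random.rand(len(spks)) < p #one transmission draw per spike
    sign = np.where(spks[:,0] < N_ex, 1.0, -K) #+1 excitatory, -K inhibitory
    np.add.at(isyn_grid, idx, h * sign) #accumulate coincident spikes
    return C_m * G_ex * isyn_grid

isyn_grid = precompute_isyn(spks, t)

def dALLdt(X, time):
    V, m, h, n = X
    i = min(int(round(time / dt)), len(t) - 1) #nearest grid point
    dVdt = (I_app(time) + isyn_grid[i] - I_Na(V, m, h) - I_K(V, n) - I_L(V)) / C_m
    dmdt = alpha_m(V)*(1.0 - m) - beta_m(V)*m
    dhdt = alpha_h(V)*(1.0 - h) - beta_h(V)*h
    dndt = alpha_n(V)*(1.0 - n) - beta_n(V)*n
    return np.array([dVdt, dmdt, dhdt, dndt])

This also speeds up the fixed-step Euler loop in the question, since the per-step synaptic work drops to a single array lookup.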

answered Sep 21 '22 by JPG