A better method to implement the alpha synapse function in MATLAB

I have implemented a leaky integrate-and-fire neuron system that gives output in the form of an alpha function alpha = t/tau * exp(1 - (t/tau)). However, the piece of code that I use to implement it takes up at least 80% of the runtime (approximately 4 s out of the total 5 s). During the course of the program this part gets called at least 30000 times, while the computealphasynapseoutput function gets called at least 1.5 million times, so I want to reduce the runtime of this part. I have tried using arrayfun to implement it, but that takes much more time than this.

Can anyone suggest a more efficient implementation for this code?

To implement the alpha synapse I used the following piece of code:

% Get the identity of the currently active neurons
idAllActiveNeuron = idAllActiveNeuron > 0;
if any(idAllActiveNeuron) % Only run if at least one neuron is active
    for iActiveNeuron = find(idAllActiveNeuron).' % run for each active neuron
        % synapticOutputArray stores the synaptic output for each neuron at each time instant
        % iIntegration is the index of the current time instant
        % spikeTimesArray is a cell array of spike times for each neuron.
        % For example, with 5 neurons where neuron 4 spikes twice and
        % neuron 5 spikes once, spikeTimesArray would be something
        % like {[], [], [], [0.0023, 0.0034], [0.0675]}
        % integrationInstant is a time value such as 0.0810
        % tau_syn stores the value of tau for each neuron

        synapticOutputArray(iActiveNeuron, iIntegration) = computealphasynapseoutput(spikeTimesArray{iActiveNeuron}, integrationInstant, tau_syn(iActiveNeuron));

    end % iActiveNeuron
end

The function computealphasynapseoutput is implemented as follows:

function synapticOutput = computealphasynapseoutput(firingTime, integrationInstant, tauSyn)
%%COMPUTEALPHASYNAPSEOUTPUT Calculates the synaptic output over all
%previous spikes of the neuron at a particular time instant using the
%alpha synapse function
%
% Usage:
%   synapticOutput = computealphasynapseoutput(firingTime, integrationInstant, tauSyn)
%
% Inputs:
%           firingTime: vector of previous firing times (in seconds)
%   integrationInstant: current integration time instant (in seconds)
%               tauSyn: synaptic time constant of the neuron (in seconds)
%
% Output:
%   synapticOutput: Synaptic output of the neuron at current time summed
%                   over all firing instances
%

% Calculate the time difference of firing from current time
timeDifference = (integrationInstant - firingTime) / tauSyn;
% Calculate and sum the synaptic output for each spike
synapticOutput = sum(timeDifference .* exp(1 - timeDifference));

end % computealphasynapseoutput
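For comparison with the answers below, here is the same per-spike computation as a minimal plain-Python sketch (Python is used only for illustration; `alpha_synapse_output`, `firing_times` and `t` are hypothetical stand-ins for the MATLAB function and its arguments). It makes the cost visible: one `exp` call per stored spike, for every active neuron at every time step.

```python
import math

def alpha_synapse_output(firing_times, t, tau_syn):
    """Sum x * exp(1 - x) over all spikes, with x = (t - t_s) / tau_syn."""
    total = 0.0
    for t_s in firing_times:
        x = (t - t_s) / tau_syn          # normalized time since this spike
        total += x * math.exp(1.0 - x)  # one exp() per spike
    return total

# One spike exactly one time constant ago gives the alpha peak value
print(alpha_synapse_output([0.0], 0.01, 0.01))  # 1.0
```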

Edit:

I have finally awarded the bounty for this question to gnovice for his awesome answer. It helped me shave an entire 40 seconds off my simulation time (from 68 s to 28 s). I hope it works for people in the future too. I would also like to acknowledge MrAzzaman and qbzenker for taking the time to answer the question and teaching me some cool new approaches, as well as the others who commented and helped me. Thanks!

asked May 21 '17 10:05 by ammportal


3 Answers

This solution may be a little weird, and very specific to the problem at hand, but very efficient. The idea is to reorganize the formula for calculating the summated alpha synapse output in such a way as to minimize the computations done, specifically reducing the number of exponentiations computed using exp, which is computationally expensive. Here's the reformulation for a single neuron:

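$$\sum_{t_s \le t} \frac{t - t_s}{\tau}\, e^{1 - \frac{t - t_s}{\tau}} = \underbrace{\frac{e^{1 - t/\tau}}{\tau}}_{f(t)} \left( A\,t + B \right), \qquad A = \sum_{t_s \le t} e^{t_s/\tau}, \qquad B = -\sum_{t_s \le t} t_s\, e^{t_s/\tau}$$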

For the result we have a function of time f(t) (that computes one exponential for the current time point t) times a linear function in t with time-invariant parameters A and B. Note that A and B only depend on the spike times, summing an exponential over all previous spike occurrences. We can store these two parameters A and B for each neuron, and when a new spike occurs in a neuron, we simply compute a new pair of terms and add them to the existing values for A and B.
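The per-neuron bookkeeping described here can be sketched in plain Python. This is a minimal sketch under the answer's assumptions; the class name `AlphaSynapseState` and the helper `direct` are hypothetical. `add_spike` pays for one `exp` when a spike happens, and `output` pays for one `exp` per evaluation no matter how many spikes have been absorbed into A and B.

```python
import math

class AlphaSynapseState:
    """Running parameters A and B for one neuron (hypothetical helper).

    A = sum(exp(t_s / tau)) and B = -sum(t_s * exp(t_s / tau)) over past spikes,
    so the summed alpha output at time t is exp(1 - t/tau) / tau * (A*t + B).
    """

    def __init__(self, tau):
        self.tau = tau
        self.A = 0.0
        self.B = 0.0

    def add_spike(self, t_s):
        # One exp() per spike, computed once when the spike happens
        e = math.exp(t_s / self.tau)
        self.A += e
        self.B -= t_s * e

    def output(self, t):
        # One exp() per evaluation, independent of the number of past spikes
        return math.exp(1.0 - t / self.tau) / self.tau * (self.A * t + self.B)

def direct(spikes, t, tau):
    # Original per-spike formulation, for comparison
    return sum(((t - s) / tau) * math.exp(1.0 - (t - s) / tau) for s in spikes)

state = AlphaSynapseState(tau=0.02)
spikes = [0.0023, 0.0034, 0.0675]
for s in spikes:
    state.add_spike(s)
print(state.output(0.0810), direct(spikes, 0.0810, 0.02))
```

Both printed values should agree up to floating-point rounding, since exp(1 - t/tau) * exp(t_s/tau) is exactly exp(1 - (t - t_s)/tau).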

Here's the caveat: notice the values of the exponents. In A and B the exponent is ts/tau, and if it gets too large the calculation could overflow, resulting in Inf. For the parameters we're using here, that won't happen. The value of ts will be 1 at its largest (since we're only simulating one second) and the smallest value of tau is 0.01 s, so the exponent will be at most 100, giving values that are large (around 10^43) but easily handled by a double. Likewise, the exponent in f(t) will have a largest negative value of -99, giving very small values (around 10^-43) that a double still handles without underflowing to 0. However, bear in mind that smaller tau values or a longer simulation time could cause trouble.
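The safe range can be checked before a run. A minimal Python sketch, assuming the largest exponent in A and B is the simulation length divided by the smallest tau (the helper name `reformulation_is_safe` is hypothetical): a double stays finite as long as the exponent is below ln(largest double), roughly 709.78.

```python
import math
import sys

# Largest x for which exp(x) is still a finite double (about 709.78)
MAX_EXP_ARG = math.log(sys.float_info.max)

def reformulation_is_safe(t_end, tau_min):
    """True if exp(t_s / tau) stays finite for every spike time t_s <= t_end."""
    return t_end / tau_min < MAX_EXP_ARG

print(reformulation_is_safe(1.0, 0.01))   # exponent 100, as in the answer
print(reformulation_is_safe(10.0, 0.01))  # exponent 1000 would overflow
print(math.exp(1.0 / 0.01))               # roughly 2.7e43
```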

So, how do we implement this? Here are the relevant pieces of code you'll have to add/modify (note that you won't even need spikeTimesArray any more):

% You'll have to initialize storage for A and B:
A = zeros(1, nNeurons);
B = zeros(1, nNeurons);
idAllActiveNeuron = false(1, nNeurons);

...

% When you are generating new spikes, modify A and B like so:
index = ...;  % A vector of neuron indices where new spikes occur this integration step
if ~isempty(index)
  expTerm = exp(integrationInstant./tau_syn(index));
  A(index) = A(index)+expTerm;
  B(index) = B(index)-integrationInstant.*expTerm;
  idAllActiveNeuron(index) = true;
end

...

% Updating the synaptic output no longer requires a loop:
if any(idAllActiveNeuron)
  synapticOutputArray(idAllActiveNeuron, iIntegration) = ...
    (integrationInstant.*A(idAllActiveNeuron)+B(idAllActiveNeuron)).*...
    exp(1-integrationInstant./tau_syn(idAllActiveNeuron))./tau_syn(idAllActiveNeuron);
end
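The three MATLAB fragments above can be mirrored in a plain-Python sketch to see the whole update loop in one place. All names are hypothetical; `spike_schedule` is an assumed stand-in for the elided step that produces `index`, and explicit loops stand in for MATLAB's logical indexing.

```python
import math

def simulate(spike_schedule, tau_syn, times):
    """Reformulated alpha-synapse update over a time grid (hypothetical helper).

    spike_schedule maps a time-step index to the neurons (0-based) that spike
    at that step; tau_syn has one time constant per neuron.
    Returns output[neuron][step].
    """
    n = len(tau_syn)
    A = [0.0] * n
    B = [0.0] * n
    active = [False] * n
    output = [[0.0] * len(times) for _ in range(n)]
    for k, t in enumerate(times):
        # New spikes at this step: add one pair of terms to A and B
        for i in spike_schedule.get(k, []):
            e = math.exp(t / tau_syn[i])
            A[i] += e
            B[i] -= t * e
            active[i] = True
        # Output for every active neuron: one exp() each, no loop over past spikes
        for i in range(n):
            if active[i]:
                output[i][k] = (t * A[i] + B[i]) * math.exp(1.0 - t / tau_syn[i]) / tau_syn[i]
    return output

dt = 1e-4
times = [k * dt for k in range(1001)]       # 0 to 0.1 s
tau_syn = [0.01, 0.02, 0.05]
schedule = {23: [0], 34: [0, 2], 675: [1]}  # step index -> spiking neurons
out = simulate(schedule, tau_syn, times)
print(max(out[0]))
```

The state per neuron is just two numbers, which is why the per-neuron spike lists are no longer needed.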

And how well does it do? Here are some measurements I made (using timeit) with different average spike firing frequencies for 1101 neurons simulated with a time step of 0.1 msec for one second (10001 time points):

[Figure: run times of the original and reformulated solutions for different average firing frequencies (image not available)]

The more spikes that occurred, the more time it took the original solution to complete. However, the time required for the reformulated solution was near constant and never exceeded 0.6 seconds. I can also confirm that the outputs for the two approaches are equal (maximum differences are on the order of 10^-13 for the synaptic output waveforms).

answered Oct 17 '22 19:10 by gnovice


I tested the following code with a set of randomly produced firing times and taus, and it consistently ran about 7-8x faster. It's a bit strange, but bear with me.

The main idea is that MATLAB excels at operations on matrices, but can be pretty slow if you have to break those operations up, for instance with your for loop. The idea, then, is to try and do the whole cell array in one go. This would be easier if you could modify how the firing times are stored (i.e., not in a cell array), but I've made do with what you've got.

Basically, what we're going to do is convert the cell array of firing times to one big vector, calculate the synaptic output for all of them at once, then sum them according to their neuron. So, for instance, we'll eventually get the following:

% given spikeTimesArray = {[], [], [], [0.0023, 0.0034], [0.0675],...}
firingTimes = [0.0023, 0.0034, 0.0675,...];
neuronInd = [4,4,5,...];

where the neuronInd vector holds the neuron index of each firing time. If you could create this vector at the same time as your firing-times array, this whole process would be that much quicker. As it is, the process to calculate neuronInd is a bit opaque.
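Building the index vector together with the firing times avoids any interpolation trick. A minimal plain-Python sketch of that idea (the names `flatten_spikes`, `firing_times` and `neuron_ind` are hypothetical; indices are 1-based to match MATLAB). In MATLAB itself, `repelem(1:numel(spikeTimesArray), numSpikes)` should give the same index vector, assuming `numSpikes` is the per-neuron count from `cellfun(@numel, spikeTimesArray)`.

```python
def flatten_spikes(spike_times_array):
    """Flatten per-neuron spike lists into parallel lists of times and 1-based neuron indices."""
    firing_times = []
    neuron_ind = []
    for neuron, spikes in enumerate(spike_times_array, start=1):
        firing_times.extend(spikes)
        neuron_ind.extend([neuron] * len(spikes))  # one index per spike of this neuron
    return firing_times, neuron_ind

firing_times, neuron_ind = flatten_spikes([[], [], [], [0.0023, 0.0034], [0.0675]])
print(firing_times)  # [0.0023, 0.0034, 0.0675]
print(neuron_ind)    # [4, 4, 5]
```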

Regardless, we calculate this as follows:

% Calculate the number of spikes for each neuron
numSpikes = cellfun(@numel, spikeTimesArray);
% Determine the index of all neurons with spikes
n = find(numSpikes);
% Calculate the starting indices for the neuronInd vector
m = cumsum([1, numSpikes]);
% Generate the neuronInd vector
neuronInd = interp1(m(n), n, 1:(m(end)-1), 'previous');
% Messy kludge to get around how interp1 works (out-of-range queries return NaN)
neuronInd(isnan(neuronInd)) = n(end);
% Calculate timeDifference (current time minus spike time, as in the original function)
delt = (integrationInstant - [spikeTimesArray{:}]) ./ tau_syn(neuronInd);
alphaTerms = delt .* exp(1 - delt);
% Sum per neuron; the size argument keeps one entry per neuron even if the last neurons have not spiked
synapticOutputArray(:, iIntegration) = accumarray(neuronInd(:), alphaTerms(:), [numel(spikeTimesArray) 1]);
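The data flow of this answer (flatten all spikes, map each one to its neuron, accumulate per neuron into a fixed-size output) can be sketched in plain Python. It is not a speed demonstration, and the names are hypothetical; `bisect_right` on the block start positions plays the role the answer gives to `interp1(..., 'previous')`, and the preallocated list plays the role of `accumarray` with a size argument.

```python
import bisect
import math
from itertools import accumulate

def alpha_output_all(spike_times_array, t, tau_syn):
    """Flatten all spikes, map each to its neuron, and accumulate per neuron."""
    num_spikes = [len(s) for s in spike_times_array]
    # 0-based position where each neuron's block starts in the flattened vector
    starts = [0] + list(accumulate(num_spikes))[:-1]
    firing_times = [ts for spikes in spike_times_array for ts in spikes]
    # One slot per neuron, like passing a size argument to accumarray
    out = [0.0] * len(spike_times_array)
    for p, ts in enumerate(firing_times):
        # Last neuron whose block starts at or before p; an empty neuron's block
        # starts where the next block starts, so bisect_right skips past it
        i = bisect.bisect_right(starts, p) - 1
        x = (t - ts) / tau_syn[i]
        out[i] += x * math.exp(1.0 - x)
    return out

spike_times_array = [[], [], [], [0.0023, 0.0034], [0.0675]]
tau_syn = [0.01, 0.02, 0.03, 0.04, 0.05]
print(alpha_output_all(spike_times_array, 0.0810, tau_syn))
```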

Hope this helps.

answered Oct 17 '22 17:10 by MrAzzaman


This code gives me a significant speedup under the right conditions (typically when the maximum number of spikes for any neuron at a given time is low and there are more than 5 neurons). With only a few neurons, its performance is similar to that of your code above.

Note that I haven't tested this code with your real inputs so you might have to do some tweaking in order to make it work. Also, it might actually be slower with your data than the dummy data I used. The idea behind the speedup is vectorization -- essentially doing operations on the entire vector instead of splitting it up over a bunch of for loops.

This will require you to make spikeTimesArray a matrix with NaNs instead of missing values. So something like this:

 spikeTimesArray={[], [], [], [0.0023, 0.0034], [0.0675]}

would become

spikeTimesArray2=[[NaN;NaN], [NaN;NaN], [NaN;NaN], [0.0023; 0.0034], [0.0675;NaN]]

and instead of using for loops, you'd just do the entire vector operation:

idAllActiveNeuron = idAllActiveNeuron > 0;
if any(idAllActiveNeuron) % Only run if at least one neuron is active
    % for iActiveNeuron = find(idAllActiveNeuron) % ELIMINATE THIS FOR LOOP

    iActiveNeuron = find(idAllActiveNeuron);

    % Each column holds one neuron's spike times, padded with NaN
    timeDifference = (integrationInstant - spikeTimesArray2(:, iActiveNeuron)) ...
                      ./ tau_syn(iActiveNeuron);

    % Sum down each column (dimension 1), ignoring the NaN padding
    synapticOutputArray(iActiveNeuron, iIntegration) = ...
        sum(timeDifference .* exp(1 - timeDifference), 1, 'omitnan');
end

Note that in this implementation I am assuming that spikeTimesArray2 is an NxM matrix, where N is the maximum number of spikes for any single neuron and M is the total number of neurons. Also, tau_syn is a row vector, so the element-wise division relies on implicit expansion (MATLAB R2016b or later).
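The NaN-padded layout and the column-wise sum can be sketched in plain Python. The helper names are hypothetical, a list of rows stands in for the MATLAB matrix, and column indices are 0-based. In MATLAB the NaN entries propagate through the arithmetic and `'omitnan'` drops them at the sum; here they are skipped explicitly.

```python
import math

def pad_with_nan(spike_times_array):
    """Stack per-neuron spike lists as columns of an N x M grid, padding with NaN."""
    n_rows = max((len(s) for s in spike_times_array), default=0)
    return [[spikes[r] if r < len(spikes) else math.nan for spikes in spike_times_array]
            for r in range(n_rows)]

def column_alpha_sums(padded, t, tau_syn, active):
    """Sum x * exp(1 - x) down each active column, skipping NaN padding."""
    out = {}
    for j in active:
        total = 0.0
        for row in padded:
            if math.isnan(row[j]):
                continue  # padding, not a real spike
            x = (t - row[j]) / tau_syn[j]
            total += x * math.exp(1.0 - x)
        out[j] = total
    return out

spike_times_array = [[], [], [], [0.0023, 0.0034], [0.0675]]
padded = pad_with_nan(spike_times_array)
for row in padded:
    print(row)
# [nan, nan, nan, 0.0023, 0.0675]
# [nan, nan, nan, 0.0034, nan]

tau_syn = [0.01, 0.02, 0.03, 0.04, 0.05]
print(column_alpha_sums(padded, 0.0810, tau_syn, active=[3, 4]))
```

The padding costs memory proportional to the busiest neuron, which is consistent with the answer's note that the speedup depends on a low maximum spike count.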

answered Oct 17 '22 19:10 by qbzenker