 

Matlab filtfilt() function implementation in Java

Has anyone tried implementing MATLAB's filtfilt() function in Java (or at least in C++)? If you guys have an algorithm, that would be a great help.

Asked Dec 01 '14 at 01:12 by Bahubali Angol


2 Answers

Alright, I know this question is ancient, but maybe I can help someone else who winds up here wondering what filtfilt actually does.

Although it is obvious from the docs that filtfilt does forward-backward (a.k.a. zero-phase) filtering, it was not so obvious to me how it deals with things like padding and initial conditions.
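A short frequency-domain sketch of why the forward-backward pass is zero-phase, assuming a real-valued impulse response h with frequency response H, and ignoring the edge effects that the padding and initial conditions are there to reduce:

```latex
\begin{aligned}
Y_1(e^{j\omega}) &= H(e^{j\omega})\,X(e^{j\omega}) && \text{(forward pass)}\\
Y_2(e^{j\omega}) &= H(e^{j\omega})\,Y_1(e^{-j\omega}) && \text{(filter the time-reversed } y_1\text{)}\\
Y(e^{j\omega}) &= Y_2(e^{-j\omega}) = H(e^{-j\omega})\,H(e^{j\omega})\,X(e^{j\omega}) = \bigl|H(e^{j\omega})\bigr|^2 X(e^{j\omega}) && \text{(reverse again)}
\end{aligned}
```

The net response |H|^2 is real and non-negative, so the phase is zero, but the magnitude response is squared: the -3 dB point of a single pass becomes a -6 dB point, and the effective filter order doubles.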

As I couldn't find any other answers here (or elsewhere) with sufficient information about these implementation details, I implemented a simplified version of Python's scipy.signal.filtfilt, based on its source and documentation (so not Java or C++, but Python). I believe the scipy version works the same way as MATLAB's.

To keep things simple, the code below was written specifically for a second-order IIR filter, and it assumes the coefficient vectors a and b are known and normalized so that a[0] = 1 (e.g. obtained from scipy.signal.butter, or calculated by hand).
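As an illustration of "calculated by hand", here is a sketch of a second-order Butterworth lowpass design using the bilinear transform with frequency prewarping. The function name `butter2_lowpass` is hypothetical; the formulas are the standard second-order Butterworth bilinear-transform design, which I expect to match scipy.signal.butter for order 2, but treat it as a sketch and compare before relying on it.

```python
import math


def butter2_lowpass(cutoff_hz, sampling_hz):
    """Hypothetical helper: second-order Butterworth lowpass coefficients.

    Uses the bilinear transform with prewarping, so the -3 dB point of the
    digital filter lands at cutoff_hz. Returns (b, a) with a[0] == 1.
    """
    k = math.tan(math.pi * cutoff_hz / sampling_hz)  # prewarped cutoff
    norm = 1.0 / (1.0 + math.sqrt(2.0) * k + k * k)
    b0 = k * k * norm
    b = [b0, 2.0 * b0, b0]
    a = [1.0,
         2.0 * (k * k - 1.0) * norm,
         (1.0 - math.sqrt(2.0) * k + k * k) * norm]
    return b, a
```

For the settings used in the test script further down (2.5 Hz cutoff, 50 Hz sampling), `butter2_lowpass(2.5, 50.0)` should agree with `scipy.signal.butter(2, 0.1)` up to floating-point rounding, which is easy to confirm with `numpy.allclose`.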

It matches the scipy.signal.filtfilt default behavior: odd padding of length 3 * max(len(a), len(b)) is applied before the forward pass, and the initial state is found using the approach from scipy.signal.lfilter_zi.
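Odd padding reflects the signal point-symmetrically about each end sample, so the padded part continues the local trend of the signal instead of jumping. A minimal sketch of just that extension step, using the same slicing as custom_filtfilt below; `odd_extend` is a hypothetical name, and it assumes a 1-D signal longer than the padding length:

```python
import numpy


def odd_extend(x, n):
    """Hypothetical helper: odd (point-symmetric) extension by n samples per side.

    Left pad:  2*x[0]  - x[n:0:-1]
    Right pad: 2*x[-1] - x[-2:-n-2:-1]
    """
    x = numpy.asarray(x, dtype=float)
    if not 0 < n < len(x):
        raise ValueError('padding length must satisfy 0 < n < len(x)')
    left = 2 * x[0] - x[n:0:-1]
    right = 2 * x[-1] - x[-2:-n - 2:-1]
    return numpy.concatenate((left, x, right))


print(odd_extend([1, 2, 4, 7, 11], 2))  # values: -2, 0, 1, 2, 4, 7, 11, 15, 18
```

A linear ramp stays a linear ramp after this extension, which is why odd padding tends to produce smaller start-up transients than zero padding.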

Disclaimer: This code is only intended to provide some insight into certain implementation details of filtfilt, so the goal is clarity rather than performance. The scipy.signal.filtfilt implementation is much faster (about 100x faster in a quick-and-dirty timeit test on my system).

import numpy


def custom_filter(b, a, x):
    """
    Filter implemented using state-space representation.

    Assumes a filter with the second-order difference equation (with a[0] = 1):

        y[n] = b[0]*x[n] + b[1]*x[n-1] + b[2]*x[n-2]
                         - a[1]*y[n-1] - a[2]*y[n-2]

    """
    # State space representation (transposed direct form II)
    A = numpy.array([[-a[1], 1], [-a[2], 0]])
    B = numpy.array([b[1] - b[0] * a[1], b[2] - b[0] * a[2]])
    C = numpy.array([1.0, 0.0])
    D = b[0]

    # Determine initial state (solve zi = A*zi + B, see scipy.signal.lfilter_zi)
    zi = numpy.linalg.solve(numpy.eye(2) - A, B)

    # Scale the initial state vector zi by the first input value
    z = zi * x[0]

    # Apply filter
    y = numpy.zeros(numpy.shape(x))
    for n in range(len(x)):
        # Determine n-th output value (note this simplifies to y[n] = z[0] + b[0]*x[n])
        y[n] = numpy.dot(C, z) + D * x[n]
        # Determine next state (i.e. z[n+1])
        z = numpy.dot(A, z) + B * x[n]
    return y

def custom_filtfilt(b, a, x):
    # Apply 'odd' padding to input signal
    padding_length = 3 * max(len(a), len(b))  # the scipy.signal.filtfilt default
    x_forward = numpy.concatenate((
        [2 * x[0] - xi for xi in x[padding_length:0:-1]],
        x,
        [2 * x[-1] - xi for xi in x[-2:-padding_length-2:-1]]))

    # Filter forward
    y_forward = custom_filter(b, a, x_forward)

    # Filter backward
    x_backward = y_forward[::-1]  # reverse
    y_backward = custom_filter(b, a, x_backward)

    # Remove padding and reverse
    return y_backward[-padding_length-1:padding_length-1:-1]

Note that this implementation does not require scipy. Moreover, it can easily be adapted to work in pure Python, without even numpy, by writing out the solution for zi and using lists instead of numpy arrays. This even comes with a substantial performance benefit, because accessing individual numpy array elements in a Python loop is much slower than accessing list elements.
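A sketch of such a pure-Python variant, assuming a second-order filter with a[0] == 1 as in custom_filter. For the 2 x 2 system, zi = A*zi + B can be solved by hand: adding its two rows gives zi[0] = (B[0] + B[1]) / (1 + a[1] + a[2]), and the second row then gives zi[1] = B[1] - a[2]*zi[0]. The name `custom_filter_pure` is hypothetical.

```python
def custom_filter_pure(b, a, x):
    """Hypothetical pure-Python version of custom_filter (second order, a[0] == 1).

    Same transposed direct form II recursion, but with scalars and lists
    instead of numpy arrays.
    """
    b0, b1, b2 = b
    _, a1, a2 = a
    # Input part of the state update (the vector B in custom_filter)
    B0 = b1 - b0 * a1
    B1 = b2 - b0 * a2
    # Closed-form solution of zi = A*zi + B (requires 1 + a1 + a2 != 0)
    zi0 = (B0 + B1) / (1.0 + a1 + a2)
    zi1 = B1 - a2 * zi0
    # Scale the initial state by the first input value
    z0, z1 = zi0 * x[0], zi1 * x[0]
    y = []
    for xn in x:
        y.append(z0 + b0 * xn)
        # Next state; the right-hand side is evaluated with the current z0, z1
        z0, z1 = -a1 * z0 + z1 + B0 * xn, -a2 * z0 + B1 * xn
    return y
```

Swapping this in for custom_filter inside custom_filtfilt should give the same result, since only the data types change, not the recursion.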

The filter itself is implemented here in a simple Python loop. It uses the state-space representation, because this representation is needed anyway to determine the initial conditions (see scipy.signal.lfilter_zi). I believe that the actual scipy implementation of the linear filter (i.e. scipy.signal.sigtools._linear_filter) does something similar in C, as can be seen here (thanks to this answer).
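For reference, the state-space recursion in custom_filter, and the reasoning behind solving zi = A*zi + B, can be written out as follows. This restates the code above; the only extra step is the steady-state argument.

```latex
\begin{aligned}
\mathbf{z}[n+1] &= A\,\mathbf{z}[n] + B\,x[n], \qquad y[n] = C\,\mathbf{z}[n] + D\,x[n],\\
A &= \begin{bmatrix} -a_1 & 1 \\ -a_2 & 0 \end{bmatrix},\quad
B = \begin{bmatrix} b_1 - b_0 a_1 \\ b_2 - b_0 a_2 \end{bmatrix},\quad
C = \begin{bmatrix} 1 & 0 \end{bmatrix},\quad
D = b_0,\\
x[n] \equiv 1,\ \mathbf{z}[n+1] = \mathbf{z}[n] = \mathbf{z}_i
&\;\Longrightarrow\; \mathbf{z}_i = A\,\mathbf{z}_i + B
\;\Longrightarrow\; (I - A)\,\mathbf{z}_i = B.
\end{aligned}
```

By linearity, the steady state for a constant input of value x[0] is x[0] times zi, which is why custom_filter starts from z = zi * x[0]: if the signal starts near x[0], the filter starts close to steady state and the start-up transient is small.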

Here's some code providing a (very basic) check for equality of the scipy output and custom output:

import numpy
import numpy.testing
import scipy.signal
from matplotlib import pyplot
from custom_filtfilt import custom_filtfilt  # assumes the code above is saved as custom_filtfilt.py

def sinusoid(sampling_frequency_Hz=50.0, signal_frequency_Hz=1.0, periods=1.0,
             amplitude=1.0, offset=0.0, phase_deg=0.0, noise_std=0.1):
    """
    Create a noisy test signal sampled from a sinusoid (time series)

    """
    signal_frequency_rad_per_s = signal_frequency_Hz * 2 * numpy.pi
    phase_rad = numpy.radians(phase_deg)
    duration_s = periods / signal_frequency_Hz
    number_of_samples = int(duration_s * sampling_frequency_Hz)
    time_s = (numpy.array(range(number_of_samples), float) /
              sampling_frequency_Hz)
    angle_rad = signal_frequency_rad_per_s * time_s
    signal = offset + amplitude * numpy.sin(angle_rad - phase_rad)
    noise = numpy.random.normal(loc=0.0, scale=noise_std, size=signal.shape)
    return signal + noise


if __name__ == '__main__':
    # Design filter
    sampling_freq_hz = 50.0
    cutoff_freq_hz = 2.5
    order = 2
    normalized_frequency = cutoff_freq_hz * 2 / sampling_freq_hz
    b, a = scipy.signal.butter(order, normalized_frequency, btype='lowpass')

    # Create test signal
    signal = sinusoid(sampling_frequency_Hz=sampling_freq_hz,
                      signal_frequency_Hz=1.5, periods=3, amplitude=2.0,
                      offset=2.0, phase_deg=25)

    # Apply zero-phase filters
    filtered_custom = custom_filtfilt(b, a, signal)
    filtered_scipy = scipy.signal.filtfilt(b, a, signal)

    # Verify near-equality
    numpy.testing.assert_array_almost_equal(filtered_custom, filtered_scipy,
                                            decimal=12)

    # Plot result
    pyplot.subplot(1, 2, 1)
    pyplot.plot(signal)
    pyplot.plot(filtered_scipy)
    pyplot.plot(filtered_custom, '.')
    pyplot.title('raw vs filtered signals')
    pyplot.legend(['raw', 'scipy filtfilt', 'custom filtfilt'])
    pyplot.subplot(1, 2, 2)
    pyplot.plot(filtered_scipy-filtered_custom)
    pyplot.title('difference (scipy vs custom)')
    pyplot.show()

For this specific case, the comparison yields a figure like the one below, suggesting equality to at least 14 decimals (machine precision, I guess?):

[Figure: filtfilt, scipy vs custom implementation]

Answered Sep 26 '22 at 14:09 by djvg


Here is my C++ implementation of the filtfilt algorithm as implemented in MATLAB. I hope this helps.

Answered Sep 22 '22 at 14:09 by Darien Pardinas