 

How to use Java Math Commons CurveFitter?

How do I use Apache Commons Math's CurveFitter to fit a function to a set of data? I was told to use CurveFitter with LevenbergMarquardtOptimizer and ParametricUnivariateFunction, but I don't know what to write in the ParametricUnivariateFunction gradient and value methods. Also, once they are written, how do I get the fitted function parameters? My function:

public static double fnc(double t, double a, double b, double c){
  return a * Math.pow(t, b) * Math.exp(-c * t);
}
asked Jul 04 '12 at 20:07 by Italo Maia



1 Answer

So, this is an old question, but I ran into the same issue recently, and ended up having to delve into mailing lists and the Apache Commons Math source code to figure it out.

This API is remarkably poorly documented, but in the current version of Apache Commons Math (3.3+) there are two parts, assuming you have a single variable with multiple parameters: the function to fit (which implements ParametricUnivariateFunction) and the curve fitter (which extends AbstractCurveFitter).

Function to Fit

  • public double value(double t, double... parameters)
    • Your equation. This is where you would put your fnc logic, reading a, b and c from parameters[0], parameters[1] and parameters[2].
  • public double[] gradient(double t, double... parameters)
    • Returns an array of the partial derivatives of the above with respect to each parameter, in the same order as parameters. An online derivative calculator may be helpful if (like me) you're rusty on your calculus, but any good computer algebra system can calculate these values.
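For the function in the question, f(t; a, b, c) = a * t^b * e^(-ct), differentiating with respect to each parameter gives the three entries that gradient should return, in the order {a, b, c}:

```latex
\frac{\partial f}{\partial a} = t^{b} e^{-ct}, \qquad
\frac{\partial f}{\partial b} = a\, t^{b} \ln(t)\, e^{-ct}, \qquad
\frac{\partial f}{\partial c} = -a\, t^{b+1} e^{-ct}
```

The b-derivative contains ln(t), so it is only defined for t > 0. In floating point, Math.pow(0, b) * Math.log(0) evaluates to NaN for b > 0, so data points at t = 0 are worth dropping or handling separately.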

Curve Fitter

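  • protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points)
    • Sets up a bunch of boilerplate (target values, weights, starting guess, and the model with its Jacobian) and returns a least-squares problem for the fitter to use.
  • protected LeastSquaresOptimizer getOptimizer()
    • Optional to override. The default already returns a LevenbergMarquardtOptimizer, so you don't need to create one yourself. fit(points) runs this optimizer on the problem from getProblem() and returns the fitted parameters as a double[].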
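Conceptually, the problem built in getProblem asks the optimizer to minimize the weighted sum of squared residuals, sum over i of w_i * (y_i - f(x_i; a, b, c))^2, where the w_i are the point weights placed in the DiagonalMatrix. A minimal plain-Java sketch of that objective, which can be handy for comparing candidate parameter sets; WeightedCost and its methods are hypothetical names and nothing here is part of Commons Math:

```java
import java.util.function.DoubleUnaryOperator;

// Hypothetical helper, not part of Commons Math: evaluates the weighted
// least-squares objective sum_i w_i * (y_i - f(x_i))^2.
public class WeightedCost {
    public static double weightedSumOfSquares(double[] x, double[] y, double[] w,
                                              DoubleUnaryOperator model) {
        if (x.length != y.length || x.length != w.length) {
            throw new IllegalArgumentException("x, y and w must have the same length");
        }
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double r = y[i] - model.applyAsDouble(x[i]); // residual at point i
            sum += w[i] * r * r;
        }
        return sum;
    }

    // The question's model a * t^b * exp(-c * t) with fixed parameters.
    public static DoubleUnaryOperator model(double a, double b, double c) {
        return t -> a * Math.pow(t, b) * Math.exp(-c * t);
    }

    public static void main(String[] args) {
        double[] x = {0.5, 1.0, 2.0, 4.0};
        double[] w = {1.0, 1.0, 1.0, 1.0};
        double[] y = new double[x.length];
        DoubleUnaryOperator truth = model(2.0, 1.5, 0.5);
        for (int i = 0; i < x.length; i++) {
            y[i] = truth.applyAsDouble(x[i]); // noise-free samples, for illustration only
        }
        System.out.println("cost at true parameters: "
            + weightedSumOfSquares(x, y, w, truth));
        System.out.println("cost at guess {1, 1, 1}: "
            + weightedSumOfSquares(x, y, w, model(1.0, 1.0, 1.0)));
    }
}
```

A good fit should drive this number well below its value at the initial guess; with real, noisy data it will not reach zero.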

Putting it all together, here's an example solution in your specific case:

import java.util.*;
import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
import org.apache.commons.math3.fitting.AbstractCurveFitter;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math3.fitting.WeightedObservedPoint;
import org.apache.commons.math3.linear.DiagonalMatrix;

class MyFunc implements ParametricUnivariateFunction {
    // f(t; a, b, c) = a * t^b * exp(-c * t), with parameters = {a, b, c}
    @Override
    public double value(double t, double... parameters) {
        return parameters[0] * Math.pow(t, parameters[1]) * Math.exp(-parameters[2] * t);
    }

    // Gradient of value() with respect to the parameters: one partial
    // derivative per parameter, in the same order as the parameters array.
    // The fitter stacks one such row per data point into the Jacobian.
    // df/db contains log(t), so keep t > 0.
    @Override
    public double[] gradient(double t, double... parameters) {
        final double a = parameters[0];
        final double b = parameters[1];
        final double c = parameters[2];

        return new double[] {
            Math.exp(-c*t) * Math.pow(t, b),                     // df/da
            a * Math.exp(-c*t) * Math.pow(t, b) * Math.log(t),   // df/db
            a * (-Math.exp(-c*t)) * Math.pow(t, b+1)             // df/dc
        };
    }
}

public class MyFuncFitter extends AbstractCurveFitter {
    @Override
    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> points) {
        final int len = points.size();
        final double[] target  = new double[len];
        final double[] weights = new double[len];
        // Starting values for {a, b, c}. For real data, a guess near the
        // expected values makes Levenberg-Marquardt more likely to converge.
        final double[] initialGuess = { 1.0, 1.0, 1.0 };

        int i = 0;
        for(WeightedObservedPoint point : points) {
            target[i]  = point.getY();
            weights[i] = point.getWeight();
            i += 1;
        }

        final AbstractCurveFitter.TheoreticalValuesFunction model = new
            AbstractCurveFitter.TheoreticalValuesFunction(new MyFunc(), points);

        return new LeastSquaresBuilder().
            maxEvaluations(Integer.MAX_VALUE).
            maxIterations(Integer.MAX_VALUE).
            start(initialGuess).
            target(target).
            weight(new DiagonalMatrix(weights)).
            model(model.getModelFunction(), model.getModelFunctionJacobian()).
            build();
    }

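    public static void main(String[] args) {
        MyFuncFitter fitter = new MyFuncFitter();
        ArrayList<WeightedObservedPoint> points = new ArrayList<WeightedObservedPoint>();

        // Add points here. You want at least as many points as parameters
        // (three); this illustration samples the model itself at
        // a = 2, b = 1.5, c = 0.5 for t = 0.25 .. 5.0 (t > 0 avoids log(0)).
        MyFunc f = new MyFunc();
        for (int k = 1; k <= 20; k++) {
            double t = 0.25 * k;
            // WeightedObservedPoint's constructor takes (weight, x, y).
            points.add(new WeightedObservedPoint(1.0, t, f.value(t, 2.0, 1.5, 0.5)));
        }

        // fit() returns the fitted parameters in the same order as the
        // parameters array used by MyFunc: {a, b, c}.
        final double[] coeffs = fitter.fit(points);
        System.out.println(Arrays.toString(coeffs));
    }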
}
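A wrong gradient tends to show up as slow, failed, or wrong convergence rather than as an obvious error, so it can be worth checking the partial derivatives numerically before fitting. A minimal sketch, assuming the same parameter order {a, b, c}; GradientCheck and its helpers are hypothetical names, and value()/gradient() are copied from MyFunc as static methods so the check runs with only the JDK:

```java
// Hypothetical standalone check: compares MyFunc's analytic gradient with a
// central-difference approximation. No Commons Math dependency needed.
public class GradientCheck {
    public static double value(double t, double... p) {
        return p[0] * Math.pow(t, p[1]) * Math.exp(-p[2] * t);
    }

    public static double[] gradient(double t, double... p) {
        final double a = p[0], b = p[1], c = p[2];
        return new double[] {
            Math.exp(-c * t) * Math.pow(t, b),                   // df/da
            a * Math.exp(-c * t) * Math.pow(t, b) * Math.log(t), // df/db
            -a * Math.exp(-c * t) * Math.pow(t, b + 1)           // df/dc
        };
    }

    // Central difference (f(p + h) - f(p - h)) / 2h for each parameter.
    public static double[] numericGradient(double t, double[] p) {
        double[] g = new double[p.length];
        for (int k = 0; k < p.length; k++) {
            double h = 1e-6 * Math.max(1.0, Math.abs(p[k]));
            double[] up = p.clone();
            double[] down = p.clone();
            up[k] += h;
            down[k] -= h;
            g[k] = (value(t, up) - value(t, down)) / (2 * h);
        }
        return g;
    }

    // Largest relative discrepancy between the analytic and numeric gradients.
    public static double maxRelativeError(double t, double[] p) {
        double[] exact = gradient(t, p);
        double[] approx = numericGradient(t, p);
        double worst = 0.0;
        for (int k = 0; k < p.length; k++) {
            double scale = Math.max(1e-12, Math.abs(exact[k]));
            worst = Math.max(worst, Math.abs(exact[k] - approx[k]) / scale);
        }
        return worst;
    }

    public static void main(String[] args) {
        double[] p = {2.0, 1.5, 0.5};
        for (double t : new double[] {0.5, 1.0, 2.0, 4.0}) {
            System.out.println("t=" + t + "  max relative error=" + maxRelativeError(t, p));
        }
    }
}
```

Errors around 1e-8 or smaller suggest the derivatives are right; a typo in one entry typically shows up as an error of order 1 for that parameter.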
answered Oct 27 '22 at 23:10 by i80and
