 

2D curve fitting in Julia

I have an array Z in Julia that represents an image of a 2D Gaussian function, i.e. Z[i,j] is the height of the Gaussian at pixel (i, j). I would like to determine the parameters of the Gaussian (mean and covariance), presumably by some sort of curve fitting.
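Written out as a formula, assuming pixel (i, j) sits at coordinates (x_i, y_j) and allowing an unknown overall scale factor s (the image heights need not sum to a normalized density), this is a nonlinear least-squares problem:

```latex
\varphi(\mathbf{u};\mu,\Sigma) = \frac{1}{2\pi\sqrt{\det\Sigma}}
  \exp\!\left(-\tfrac{1}{2}(\mathbf{u}-\mu)^{\top}\Sigma^{-1}(\mathbf{u}-\mu)\right),
  \qquad \mathbf{u} = (x, y)^{\top}

(\hat\mu, \hat\Sigma, \hat s) = \arg\min_{\mu \in \mathbb{R}^2,\ \Sigma \succ 0,\ s > 0}
  \sum_{i,j} \Bigl( Z_{ij} - s\,\varphi\bigl((x_i, y_j)^{\top}; \mu, \Sigma\bigr) \Bigr)^2
```

That is six unknowns (two for μ, three for the symmetric Σ, one for s); the requirement that Σ be positive definite (Σ ≻ 0) is the part that needs special care.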

I've looked into various methods for fitting Z. I first tried the Distributions package, but it is designed for a somewhat different situation (fitting randomly sampled points). Then I tried the LsqFit package, but it seems to be tailored to 1D fitting: it throws errors when I try to fit 2D data, and I can't find any documentation that points to a solution.

How can I fit a Gaussian to a 2D array in Julia?

asked Sep 06 '25 11:09 by Yly


1 Answer

The simplest approach is to use Optim.jl. Here is some example code (it is not optimized for speed, but it shows how you can approach the problem):

using Distributions, Optim, LinearAlgebra

# generate some sample data
true_d = MvNormal([1.0, 0.0], [2.0 1.0; 1.0 3.0])
const xr = -3:0.1:3
const yr = -3:0.1:3
const s = 5.0
const m = [s * pdf(true_d, [x, y]) for x in xr, y in yr]

# parameter vector layout: mean (2 entries), covariance entries (3), scale (1)
decode(p) = (mu=p[1:2], sig=[p[3] p[4]; p[4] p[5]], s=p[6])

function objective(p)
    mu, sig, sc = decode(p)
    # sig might not be a valid (positive definite) covariance matrix,
    # so return a large penalty value in that case instead of throwing
    isposdef(sig) || return sum(m)
    est_d = MvNormal(mu, sig)
    ref_m = [sc * pdf(est_d, [x, y]) for x in xr, y in yr]
    sum((a - b)^2 for (a, b) in zip(ref_m, m))
end

# test for an example starting point (Nelder-Mead is used by default)
result = optimize(objective, [1.0, 0.0, 1.0, 0.0, 1.0, 1.0])
decode(Optim.minimizer(result))
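A local optimizer like this can be sensitive to its starting point. One cheap way to get a starting guess, assuming the image has no background offset and the blob lies well inside it, is to treat Z as weights and compute its weighted mean and covariance (the method of moments). This is a swapped-in technique, not part of the answer above; `moment_fit` and `gauss2d` are hypothetical helper names, and the synthetic image uses a wider window (-8:0.1:8) so that truncation at the border does not bias the moments:

```julia
using LinearAlgebra

# Hypothetical helper (not from any package): estimate the mean, covariance
# and scale of a single 2D Gaussian blob from its image by weighted moments.
# Assumes Z[i, j] is the height at (xr[i], yr[j]), xr/yr are uniform ranges,
# Z has no background offset, and the blob is not cut off at the border.
function moment_fit(Z::AbstractMatrix, xr::AbstractRange, yr::AbstractRange)
    size(Z) == (length(xr), length(yr)) ||
        throw(DimensionMismatch("Z must be length(xr) × length(yr)"))
    X = [x for x in xr, y in yr]          # x coordinate of every pixel
    Y = [y for x in xr, y in yr]          # y coordinate of every pixel
    w = sum(Z)                            # total weight
    mx = sum(Z .* X) / w
    my = sum(Z .* Y) / w
    sxx = sum(Z .* (X .- mx) .^ 2) / w
    syy = sum(Z .* (Y .- my) .^ 2) / w
    sxy = sum(Z .* (X .- mx) .* (Y .- my)) / w
    # on a fine grid, s * pdf sums to about s / (dx * dy), so rescale the mass
    s = w * step(xr) * step(yr)
    return (mu = [mx, my], sig = [sxx sxy; sxy syy], s = s)
end

# Hypothetical helper: scaled bivariate normal density, used only to make test data
function gauss2d(x, y, mu, sig, s)
    d = [x - mu[1], y - mu[2]]
    return s * exp(-0.5 * dot(d, sig \ d)) / (2π * sqrt(det(sig)))
end

xr = -8:0.1:8   # wide window so the blob is not truncated
yr = -8:0.1:8
Z = [gauss2d(x, y, [1.0, 0.0], [2.0 1.0; 1.0 3.0], 5.0) for x in xr, y in yr]

est = moment_fit(Z, xr, yr)
# starting vector in the layout used by decode: [mu; sig11; sig12; sig22; s]
x0 = [est.mu; est.sig[1, 1]; est.sig[1, 2]; est.sig[2, 2]; est.s]
```

The resulting `x0` could then be passed as the starting point, e.g. `optimize(objective, x0)`. On a window as narrow as -3:0.1:3 the moments will be biased because the blob is cut off at the edges, so they are only a starting guess, not a replacement for the fit.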

Alternatively, you could use constrained optimization, e.g. with JuMP and NLopt:

using Distributions, JuMP, NLopt

true_d = MvNormal([1.0, 0.0], [2.0 1.0; 1.0 3.0])
const xr = -3:0.1:3
const yr = -3:0.1:3
const s = 5.0
const Z = [s * pdf(true_d, [x, y]) for x in xr, y in yr]

# uses the JuMP 1.x API with the NLopt optimizer
model = Model(NLopt.Optimizer)
set_optimizer_attribute(model, "algorithm", :LD_MMA)

@variable(model, m1, start = 0.0)
@variable(model, m2, start = 0.0)
@variable(model, sig11 >= 0.001, start = 1.0)
@variable(model, sig12, start = 0.0)
@variable(model, sig22 >= 0.001, start = 1.0)
@variable(model, sc >= 0.001, start = 1.0)

function obj(m1, m2, sig11, sig12, sig22, sc)
    est_d = MvNormal([m1, m2], [sig11 sig12; sig12 sig22])
    ref_Z = [sc * pdf(est_d, [x, y]) for x in xr, y in yr]
    sum((a - b)^2 for (a, b) in zip(ref_Z, Z))
end

# register the user-defined function so it can be used in @NLobjective;
# autodiff = true computes its gradient automatically
register(model, :obj, 6, obj; autodiff = true)
@NLobjective(model, Min, obj(m1, m2, sig11, sig12, sig22, sc))
# with sig11 > 0, requiring det(sig) = sig11*sig22 - sig12^2 > 0 keeps sig positive definite
@NLconstraint(model, sig12 * sig12 + 0.001 <= sig11 * sig22)

optimize!(model)
status = termination_status(model)
value.([m1, m2, sig11, sig12, sig22, sc])
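Regarding the LsqFit errors mentioned in the question: LsqFit's `curve_fit(model, xdata, ydata, p0)` passes `xdata` to the model unchanged, so one workaround is to flatten the image with `vec(Z)` and pass the pixel coordinates as an n×2 matrix. The sketch below assumes that layout and swaps in a log-Cholesky parameterization of Σ (Σ = L Lᵀ with L lower triangular and a positive diagonal), so every parameter vector is feasible and neither the try/penalty trick nor an explicit constraint is needed. `unpack_params`, `gauss_model` and `height` are hypothetical helper names, and the code is not tuned for speed:

```julia
using LsqFit, LinearAlgebra

# Hypothetical helper: map an unconstrained vector p to (mu, Sigma, s).
# Sigma = L * L' with a positive diagonal on L (via exp), so Sigma is
# always positive definite (log-Cholesky parameterization).
function unpack_params(p)
    T = eltype(p)
    L = [exp(p[3]) zero(T); p[4] exp(p[5])]
    return p[1:2], L * L', p[6]
end

# Model for curve_fit: xy is an n×2 matrix with one pixel's (x, y) per row;
# returns the predicted height for each row.
function gauss_model(xy, p)
    mu, Sigma, s = unpack_params(p)
    Sinv = inv(Sigma)
    c = s / (2π * sqrt(det(Sigma)))
    return [c * exp(-0.5 * dot(r - mu, Sinv * (r - mu))) for r in eachrow(xy)]
end

# Synthetic image with known parameters (same values as in the answer)
mu_true, Sigma_true, s_true = [1.0, 0.0], [2.0 1.0; 1.0 3.0], 5.0
xr, yr = -3:0.1:3, -3:0.1:3
function height(x, y)
    d = [x, y] - mu_true
    return s_true * exp(-0.5 * dot(d, Sigma_true \ d)) / (2π * sqrt(det(Sigma_true)))
end
Z = [height(x, y) for x in xr, y in yr]

# Flatten: row k of xy holds the coordinates of pixel vec(Z)[k] (column-major order)
pts = vec([(x, y) for x in xr, y in yr])
xy = [first.(pts) last.(pts)]

# start from mu = [1, 0], Sigma = I, s = 1, as in the Optim example above
fit = curve_fit(gauss_model, xy, vec(Z), [1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
mu_hat, Sigma_hat, s_hat = unpack_params(fit.param)
```

The `exp` on the diagonal of L is what enforces positive definiteness. LsqFit also accepts `lower`/`upper` bound keywords, but box bounds on the raw entries of Σ cannot by themselves express det Σ > 0, which is why this sketch reparameterizes instead.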
answered Sep 13 '25 10:09 by Bogumił Kamiński