I have an array Z in Julia which represents an image of a 2D Gaussian function. I.e. Z[i,j] is the height of the Gaussian at pixel i,j. I would like to determine the parameters of the Gaussian (mean and covariance), presumably by some sort of curve fitting.
I've looked into various methods for fitting Z: I first tried the Distributions package, but it is designed for a somewhat different situation (randomly selected points). Then I tried the LsqFit package, but it seems to be tailored for 1D fitting, as it is throwing errors when I try to fit 2D data, and there is no documentation I can find to lead me to a solution.
How can I fit a Gaussian to a 2D array in Julia?
The simplest approach is to use Optim.jl. Here is an example (it is not optimized for speed, but it shows how you can handle the problem):
using Distributions, Optim
# generate some sample data
true_d = MvNormal([1.0, 0.0], [2.0 1.0; 1.0 3.0])
const xr = -3:0.1:3
const yr = -3:0.1:3
const s = 5.0
const m = [s * pdf(true_d, [x, y]) for x in xr, y in yr]
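# unpack the parameter vector into mean, symmetric covariance matrix and scale
decode(x) = (mu=x[1:2], sig=[x[3] x[4]; x[4] x[5]], s=x[6])
function objective(x)
    mu, sig, s = decode(x)
    try # sig might be infeasible so we have to handle this case
        est_d = MvNormal(mu, sig)
        # u, v are grid coordinates (renamed so they do not shadow the argument x)
        ref_m = [s * pdf(est_d, [u, v]) for u in xr, v in yr]
        sum((a-b)^2 for (a,b) in zip(ref_m, m))
    catch
        sum(m) # fallback value returned when MvNormal rejects sig
    end
end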
# test for an example starting point
result = optimize(objective, [1.0, 0.0, 1.0, 0.0, 1.0, 1.0])
decode(result.minimizer)
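The starting point above is written by hand. For a real image, a data-driven starting point may help the optimizer; one possible sketch is to use the weighted moments of the image itself. This is a different technique from the fit above and only a rough estimate. The helper name moment_guess is hypothetical, and the code assumes Z is non-negative, 1-indexed, and sampled on the ranges xr and yr:

```julia
# Hypothetical helper: estimate [mu1, mu2, sig11, sig12, sig22, scale] from the
# weighted moments of an image Z sampled on the grid xr × yr (both ranges).
function moment_guess(Z, xr, yr)
    total = sum(Z)
    # weighted average of f(x, y), using the pixel heights as weights
    wmean(f) = sum(Z[i, j] * f(xr[i], yr[j])
                   for i in eachindex(xr), j in eachindex(yr)) / total
    mx = wmean((x, y) -> x)
    my = wmean((x, y) -> y)
    sxx = wmean((x, y) -> (x - mx)^2)
    sxy = wmean((x, y) -> (x - mx) * (y - my))
    syy = wmean((x, y) -> (y - my)^2)
    # Riemann-sum estimate of the volume under Z, which plays the role of the scale
    scale = total * step(xr) * step(yr)
    return [mx, my, sxx, sxy, syy, scale]
end
```

With the sample data above, moment_guess(m, xr, yr) could be passed to optimize in place of the hand-written vector. If the Gaussian is cut off by the edge of the image, as it is in the sample data, the moments are biased, so treat the result as a starting point and not as the fit.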
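Alternatively, you could use constrained optimization, e.g. like this (the code uses the older JuMP 0.18 API; in later JuMP releases the solver is passed to Model differently, and setvalue, solve and getvalue became set_start_value, optimize! and value):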
using Distributions, JuMP, NLopt
true_d = MvNormal([1.0, 0.0], [2.0 1.0; 1.0 3.0])
const xr = -3:0.1:3
const yr = -3:0.1:3
const s = 5.0
const Z = [s * pdf(true_d, [x, y]) for x in xr, y in yr]
m = Model(solver=NLoptSolver(algorithm=:LD_MMA))
@variable(m, m1)
@variable(m, m2)
@variable(m, sig11 >= 0.001)
@variable(m, sig12)
@variable(m, sig22 >= 0.001)
@variable(m, sc >= 0.001)
function obj(m1, m2, sig11, sig12, sig22, sc)
    est_d = MvNormal([m1, m2], [sig11 sig12; sig12 sig22])
    ref_Z = [sc * pdf(est_d, [x, y]) for x in xr, y in yr]
    sum((a-b)^2 for (a,b) in zip(ref_Z, Z))
end
JuMP.register(m, :obj, 6, obj, autodiff=true)
@NLobjective(m, Min, obj(m1, m2, sig11, sig12, sig22, sc))
@NLconstraint(m, sig12*sig12 + 0.001 <= sig11*sig22)
setvalue(m1, 0.0)
setvalue(m2, 0.0)
setvalue(sig11, 1.0)
setvalue(sig12, 0.0)
setvalue(sig22, 1.0)
setvalue(sc, 1.0)
status = solve(m)
getvalue.([m1, m2, sig11, sig12, sig22, sc])
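The bounds and the nonlinear constraint in this model can be read as a positive-definiteness condition on the covariance matrix. As a short derivation (the symbols \Sigma and \varepsilon are notation introduced here, with \varepsilon = 0.001 matching the code):

```latex
\Sigma = \begin{pmatrix} \sigma_{11} & \sigma_{12} \\ \sigma_{12} & \sigma_{22} \end{pmatrix} \succ 0
\iff \sigma_{11} > 0 \ \text{and}\ \det\Sigma = \sigma_{11}\sigma_{22} - \sigma_{12}^{2} > 0 ,
\qquad
\sigma_{11} \ge \varepsilon,\quad \sigma_{12}^{2} + \varepsilon \le \sigma_{11}\sigma_{22}
\;\Longrightarrow\; \det\Sigma \ge \varepsilon > 0 .
```

So any point that satisfies the constraints gives a valid covariance matrix, which is why this version needs no try/catch. The lower bound on sig22 is implied by the other two conditions, but it does no harm.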