Using an example from Andrew Ng's class (finding the parameters for linear regression using the normal equation):
With Python:
import numpy as np
X = np.array([[1, 2104, 5, 1, 45], [1, 1416, 3, 2, 40], [1, 1534, 3, 2, 30], [1, 852, 2, 1, 36]])
y = np.array([[460], [232], [315], [178]])
# normal equation: θ = (XᵀX)⁻¹ Xᵀ y
θ = ((np.linalg.inv(X.T.dot(X))).dot(X.T)).dot(y)
print(θ)
Result:
[[ 7.49398438e+02]
[ 1.65405273e-01]
[ -4.68750000e+00]
[ -4.79453125e+01]
[ -5.34570312e+00]]
With Julia:
X = [1 2104 5 1 45; 1 1416 3 2 40; 1 1534 3 2 30; 1 852 2 1 36]
y = [460; 232; 315; 178]
θ = ((X' * X)^-1) * X' * y
Result:
5-element Array{Float64,1}:
207.867
0.0693359
134.906
-77.0156
-7.81836
Furthermore, when I multiply X by Julia's θ (but not by Python's), I get numbers close to y.
I can't figure out what I am doing wrong. Thanks!
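pinv(X), which computes the pseudoinverse, is more broadly applicable than inv(X), which is what X^-1 computes. Here X has 4 rows and 5 columns, so X'*X is a 5×5 matrix of rank at most 4: it is singular, and whatever inv returns for it is dominated by rounding error. Neither Julia nor Python does well using inv, but in this case Julia apparently does better.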
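To see why inv struggles here, a quick check in NumPy (a sketch; `exact_det` is a hypothetical helper written only for this illustration, not a NumPy function). X has integer entries, so X.T @ X is formed exactly, and a determinant computed over exact rationals comes out as exactly zero:

```python
from fractions import Fraction

import numpy as np

X = np.array([[1, 2104, 5, 1, 45],
              [1, 1416, 3, 2, 40],
              [1, 1534, 3, 2, 30],
              [1, 852, 2, 1, 36]])

# Integer entries, so X.T @ X is computed exactly (no rounding).
XtX = X.T @ X


def exact_det(m):
    """Determinant by Gaussian elimination over exact rationals."""
    a = [[Fraction(int(v)) for v in row] for row in m]
    n = len(a)
    det = Fraction(1)
    for col in range(n):
        # Find a row with a nonzero entry in this column to use as pivot.
        pivot = next((r for r in range(col, n) if a[r][col] != 0), None)
        if pivot is None:
            return Fraction(0)  # no pivot available: matrix is singular
        if pivot != col:
            a[col], a[pivot] = a[pivot], a[col]
            det = -det  # a row swap flips the sign
        det *= a[col][col]
        # Eliminate entries below the pivot.
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            for c in range(col, n):
                a[r][c] -= factor * a[col][c]
    return det


print(XtX.shape)                 # (5, 5)
print(np.linalg.matrix_rank(X))  # 4
print(exact_det(XtX))            # 0
```

Since rank(XᵀX) = rank(X) ≤ 4, no amount of floating-point care makes inv meaningful for this matrix; the pseudoinverse sidesteps the problem by discarding the zero singular value.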
However, if you change the expression to
julia> z=pinv(X'*X)*X'*y
5-element Array{Float64,1}:
188.4
0.386625
-56.1382
-92.9673
-3.73782
You can verify that X*z = y:
julia> X*z
4-element Array{Float64,1}:
460.0
232.0
315.0
178.0
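A NumPy counterpart of the pinv expression above, as a sketch: it relies on the identity pinv(XᵀX)·Xᵀ = pinv(X), so it takes the pseudoinverse of X directly instead of forming X.T @ X first.

```python
import numpy as np

X = np.array([[1, 2104, 5, 1, 45],
              [1, 1416, 3, 2, 40],
              [1, 1534, 3, 2, 30],
              [1, 852, 2, 1, 36]], dtype=float)
y = np.array([[460], [232], [315], [178]], dtype=float)

# Minimum-norm solution via the pseudoinverse of X itself.
# Mathematically pinv(X.T @ X) @ X.T == pinv(X), but computing pinv(X)
# directly never forms the singular 5x5 matrix X.T @ X.
theta = np.linalg.pinv(X) @ y

print(theta.ravel())
print(np.allclose(X @ theta, y))  # True
```

With 4 equations and 5 unknowns there are infinitely many exact fits; the pseudoinverse picks the one with the smallest norm, which is why it agrees with the Julia pinv result.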
A more numerically robust approach in Python, which avoids doing the matrix algebra yourself, is to use numpy.linalg.lstsq to do the regression:
In [29]: np.linalg.lstsq(X, y)
Out[29]:
(array([[ 188.40031942],
[ 0.3866255 ],
[ -56.13824955],
[ -92.9672536 ],
[ -3.73781915]]),
array([], dtype=float64),
4,
array([ 3.08487554e+03, 1.88409728e+01, 1.37100414e+00,
1.97618336e-01]))
(Compare the solution vector with @waTeim's answer in Julia.)
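The same call with its return tuple unpacked, as a sketch assuming NumPy 1.14 or later (where `rcond=None` is accepted). The empty residuals array and the rank of 4 in the output above both say the same thing: X has fewer rows than columns, so an exact fit exists and lstsq returns the minimum-norm one.

```python
import numpy as np

X = np.array([[1, 2104, 5, 1, 45],
              [1, 1416, 3, 2, 40],
              [1, 1534, 3, 2, 30],
              [1, 852, 2, 1, 36]], dtype=float)
y = np.array([[460], [232], [315], [178]], dtype=float)

# lstsq returns (solution, residuals, rank, singular_values).
# Passing rcond=None opts in to the newer default cutoff and avoids the
# FutureWarning that some NumPy versions print when rcond is omitted.
theta, residuals, rank, sv = np.linalg.lstsq(X, y, rcond=None)

print(rank)                       # 4
print(residuals.shape)            # (0,)  empty because rank < number of columns
print(sv.shape)                   # (4,)
print(np.allclose(X @ theta, y))  # True
```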
You can see the source of the ill-conditioning by printing the matrix inverse you're calculating:
In [30]: np.linalg.inv(X.T.dot(X))
Out[30]:
array([[ -4.12181049e+13, 1.93633440e+11, -8.76643127e+13,
-3.06844458e+13, 2.28487459e+12],
[ 1.93633440e+11, -9.09646601e+08, 4.11827338e+11,
1.44148665e+11, -1.07338299e+10],
[ -8.76643127e+13, 4.11827338e+11, -1.86447963e+14,
-6.52609055e+13, 4.85956259e+12],
[ -3.06844458e+13, 1.44148665e+11, -6.52609055e+13,
-2.28427584e+13, 1.70095424e+12],
[ 2.28487459e+12, -1.07338299e+10, 4.85956259e+12,
1.70095424e+12, -1.26659193e+11]])
Eeep!
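Taking the dot product of this with X.T leads to a catastrophic loss of precision: entries of order 1e11 to 1e14 have to cancel against each other to produce coefficients of order 100, so the rounding error in those entries swamps the result.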
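One way to quantify the ill-conditioning is to compare condition numbers (a sketch; the exact value printed for X.T @ X varies by platform and may even be inf, so it is shown without an expected output). X itself is only moderately ill-conditioned; X.T @ X, being singular, is hopeless:

```python
import numpy as np

X = np.array([[1, 2104, 5, 1, 45],
              [1, 1416, 3, 2, 40],
              [1, 1534, 3, 2, 30],
              [1, 852, 2, 1, 36]], dtype=float)

# For a non-square matrix, np.linalg.cond (default p=None) is the ratio of
# the largest to the smallest of its min(M, N) singular values.
cond_X = np.linalg.cond(X)

# X.T @ X is 5x5 but has rank at most 4, so in exact arithmetic its
# condition number is infinite; in floating point it is enormous (or inf).
cond_XtX = np.linalg.cond(X.T @ X)

print(f"cond(X)       = {cond_X:.3g}")  # cond(X)       = 1.56e+04
print(f"cond(X.T @ X) = {cond_XtX:.3g}")
```

For a tall X with full column rank, forming XᵀX squares the condition number; here X is wide (4×5), so the squaring turns into outright singularity, which is why solvers that work on X directly (pinv, lstsq) behave so much better.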