How can you implement Householder-based QR decomposition in Python?

I'm currently trying to implement the Householder-based QR decomposition for rectangular matrices as described in http://eprints.ma.man.ac.uk/1192/1/qrupdating_12nov08.pdf (pages 3, 4, 5).

Apparently I got some of the pseudocode wrong, though, since (1) my results differ from numpy.linalg.qr() and (2) the matrix R produced by my routines is not an upper triangular matrix.

My code (also available at https://pyfiddle.io/fiddle/afcc2e0e-0857-4cb2-adb5-06ff9b80c9d3/?i=true):

import math
import argparse
import numpy as np
from typing import Union

def householder(alpha: float, x: np.ndarray) -> Union[np.ndarray, int]:
    """
    Computes Householder vector for alpha and x.
    :param alpha:
    :param x:
    :return:
    """

    s = math.pow(np.linalg.norm(x, ord=2), 2)
    v = x

    if s == 0:
        tau = 0
    else:
        t = math.sqrt(alpha * alpha + s)
        v_one = alpha - t if alpha <= 0 else -s / (alpha + t)

        tau = 2 * v_one * v_one / (s + v_one * v_one)
        v /= v_one

    return v, tau


def qr_decomposition(A: np.ndarray, m: int, n: int) -> Union[np.ndarray, np.ndarray]:
    """
    Applies Householder-based QR decomposition on specified matrix A.
    :param A:
    :param m:
    :param n:
    :return:
    """
    H = []
    R = A
    Q = A
    I = np.eye(m, m)

    for j in range(0, n):
        # Apply Householder transformation.
        x = A[j + 1:m, j]
        v_householder, tau = householder(np.linalg.norm(x), x)
        v = np.zeros((1, m))
        v[0, j] = 1
        v[0, j + 1:m] = v_householder

        res = I - tau * v * np.transpose(v)
        R = np.matmul(res, R)
        H.append(res)

    return Q, R

m = 10
n = 8

A = np.random.rand(m, n)
q, r = np.linalg.qr(A)
Q, R = qr_decomposition(A, m, n)

print("*****")
print(Q)
print(q)
print("-----")
print(R)
print(r)

So I'm unclear about how to introduce zeros into my R matrix, and about which part of my code is incorrect. I'd be grateful for any pointers! Thanks a lot for your time.

justonemorething asked Nov 26 '18 21:11


1 Answer

There were a bunch of problems/missing details in the notes you linked to. After consulting a few other sources (including this very useful textbook), I was able to come up with a working implementation of something similar.

The working algorithm

Here's the code for a working version of qr_decomposition:

import numpy as np
from typing import Tuple

def householder(x: np.ndarray) -> Tuple[np.ndarray, float]:
    """Return v and tau such that (I - tau * outer(v, v)) @ x is zero below its first entry."""
    alpha = x[0]
    s = np.power(np.linalg.norm(x[1:]), 2)
    v = x.copy()

    if s == 0:
        tau = 0
    else:
        t = np.sqrt(alpha**2 + s)
        v[0] = alpha - t if alpha <= 0 else -s / (alpha + t)

        tau = 2 * v[0]**2 / (s + v[0]**2)
        v /= v[0]

    return v, tau

def qr_decomposition(A: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    m, n = A.shape
    R = A.astype(float)  # float copy, so integer inputs also work
    Q = np.identity(m)

    for j in range(0, n):
        # Apply Householder transformation.
        v, tau = householder(R[j:, j])
        H = np.identity(m)
        H[j:, j:] -= tau * np.outer(v, v)  # outer product v v^T
        R = H @ R
        Q = H @ Q

    # Q now holds H_n ... H_1, so its transpose is the full Q; keep the first n columns.
    return Q[:n].T, R[:n]

m = 5
n = 4

A = np.random.rand(m, n)
q, r = np.linalg.qr(A)
Q, R = qr_decomposition(A)

with np.printoptions(linewidth=9999, precision=20, suppress=True):
    print("**** Q from qr_decomposition")
    print(Q)
    print("**** Q from np.linalg.qr")
    print(q)
    print()
    
    print("**** R from qr_decomposition")
    print(R)
    print("**** R from np.linalg.qr")
    print(r)

Output:

**** Q from qr_decomposition
[[ 0.5194188817843675  -0.10699353671401633  0.4322294754656072  -0.7293293270703678 ]
 [ 0.5218635773595086   0.11737804362574514 -0.5171653705211056   0.04467925806590414]
 [ 0.34858177783013133  0.6023104248793858  -0.33329256746256875 -0.03450824948274838]
 [ 0.03371048915852807  0.6655221685383623   0.6127023580593225   0.28795294754791   ]
 [ 0.5789790833500734  -0.411189947884951    0.24337120818874305  0.618041080584351  ]]
**** Q from np.linalg.qr
[[-0.5194188817843672    0.10699353671401617   0.4322294754656068    0.7293293270703679  ]
 [-0.5218635773595086   -0.11737804362574503  -0.5171653705211053   -0.044679258065904115]
 [-0.3485817778301313   -0.6023104248793857   -0.33329256746256863   0.03450824948274819 ]
 [-0.03371048915852807  -0.665522168538362     0.6127023580593226   -0.2879529475479097  ]
 [-0.5789790833500733    0.41118994788495106   0.24337120818874317  -0.6180410805843508  ]]

**** R from qr_decomposition
[[ 0.6894219296137802      1.042676051151294       1.3418719684631446      1.2498925815126485    ]
 [ 0.00000000000000000685  0.7076056836914905      0.29883043386651403     0.41955370595004277   ]
 [-0.0000000000000000097  -0.00000000000000007292  0.5304551654027297      0.18966088433421135   ]
 [-0.00000000000000000662  0.00000000000000008718  0.00000000000000002322  0.6156558913022807    ]]
**** R from np.linalg.qr
[[-0.6894219296137803  -1.042676051151294   -1.3418719684631442  -1.2498925815126483 ]
 [ 0.                  -0.7076056836914905  -0.29883043386651376 -0.4195537059500425 ]
 [ 0.                   0.                   0.53045516540273     0.18966088433421188]
 [ 0.                   0.                   0.                  -0.6156558913022805 ]]

This version of qr_decomposition reproduces the output of np.linalg.qr almost exactly. The remaining differences (tiny nonzero values below the diagonal of R, and some flipped signs) are discussed below.
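Since the signs can legitimately differ, comparing entries one by one against np.linalg.qr isn't the most robust test. A minimal sketch of an alternative is to check the defining properties directly: Q has orthonormal columns, R is upper triangular, and Q @ R reproduces A. The helper name check_qr and the tolerance are arbitrary choices for illustration, and the sketch assumes the reduced shapes used here (Q is m x n, R is n x n):

```python
import numpy as np

def check_qr(A, Q, R, tol=1e-10):
    """Return True if (Q, R) looks like a reduced QR factorization of A.

    check_qr is a hypothetical helper; tol is an arbitrary tolerance.
    """
    n = Q.shape[1]
    orthonormal = np.allclose(Q.T @ Q, np.eye(n), atol=tol)  # Q^T Q == I
    upper = np.allclose(np.tril(R, -1), 0.0, atol=tol)       # nothing below the diagonal
    reconstructs = np.allclose(Q @ R, A, atol=tol)           # Q R == A
    return orthonormal and upper and reconstructs

rng = np.random.default_rng(0)
A = rng.random((5, 4))
q, r = np.linalg.qr(A)

print(check_qr(A, q, r))                                    # True
print(check_qr(A, q, r + np.tril(np.ones_like(r), -1)))     # False: R no longer triangular
```

Passing the Q, R returned by qr_decomposition instead of q, r should also give True, whichever sign convention it uses.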

Numerical precision of the output

The values in the outputs of np.linalg.qr and qr_decomposition match to high precision. However, the floating-point operations that qr_decomposition uses to produce the zeros in R don't cancel exactly, so the entries below the diagonal end up tiny (around 1e-17 in the output above) rather than exactly 0.0.

It turns out that np.linalg.qr isn't doing any fancy floating-point tricks to make the zeros in its output exactly 0.0. It just calls np.triu, which forcibly sets everything below the diagonal to 0.0. So to get the same result, just change the return line in qr_decomposition to:

return Q[:n].T, np.triu(R[:n])
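To see how small these leftover values are, here is a minimal sketch that applies a single Householder reflector to the first column of a random matrix, measures the below-diagonal residue against machine epsilon, and then cleans it up with np.triu. The reflector is written out inline (using the copysign-style choice of v, an assumption made just to keep the snippet self-contained), and the 100 * eps * norm(A) bound is deliberately loose:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.random((5, 4))
m, n = A.shape

# One Householder reflector mapping the first column x onto a multiple of e1,
# with v = x + copysign(||x||, x[0]) * e1.
x = A[:, 0]
v = x.copy()
v[0] += np.copysign(np.linalg.norm(x), x[0])
H = np.eye(m) - 2.0 * np.outer(v, v) / (v @ v)
R = H @ A

# Below the diagonal, column 0 is zero only in exact arithmetic.
eps = np.finfo(A.dtype).eps
residue = np.abs(R[1:, 0]).max()
print(residue, residue < 100 * eps * np.linalg.norm(A))  # expect a tiny value (maybe 0.0), True

# np.triu sets everything strictly below the diagonal to exactly 0.0.
R_clean = np.triu(R)
print(np.count_nonzero(np.tril(R_clean, -1)))  # 0
```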

Signs (+/-) in the output

Some of the +/- signs in Q and R differ between the outputs of np.linalg.qr and qr_decomposition, but this isn't really an issue, as there are many valid choices for the signs (see this discussion of the uniqueness of Q and R). You can exactly match the sign convention that np.linalg.qr uses by switching to an alternative algorithm to generate v and tau:

def householder_vectorized(a):
    """Use this version of householder to reproduce the output of np.linalg.qr
    exactly (specifically, to match the sign convention it uses).

    based on https://rosettacode.org/wiki/QR_decomposition#Python
    """
    v = a / (a[0] + np.copysign(np.linalg.norm(a), a[0]))
    v[0] = 1
    tau = 2 / (v.T @ v)

    return v, tau
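The sign convention is easiest to see on a single vector. With this choice of v, the reflector H = I - tau v v^T maps a column a onto -copysign(norm(a), a[0]) times e1, so the new diagonal entry always has the opposite sign of a[0]. The first householder function instead always produces +norm(a), which is why that version's R had a positive diagonal. A small check with an arbitrarily chosen test vector (the function is repeated so the snippet runs on its own):

```python
import numpy as np

def householder_vectorized(a):
    # Same as householder_vectorized above; repeated so this snippet runs alone.
    v = a / (a[0] + np.copysign(np.linalg.norm(a), a[0]))
    v[0] = 1
    tau = 2 / (v.T @ v)
    return v, tau

# Arbitrary test column with norm sqrt(9 + 1 + 4 + 4) = sqrt(18).
results = {}
for first in (3.0, -3.0):
    a = np.array([[first], [1.0], [2.0], [2.0]])  # column vector, shape (4, 1)
    v, tau = householder_vectorized(a)
    H = np.eye(4) - tau * (v @ v.T)               # tau has shape (1, 1) here
    results[first] = (H @ a).ravel()
    print(first, results[first])
# The first entry comes out as -copysign(sqrt(18), a[0]); the others are ~0.
```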

Exactly matching the output of np.linalg.qr

Putting it all together, this version of qr_decomposition matches the output of np.linalg.qr, including its sign convention and its exact zeros below the diagonal:

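import numpy as np
from typing import Tuple

def householder_vectorized(a):
    """Householder vector v (v[0] == 1) and tau for a column vector a,
    using the same sign convention as np.linalg.qr."""
    v = a / (a[0] + np.copysign(np.linalg.norm(a), a[0]))
    v[0] = 1
    tau = 2 / (v.T @ v)

    return v, tau

def qr_decomposition(A: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    m, n = A.shape
    R = A.copy()
    Q = np.identity(m)

    for j in range(0, n):
        # Apply Householder transformation.
        v, tau = householder_vectorized(R[j:, j, np.newaxis])

        H = np.identity(m)
        H[j:, j:] -= tau * (v @ v.T)
        R = H @ R
        Q = H @ Q

    return Q[:n].T, np.triu(R[:n])

m = 5
n = 4

A = np.random.rand(m, n)
q, r = np.linalg.qr(A)
Q, R = qr_decomposition(A)
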
with np.printoptions(linewidth=9999, precision=20, suppress=True):
    print("**** Q from qr_decomposition")
    print(Q)
    print("**** Q from np.linalg.qr")
    print(q)
    print()
    
    print("**** R from qr_decomposition")
    print(R)
    print("**** R from np.linalg.qr")
    print(r)

Output:

**** Q from qr_decomposition
[[-0.10345123000824041   0.6455437884382418    0.44810714367794663  -0.03963544711256745 ]
 [-0.55856415402318     -0.3660716543156899    0.5953932791844518    0.43106504879433577 ]
 [-0.30655198880585594   0.6606757192118904   -0.21483067305535333   0.3045011114089389  ]
 [-0.48053620675695174  -0.11139783377793576  -0.6310958848894725    0.2956864520726446  ]
 [-0.5936453158283703   -0.01904935140131578  -0.016510508076204543 -0.79527388379824    ]]
**** Q from np.linalg.qr
[[-0.10345123000824041   0.6455437884382426    0.44810714367794663  -0.039635447112567376]
 [-0.5585641540231802   -0.3660716543156898    0.5953932791844523    0.4310650487943359  ]
 [-0.30655198880585594   0.6606757192118904   -0.21483067305535375   0.30450111140893893 ]
 [-0.48053620675695186  -0.1113978337779356   -0.6310958848894725    0.29568645207264455 ]
 [-0.5936453158283704   -0.01904935140131564  -0.0165105080762043   -0.79527388379824    ]]

**** R from qr_decomposition
[[-1.653391466100325   -1.0838054573405895  -1.0632037969249921  -1.1825735233596888 ]
 [ 0.                   0.7263519982452554   0.7798481878600413   0.5496287509656425 ]
 [ 0.                   0.                  -0.26840760341581243 -0.2002757085967938 ]
 [ 0.                   0.                   0.                   0.48524469321440966]]
**** R from np.linalg.qr
[[-1.6533914661003253 -1.0838054573405895 -1.0632037969249923 -1.182573523359689 ]
 [ 0.                  0.7263519982452559  0.7798481878600418  0.5496287509656428]
 [ 0.                  0.                 -0.2684076034158126 -0.2002757085967939]
 [ 0.                  0.                  0.                  0.4852446932144096]]

Aside from the inevitable rounding error in the trailing digits, the outputs now match.
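If you'd rather keep the first householder (positive diagonal in R) and still compare against np.linalg.qr, an alternative is to normalize both factorizations to a common sign convention before comparing. For a full-rank A, the reduced QR factorization whose R has a positive diagonal is unique, so flipping the signs of Q's columns and R's rows to make diag(R) non-negative removes the ambiguity. A minimal sketch; normalize_signs is a hypothetical helper (not part of NumPy), and the sign flips below just simulate a second, equally valid implementation:

```python
import numpy as np

def normalize_signs(Q, R):
    """Flip column signs of Q and row signs of R so that diag(R) >= 0.

    With D = diag(d) and D @ D = I, (Q D)(D R) == Q R, so the product is unchanged.
    """
    d = np.sign(np.diag(R))
    d[d == 0] = 1.0                 # leave zero pivots alone
    return Q * d, R * d[:, np.newaxis]

rng = np.random.default_rng(2)
A = rng.random((5, 4))
q, r = np.linalg.qr(A)

# A second valid factorization of the same A, with some signs flipped.
flip = np.array([1.0, -1.0, -1.0, 1.0])
q2, r2 = q * flip, r * flip[:, np.newaxis]

q_n, r_n = normalize_signs(q, r)
q2_n, r2_n = normalize_signs(q2, r2)
print(np.allclose(q_n, q2_n), np.allclose(r_n, r2_n))  # True True
```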

tel answered Oct 10 '22 17:10