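I need to take the matrix product of two NumPy matrices (or other 2-D arrays) containing log probabilities. The naive way,

    np.log(np.dot(np.exp(a), np.exp(b)))

is not preferred because np.exp underflows to 0 for very negative log probabilities, so entries of the result lose precision or become -inf.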
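A minimal sketch of the failure, assuming float64 inputs: exp(-800) is smaller than the smallest positive float64 (about 4.9e-324), so it rounds to 0 and the naive formula returns -inf even though the true value is finite.

```python
import numpy as np

# Log probabilities below log(smallest float64), which is about -744.4
a = np.full((2, 3), -800.0)
b = np.full((3, 2), -800.0)

with np.errstate(divide="ignore"):  # silence the log(0) warning
    naive = np.log(np.dot(np.exp(a), np.exp(b)))

# Exact answer: each entry is log(3 * exp(-1600)) = -1600 + log(3)
exact = -1600.0 + np.log(3.0)
print(naive)  # every entry is -inf
print(exact)
```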
Using

    import numpy as np
    from scipy.special import logsumexp  # formerly scipy.misc.logsumexp

    res = np.zeros((a.shape[0], b.shape[1]))
    for n in range(b.shape[1]):
        # broadcast b[:, n] over the rows of a, then sum over columns
        res[:, n] = logsumexp(a + b[:, n], axis=1)
works but runs about 100 times slower than np.log(np.dot(np.exp(a), np.exp(b)))
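For reference, here is the loop above wrapped in a function so it can be checked against the naive formula on values where exp does not underflow. The name logdot_old is borrowed from the timing comparison further down; the wrapper itself is an illustrative sketch, and it assumes scipy.special.logsumexp (the current location of the old scipy.misc.logsumexp).

```python
import numpy as np
from scipy.special import logsumexp


def logdot_old(a, b):
    """Log-space matrix product via one logsumexp per output column."""
    res = np.empty((a.shape[0], b.shape[1]))
    for n in range(b.shape[1]):
        # a[i, k] + b[k, n] for all i, k; reduce over k
        res[:, n] = logsumexp(a + b[:, n], axis=1)
    return res


rng = np.random.default_rng(0)
a = np.log(rng.random((4, 5)))
b = np.log(rng.random((5, 3)))

# With moderate log values the naive formula is safe, so both should agree
naive = np.log(np.dot(np.exp(a), np.exp(b)))
print(np.allclose(logdot_old(a, b), naive))  # True
```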
Using

    logsumexp(
        (np.tile(a, (b.shape[1], 1))
         + np.repeat(b.T, a.shape[0], axis=0)).reshape(b.shape[1], a.shape[0], a.shape[1]),
        axis=2).T
or other combinations of tile and reshape also work, but they run even slower than the loop above because they materialize an intermediate array of b.shape[1] × a.shape[0] × a.shape[1] elements, which needs a prohibitive amount of memory for realistically sized input matrices.
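The same all-at-once computation can be written more compactly with plain broadcasting instead of tile/repeat (a swapped-in technique, not the one in the question). This sketch makes the memory problem explicit: the intermediate array has shape (a.shape[0], a.shape[1], b.shape[1]), i.e. n·r·m elements.

```python
import numpy as np
from scipy.special import logsumexp


def logdot_broadcast(a, b):
    """Exact log-space matrix product; builds an (n, r, m) temporary."""
    # a[:, :, None] has shape (n, r, 1); b[None, :, :] has shape (1, r, m)
    s = a[:, :, None] + b[None, :, :]  # s[i, k, j] = a[i, k] + b[k, j]
    return logsumexp(s, axis=1)         # reduce over k -> shape (n, m)


rng = np.random.default_rng(1)
a = np.log(rng.random((4, 5)))
b = np.log(rng.random((5, 3)))
print(np.allclose(logdot_broadcast(a, b),
                  np.log(np.dot(np.exp(a), np.exp(b)))))  # True

# For the sizes timed below (1000x2000 times 2000x1500) the temporary
# would hold 1000 * 2000 * 1500 = 3e9 float64 values, about 24 GB.
```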
I am currently considering writing a NumPy extension in C to compute this, but of course I'd rather avoid that. Is there an established way to do this, or does anybody know of a less memory-intensive way of performing this computation?
EDIT: Thanks to larsmans for this solution (see below for derivation):
    import numpy as np

    def logdot(a, b):
        max_a, max_b = np.max(a), np.max(b)
        exp_a, exp_b = a - max_a, b - max_b   # shifted copies
        np.exp(exp_a, out=exp_a)              # exponentiate in place
        np.exp(exp_b, out=exp_b)
        c = np.dot(exp_a, exp_b)
        np.log(c, out=c)
        c += max_a + max_b                    # undo the shift
        return c
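A quick self-contained check of logdot (restated here so the snippet runs on its own) on inputs where the naive formula underflows. The reference values come from an exact per-entry scipy.special.logsumexp, which is an assumption of this sketch rather than part of the original answer.

```python
import numpy as np
from scipy.special import logsumexp


def logdot(a, b):
    max_a, max_b = np.max(a), np.max(b)
    exp_a, exp_b = a - max_a, b - max_b  # shifted copies (new arrays)
    np.exp(exp_a, out=exp_a)             # exponentiate in place
    np.exp(exp_b, out=exp_b)
    c = np.dot(exp_a, exp_b)
    np.log(c, out=c)
    c += max_a + max_b                   # undo the shift
    return c


rng = np.random.default_rng(2)
a = np.log(rng.random((3, 4))) - 800.0   # exp(a) underflows to 0
b = np.log(rng.random((4, 2))) - 800.0

ref = logsumexp(a[:, :, None] + b[None, :, :], axis=1)
print(np.all(np.isfinite(logdot(a, b))))  # True
print(np.allclose(logdot(a, b), ref))     # True
```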
A quick comparison of this method to the method posted above (logdot_old) using IPython's %timeit magic function yields the following:
    In [1] a = np.log(np.random.rand(1000,2000))
    In [2] b = np.log(np.random.rand(2000,1500))
    In [3] x = logdot(a, b)
    In [4] y = logdot_old(a, b) # this takes a while
    In [5] np.any(np.abs(x-y) > 1e-14)
    Out [5] False
    In [6] %timeit logdot_old(a, b)
    1 loops, best of 3: 1min 18s per loop
    In [6] %timeit logdot(a, b)
    1 loops, best of 3: 264 ms per loop
Obviously larsmans' method obliterates mine!
logsumexp works by evaluating the right-hand side of the equation

    log(∑ exp[a]) = max(a) + log(∑ exp[a - max(a)])

I.e., it pulls out the max before starting to sum, to prevent overflow in exp. The same can be applied before doing vector dot products:
    log(exp[a] ⋅ exp[b])
      = log(∑ exp[a] × exp[b])
      = log(∑ exp[a + b])
      = max(a + b) + log(∑ exp[a + b - max(a + b)])    { this is logsumexp(a + b) }
but by taking a different turn in the derivation, we obtain
    log(∑ exp[a] × exp[b])
      = max(a) + max(b) + log(∑ exp[a - max(a)] × exp[b - max(b)])
      = max(a) + max(b) + log(exp[a - max(a)] ⋅ exp[b - max(b)])
The final form has a vector dot product in its innards. It also extends readily to matrix multiplication, so we get the algorithm
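    import numpy as np

    def logdotexp(A, B):
        max_A = np.max(A)
        max_B = np.max(B)
        # shift by the global maxima so the largest exponent is exp(0) = 1
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)
        C += max_A + max_B   # undo the shift
        return C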
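The vector identity derived above can be checked numerically. This sketch compares the first form of the derivation, scipy.special.logsumexp(a + b), with the final shifted-dot-product form on made-up vectors whose exponentials all underflow.

```python
import numpy as np
from scipy.special import logsumexp

a = np.array([-900.0, -905.0, -910.0])
b = np.array([-850.0, -851.0, -860.0])

# First form of the derivation: logsumexp(a + b)
lhs = logsumexp(a + b)

# Final form: max(a) + max(b) + log(exp[a - max(a)] . exp[b - max(b)])
rhs = a.max() + b.max() + np.log(np.dot(np.exp(a - a.max()),
                                        np.exp(b - b.max())))

print(np.isclose(lhs, rhs))  # True
```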
This creates two A-sized temporaries and two B-sized ones, but one of each can be eliminated by

    exp_A = A - max_A
    np.exp(exp_A, out=exp_A)

and similarly for B. (If the input matrices may be modified by the function, all the temporaries can be eliminated.)
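Putting the temporary-saving suggestion together gives something like the sketch below. The overwrite flag is a hypothetical parameter added for illustration: when it is True, the function assumes the caller allows A and B (float ndarrays) to be overwritten, so only the output C is allocated.

```python
import numpy as np


def logdotexp_lowmem(A, B, overwrite=False):
    """Log-space matrix product with fewer temporaries.

    overwrite is a hypothetical flag: if True, A and B (float ndarrays)
    are clobbered and no A- or B-sized temporaries are created.
    """
    max_A, max_B = np.max(A), np.max(B)  # scalars, computed before any overwrite
    if overwrite:
        exp_A, exp_B = A, B
        exp_A -= max_A                   # shift in place
        exp_B -= max_B
    else:
        exp_A = A - max_A                # one A-sized temporary
        exp_B = B - max_B                # one B-sized temporary
    np.exp(exp_A, out=exp_A)             # exponentiate without another copy
    np.exp(exp_B, out=exp_B)
    C = np.dot(exp_A, exp_B)
    np.log(C, out=C)
    C += max_A + max_B
    return C


rng = np.random.default_rng(3)
A = np.log(rng.random((3, 4)))
B = np.log(rng.random((4, 2)))
expected = np.log(np.dot(np.exp(A), np.exp(B)))

print(np.allclose(logdotexp_lowmem(A, B), expected))  # True
A2, B2 = A.copy(), B.copy()
print(np.allclose(logdotexp_lowmem(A2, B2, overwrite=True), expected))  # True
```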
Suppose A.shape == (n, r) and B.shape == (r, m). In computing the matrix product C = np.dot(A, B), there are actually n*m separate summations, each over r terms. To get stable results when working in log-space, you need the logsumexp trick in each of these summations. Fortunately, with NumPy broadcasting it is quite easy to control the stability of the rows of A and the columns of B separately.
Here is the code:
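    import numpy as np

    def logdotexp(A, B):
        # one max per row of A (shape (n, 1)) and per column of B (shape (1, m))
        max_A = np.max(A, axis=1, keepdims=True)
        max_B = np.max(B, axis=0, keepdims=True)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)
        C += max_A + max_B   # (n, 1) + (1, m) broadcasts to (n, m)
        return C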
Note:
The reasoning behind this is similar to FredFoo's answer, but he used a single maximum value for each matrix. Since he did not consider each of the n*m summations separately, some elements of the final matrix might still be unstable, as mentioned in one of the comments.
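To see what the keepdims broadcasting does, this sketch restates the row/column version of logdotexp so it runs on its own, prints the shapes of the row and column maxima, and checks the result against an exact per-entry scipy.special.logsumexp. The test matrices (rows of A and columns of B on very different scales) are made up for illustration.

```python
import numpy as np
from scipy.special import logsumexp


def logdotexp(A, B):
    max_A = np.max(A, axis=1, keepdims=True)  # shape (n, 1): one max per row of A
    max_B = np.max(B, axis=0, keepdims=True)  # shape (1, m): one max per column of B
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B                        # (n, 1) + (1, m) -> (n, m)
    return C


A = np.array([[0.0, -1.0, -2.0],
              [-1000.0, -1001.0, -1002.0]])
B = np.array([[0.0, -2000.0],
              [-1.0, -2001.0],
              [-2.0, -2002.0]])

ref = logsumexp(A[:, :, None] + B[None, :, :], axis=1)  # exact, shape (2, 2)
print(np.max(A, axis=1, keepdims=True).shape)  # (2, 1)
print(np.max(B, axis=0, keepdims=True).shape)  # (1, 2)
print(np.allclose(logdotexp(A, B), ref))       # True
```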
Comparing with the currently accepted answer, using @identity-m's counterexample:
    import numpy as np

    def logdotexp_less_stable(A, B):
        max_A = np.max(A)
        max_B = np.max(B)
        C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
        np.log(C, out=C)
        C += max_A + max_B
        return C

    print('old method:')
    print(logdotexp_less_stable([[0, 0], [0, 0]], [[-1000, 0], [-1000, 0]]))
    print('new method:')
    print(logdotexp([[0, 0], [0, 0]], [[-1000, 0], [-1000, 0]]))
which prints
    old method:
    [[ -inf 0.69314718]
     [ -inf 0.69314718]]
    new method:
    [[-9.99306853e+02 6.93147181e-01]
     [-9.99306853e+02 6.93147181e-01]]
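One caveat worth checking before relying on this: per-row and per-column maxima still do not guarantee that every summation is shifted by its own largest term. If the row maximum of A and the column maximum of B sit at different positions k, every product in that summation can underflow. A minimal sketch of such a case (the matrices are made up for illustration; the errstate guard is an addition to silence the log(0) warning), compared with an exact per-entry scipy.special.logsumexp, which is exact but needs n·r·m memory:

```python
import numpy as np
from scipy.special import logsumexp


def logdotexp(A, B):
    max_A = np.max(A, axis=1, keepdims=True)
    max_B = np.max(B, axis=0, keepdims=True)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    with np.errstate(divide="ignore"):  # C may contain zeros here
        np.log(C, out=C)
    C += max_A + max_B
    return C


# Row max of A is at k=0, column max of B is at k=1
A = np.array([[0.0, -1000.0]])
B = np.array([[-1000.0], [0.0]])

ref = logsumexp(A[:, :, None] + B[None, :, :], axis=1)
print(logdotexp(A, B))  # -inf
print(ref)              # about -999.307, i.e. -1000 + log(2)
```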