
Is there a way to compute the matrix logarithm of a PyTorch tensor?

Tags: python, pytorch

I am trying to compute matrix logarithms in PyTorch, but I need to stay with tensors because I compute gradients afterwards, which means I can't use NumPy arrays. Basically, I'm trying to do the equivalent of SciPy's scipy.linalg.logm(), but with PyTorch tensors.

Nathan asked Sep 17 '25 10:09



1 Answer

Unfortunately, the matrix logarithm (unlike the matrix exponential) is not implemented in PyTorch yet, but matrix powers are. In the meantime, you can approximate the matrix logarithm with its power series expansion, log(A) = sum over k >= 1 of (-1)^(k+1) (A - I)^k / k, truncated once the accuracy is sufficient. Note that this series only converges when A is close enough to the identity (||A - I|| < 1).
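The truncated power series could be sketched roughly as follows. This is only an illustration, not a tuned implementation: the function name `logm_series` and the default `n_terms` are my own assumptions, and it presumes A is close enough to the identity for the series to converge. It builds the powers by repeated matrix multiplication instead of calling a matrix power routine for every term, and because it uses only ordinary tensor operations, autograd can differentiate through it.

```python
import torch

def logm_series(A: torch.Tensor, n_terms: int = 50) -> torch.Tensor:
    """Truncated series log(A) ~ sum_{k=1}^{n_terms} (-1)^(k+1) (A - I)^k / k.

    Hypothetical helper; assumes A is square and ||A - I|| < 1.
    """
    n = A.size(-1)
    identity = torch.eye(n, dtype=A.dtype, device=A.device)
    X = A - identity
    power = X            # holds X^k, starting at k = 1
    result = X           # the k = 1 term of the series
    for k in range(2, n_terms + 1):
        power = power @ X                      # X^k from X^(k-1)
        sign = 1.0 if k % 2 == 1 else -1.0     # (-1)^(k+1)
        result = result + sign * power / k
    return result

# Usage: a matrix near the identity, so the series converges quickly.
B = torch.tensor([[0.0, 1.0, 0.5],
                  [0.2, 0.0, 0.3],
                  [0.1, 0.4, 0.0]], dtype=torch.float64)
A = (torch.eye(3, dtype=torch.float64) + 0.1 * B).requires_grad_(True)

L = logm_series(A)
# Sanity check: exponentiating the result should recover A.
roundtrip_error = (torch.linalg.matrix_exp(L) - A).abs().max().item()

# Gradients flow back to A through the series.
L.sum().backward()
```

For matrices far from the identity the series diverges, so in that case some rescaling or the SciPy-based approach would be needed instead.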

Alternatively, Lezcano proposes a (slow) differentiable matrix logarithm via the adjoint here. I'll quote their suggested solution:

import scipy.linalg
import torch

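def adjoint(A, E, f):
    # Build the block upper-triangular matrix [[A^H, E], [0, A^H]]
    A_H = A.T.conj().to(E.dtype)
    n = A.size(0)
    M = torch.zeros(2*n, 2*n, dtype=E.dtype, device=E.device)
    M[:n, :n] = A_H
    M[n:, n:] = A_H
    M[:n, n:] = E
    # The upper-right block of f(M) is the derivative of f at A^H applied to E
    return f(M)[:n, n:].to(A.dtype)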

def logm_scipy(A):
    return torch.from_numpy(scipy.linalg.logm(A.cpu(), disp=False)[0]).to(A.device)

class Logm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        assert A.ndim == 2 and A.size(0) == A.size(1)  # Square matrix
        assert A.dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128)
        ctx.save_for_backward(A)
        return logm_scipy(A)

    @staticmethod
    def backward(ctx, G):
        A, = ctx.saved_tensors
        return adjoint(A, G, logm_scipy)

logm = Logm.apply
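As far as I understand it, the adjoint helper relies on a standard block-matrix identity for matrix functions, sketched below. The symbols X and Df(X)[E] (the Fréchet derivative of f at X in direction E) are notation I am introducing here, not taken from the quoted code, and the identity assumes f is sufficiently smooth on the spectrum of X.

```latex
f\!\left(\begin{bmatrix} X & E \\ 0 & X \end{bmatrix}\right)
= \begin{bmatrix} f(X) & Df(X)[E] \\ 0 & f(X) \end{bmatrix},
\qquad X = A^{H},\quad E = G
```

With X = A^H and E = G (the incoming gradient), the upper-right block is what backward returns. This also suggests why the approach is slow: every backward pass evaluates logm on a 2n x 2n matrix on the CPU.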
flawr answered Sep 20 '25 00:09
