
My LSTM learns, loss decreases, but Numerical Gradients don't match Analytical Gradients

The following is self-contained, and when you run it, it will:

1. print the loss to verify it's decreasing (learning a sin wave),

2. check the numeric gradients against my hand-derived gradient function.

The two gradients tend to agree only to a relative error of 1e-1 to 1e-2 (as computed in check_grad below; still bad, but it shows the backprop is at least roughly right), and there are occasional extreme outliers.

I have spent all Saturday backing out to a normal FFNN and getting that to work (yay, gradients match!), and now all Sunday on this LSTM, and I can't find the bug in my logic. Also, it depends heavily on my random seed: sometimes the check looks great, sometimes awful.

I've hand-checked my implementation against hand-derived derivatives for the LSTM equations (I did the calculus; the forward equations are recapped after the links below), and against the implementations in these three blog posts/gists:

  • http://blog.varunajayasiri.com/numpy_lstm.html

  • https://gist.github.com/karpathy/d4dee566867f8291f086

  • http://colah.github.io/posts/2015-08-Understanding-LSTMs/

And I tried the (amazing) debugging methods suggested here: https://blog.slavv.com/37-reasons-why-your-neural-network-is-not-working-4020854bd607
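
For reference, here is the forward pass exactly as my lstm() below computes it (writing each gate after its nonlinearity), where xh is the concatenation [inp, h_old] and every gate weight W has shape (x_n + h_n, h_n):

f = sigmoid(xh @ Wf + bf)   # forget gate
i = sigmoid(xh @ Wi + bi)   # input gate
g = tanh(xh @ Wg + bg)      # candidate cell state (C-tilde)
o = sigmoid(xh @ Wo + bo)   # output gate
c = f * c_old + i * g       # new cell state
h = o * tanh(c)             # new hidden state
y = h @ Wy + by             # dense readout layer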

Can you help me see where I have implemented something wrong?

import numpy as np
np.set_printoptions(precision=3, suppress=True)

def check_grad(params, In, Target, f, df_analytical, delta=1e-5, tolerance=1e-7, num_checks=10):
    """
    delta : how far on either side of the param value to go

    tolerance : how far the analytical and numerical values can diverge
    """

    h_n = params['Wf'].shape[1] # TODO: h & c should be passed in (?)
    h = np.zeros(h_n)
    c = np.zeros(h_n)

    y, outputs, loss, h, c, caches = f(params, h, c, In, Target)
    dparams = df_analytical(params, In, Target, outputs, caches)

    passes = True
    for _ in range(num_checks):
        print()
        for pname, p, dpname, dp in zip(params.keys(), params.values(), dparams.keys(), dparams.values()):

            pix = np.random.randint(0, p.size)
            old_val = p.flat[pix]

            # d = delta * abs(old_val) if old_val != 0 else 1e-5
            d = delta

            p.flat[pix] = old_val + d
            _, _, loss_plus, _, _, _ = f(params, h, c, In, Target) # only the loss is needed here
            p.flat[pix] = old_val - d
            _, _, loss_minus, _, _, _ = f(params, h, c, In, Target)
            p.flat[pix] = old_val

            grad_analytic = dp.flat[pix]
            grad_numeric = (loss_plus - loss_minus) / (2 * d)

            denom = abs(grad_numeric + grad_analytic) + 1e-12 # max((abs(grad_numeric), abs(grad_analytic)))
            relative_error = abs(grad_analytic - grad_numeric) / denom

            if relative_error > tolerance:
                print(("fails: %s % 4d |  r: % 3.4f,   a: % 3.4f,   n: % 3.4f,   a/n: %0.2f") % (pname, pix, relative_error, grad_analytic, grad_numeric, grad_analytic/grad_numeric))
            passes &= relative_error <= tolerance

    return passes


# ----------

def lstm(params, inp, h_old, c_old):

    Wf, Wi, Wg, Wo, Wy = params['Wf'], params['Wi'], params['Wg'], params['Wo'], params['Wy']
    bf, bi, bg, bo, by = params['bf'], params['bi'], params['bg'], params['bo'], params['by']

    xh = np.concatenate([inp, h_old])

    f = np.dot(xh, Wf) + bf
    f_sigm = 1 / (1 + np.exp(-f))

    i = np.dot(xh, Wi) + bi
    i_sigm = 1 / (1 + np.exp(-i))

    g = np.dot(xh, Wg) + bg # C-tilde or C-bar in the literature
    g_tanh = np.tanh(g)

    o = np.dot(xh, Wo) + bo
    o_sigm = 1 / (1 + np.exp(-o))

    c = f_sigm * c_old + i_sigm * g_tanh

    c_tanh = np.tanh(c)
    h = o_sigm * c_tanh

    y = np.dot(h, Wy) + by # NOTE: this is a dense layer bolted on after a normal LSTM
    # TODO: should it have a nonlinearity after it? MSE would not work well with, for ex, a sigmoid

    cache = (xh, f, f_sigm, i, i_sigm, g, g_tanh, o, o_sigm, c, c_tanh, c_old, h)
    return y, h, c, cache


def dlstm(params, dy, dh_next, dc_next, cache):

    Wf, Wi, Wg, Wo, Wy = params['Wf'], params['Wi'], params['Wg'], params['Wo'], params['Wy']
    bf, bi, bg, bo, by = params['bf'], params['bi'], params['bg'], params['bo'], params['by']

    xh, f, f_sigm, i, i_sigm, g, g_tanh, o, o_sigm, c, c_tanh, c_old, h = cache

    dby = dy.copy()
    dWy = np.outer(h, dy)
    dh = np.dot(dy, Wy.T) + dh_next.copy()
    do = c_tanh * dh * o_sigm * (1 - o_sigm)
    dc = dc_next.copy() + o_sigm * dh * (1 - c_tanh ** 2) # TODO: copy?
    dg = i_sigm * dc * (1 - g_tanh ** 2)
    di = g_tanh * dc * i_sigm * (1 - i_sigm)
    df = c_old  * dc * f_sigm * (1 - f_sigm) # ERROR FIXED: ??? c_old -> c?, c->c_old?

    dWo = np.outer(xh, do); dbo = do; dXo = np.dot(do, Wo.T)
    dWg = np.outer(xh, dg); dbg = dg; dXg = np.dot(dg, Wg.T)
    dWi = np.outer(xh, di); dbi = di; dXi = np.dot(di, Wi.T)
    dWf = np.outer(xh, df); dbf = df; dXf = np.dot(df, Wf.T)

    dX = dXo + dXg + dXi + dXf
    dh_next = dX[-h.size:]
    dc_next = f_sigm * dc

    dparams = dict(Wf = dWf, Wi = dWi, Wg = dWg, Wo = dWo, Wy = dWy,
                   bf = dbf, bi = dbi, bg = dbg, bo = dbo, by = dby)

    return dparams, dh_next, dc_next


def lstm_loss(params, h, c, inputs, targets):
    loss = 0
    outputs = []
    caches = []
    for inp, target in zip(inputs, targets):
        y, h, c, cache = lstm(params, inp, h, c)
        loss += np.mean((y - target) ** 2)
        outputs.append(y)
        caches.append(cache)
    loss = loss # / inputs.shape[0]
    return y, outputs, loss, h, c, caches

def dlstm_loss(params, inputs, targets, outputs, caches):
    h_shape = caches[0][-1].shape
    dparams = {k:np.zeros_like(v) for k, v in params.items()}
    dh = np.zeros(h_shape)
    dc = np.zeros(h_shape)

    for inp, out, target, cache in reversed(list(zip(inputs, outputs, targets, caches))):
        dy = 2 * (out - target)
        dps, dh, dc = dlstm(params, dy, dh, dc, cache)
        for dpk, dpv in dps.items():
            dparams[dpk] += dpv
    return dparams


# ----------
# setup

x_n = 1
h_n = 5
o_n = 1

params = dict(
    Wf = np.random.normal(size=(x_n + h_n, h_n)),
    Wi = np.random.normal(size=(x_n + h_n, h_n)),
    Wg = np.random.normal(size=(x_n + h_n, h_n)),
    Wo = np.random.normal(size=(x_n + h_n, h_n)),
    Wy = np.random.normal(size=(h_n, o_n)),
    bf = np.zeros(h_n) + np.random.normal(size=h_n) * 0.1,
    bi = np.zeros(h_n) + np.random.normal(size=h_n) * 0.1,
    bg = np.zeros(h_n) + np.random.normal(size=h_n) * 0.1,
    bo = np.zeros(h_n) + np.random.normal(size=h_n) * 0.1,
    by = np.zeros(o_n) + np.random.normal(size=o_n) * 0.1,
)

for name in ['Wf', 'Wi', 'Wg', 'Wo', 'Wy']:
    W = params[name]
    W *= np.sqrt(2 / (W.shape[0] + W.shape[1])) # Xavier initialization
for name in params:
    params[name] = params[name].astype('float64')


# ----------
# Sanity check, learn sin wave

def test_sin():
    emaloss = 1 # exponential moving average of the loss
    emak = 0.99

    for t in range(5000):
        data = np.sin(np.linspace(0, 3 * np.pi, 30))
        start = np.random.randint(0, data.size // 4)
        end = np.random.randint((data.size * 3) // 4, data.size)
        inputs = data[start:end, None]
        targets = np.roll(inputs, 1, axis=0)


        h_n = params['Wf'].shape[1] # TODO: h & c should be passed in
        h = np.random.normal(size=h_n)
        c = np.random.normal(size=h_n)

        y, outputs, loss, h, c, caches = lstm_loss(params, h, c, inputs, targets)
        dparams = dlstm_loss(params, inputs, targets, outputs, caches)

        for k in params.keys():
            params[k] -= dparams[k] * 0.01


        emaloss = emaloss * emak + loss * (1 - emak)
        if t % 100 == 0:
            print('%.4f' % emaloss)
test_sin()

# ----------
data = np.sin(np.linspace(0, 4 * np.pi, 90))
start = np.random.randint(0, data.size // 4)
end = np.random.randint((data.size * 3) // 4, data.size)
inputs = data[start:end, None]
targets = np.roll(inputs, 1, axis=0)

for inp, targ in zip(inputs, targets):
    assert(check_grad(params, inputs, targets, lstm_loss, dlstm_loss, delta=1e-5, tolerance=1e-7, num_checks=10))
print('grads are ok') # <- I never reach here
— asked Jan 21 '19 by Josh.F

2 Answers

Solved it! In my check_grad, I need to build the caches that get fed to df_analytical, but in doing so I also overwrite h and c, which should have stayed the np.zeros I just created.

y, outputs, loss, h, c, caches = f(params, h, c, In, Target) # <- overwrites the zeroed h and c
...
_, _, loss_minus, _, _, _ = f(params, h, c, In, Target)      # <- so the probes run from the wrong state
p.flat[pix] = old_val

So simply not overwriting h and c fixes it, and the LSTM code itself was A-OK:

_, outputs, loss, _, _, caches = f(params, h, c, In, Target)
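
For anyone hitting the same thing, here's a minimal sketch of the fixed start of check_grad (same names as in the question; only the tuple unpacking changes):

h_n = params['Wf'].shape[1]
h = np.zeros(h_n) # must still be all-zeros when the loss_plus/loss_minus probes run
c = np.zeros(h_n)

# build outputs/caches for df_analytical, but discard the returned h and c
# so the zero initial state stays intact for the numeric probes
_, outputs, loss, _, _, caches = f(params, h, c, In, Target)
dparams = df_analytical(params, In, Target, outputs, caches)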
— answered Oct 15 '22 by Josh.F


I think the problem might be this line:

c = f_sigm * c_old + i_sigm * g_tanh
— answered Oct 15 '22 by Felipe Valdes