LSTM in NumPy/Python

Posted on May 11, 2018 by Ilya

Intro

For an introduction to LSTMs, please refer to http://colah.github.io/posts/2015-08-Understanding-LSTMs/. What follows is a clean and minimal implementation of the forward pass of an LSTM cell. I have used the equations as given by Chris Olah.

The code


```python

import numpy as np

def cc(a, b):
    """Stack `a` on top of `b` (used to concatenate column-vectors)."""
    return np.vstack((a, b))


# Short alias for matrix multiplication.
mm = np.matmul

def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), element-wise.

    The value is computed as exp(-log(1 + exp(-x))) using np.logaddexp,
    which is mathematically identical to the direct formula but does not
    overflow inside np.exp when x is a large negative number.

    Parameters
    ----------
    x : scalar or numpy array

    Returns
    -------
    Values in [0, 1] with the same shape as `x`.
    """
    # logaddexp(0, -x) == log(exp(0) + exp(-x)) == log(1 + exp(-x)),
    # evaluated without forming the potentially huge exp(-x).
    return np.exp(-np.logaddexp(0.0, -x))


def lstm(parameters, C, h, x):
    """Run one forward step of an LSTM cell.

    Parameters
    ----------
    parameters : sequence of 8 arrays
        (Wf, bf, Wi, bi, Wc, bc, Wo, bo): weight matrix and bias for the
        forget gate, input gate, candidate cell state and output gate.
        Each weight multiplies the stacked vector [h; x].
    C : array
        Previous cell state (a column vector in the accompanying demo).
    h : array
        Previous hidden state (a column vector in the accompanying demo).
    x : array
        Current input (a column vector in the accompanying demo).

    Returns
    -------
    (C_next, h_next) : the updated cell state and hidden state.
    """
    Wf, bf, Wi, bi, Wc, bc, Wo, bo = parameters

    # Every gate sees the previous hidden state stacked on the input.
    z = cc(h, x)

    forget_gate = sigmoid(mm(Wf, z) + bf)
    input_gate = sigmoid(mm(Wi, z) + bi)
    output_gate = sigmoid(mm(Wo, z) + bo)
    C_hat = np.tanh(mm(Wc, z) + bc)  # candidate cell state

    # Keep part of the old cell state and add the gated candidate.
    C_next = forget_gate * C + input_gate * C_hat

    # h_t = o_t * tanh(C_t): the hidden state is read from the *updated*
    # cell state, not from the previous one.
    h_next = output_gate * np.tanh(C_next)

    return C_next, h_next


def genRandomParameters(n_in, n_out, gen=np.random.randn):
    """Create random weights and biases for the four LSTM gates.

    `gen` is called as gen(rows, cols). For each gate, in the order
    f, i, c, o, a bias of shape (n_out, 1) is drawn first and then a
    weight of shape (n_out, n_out + n_in).

    Returns the 8-tuple (Wf, bf, Wi, bi, Wc, bc, Wo, bo).
    """
    stacked_size = n_out + n_in

    flat = []
    for _gate in ('f', 'i', 'c', 'o'):
        # Bias is generated before the weight so the sequence of `gen`
        # calls stays the same for a seeded generator.
        bias = gen(n_out, 1)
        weight = gen(n_out, stacked_size)
        flat.extend((weight, bias))

    return tuple(flat)


if __name__ == '__main__':
    # Demo: run a single LSTM step on random data and print the input
    # together with the cell/hidden state before and after the step.

    n_in, n_out = 30, 20

    parameters = genRandomParameters(n_in, n_out)

    # Random input and initial state; index 0 holds the initial state.
    x = np.random.randn(n_in, 1)
    C = [np.random.randn(n_out, 1)]
    h = [np.random.randn(n_out, 1)]

    # Index 1 holds the state after one step.
    C_after, h_after = lstm(parameters, C[0], h[0], x)
    C.append(C_after)
    h.append(h_after)

    print("Here is the input at step 0\n")
    print(x)
    print("\n")
    for step, (cell_state, hidden_state) in enumerate(zip(C, h)):
        print("Here is the state {}\n".format(step))
        print("\nC = ")
        print(cell_state)
        print("\nh = ")
        print(hidden_state)
        print("\n")
```

Output

Here is the input at step 0

[[-0.05627357]
 [-0.29153725]
 [-1.28900898]
 [ 1.70351343]
 [ 0.19033016]
 [-0.68916496]
 [-1.57563025]
 [ 0.63144996]
 [ 0.95782593]
 [ 1.06586039]
 [-1.30962654]
 [-0.21002679]
 [-1.74616201]
 [ 0.86942999]
 [ 1.46664102]
 [-0.528788  ]
 [ 0.91849627]
 [-0.5678924 ]
 [-0.54006408]
 [-0.98086492]
 [ 1.02920194]
 [-1.13406362]
 [-0.56705029]
 [-2.31024086]
 [ 0.30760565]
 [ 2.1718798 ]
 [-0.25775158]
 [ 0.97286009]
 [ 0.1501893 ]
 [-0.75663434]]


Here is the state 0


C =
[[ 0.22116217]
 [-0.18373211]
 [ 0.8237491 ]
 [-1.99654357]
 [-0.62020134]
 [ 0.92473404]
 [ 0.82270896]
 [ 1.58050728]
 [-0.64807869]
 [ 0.64705852]
 [-0.83207134]
 [-0.06340116]
 [ 0.37436416]
 [ 2.10939065]
 [ 0.39075321]
 [ 1.1974727 ]
 [-0.17977976]
 [ 1.15609134]
 [-0.72855696]
 [-0.26332118]]

h =
[[-1.30358873]
 [ 0.52950027]
 [-0.77330343]
 [-1.29107344]
 [ 1.94376534]
 [-0.24147924]
 [ 1.14879692]
 [ 1.88227173]
 [ 0.41226936]
 [-0.93637718]
 [-1.28293428]
 [-0.42759261]
 [-1.03448204]
 [ 0.2648483 ]
 [ 1.25716327]
 [-1.1590435 ]
 [ 0.47442464]
 [-1.51877055]
 [ 2.41283117]
 [-0.90421226]]


Here is the state 1


C =
[[ 2.09379677e-04]
 [ 6.57006251e-01]
 [ 1.69534554e+00]
 [ 7.77343533e-01]
 [-1.00762388e+00]
 [-1.80160070e-01]
 [ 9.93025673e-01]
 [ 9.59586938e-01]
 [-6.39248726e-02]
 [ 1.56690524e+00]
 [-6.61027783e-01]
 [ 9.21934182e-01]
 [-9.82113268e-01]
 [ 2.10935721e+00]
 [ 9.99832539e-01]
 [-1.94399191e-01]
 [-1.17692612e+00]
 [ 1.99813640e+00]
 [-2.05687643e-03]
 [-1.25625351e+00]]

h =
[[ 2.15788523e-01]
 [-1.81692141e-01]
 [ 2.44754749e-02]
 [-1.47709331e-03]
 [-3.41002130e-04]
 [ 6.82350229e-01]
 [ 2.08100037e-06]
 [ 1.64620825e-01]
 [-3.94912532e-01]
 [ 1.83452280e-02]
 [-6.81294774e-01]
 [-5.38801179e-02]
 [ 3.45337878e-01]
 [ 6.71864623e-01]
 [ 3.70973883e-01]
 [ 8.31566331e-01]
 [-1.77707301e-01]
 [ 8.15617291e-01]
 [-6.16516306e-01]
 [-4.97042055e-02]]

References