Intro
For an introduction to LSTMs, please refer to http://colah.github.io/posts/2015-08-Understanding-LSTMs/. What follows is a clean, minimal implementation of the forward pass of an LSTM cell, using the equations given in Chris Olah's post.
The code
```python
import numpy as np
def cc(a, b):
    """Stack ``a`` on top of ``b`` (row-wise concatenation via ``np.vstack``).

    Used to join two column vectors into a single, taller column vector.
    """
    return np.vstack([a, b])


# Short alias for matrix multiplication.
mm = np.matmul
def sigmoid(x):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    The value is computed in a numerically stable form: the exponential is
    only ever taken of a non-positive number, so large-magnitude negative
    inputs do not overflow ``np.exp``.

    Args:
        x: Real scalar or array-like.

    Returns:
        The sigmoid of ``x``, with the same shape as ``x``.
    """
    x = np.asarray(x)
    # exp(-|x|) always lies in [0, 1], so it can never overflow.
    decay = np.exp(-np.abs(x))
    # x >= 0:  1 / (1 + exp(-x))
    # x <  0:  exp(x) / (1 + exp(x))   (algebraically the same value)
    return np.where(x >= 0, 1.0, decay) / (1.0 + decay)
def lstm(parameters, C, h, x):
    """Run one forward step of an LSTM cell.

    Args:
        parameters: Sequence of eight arrays, unpacked in the order
            ``(Wf, bf, Wi, bi, Wc, bc, Wo, bo)``: weight matrix and bias for
            the forget gate, the input gate, the candidate cell state and
            the output gate, respectively.
        C: Cell state from the previous step.
        h: Hidden state from the previous step.
        x: Input for the current step.

    Returns:
        A tuple ``(C_next, h_next)`` holding the updated cell state and the
        updated hidden state.
    """
    Wf, bf, Wi, bi, Wc, bc, Wo, bo = parameters

    # Every gate sees the previous hidden state stacked on top of the input.
    z = cc(h, x)

    forget_gate = sigmoid(mm(Wf, z) + bf)
    input_gate = sigmoid(mm(Wi, z) + bi)
    output_gate = sigmoid(mm(Wo, z) + bo)

    # Candidate values that may be written into the cell state.
    C_hat = np.tanh(mm(Wc, z) + bc)

    # Keep part of the old state, add part of the candidate.
    C_next = forget_gate * C + input_gate * C_hat

    # The hidden state is read out from the *updated* cell state
    # (h_t = o_t * tanh(C_t)), not from the previous one.
    h_next = output_gate * np.tanh(C_next)

    return C_next, h_next
def genRandomParameters(n_in, n_out, gen=np.random.randn):
n_z = n_out + n_in
W, b = {}, {}
for k in ['f', 'i', 'c', 'o']:
b[k] = gen(n_out, 1)
W[k] = gen(n_out, n_z)
return W['f'], b['f'], W['i'], b['i'], W['c'], b['c'], W['o'], b['o']
if __name__ == '__main__':
    # Demo: draw random parameters, a random input and a random initial
    # state, run a single LSTM step, then print the input and both states.
    n_in, n_out = 30, 20
    parameters = genRandomParameters(n_in, n_out)

    x = np.random.randn(n_in, 1)
    C = {0: np.random.randn(n_out, 1)}
    h = {0: np.random.randn(n_out, 1)}
    C[1], h[1] = lstm(parameters, C[0], h[0], x)

    print("Here is the input at step 0\n")
    print(x)
    print("\n")

    for i in sorted(C):
        print("Here is the state {}\n".format(i))
        for label, state in (("\nC = ", C[i]), ("\nh = ", h[i])):
            print(label)
            print(state)
        print("\n")
```
Output
Here is the input at step 0 [[-0.05627357] [-0.29153725] [-1.28900898] [ 1.70351343] [ 0.19033016] [-0.68916496] [-1.57563025] [ 0.63144996] [ 0.95782593] [ 1.06586039] [-1.30962654] [-0.21002679] [-1.74616201] [ 0.86942999] [ 1.46664102] [-0.528788 ] [ 0.91849627] [-0.5678924 ] [-0.54006408] [-0.98086492] [ 1.02920194] [-1.13406362] [-0.56705029] [-2.31024086] [ 0.30760565] [ 2.1718798 ] [-0.25775158] [ 0.97286009] [ 0.1501893 ] [-0.75663434]] Here is the state 0 C = [[ 0.22116217] [-0.18373211] [ 0.8237491 ] [-1.99654357] [-0.62020134] [ 0.92473404] [ 0.82270896] [ 1.58050728] [-0.64807869] [ 0.64705852] [-0.83207134] [-0.06340116] [ 0.37436416] [ 2.10939065] [ 0.39075321] [ 1.1974727 ] [-0.17977976] [ 1.15609134] [-0.72855696] [-0.26332118]] h = [[-1.30358873] [ 0.52950027] [-0.77330343] [-1.29107344] [ 1.94376534] [-0.24147924] [ 1.14879692] [ 1.88227173] [ 0.41226936] [-0.93637718] [-1.28293428] [-0.42759261] [-1.03448204] [ 0.2648483 ] [ 1.25716327] [-1.1590435 ] [ 0.47442464] [-1.51877055] [ 2.41283117] [-0.90421226]] Here is the state 1 C = [[ 2.09379677e-04] [ 6.57006251e-01] [ 1.69534554e+00] [ 7.77343533e-01] [-1.00762388e+00] [-1.80160070e-01] [ 9.93025673e-01] [ 9.59586938e-01] [-6.39248726e-02] [ 1.56690524e+00] [-6.61027783e-01] [ 9.21934182e-01] [-9.82113268e-01] [ 2.10935721e+00] [ 9.99832539e-01] [-1.94399191e-01] [-1.17692612e+00] [ 1.99813640e+00] [-2.05687643e-03] [-1.25625351e+00]] h = [[ 2.15788523e-01] [-1.81692141e-01] [ 2.44754749e-02] [-1.47709331e-03] [-3.41002130e-04] [ 6.82350229e-01] [ 2.08100037e-06] [ 1.64620825e-01] [-3.94912532e-01] [ 1.83452280e-02] [-6.81294774e-01] [-5.38801179e-02] [ 3.45337878e-01] [ 6.71864623e-01] [ 3.70973883e-01] [ 8.31566331e-01] [-1.77707301e-01] [ 8.15617291e-01] [-6.16516306e-01] [-4.97042055e-02]]