Keras LSTM to Java

We have a lot of amazing frameworks for deep learning that allow easy and fast prototyping and let us build complex architectures without even thinking about what happens inside them. But sometimes you need to deploy your model somewhere… let's say somewhere you can't use your favorite

from keras.layers.recurrent import LSTM

I recently faced this problem when I had to deploy a recurrent neural network for action recognition, trained in Keras, in Java. My client doesn't want a microservices architecture; he wants everything in Java, and that's that :)

So, let’s see how we can do it.

Loading weights from the trained model

First, I trained a 2-layer LSTM model with a softmax on top, classifying into 3 classes:

The input embedding is a vector of length 11, and each LSTM layer has 15 hidden units.
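For reference, the model definition could look roughly like this (a sketch assuming the Keras 1.x API; the exact training script may differ):

from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.layers.recurrent import LSTM

model = Sequential()
# sequences of 11-dimensional vectors, 15 hidden units per LSTM layer
model.add(LSTM(15, input_dim=11, return_sequences=True))
model.add(LSTM(15, return_sequences=False))
# softmax classifier over the 3 classes
model.add(Dense(3))
model.add(Activation('softmax'))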

Now let's load our weights from the .hdf5 file and inspect the structure:
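Something along these lines does the job (a sketch using h5py; the file name is just an example, and model is the trained Keras model from above):

import h5py

# print every dataset stored in the weights file together with its shape
with h5py.File('lstm_model_weights.hdf5', 'r') as f:
    def print_dataset(name, obj):
        if isinstance(obj, h5py.Dataset):
            print(name, obj.shape)
    f.visititems(print_dataset)

# the layers themselves tell us what each param_* actually is
for layer in model.layers:
    print([w.name for w in layer.trainable_weights])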

The output looks like this:

/layer_0/param_0 (11L, 15L)
/layer_0/param_1 (15L, 15L)
/layer_0/param_10 (15L, 15L)
/layer_0/param_11 (15L,)
/layer_0/param_2 (15L,)
/layer_0/param_3 (11L, 15L)
/layer_0/param_4 (15L, 15L)
/layer_0/param_5 (15L,)
/layer_0/param_6 (11L, 15L)
/layer_0/param_7 (15L, 15L)
/layer_0/param_8 (15L,)
/layer_0/param_9 (11L, 15L)
/layer_1/param_0 (15L, 15L)
/layer_1/param_1 (15L, 15L)
/layer_1/param_10 (15L, 15L)
/layer_1/param_11 (15L,)
/layer_1/param_2 (15L,)
/layer_1/param_3 (15L, 15L)
/layer_1/param_4 (15L, 15L)
/layer_1/param_5 (15L,)
/layer_1/param_6 (15L, 15L)
/layer_1/param_7 (15L, 15L)
/layer_1/param_8 (15L,)
/layer_1/param_9 (15L, 15L)
/layer_2/param_0 (15L, 3L)
/layer_2/param_1 (3L,)
[lstm_1_W_i, lstm_1_U_i, lstm_1_b_i, lstm_1_W_c, lstm_1_U_c, lstm_1_b_c, lstm_1_W_f, lstm_1_U_f, lstm_1_b_f, lstm_1_W_o, lstm_1_U_o, lstm_1_b_o]
[lstm_2_W_i, lstm_2_U_i, lstm_2_b_i, lstm_2_W_c, lstm_2_U_c, lstm_2_b_c, lstm_2_W_f, lstm_2_U_f, lstm_2_b_f, lstm_2_W_o, lstm_2_U_o, lstm_2_b_o]
[dense_1_W, dense_1_b]

param_0, param_1 and so on don't look very descriptive; from the names alone I can't tell what these weights are responsible for. But the output of layer.trainable_weights shows us exactly what we want. And if we check one of the most popular LSTM tutorials (colah's "Understanding LSTMs"), we are in luck: the weight matrix notation is the same! We can see that param_0 is just W_i, param_1 is U_i, and so on.
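Concretely, in that notation a single LSTM step uses the weights like this (with our saved shapes the inputs are row vectors, so we multiply as x·W rather than W·x; * denotes element-wise multiplication):

i_t = sigmoid(x_t W_i + h_{t-1} U_i + b_i)
f_t = sigmoid(x_t W_f + h_{t-1} U_f + b_f)
c~_t = tanh(x_t W_c + h_{t-1} U_c + b_c)
c_t = f_t * c_{t-1} + i_t * c~_t
o_t = sigmoid(x_t W_o + h_{t-1} U_o + b_o)
h_t = o_t * tanh(c_t)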

Now let's save the matrices in an easy-to-read format:

import numpy as np

for i, l in enumerate(layers):  # `layers`: list of dicts mapping weight names to numpy arrays
    for key in l.keys():
        np.savetxt('./weights/' + str(i) + '_' + key + '.txt', l[key])

Now they are nicely stored as .txt files that look like this:

1.118575707077980042e-01 6.875121593475341797e-01 -6.481686234474182129e-01 -1.580208778381347656e+00 4.655661880970001221e-01 5.263564586639404297e-01 8.121936023235321045e-02 -3.387819603085517883e-02 3.302907645702362061e-01 5.148626565933227539e-01 -5.716431140899658203e-01 2.599978260695934296e-02 -1.588541269302368164e-01 1.211983323097229004e+00 -6.513738632202148438e-01

Java LSTM from scratch

I am going to follow the tutorial mentioned above to implement the LSTM. For all the details, check out the code on GitHub; here are just some parts of it.

First, I decided to use jblas for the matrix routines, since we will use them a lot.

Let's write a method to load each saved weight matrix:
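A minimal sketch of such a loader (the class and method names are mine, not necessarily the ones in the repo): it reads a whitespace-separated matrix written by np.savetxt into a jblas DoubleMatrix. Note that 1-D bias arrays come back as n×1 column vectors, so they may need a transpose.

import org.jblas.DoubleMatrix;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public class WeightLoader {

    // Reads a matrix saved with np.savetxt (one row per line, values
    // separated by whitespace) into a jblas DoubleMatrix.
    public static DoubleMatrix loadMatrix(String path) throws IOException {
        List<double[]> rows = new ArrayList<>();
        for (String line : Files.readAllLines(Paths.get(path))) {
            line = line.trim();
            if (line.isEmpty()) continue;
            String[] tokens = line.split("\\s+");
            double[] row = new double[tokens.length];
            for (int i = 0; i < tokens.length; i++) {
                row[i] = Double.parseDouble(tokens[i]);
            }
            rows.add(row);
        }
        return new DoubleMatrix(rows.toArray(new double[0][]));
    }
}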

We also need classes for the activation functions (sigmoid, tanh, softmax) and for the layers themselves.

You can check them out on GitHub; here I just post the code for the forward-propagation routine of the LSTM. Everything is pretty straightforward, just step-by-step matrix multiplications (think carefully about the dimensions!).
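Here is a sketch of what that routine boils down to with jblas (field names are mine; it mirrors the gate equations above, treating inputs as 1×inputDim row vectors and biases as 1×hiddenSize row vectors):

import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

public class LstmLayer {
    // Input-to-hidden weights (W_*), recurrent weights (U_*) and biases (b_*),
    // loaded from the saved .txt files.
    DoubleMatrix Wi, Ui, bi, Wf, Uf, bf, Wc, Uc, bc, Wo, Uo, bo;

    // Hidden state and cell state carried between time steps.
    DoubleMatrix h, c;

    static DoubleMatrix sigmoid(DoubleMatrix x) {
        // element-wise 1 / (1 + exp(-x))
        DoubleMatrix e = MatrixFunctions.exp(x.neg());
        return DoubleMatrix.ones(e.rows, e.columns).div(e.add(1.0));
    }

    // One forward step for a single time step x (1 x inputDim row vector).
    public DoubleMatrix step(DoubleMatrix x) {
        if (h == null) {                       // lazily start from zero states
            h = DoubleMatrix.zeros(1, Wi.columns);
            c = DoubleMatrix.zeros(1, Wi.columns);
        }
        DoubleMatrix i = sigmoid(x.mmul(Wi).add(h.mmul(Ui)).add(bi));   // input gate
        DoubleMatrix f = sigmoid(x.mmul(Wf).add(h.mmul(Uf)).add(bf));   // forget gate
        DoubleMatrix o = sigmoid(x.mmul(Wo).add(h.mmul(Uo)).add(bo));   // output gate
        DoubleMatrix cc = MatrixFunctions.tanh(x.mmul(Wc).add(h.mmul(Uc)).add(bc)); // candidate cell
        c = f.mul(c).add(i.mul(cc));           // new cell state
        h = o.mul(MatrixFunctions.tanh(c));    // new hidden state
        return h;
    }
}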

To build the network we can use the following approach (yeah-yeah, just making it look prettier, adding layers like in Keras :D):
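A possible way to wire it together, reusing the LstmLayer and WeightLoader sketches above (again, the class names and the exact builder API here are illustrative, not necessarily what the repo does):

import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import java.util.ArrayList;
import java.util.List;

public class Network {
    private final List<LstmLayer> lstmLayers = new ArrayList<>();
    private DoubleMatrix denseW, denseB;            // softmax classifier on top

    public Network add(LstmLayer layer) { lstmLayers.add(layer); return this; }

    public Network addDense(DoubleMatrix w, DoubleMatrix b) {
        denseW = w;                                 // (hidden x numClasses)
        denseB = b;                                 // assumed 1 x numClasses row vector
        return this;
    }

    // sequence: one row per time step (seqLen x inputDim); returns 1 x numClasses probabilities.
    public DoubleMatrix predict(DoubleMatrix sequence) {
        DoubleMatrix top = null;
        for (int t = 0; t < sequence.rows; t++) {
            DoubleMatrix x = sequence.getRow(t);
            for (LstmLayer layer : lstmLayers) {
                x = layer.step(x);                  // output of one layer feeds the next
            }
            top = x;                                // keep the last hidden state of the top layer
        }
        DoubleMatrix logits = top.mmul(denseW).add(denseB);
        DoubleMatrix e = MatrixFunctions.exp(logits.sub(logits.max()));
        return e.div(e.sum());                      // softmax over the classes
    }
}

Since the LstmLayer sketch keeps h and c between calls, you would reset those states before feeding a new sequence.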

All that's left is to check the results and compare them with the Keras output. They are the same :)


Actually, it was a nice exercise to code some math, especially when you are used to "out-of-the-box" tools. It also shows that LSTMs are not that complicated to implement, and if needed you can always port them to another language to make them run on any device.

Thank you for your attention!
