手刻 LSTM 模型

Winston Chen
Taiwan AI Academy
Published in
5 min readSep 17, 2019

Overview

RNN 與 LSTM 相較於 DNN 與 CNN 來說較不容易理解,透過實作手刻模型,能夠幫助理解 LSTM 模型的運作模式,本文是透過 tensorflow 中基礎的指令實作 LSTM 模型。

1. What is RNN?

下圖是將一般常見的 DNN 旋轉九十度後的圖,X 是輸入的訓練資料,A 則是通過 hidden layer 後計算得到的結果,將訓練資料經過 activation function 後便可以得到輸出 h,也就是模型的輸出。

source:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

一般而言在訓練 DNN 時,每一筆資料都是獨立存在的,然而當遇到訓練資料是音訊或是文字時,通常這一類的訓練資料皆存在時序上的依賴,所以在訓練時會將資料的順序納入模型中,這時候每筆資料經過 hidden layer 得到的東西我們稱之為狀態(State),而狀態會紀錄當下這筆資料的特性,而狀態透過某些轉換後會傳入下一筆資料的狀態中與下一筆資料計算所得的狀態結合後計算出輸出h,而每一筆訓練資料都可以得到一組輸出h,但是通常在 RNN 中我們只會將最後一筆資料得到的結果當為最後的輸出,因為這個輸出考慮了之前的每一筆資料與其對應的狀態進而得到最後的結果。

source:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

2. What is Long Short-term memory?

上述的 RNN 模型相較於普通的 DNN 確實能夠將資料的時序性納入模型,但是當今天訓練資料過於多時,一開始的狀態對於模型最後的輸出影響已經趨近於消失了,所以發展出 LSTM,由機器自己去學習在訓練資料中,哪些狀態可以忘記,哪些資料必須保存。

相較於 RNN,除了上一個狀態到下一個狀態的轉換外,LSTM 模型中多了三個參數:input gate、forget gate 與 output gate,首先,input gate 會控制當前這筆資料輸入的比例,而 forget gate 則是控制上一個狀態要保留多少比例傳入當前的狀態,而最後模型要輸出結果時,會先通過 output gate 決定要輸出多少比例再輸出結果。

source:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

3. Implement

我們使用了特斯拉股價的收盤價作為訓練資料。

3.1 Loading data and data preprocessing

將資料讀進來後,只取收盤價作為訓練資料,並且將資料正規化。

3.2 Generating training data and testing data

我們可以通過使用窗函數大小來決定每組訓練資料中有幾個資料點,在這裡每一組訓練資料會包含七資料點,然後對應的目標值為第八個資料點,最後將訓練資料切一部分為測試資料。

3.3 Building model

在LSTM中,最為重要的便是四個 gate 的參數,所以我們分別先定義四個 gate 的參數,一開始先使用 truncated normal distribution 產生初始參數,而這組參數會在之後的訓練過程中不斷的更新,進而找到收斂點,最後則是通過輸出層的參數後便能得到最終輸出結果。

將四個 gate 的參數定義完之後,接著就是進行 LSTM 的計算,我們會透過 input 與上一個資料點的 output 分別計算出四組 gate 的參數,接著將上一個資料點的狀態通過 forget gate 加上 input 通過 input gate 後更新這個資料點的狀態,output 則是將這一個資料點的狀態通過 tanh activation 後通過output gate 便可以得到輸出。

最後就是定義訓練的流程了!首先,在每一個 batch 開始之前,先將state[-1] 與 output[-1] 初始值設定為0,接著每一次訓練都要皆要按照順序計算每一個資料點的狀態與輸出並且持續更新,計算完七個資料點後,便將最後一個資料點的輸出通過輸出層就得到 LSTM model 的結果。

3.4 Result

參考資料

http://colah.github.io/posts/2015-08-Understanding-LSTMs/

https://github.com/lucko515/tesla-stocks-prediction/blob/master/lstm_from_scratch_tensorflow.ipynb?fbclid=IwAR3F6xUxC8aRJcvuD2e0vzVvvwxnyFqvvDTk3prved77qPjEA0BhZhIU4ak

--

--