言語モデルの性能が、実装により異なる件を解決する

piqcy
programming-soda
Published in
10 min readOct 11, 2018

言語モデルは、自然言語処理において最も基本的なタスクです。ただ、実装によってその性能は大きく左右されます。本記事では「言語モデルの実装」として紹介されることの多い2つの実装パターン、そしてその性能差について解説を行っていきます。

本記事を読むことで、以下の点が理解できます。

  1. 言語モデルの実装として紹介されることの多い2つのパターン、バッチ型とシーケンシャル型の違い
  2. バッチ型とシーケンシャル型の性能差と、その原因
  3. 性能差を埋めるための工夫

また、検証のため使用した実装は以下で参照可能です。開発中の自然言語前処理ライブラリchariotを使用しています(良かったら★をよろしくです)。

言語モデルの実装パターン

「言語モデルの実装」として紹介されることの多い実装は、主に2つあります。1つがバッチ型、もう一つがシーケンシャル型です(このタイプは私が名付けたので、一般的ではありません)。そして、バッチ型の性能は素直に実装するとシーケンシャル型に大きく劣ります。本節ではまず実装の違い、次節で性能差の原因について解説します。

バッチ型のモデルは、固定長の系列から予測を行うタイプの実装です。特にKerasによる言語モデルの実装ではこのタイプが使われていることが多いです。一定長の系列を与えて、系列に続く単語(文字)を予測させる形です。図にすると以下のようになります。

バッチ型の言語モデル実装

コードにすると、概ね以下のような実装になります。batch_size×sequence_lengthのデータを与え、各系列についての予測値を出力する形になります。

model = K.Sequential()
model.add(K.layers.Embedding(input_dim=vocab_size,
output_dim=embedding_size))
model.add(K.layers.LSTM(hidden_size))
model.add(K.layers.Dense(vocab_size, activation="softmax"))

シーケンシャル型のモデルは、逐次予測を行うタイプの実装です。PyTorchやChainerといった動的グラフのフレームワークでは、こちらの実装が多く使われています。実装を図にすると以下のようになります。sequence_lengthは1となり、それを繰り返し入力していく形になります。そして、一定長(下図ではbptt_size)でBack Propagation(BPTT)を行います。

シーケンシャル型の言語モデル実装

コードにすると、以下のようになります。コード中のdataは、batch_size × 1をbptt_size個セットにしたものになります。targetsはそれを一つずらしたものです。

hidden = model.init_hidden(batch_size)
for batch, i in enumerate(range(0, train_size - 1, bptt_size)):
data, targets = get_batch(train_data, i)
model.zero_grad()
output, hidden = model(data, hidden)
loss = cross_entropy(output.reshape(-1, vocab_size), targets)
loss.backward()

バッチ型言語モデルのsequence_lengthを、シーケンシャルモデルにおけるbptt_sizeと考えると、両者はほぼ等価にも思えます。しかし、実際に学習させるとその性能には大きな開きがあります。

以下はWikiText2を学習させた場合の結果です。シーケンシャル型はperplexityが100を切る一方、バッチ型は100以上になっています。

バッチ型の実行結果(Keras実装)
シーケンシャル型の実行結果(PyTorch実装)

スコアが悪いだけでなく、バッチ型は早い段階でOverfitをしています。バッチ型のvalidationスコアだけ抜き出したものが以下になりますが、すごいOverfitしているのが分かると思います。

バッチ型の実行結果(validation scoreのみ)

このような差が発生するのには、以下の理由があります(おそらく)。

  1. バッチ型の言語モデル実装では、sequence_length分だけ事前情報が与えられているパターンしか学習できない。
  2. バッチ間ではStateが引き継がれないため、sequence_lengthの初期では予測が不利になる。

上記2点の対策を行うことで、以下のように過学習を避け性能を上げることができます。

対策後のバッチ型実行結果

では、順にみていきましょう。

バッチ型言語モデルにおける学習の問題点

バッチ型の実装における最大の問題点は、sequence_length分だけ事前情報が与えられているパターンしか学習できないという点です。

例として、A,B,C,Dという4つの系列を与えるケースを考えます。シーケンシャル型ではA=>B、B=>C・・・といった形で予測/学習を行っていきます。B=>Cの予測の際は、その前にA=>Bが処理されたという情報が引き継がれます。そのため、実質的にはA, BからCを予測しているのと同じになります。一方、バッチ型の場合はA, B, CからDを予測するだけです。

シーケンシャル型とバッチ型の学習内容の差異

バッチ型でシーケンシャル型と同じ内容を学習するなら、系列の長さ毎にデータを作る必要が出てきます。A,B,C,Dの4つがあったら、A, B, C=>Dだけでなく、A=>B、A,B=>Cも学習でテータに含まないといけないということです。このように対策しても精度が改善することは確認済みですが、この場合シーケンス分だけデータが増え学習に時間がかかります。隠れ層の計算を毎回最初からやっていることになり非効率的です。

そのため、各ステップの隠れ層の状態から予測する形にします。図にすると以下のような形です。これにより、隠れ層の再計算をすることなく、都度に予測する形の学習が可能になります。

対策後のバッチ型の学習

Kerasの場合以下のような実装になります。

model = K.Sequential()
model.add(K.layers.Embedding(input_dim=vocab_size,
output_dim=embedding_size)
model.add(K.layers.LSTM(hidden_size, return_sequences=True))
model.add(K.layers.TimeDistributed(K.layers.Dense(vocab_size,
activation="softmax"))

これでかなり改善します。

学習方法改善後のバッチ型モデルの精度(validationスコアのみ)

ただ、8epoch目以降少し不吉な動きをしています。そのため、2点目の対策も行います。

バッチ型言語モデルにおけるStateの引継ぎ

シーケンシャル型の実装では、内部状態のリセットが行われません。え、と思った方もいるかもしれませんが、PyTorchは明らかにChainerでもリセットを行っている節は見られません。

リセットを行わないのは、ある意味では当然です。途中で切るのはあくまで学習上(BPTT)の都合であって、実際はその先も文は続いているわけですから、それまで文を処理した内部状態はリセットすべきではありません。リセットしてしまったら、「それまでの文」の情報が失われてしまいます。逆に言えば、バッチ型では系列初期の予測は常にコンテキスト不足の状態になっています。

そのため、バッチ型でも以下の対策が必要になります。

  • バッチ単位で内部状態をリセットしない
  • バッチの順番はShuffleしない

Kerasでは、stateful=Trueにすることで内部状態の引継ぎを行えます。ただ、validationの時は一旦止めるといった器用なことができないのと、バッチのshapeを明示的に与える必要があります。

statefulにした場合のvalidationスコア

これで、シーケンシャル型のスコアとほぼ同等になりました。ものすごく具体的には、KerasのSequentialモデルをfitで学習させたい時でも、上記の工夫によりPyTorch/Chainerと同等の精度を出すことが可能です。

その他の注意点

最後に、その他の注意すべき点に触れておきます。他のフレームワークの実装を参考に、自分がメインで使っているフレームワークで実装する、というシチュエーションはよくあると思います。この時、単純にモデルの構成だけでなくその初期値にも気を払う必要があります。

LSTMを例にとり、ChainerとKerasを比較してみます。LSTMは入力に対する重みと隠れ層にかかる重みの2種類があります(Chainerではlateral/upward、Kerasではkernel/recurrent)。ChainerのLSTMは、デフォルトの場合何れもscale=1.0の LeCunNormalで初期化され、sigmoidでアクティベーションされます(sigmoidは高速化のためか、ちょっとカスタマイズされた実装になっています)。これに対し、Kerasはkernel(lateral)はglorot_uniform、recurrent(upward)はorthogonalと別々になっており、アクティベーションは中央が線形になったhard_sigmoidが使われています。

そんなに性能差が出るかというと差は微妙なのですが、設定が異なる点が多いのはわかっていただけるのではないかなと思います。他のフレームワークを基に実装したけど性能がいまいち出ない、という場合この辺りを疑ってみるとよいと思います。大体目に付くところはチェックしているので、目に見えない箇所に原因があることが多いです。ちなみに私はこれが原因で数日つぶしたことがあります(以下の初期値の差異は、1.11.0では修正されています)。

見えない差異にも気を付けて、良い言語モデルライフを!

--

--

piqcy
programming-soda

All change is not growth, as all movement is not forward. Ellen Glasgow