Kerasくんとgeneratorの魔法

Sharat Chinnapa
The HumAIn Blog
Published in
5 min readMay 30, 2019
魔法の書がありました!

ゴーリストのチナパです!さてさて、今回は面白いのがありますよ〜

タイトルがちょっとハリーポッターっぽくて失礼します、機械学習の記事です。

機械学習というと、ビッグデータが思いつくと思います。
大量のデータで数日感学習させたり…

でも一つだけ問題があります

51ギガ?!

可愛そうなパソコンが諦めました。そんなRAM持ってませんって。

大量のデータで学習した時にこういうのもよくあります。私の場合には、数ギガバイトのcsvファイルの文字データをベクトル化しようとして、このようなことになりましたが、画像処理・音声分析の世界にはもっと頻繁でしょう。

ここではPythonの便利な機能「generator」を使うと良いです。

「generator」?なにそれ

range(10)

これもgeneratorです。何度も見たかと思います。

つまりgeneratorはループに使えるiterator系のobjectです


for x in range(10):
print(x) #0~9がprintされます

上記のようにfor ループに使えますが、generator の本当の取り柄は全てのデータを同時にメモリー(RAM)にロードしていないことです。

ただし、このようにしますと:

my_gen = range(10)
for x in my_gen:
print(x) #無事0~9がprintされます
# 二回めはうまく行かない
for x in my_gen:
print(x) #何もprintされない

これはメモリーに保存されてないからです。
ただし、データ量が多すぎてパソコンが泣いている私たちにはこの機能こそが素晴らしいです。

自分のgeneratorはどうやってつける?

Pythonにはメソッドの中に`return`を使う場合がよくありますが、`yield`を使う場合もあります。メソッドには`yield`が使われてるとそのメソッドがgeneratorになります。

以下のメソッドが簡単なgeneratorです。nx0 からnx(n-1)までの数字を出します。

def my_generator(n):
i = 0
while i < n:
i += 1
yield n*(i-1)
for x in my_generator(5):
print(x) # 0, 5, 10, 15, 20のようなものが順番にprintされます

でもこんなのはどうやって機械学習に活かせますか?

そうですね、本題に着きました。

keras のfit_generator()メソッドを利用します。このために、永遠までデータを出してくれるgeneratorメソッドが必要です。

つまり以下のようなものは◎です。

def gen():
while True:
yield 1 # ただしこんなデータ要りません!

すみません、しっかりします。前提として、データをcsvファイルであると想定しています。これはもちろん必須ではありませんが、pandasを便利に使えるし、文字データならcsvにまとめやすいので、そういうことにしましょう。

ポイントとしては `while True`の中に、
pd.read_csv(skiprows=i, nrows=batch_size) があることです。
これはcsvデータの必要な部分(ここでは32行)をデータ化しようとします。
この例には、nullデータを外すための処理も入れてます、これは事前に掃除されたデータであれば無くしてもいい部分ですね。

yield data[‘input_x’], data[‘outputs_y’]

もありますが、複数のインプットが必要な場合には
yield [data[‘input_x1’], data[‘input_x’]], data[‘outputs_y’]
見たいにまとめたらうまく行きます。

学習させましょう!

kerasの`fit_generator`メソッドで学習データとバリデーションデータのそれぞれのgeneratorを作っておきます。

ここの変わったところがsteps_per_epoch= train_size // batch_size
あたりですね。全てのデータがメモリーに入っていない上、generate_input_data() のメソッドが永遠まで続くようになってますので、1epochの大きさが分かるために、設定しないといけません。

お時間たっぷりかけると思いますが、無事大量のデータで学習できるようになってます!

まとめ

今回は

  1. Pythonのgeneratorについて学びました
  2. kerasで使えるgeneratorを作成しました
  3. model.fit_generator() を利用しました。

ここまで読んでくれてありがとうございます。
では、また!

--

--

Sharat Chinnapa
The HumAIn Blog

Programmer, writer, dancer, learning how to make the world a better place at HumAIn.