PyTorchバックエンドの確率的プログラミング言語Pyroと生成モデルのツールPixyz

大川 洋平
PyTorch
Published in
18 min readDec 18, 2020

変分オートエンコーダーを題材に、確率的プログラミング言語Pyroと生成モデルのツールPixyzについて紹介します。

目次

  • PyroとPixyz
  • PyTorchユーザーが気になる点
  • 畳み込みを使った変分オートエンコーダーによるMNIST学習
  • Pyroを使ったVAEの実装
  • Pixyzを使ったVAEの実装
  • 終わりに

PyroとPixyz

確率モデリングを行うためのプログラミングツールとして、確率的プログラミング言語(Probabilistic Programming Language, PPL)があります。Pythonを使ったPPLにはPyMCやPyStan、TensorFlow Probabilityなどがありますが、Uber社がPyTorchをバックエンドにしたPPLであるPyroをリリースしています。Pyroは柔軟で汎用的な確率モデリングのためのツールを目指しており、観測データの確率分布をモデリングする生成モデルにもPyroは対応しています。

ところでPyTorchをバックエンドにした生成モデルのツールとして東京大学大学院工学系研究科 特任助教の鈴木雅大さんが開発されているPixyzがあります。生成モデル自体が確率分布を扱うものですから、必然Pixyzも確率モデリングに対応する側面を持っています。ただしあくまでPixyzのターゲットは生成モデル、特に深層生成モデルとなっています。

PyroとPixyzはどちらもPyTorchをバックエンドにしており、PyTorchの深層学習モジュールとの連携が容易であり、GPUを使った高速な演算が可能であるという大きなメリットがあります。この記事ではPyTorchについて知識がある読者を対象に、変分オートエンコーダー(Variational Autoencoders, VAE)を題材としてPyroとPixyzを紹介します。

PyTorchユーザーが気になる点

PyroやPixyzを使用する上で、PyTorchユーザーがおそらく気になると思われる点を幾つか挙げてみましょう。

PyTorchで実装したニューラルネットワークは流用できる?

PyroとPixyzともにYESです。ただし両者ともそのまま使用するのではなく、それぞれ少しだけ工夫が必要です。

Pyroではtorch.nn.Moduleを継承したクラス内に、追加でmodelメソッドguideメソッドを実装する必要があります。VAEではmodelメソッドは生成ネットワーク、つまり潜在変数からの観測データのサンプリングが該当します。一方、推論ネットワーク、つまり観測データからの潜在変数の推論がguideメソッドに対応します。PyTorchで実装したニューラルネットワークのインスタンスを使ってmodelメソッドとguideメソッドを記述すればよいため、ニューラルネットワークの実装自体には手を加える必要はありません。(modelとguideをそれぞれtorch.nn.Moduleを継承した別クラスとして定義することも可能です。)

Pixyzでは生成モデルと推論モデルをそれぞれpixyz.distributionsパッケージ内のクラス(Distributions API)を継承した自作クラスとして実装する必要があります。Distributions APIにはベルヌーイ分布や正規分布に対応する確率分布クラスが用意されており、出力に合わせたクラスを継承して生成ネットワークと推論ネットワークを実装します。しかしPyTorchで実装したニューラルネットワークのインスタンスを該当のクラスでラッピングしてやればよいだけですから、追加のコーディングはほぼないと言えます。

誤差逆伝播とネットワークパラメータの更新はPyTorchの実装のままでよい?

PyroとPixyzともにNoです。しかし、両者ともにツールが複雑な処理を隠蔽してくれますので、むしろ実装は簡単になるかもしれません。

PyroではSVIクラス(Stochastic Variational Inference)のstepメソッドを実行するだけで、損失関数の計算からネットワークパラメータの更新までを行ってくれます。SVIのインスタンスを作成する際には、実装したmodelメソッドとguideメソッド、最適化アルゴリズムと損失関数を指定します。テスト時にはstepメソッドの代わりにevaluate_lossメソッドを実行することで損失値の計算だけ行うことが出来ます。

Pixyzではpixyz.modelsパッケージ内のModelクラス(Model API)のtrainメソッドを実行することでネットワークパラメータの更新までを行ってくれます。Modelのインスタンスを作成する際に、実装した生成ネットワークと推論ネットワーク、最適化アルゴリズムと損失関数を指定します。testメソッドを実行することで損失値だけ計算することも可能です。

確率分布からデータをサンプリングできる?

もちろんPyroとPixyzともにYesです。ただし両ツールで少しだけ気にしておく点があります。

Pyroでは確率分布を指定してpyro.sampleメソッドを実行することでサンプリングできますが、その際にはサンプリングした変数に固有な名前を付けておく必要があります。guideメソッドで推論した潜在変数の名前と、modelメソッドで観測データの生成に使用する潜在変数の名前が一致していることで、guideメソッドとmodelメソッドで連携をとることが出来ます。

PixyzではDistributions APIを継承した確率分布のクラスにsampleメソッドを実行することでデータをサンプリングできます。ただし条件付き確率としてクラス定義したものは、条件に対応するTensorをdict形式で渡す必要があります。またsampleメソッドは単一のサンプルを得るため、分布の期待値を得るにはsample_meanメソッドを使う必要があります。

畳み込みを使った変分オートエンコーダーによるMNIST学習

本記事ではPyroとPixyzを紹介する題材として、畳み込みを使ったVAEを取り上げます。サンプルコードを紹介する前に、簡単にVAEについて説明します。VAEの詳細な説明は論文 [1, 2] を御覧ください。日本語の書籍では [3] に詳細な解説が記載されています。

VAEは観測データ x の確率分布をパラメータ θ を使って p_θ(x) としてモデリングするもので、直接観測できるデータ x の背後に観測できない潜在変数 z が存在している構造を仮定しています。z はデータ空間よりも低次元な潜在空間に存在しており、その確率分布は扱いやすい単純な分布であると仮定します。

変分オートエンコーダーの構造

x が与えられた際の z の事後分布 p_θ(z|x) は、xz の同時確率分布 p_θ(x, z) と p_θ(x) によって求めることが出来ます。

しかしこの ​p_θ(z|x) を求めることは難しいため、変分パラメータと呼ばれる​パラメータ φ をもつ扱いやすい分布関数 q_φ(z|x) ​で近似します。 q_φ(z|x) は近似事後分布と呼ばれます。

VAEを学習するには、データ分布 p_θ(x)​ の対数尤度である log p_θ(x) ​を最大化する θ を求める必要があります。ここで log p_θ(x)​ は q_φ(z|x) を含んだ形として以下に変形できます。

上の式の最後の右辺第1項はエビデンス下界(evidence lower bound, ELBO) と呼ばれるもので、右辺第2項は近似事後分布 q_φ(z|x) ​と真の事後分布 p_θ(z|x) のKLダイバージェンスです。KLダイバージェンスは非負の値をとるため、学習の手段としてはELBOを最大化することによって log p_θ(x)​ の最大化を図ります。ELBOはさらに以下に展開できます。

p_θ(z)​ は z ​の事前分布であり、扱いやすい単純な分布関数を仮定したものです。ELBOの右辺第1項は近似事後分布 q_φ(z|x) と潜在変数の事前分布 p_θ(z)​ とのKLダイバージェンスを、右辺第2項は z が与えられた際の x の事後分布について対数尤度を q_φ(z|x) に関する期待値でとったものです。

実装する際にはELBOを負にしたものを損失関数として設定します。負のELBOは変分エネルギーとも呼ばれ、変分エネルギーを最小化することによってELBOを最大化してパラメータ θ と変分パラメータ φ を学習します。

VAEの学習を行う際には、まず​ p_θ(z) を単純な分布関数に仮定します。本記事では​ p_θ(z) は多変量の標準正規分布としています。

次に近似事後分布​ q_φ(z|x) を以下の形として、真の事後分布を近似します。 f_φφ をパラメータに持つ関数です。

上の q_φ(z|x) の式は、観測データの個々のサンプルから推論した潜在空間中の正規分布を組み合わせて事後分布を近似することを意味しています。また、推論した個々の正規分布の形は共通した φ によって決まるため、観測データの一部を使って φ を更新すると他の観測データサンプルに対応する正規分布も更新できるという計算効率の高い性質があります。

変分パラメータの更新

q_φ(z|x) は入力が観測データ x 、出力が正規分布の平均と分散の値であるニューラルネットワークに設計します。これを推論ネットワーク、あるいは確率的エンコーダーと呼びます。すなわち推論ネットワークの重みが変分パラメータ φ です。推論ネットワークの出力と標準正規分布(=仮定した潜在変数の事前分布)のKLダイバージェンスによって損失関数の第1項を計算できます。推論ネットワークの出力が正規分布に近いほど損失関数の第1項は小さくなるため、観測データから推定される潜在変数が正規分布に従う働きを持ちます。

一方で p_θ(x|z) ​を多変量ベルヌーイ分布に従うと規定し、p_θ(x|z) を潜在変数を入力としてデータ確率を出力するニューラルネットワークに設計します。これは生成ネットワーク、あるいは確率的デコーダーと呼ばれます。観測データ x の値域を [0, 1] として、観測データと生成ネットワーク出力とのバイナリ交差エントロピーによって損失関数の第2項を計算できます。観測データを推論ネットワークと生成ネットワークに通した結果が元に近いほど損失関数の第2項は小さくなるため、損失関数の第2項は再構成誤差に対応しています。

以降から本記事で使用するサンプルコードについて説明します。サンプルコードは以下のGitHubリポジトリに保存していますので御覧ください。

本記事では推論ネットワーク q_φ(z|x) と生成ネットワーク p_θ(x|z) ​ をそれぞれ畳み込みを使ったニューラルネットワークに設計します。PyroとPixyz それぞれのサンプルコードで共通のニューラルネットワークを使うことにしましょう。

推論ネットワークである Encoder クラスは nn.MaxPool2d によって特徴マップを空間的にダウンサンプリングしていき、最終的に正規分布の平均と分散に対応する locscale を出力します。

生成ネットワークである Decoder クラスは nn.ConvTranspose2d によって特徴マップを空間的にアップサンプリングしていき、最終的に生成画像を出力します。出力は nn.Sigmoid を通っており、出力の値域は [0, 1] になっています。

本記事でモデリングの対象とするデータは、手書き数字のMNISTです。MNISTのデータローダーを作成するコードを以下に記載します。PyroとPixyzのサンプルコードで共通して使用します。

潜在変数の補間を行う関数を以下に記載します。こちらもPyroとPixyzのサンプルコードで共通して使っています。潜在変数 z の先頭4サンプルを選択し、2×2に空間的に配置したものをnn.functional.interpolate で8×8に双線形補間、その後に64サンプルになるようTensorの軸の順序を調整しています。

Pyroを使ったVAEの実装

Pyroの実装ではまず model メソッドと guide メソッドを持ったニューラルネットワークのクラスを定義する必要があります。(modelとguideをそれぞれ別のクラスとして定義することも可能です。)サンプルコードではこのクラスを VAE という名前の自作クラスとします。

ここでは推論ネットワークの機能を Encoder クラスのインスタンスとして、生成ネットワークの機能を Decoder クラスのインスタンスとして持たせています。

modelとguideのソースコード中にpyro.plate があります。これは with pyro.plate 中の pyro.sample と組み合わせることで、指定した大きさのサンプルを確率分布からサンプリングすることを意味しています。このとき、Tensor中で独立な次元を指定する必要があります。指定には to_event で次元を指定しますが、右側から数えた次元の順番であることに注意してください。例えば2次元Tensorで第1次元が独立なら to_event(1) となり、4次元Tensorで第1次元が独立なら to_event(3) となります。

VAE クラスのコード全体は以下となります。model で事前分布からサンプリングした潜在変数と、guide で観測データから推定した潜在変数に、 'z' と共通の名前をつける必要があることに注目してください。

学習ループの前に、定義した model メソッドと guide メソッド、最適化アルゴリズムと損失関数を SVI に設定する必要があります。最適化アルゴリズムは、PyTorchの torch.optimパッケージのクラスを pyro.optim.PyroOptim でラッピングして使用します。

Pyroを使ったVAEのサンプルコードの全体は以下になります。訓練時においては SVI.step メソッドで損失計算と学習パラメータ更新を、テスト時においては SVI.evaluate_loss メソッドで損失値を計算できます。

50エポック学習後の画像再構成の結果が以下です。うまく再構成できているように見えます。

Pyroを使ったVAEの画像再構成(上段:元画像 下段:再構成画像)

事前分布からサンプリングした潜在変数を、生成ネットワークを通して画像に変換した結果が以下です。幾つか怪しい数字もありますが、手書き数字のような画像が生成されていることが見てとれます。

Pyroを使ったVAEの生成画像

Pixyzを使ったVAEの実装

Pixyzでは推論ネットワークと生成ネットワークをそれぞれDistoribution APIのクラスを継承した自作クラスとして定義する必要があります。推論ネットワークは出力が正規分布であるため pixyz.distributions.Normal クラスを、生成ネットワークは出力がベルヌーイ分布であるため pixyz.distributions.Bernoulli を継承します。それぞれ初期化メソッドで、条件つき確率の確率変数とその条件に該当する変数を varcond_var で定義します。

損失関数はLoss APIを使って定義します。 pp_θ(x|z) 、qq_φ(z|x) ​、 p_priorp_θ(z) に対応しており、数式のまま損失関数を実装すればよいことが分かります。定義した損失関数と推論ネットワーク、生成ネットワーク、最適化アルゴリズムをModel APIに設定します。

生成ネットワークから観測データをサンプリングする際は、 p.sample_meanメソッドを使って分布の期待値を出力します。

p.sampleメソッドを使うと1サンプルだけサンプリングされ、画素値が白黒のバイナリ画像が生成されるため注意してください。この理由は生成ネットワークの出力をベルヌーイ分布としているため(=GeneratorクラスをBernoulliクラスを継承して実装したため、)です。p.sample メソッドを使った場合の例を以下に示します。

p.sampleを使った場合の生成画像

PixyzでのVAEのサンプルコードの全体は以下になります。以下のサンプルコードでは生成ネットワークからのサンプリングは p.sample_mean を使っています。 学習時は Model.train メソッドで損失計算からパラメータ更新までを実行でき、テスト時は Model.test メソッドで損失値の計算のみを行っています。

50エポック学習後の再構成画像は、Pyroの場合と同じようにうまく再構成できています。

Pixyzを使ったVAEの画像再構成(上段:元画像 下段:再構成画像)

潜在変数を補間したものから補間画像を生成した結果を示します。潜在空間の中で、「7」と「2」の間や「2」と「0」の間から、「3」や「8」らしき画像が生成されてます。

Pixyzを使ったVAEでの画像補間

終わりに

本記事では確率的プログラミング言語Pyroと生成モデルのツールPixyzを紹介しました。どちらもPyTorchの深層学習モジュールとの連携が容易であり、GPU演算が可能で、かつ、複雑な処理が隠蔽されて実装が簡単になっている素晴らしいツールです。

もしPyroとPixyzどちらを使うか迷っている読者の方がいらっしゃるなら、生成モデルの実装ならPixyzを使うことをお勧めします。PixyzはニューラルネットワークをDistribution APIでラッピングするだけで実装ができますし、Pixyzの論文ではPyroと速度比較した際にPixyzの方が優位だったという試験結果も記載されています [4]。本記事では紹介しませんでしたが、jupyter notebookを使えば定義した確率分布をLATEXフォーマットの数式として確認することも可能です。

一方Pyroは汎用的な確率的プログラミング言語ですから、目的が生成モデル以外ならPyroという選択になるのではないでしょうか。Pyroは様々な機能を提供しており、たとえばpyro.contrib.forecastという時系列解析の機能やpyro.contrib.gp というガウス過程に対応した機能なども提供しています。

本記事がPyroとPixyzに触れるきっかけになれば幸いです。

参考文献

[1] Kingma, Diederik P., and Max Welling. “Auto-encoding variational bayes.” arXiv preprint arXiv:1312.6114 (2013).

[2] Kingma, Diederik P., and Max Welling. “An introduction to variational autoencoders.” arXiv preprint arXiv:1906.02691 (2019).

[3] 須山 敦志(2019)『ベイズ深層学習』講談社

[4]鈴木雅大, et al. “Pixyz: 複雑な深層生成モデル開発のためのフレームワーク.” 人工知能学会全国大会論文集 一般社団法人 人工知能学会. 一般社団法人 人工知能学会, 2019.

--

--

大川 洋平
PyTorch
Writer for

機械学習と深層学習を使った外観検査やロボット制御技術の開発に取り組んでいます。著書「PyTorchニューラルネットワーク 実装ハンドブック」「NumPy&SciPy数値計算実装ハンドブック」