MNISTを加工したセマンティックセグメンテーション(Semantic Segmentation)のデータセットを自作し、作成したデータセットを学習するニューラルネットワークをPyTorchで実装します。
目次
- はじめに
- セマンティックセグメンテーションとは
- MNISTを加工したデータセットを自作する
- シンプルなU-Net構造のネットワークを定義する
- 交差エントロピー損失と評価関数を定義する
- セマンティックセグメンテーションを学習する
- おわりに
はじめに
MNISTを使った画像のクラス分類(Classification)は、深層学習の実装方法を学ぶうえで良い題材です。MNISTでは簡単な構造のニューラルネットワークを短時間だけ訓練してもクラス分類ができるようになるため、ニューラルネットワークを学習させる仕組みを理解しやすいでしょう。PyTorchの公式リポジトリでもMNISTクラス分類の実装例が公開されています。
本記事ではMNISTクラス分類を学んだPyTorchビギナーユーザーを対象に、セマンティックセグメンテーションの実装例を紹介します。MNISTを加工してセマンティックセグメンテーション用のデータセットを自作し、U-Net [1] をシンプルにしたネットワークで学習を行います。
本記事のサンプルコードは、以下のリポジトリに掲載しています。
セマンティックセグメンテーションとは
セマンティックセグメンテーションは、画素単位でクラス分類を行うタスクです。クラスを識別するためには物体の全体像を捉えることが重要になるため、畳み込みニューラルネットワークでは広い受容野を形成することで物体全体の特徴を捉えます。広い受容野を形成するにはダウンサンプリングによって特徴マップの解像度を縮小することが効果的ですが、一方で解像度が小さくなった特徴マップを単純にアップサンプリングで元の解像度に戻すだけでは物体の輪郭を正確に捉えることが難しくなるというデメリットも存在します。
U-Net は、スキップ結合を持ったエンコーダー・デコーダー構造のセマンティックセグメンテーションのモデルです(図1)。スキップ結合によって物体全体の特徴を捉えながら正確な位置特定を可能にしています。
U-Netのエンコーダーはダウンサンプリングと畳み込みを繰り返すことで広い受容野を形成します。一方でデコーダーは、エンコーダー側の特徴マップをスキップ結合で組み込みながらアップサンプリングと畳み込みを行います。
U-Netは後のセマンティックセグメンテーションの手法に大きな影響を与えました。ボリュームイメージに対応した 3D U-Net [2] や、U-Net構造を発展させたU-Net++ [3] 、Attention U-Net [4] 、U-Netが入れ子構造になったU²-Net [5] などが提案されています。
本記事では、シンプルなU-Net構造のニューラルネットワークでMNISTのセマンティックセグメンテーションを行います(図2)。ニューラルネットワークはダウンサンプリング(図中赤線)と3×3の畳み込み(図中黄色線)を繰り返すことで広い受容野を形成します。アップサンプリング(図中青線)した特徴マップと高解像度な特徴マップをチャネル方向に結合(図中灰色線)し、大域的な特徴と局所的な特徴を合わせた特徴マップを作成しています。最終的にネットワークは、1×1の畳み込み(図中紫線)を使って各画素におけるクラスに対するロジットを出力します。
MNISTを加工したデータセットを自作する
MNISTでセマンティックセグメンテーションを行うために、MNISTを加工したデータセットを作成しましょう。本記事では、 torch.utils.data.Dataset
を継承した SegmentationMNIST
クラスとして実装します。 SegmentationMNIST
はTorchvisionの torchvision.datasets.MNIST
をコンポジションで利用することにします。
SegmentationMNIST
クラスは __getitem__
メソッドで画素値が threshold
を超えた画素を前景、閾値以下の画素を背景として正解マスク(教師マスク)を作成しています。背景には0番のラベル、前景には元のラベル番号に1加算したラベルを与えていることに注意してください。つまり、この自作セグメンテーション用データセットのクラス数は11です。正解マスクはone-hotエンコーディングしておらず、データ型は torch.int64
になっています。画像と正解マスクの例を図3に示します。
__getitem__
メソッドではデータ拡張(Data Augmentation)も行っています。データ拡張には、Albumentationsを使用しました。セマンティックセグメンテーションでは、画像に適用した幾何変換のデータ拡張と同一の変換を正解マスクに対して行う必要がありますが、Albumentationsで簡単に実装することができます。Albumentationsでtorch.Tensorへ変換するには、albumentations.pytorch.transforms.ToTensorV2を使いましょう。図4は、ランダムアフィン変換をデータ拡張に使用した例です。
シンプルなU-Net構造のネットワークを定義する
図2に示したネットワークを、SimpleUNet
クラスとして実装します。ここではダウンサンプリングは torch.nn.MaxPool2d
でのマックスプーリングを、アップサンプリングは torch.nn.Upsample
での双線形補間を使っています。ネットワーク中の層を nn.Sequential
でブロックにまとめ、そのブロックを nn.ModuleList
でエンコーダーやデコーダーにまとめています。
順伝播である forward
メソッドでは、エンコーダーで解像度が低くなっていく特徴マップを順にlistに追加してきいます。作成したlistから特徴マップを [::-1]
で逆順に取り出すと、解像度が低いものから高いものの順になります。逆順に取り出した特徴マップとデコーダーのブロックを zip
で同時にループし、エンコーダー側の特徴マップとデコーダーで作成した特徴マップを torch.cat
でチャネル方向に連結しながら処理していきます。
交差エントロピー損失と評価関数を定義する
ここでPyTorchでクラス分類を実装する方法を振り返ってみましょう。クラス分類の損失を実装する簡単な方法は、ニューラルネットワークの出力をTensor形状が [N, C] のロジット、正解ラベル(教師ラベル)をTensor形状が [N] の整数として、ロジットと正解ラベルを torch.nn.CrossEntropyLoss
に与えてソフトマックス交差エントロピー損失を計算することです。ここで N はミニバッチサイズ、C はクラス数です。正解ラベルはone-hot形式ではなくラベル番号を整数で与えることに注意しましょう。
実は torch.nn.CrossEntropyLoss
は、ロジットのTensor形状が [N, C, d_1, d_2, …, d_k] 、正解ラベルのTensor形状が [N, d_1, d_2, ..., d_k] であれば、任意のTensor形状に対してソフトマックス交差エントロピー損失を計算することができます。この機能を利用すれば、セマンティックセグメンテーションでロジットのTensor形状を [N, C, H, W] 、正解マスクのTensor形状を [N, H, W] としてセマンティックセグメンテーションで交差エントロピーを実装することができます。ここで H は画像高さ、W は画像の幅です。 torch.nn.CrossEntropyLoss
では正解ラベルのデータ型は torch.int64
にする必要があるため、SegmentationMNIST
の __getitem__
メソッド内で正解マスクのデータ型を torch.int64
に変換しています。
ネットワークがロジットを出力すると、出力したロジットに対してチャネル方向に torch.argmax
をとることで、ロジットを推論マスクに変換することができます。本記事では推論したマスクの評価指標として、前景についての真陽性(TP)・偽陽性(FP)・偽陰性(FN)の面積を算出することにします。実際の多くのケースでは、このTP・FP・FNの数値からIoUやDICEスコアといった評価値が算出されます。
実装した MaskScoring
クラスはマスクをone-hot形式にして論理演算を行い、TP・FP・FNの面積を算出するクラスです。one-hot形式への変換には torch.nn.functional.one_hot
が便利です。
セマンティックセグメンテーションを学習する
セマンティックセグメンテーションの学習を行うコードはリポジトリの learning_segmentation.py
に、推論を行うコードは inference_segmentation.py
に実装してあります。
20エポック学習したシンプルなU-Netモデルで、テストデータに対して推論を行った結果の一例を図6に示します。元画像と推論結果のマスクを重ねて表示しています。一部を誤っていますが(下から2段目、右から3列目の「3」の画像など)、おおむね正確に推定できています。テストデータに対してモデルを評価したところは、TPの面積が1,247,593に対して、FPの面積は21,385、FNの面積は31,288でした。
おわりに
本記事ではPyTorchでセマンティックセグメンテーションを実装する方法を紹介するために、MNISTを加工した自作データセットを作成し、シンプルなU-Netを学習させました。
TorchVisionは、torchvision.datasets
からさまざまなセマンティックセグメンテーションのデータセットにアクセス可能です。本記事では低解像度で背景が黒一色という単純な画像を扱いましたが、より高解像度で複雑な画像のデータセットをTorchVisionから試してみてはいかがでしょうか。より複雑な画像を学習するためには、モデルの構造や損失関数を工夫する必要があります。
本記事でPyTorchでのセマンティックセグメンテーション実装に興味を持っていただければ幸いです。
参考文献
[1] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. “U-net: Convolutional networks for biomedical image segmentation.” International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.
[2] Çiçek, Özgün, et al. “3D U-Net: learning dense volumetric segmentation from sparse annotation.” International conference on medical image computing and computer-assisted intervention. Springer, Cham, 2016.
[3] Zhou, Zongwei, et al. “Unet++: A nested u-net architecture for medical image segmentation.” Deep learning in medical image analysis and multimodal learning for clinical decision support. Springer, Cham, 2018. 3–11.
[4] Oktay, Ozan, et al. “Attention u-net: Learning where to look for the pancreas.” arXiv preprint arXiv:1804.03999 (2018).
[5] Qin, Xuebin, et al. “U2-Net: Going deeper with nested U-structure for salient object detection.” Pattern Recognition 106 (2020): 107404.