MNISTを加工して物体検出(Object Detection)のデータセットを自作し、作成したデータセットを学習するニューラルネットワークをPyTorchで実装します。
目次
- はじめに
- 物体検出とは
- MNISTを加工したデータセットを自作する
- シンプルなPointNet構造のネットワークを定義する
- 損失関数と評価関数を定義する
- 物体検出を学習する
- おわりに
はじめに
深層学習の実装方法を学び始める際に、MNISTを使った画像のクラス分類(Classification)は良い題材になっています。最初にMNISTクラス分類から始めたPyTorchユーザーの方も多いのではないでしょうか。しかしPyTorchビギナーユーザーにとって、クラス分類に比べると物体検出タスクの実装は少し敷居が高く感じられるかもしれません。
本記事はPyTorch初心者の方を対象に、物体検出の実装例を紹介します。MNISTを加工して物体検出用のデータセットを自作し、 CenterNet (Objects as points) [1] をシンプルにしたネットワークで学習を行います。本記事は「PyTorch初心者のためのMNISTセマンティックセグメンテーション」の続編ですが、本記事単独で読むことができます。
本記事のサンプルコードは、以下のリポジトリに掲載しています。
物体検出とキーポイントベースの手法
物体検出は、バウンディングボックスによって物体のクラス分類と位置特定を同時に行うタスクです。深層学習を使った物体検出手法の多くで、アンカーボックスと呼ばれる基準ボックスからの相対関係でバウンディングボックスを推定する方法が採用されています [2, 3, 4]。画像中に高密度に配置したアンカーボックスを絞り込んでバウンディングボックスを予測しますが、正解ボックス(教師ボックス)と重なるアンカーボックスは少数のため学習が遅くなることや、アンカーボックスに関するハイパーパラメータの設計が必要になることが問題にもなります [5]。
そこでアンカーボックスを使わない、キーポイントベースの物体検出の手法も提案されるようになってきました。ConerNet [5]、ExtremeNet [6]、CenterNet (Keypoint triplets for object detection) [7] 、CenterNet (Objects as points) [1] などが提案されています。偶然でしょうが、”CenterNet”という名称で二つのキーポイントベースの物体検出手法が提案されていますので注意してください。
キーポイントベースの物体検出手法は構造がシンプルなため、PyTorch初心者でも実装は難しくありません。 本記事では後者の CenterNet (Objects as points) (以下、単にCenterNetと記述)をシンプルにしたネットワークを実装して物体検出を行います。実装するネットワークは、共通の畳み込みニューラルネットワーク部と、キーポイントマップ、オフセットマップ、サイズマップをそれぞれ出力する3つの独立した部分から成っています(図2)。
キーポイントマップはクラス別に物体中心の確率をヒートマップ形式で表現するもので、クラス数だけのチャネルを持っています。キーポイントマップは入力画像より解像度が低いため、マップから検出した物体中心の座標は真の物体中心から乖離します。オフセットのマップはその乖離を補正するためのもので、XとYの2チャネルが存在します。サイズマップはバウンディングボックスの幅と高さを表現します。
MNISTを加工したデータセットを自作する
まずMNISTを加工して物体検出用のデータセットを作成しましょう。DetectionMNIST
クラスは、正解データとしてバウンディングボックスと各種マップを作ります。
__getitem__
メソッドでは、画素値が threshold
より大きい画素を囲むようにバウンディングボックスを作っています。
ここではキーポイントベースのモデルを扱うために、__getitem__
メソッドから _make_maps
メソッドを呼び出してキーポイント、オフセット、サイズのそれぞれのマップも作成しています。マップの作成は損失関数の計算時に行うなど様々な実装の形態が考えられますが、本記事ではデータセット側でマップを作成することにします。
MNISTの本来の解像度は28×28ですが、それぞれのマップは7×7の解像度になっています。図4に示した「4」の画像のキーポイントマップでは、「4」のチャネルのみ物体中心をピークにしてガウシアン分布でヒートマップが形成されています。図5はオフセットとサイズのマップです。キーポイント以外の位置では値を設定しません。
図6に、各種マップとバウンディングボックスの関係性を示します。図中の赤丸がキーポイントマップのピーク位置に対応する物体中心位置、そこから伸びた矢印の先がオフセットで補正した物体中心位置です。補正した物体中心と、サイズマップで記したバウンディングボックスの幅と高さを使ってバウンディングボックス(図中点線)を描画しています。
DetectionMNIST
クラスではAlbumentationsを使ったデータ拡張も行っています。データ拡張の結果の一例を図7に示しています。
Albumentationsでバウンディングボックスのデータ拡張を行う際は、bbox_params
でバウンディングボックスの形式を指定することに注意してください。本記事ではバウンディングボックスは [xmin, ymin, xmax, ymax] のPASCAL VOC形式で表現しています。
物体検出用のデータローダーを作る際には、少し注意が必要です。DetectionMNIST
はそれぞれの画像についてバウンディングボックスを形状 [M, 4]のTensor、そのラベルを形状 [M] のTensorとしてtorch.utils.data.DataLoader
に提供します。ここでMはそれぞれの画像中のバウンディングボックスの個数に、4は[xmin, ymin, xmax, ymax]の4つの座標に対応します。一般的な物体検出を考えると、画像に写るバウンディングボックスの個数は不定な(=Mの数が揃っていない)ため、torch.utils.data.DataLoader
はバウンディングボックスやラベルを一つのミニバッチTensorに束ねることはできません。そこで引数のcollate_fn
を使うことで、画像とマップはそれぞれ4次元Tensorに、バウンディングボックスとラベルはTensorのlistのままでミニバッチを作成します。
MNISTでは画像中の物体(数字)は必ず一つであるため DetectionMNIST
ではM=1で統一されていますが、一般的な物体検出を意識してcollate_fn
を使ったデータローダーの実装にしています。
シンプルなPointNet構造のネットワークを定義する
図2に示したPointNetを簡単にしたネットワークを SimplePointnet
クラスとして実装します。backbone
では stride=2
の畳み込みで2回ダウンサンプリングを行い、28×28の解像度の画像から7×7の特徴マップを作成しています。keypoint_head
、 offset_head
、 size_head
でそれぞれの7×7のマップを作成しています。
キーポイントマップは最小値0最大値1になるため、keypoint_head
の最後にシグモイド関数を使っています。バウンディングボックスの幅と高さは負になることはないため、 size_head
の最後にはReLUを入れています。
ここで推論時にネットワークが出力したキーポイントマップから、物体中心を推定する方法を説明します。学習が進んだネットワークが出力するキーポイントマップは、正解のキーポイントマップ(教師キーポイントマップ)同様に物体中心をピークにしたガウシアン分布になります。そこでネットワークが出力したキーポイントマップに対して、隣接する8近傍より大きな画素値の位置をピークとして検出します。
キーポイントマップからピーク検出する順序を示したものが図8です。最初にキーポイントマップに対してカーネルサイズ3、ストライド1 のマックスプーリングでモルフォロジー膨張処理を行います。次に膨張結果とキーポイントマップが等しく、かつ、他チャネルより大きな位置をTrueにしたマスクを作成します。ただしノイズなどでピーク位置以外も抽出してしまう可能性があるため、マスクとキーポイントマップの積をとることでピーク位置だけ値が残ったマップを得ます。最終的に、指定した閾値以上である座標をピークとして検出することができます。
損失関数と評価関数を定義する
損失関数は PointnetLoss
クラスとして実装しています。損失関数は、キーポイントマップに関する損失と、オフセットとサイズのそれぞれのL1損失の重み付き和で構成されています。ここではキーポイントマップに関する損失は簡単に Binary Cross Entropy 損失で実装しています。元論文を参考にしたPenalty Reduced Focal Lossでの損失も実装して切り替えられるようにしていますので、興味のある方は御覧ください。
本記事では推論結果の評価指標として、正解ボックスと推論ボックスの Intersection over Union (IoU) から判定した真陽性(TP)・偽陽性(FP)・偽陰性(FN)のバウンディングボックスの個数をカウントすることにします。ここでは詳細には触れませんが、物体検出タスクではTP・FP・FNの個数を元に算出したAverage Precision (AP) を評価指標とするケースが多くみられます。
評価関数のクラスとして、BboxScoring
クラスを実装しています。バウンディングボックス同士のIoUは、torchvision.ops.box_iou
を使いました。この関数はM個とN個のバウンディングボックスに対してM×NのIoUの行列を算出します。正解バウンディングボックスに対してIoUが大きい順に推論バウンディングボックスを当てはめ、TP・FP・FNを算出しています。
物体検出を学習する
実装したコードの全体は、本記事のリポジトリを御覧ください。物体検出の学習は learning_detection.py
、推論は inference_detection.py
で実装しています。
20エポック学習したシンプルなPointNetで、テストデータに対して推論を行った結果が図9です。正解のバウンディングボックスを実線、モデルの推論結果を破線で示しています。色は各クラスに対応しています。多くの画像で、正解ボックスに近いバウンディングボックスを推論できました。推論結果を評価したところ、テストデータの10000個のバウンディングボックスに対して、TP 9825個、 FP 178個、FN175個の精度となりました(IoU0.75以上をTPとして判定)。
おわりに
本記事ではPyTorchで物体検出を実装する方法を紹介するために、MNISTを加工した自作データセットを作成し、シンプルなPointNetを学習させました。
本記事で実装したネットワークは簡素なものでしたが、扱ったデータセットが低解像度で背景が黒一色、しかも画像には必ず一つだけ物体が写っているという単純なものだったため、うまく対応できました。一般的な物体検出データセットの多くは、より複雑で難しいものになります。TorchVisionは torchvision.models
でさまざまな物体検出のモデルを提供しています。実際の物体検出データセットに、TorchVisionが提供するモデルを試してみてください。
本記事がPyTorchで物体認識を実装する参考になれば幸いです。
参考文献
[1] Zhou, Xingyi, Dequan Wang, and Philipp Krähenbühl. “Objects as points.” arXiv preprint arXiv:1904.07850 (2019).
[2] Ren, Shaoqing, et al. “Faster R-CNN: towards real-time object detection with region proposal networks.” IEEE transactions on pattern analysis and machine intelligence 39.6 (2016): 1137–1149.
[3] Redmon, Joseph, et al. “You only look once: Unified, real-time object detection.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
[4] Liu, Wei, et al. “Ssd: Single shot multibox detector.” European conference on computer vision. Springer, Cham, 2016.
[5] Law, Hei, and Jia Deng. “Cornernet: Detecting objects as paired keypoints.” Proceedings of the European conference on computer vision (ECCV). 2018.
[6] Zhou, Xingyi, Jiacheng Zhuo, and Philipp Krahenbuhl. “Bottom-up object detection by grouping extreme and center points.” Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
[7] Duan, Kaiwen, et al. “Centernet: Keypoint triplets for object detection.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019.