DenseNetの論文を読む(初心者)
ResNet、ResNeXt、Xceptionに引き続き、今回もDenseNetの論文を読んでまとめていきます。いつも通り、初心者向けです。
参照した論文は以下
Densely Connected Convolutional Networks
DenseNetの構造
ResNetもDensNetともに1×1、3×3のレイヤーを並べたものをブロック単位として積み重ね、浅い層から深い層へのresidualを伝えるという基本構造は似ています。
ただし、
ResNetでは正規ルートの値とresidualの値は”足し算”しますが、
x = keras.layers.Add()[main,residual]
DenseNetの場合だと”連結”します。
x = keras.layers.Concatnate()[main,residual]
growth rate
# DenseBlock
main = input
x = BatchNormalization()(input)
x = Activation(“relu”)(x)
x = Conv2D(128, (1, 1))(x)
x = BatchNormalization()(x)
x = Activation(“relu”)(x)
x = Conv2D(k, (3, 3), padding=”same”)(x)
x = Concatenate()([main, x])
Densblockをコード化すると上記のように書けますが”k”は自分で決めることができます。例えばk = 32にして上記のコードをfor文で回していくと、 Concatnate()[main,x]
によってフィルターが32ずつ増えていくのがわかりますね。これがDenseNetの1つの大きな特徴であり”growth rate”と呼ばれるものです。この”growth rate”を操ることでどの程度新しい情報をメイン側に追加するかをコントロールできるというわけです。
Compression
DenseNetはblockの間にtransition layer (1×1convと2×2平均pooling)と呼ばれるものが挟まれています。コード化すると
def transition_layer(input, input_channels):
n_channels = int(input_channels * compression)
x = Conv2D(n_channels, (1, 1))(input)
x = AveragePooling2D((2, 2))(x)
return x, n_channels
compression (圧縮係数)θは0<θ≤1の間で設定ができます(今回の論文では0.5で実験)
1×1のフィルターでチャンネル数を圧縮を行います。また続くAveragePoolingにより解像度は半分にダウンサンプリングされます。
Bottleneck layers
DenseBlock内の3×3フィルターの前に1×1フィルターを挟むと計算量が落ちますと書いてありましたが、この手法はDenseNetモデル以外にもたくさん使われているので、主だった特徴とはいえないですね。
ちなみに計算量が落ちる理由は以前のResNeXtの記事で(かなり丁寧に)説明したので省略します。
実験結果
CIFAR、SVHN、ImageNetのデータセットを使用し、いろいろ条件を変えて訓練してますが、、、、、その結果は論文見たほうが早いので省略します。ものすごくざっくりいうとResNetよりも性能いいと論文は言っています。
なぜ性能がよくなったのか??
ResNetもDenseNetも基本構造は似ていますが、Residualの入力が加算ではなく”連結”であるという一見小さな変更点が大きな効力を発揮します。
連結によって学習された機能マップは、後続のすべてのレイヤーからアクセスできます。これにより、ネットワーク全体での機能の再利用が促進され、モデルがよりコンパクトになります。
個々の層がより短い接続を介して損失関数から追加の監視を受けます。論文ではこれを一種の“deep supervision”だと言っています。
図は各層の重みの平均値をヒートマップで可視化したもので、論文では
・全ての層で重みが同ブロックの多くの入力に分散している
・浅い層の出力特徴量がdense block全体で使われている
・Transition layerでも重みが分散している
・最初の層の情報が直接最後の層に伝わっている
と言っていますが、いまいち僕はよくわかってません。
モデル全体を通して重みがいい感じに分散していますねーーーってことが言いたいのでしょうか。
最後に
レイヤー同士が密接につながり、どれくらいの情報を他のレイヤーに伝えるのかgrowth_rate:kによってコントロールできるというのがDenseNetの発明ですね。