不到50行程式讓 AI 幫你偵測異常資料

Abby Yeh
Taiwan AI Academy
Published in
7 min readJun 18, 2019

A Simple Way to Detect Abnormal Images

Photo by Markus Spiske on Unsplash

越來越多應用不適用一般的 CNN classifier,像是瑕疵檢測、癌症預測。幾萬筆資料中可能只有幾百筆是希望被檢測出來的資料。然而當資料比例懸殊的時候,模型常常會學會把所有資料都預測為資料量多的類別來達到高準確率,根本無法檢測出希望被檢測出的資料。

這種情況下,可以運用統計學的 oversampling、downsmapling 來處理不平均的資料,或是調整 class weights 去強迫模型去判對資料量少的類別。

但有時候可能這些資料少到用上述的方法會導致嚴重的 overfitting。這時候就可以用異常檢測 (anomaly detection) 來代替分類。異常檢測的概念就是只使用資料多的那個類別的資料作爲正常資料來訓練模型,在預測時就可以判斷資料是否正常。

只看正常資料要怎麼分辨出資料是不是異常呢?下面就是使用簡單的 CNN feature extractor 加上 clustering 來達成瑕疵檢測的例子。

figure 1. 論文中的示意圖

這個方法是借用來自一篇 ICML 2018 的論文 Deep One-Class Classification 的概念。打開這篇論文你可能會被滿滿的數學式嚇到,但其實這篇論文所介紹的 Deep SVDD 概念其實十分簡單。 Deep SVDD 參考了 Support Vector Machine 的概念,將資料轉換到另一個維度的空間上,讓 one class classifier 可以把正常資料和異常資料分開。

利用 Keras + Scikit Learn 這兩個套件,幾乎不用寫程式就可以實現類似的概念。

— Python 實作 —

首先要匯入需要的函式庫,我們只需要這四個函式就可以來實作了。

在論文中,是使用神經網路學會 feature representations 並找到圓心和半徑來分開圓內的正常資料與圓外異常的資料。雖然這個網路架構十分簡單, loss function 也不難實作,但是需要自己把 gradient 傳回去。對不熟悉函式庫的人會有點難度。

所以這裡我們會用偷吃步,用 ImageNet 的 pre-trained model 來代替論文中的 feature representations。這樣只要簡單的四行程式就可以把圖片映射到一個比較好分割的空間。

先用 Keras 載入用 ImageNet 訓練過的模型 ,接著選定一層 layer,過 global average pooling 降低維度,就完成 feature representations 的模型了。

one class classifier 的部分用 Scikit Learn 的 clustering 就可以輕鬆實現了。

figure 2. clustering examples from Scikit Learn documentation.

用 Scikit Learn 的 Gaussian mixture model 來模仿 one class classifier ,這個模型是找到平均值和標準差來表示資料和論文中找原心和半徑不同。把平均值視為是圓心就很接近原本的方法了。

要判斷是否為異常值時,用 Gaussian mixture model 的 score_samples 函式,得到資料為該類別機率來代替資料到圓心的距離 (離分佈中心越近機率越大),最後根據資料的機率分布挑出一個 threshold 代替半徑就完成模型啦。

接下來就讀入資料試試看吧!讀入 Kears 裡的 MNIST 資料集,將「 1」 作為正常值,「7」 作為異常值。x_ok 為訓練資料中標籤為「 1」的資料,x_test、y_test 為測試資料中標籤為「 1」或「7」的資料。

reshape_x: 因為使用 Keras 的 VGG pre-trained model,圖片最小必須是 30*30,所以將原本的圖片放為兩倍大,並將圖片變為 3 維。

訓練的部分只有兩行,使用前面用 VGG16建立的模型預測訓練圖片 (x_ok) 得到新的 features 作為 input data,訓練 Gaussian mixture model。

— 預測結果 —

利用 score_samples 計算訓練資料的分數,將平均加三個標準差作爲 threshold 就可以來預測了。

一樣使用前面用 VGG16建立的模型預測測試圖片 (x_test) 得到新的 features 作為 input data,用 Gaussian mixture model 的 score_samples 算出測試資料的分數。並將大於 threshold 的值判為正常,小於 threshold 的判為異常就可以得到大約 0.97 的正確率。

the output of svdd_part6.py

接著來視覺化預測的結果吧!將測試資料中的 「1」標為藍色,「7」標為粉紅色,畫成 x 軸為第幾筆資料,y 軸為 score 的分散圖。再用折線圖標上 threshold 就可以得到下面這張圖片。

visualization of output,score 都是負數是因為 score_samples 的結果是取完 log 的機率

從視覺化的結果可以看出來,預測的結果很不錯,藍色的點都集中在 threshold 上方而且離 threshold 有一小段距離,而粉紅色的點大部分也都在 threshold 下方。

但是還是沒有百分之百的的判對。把原始的 x_test 畫出來,就可以發現只要「1」不是直直的一豎,「7」上面那橫不夠寬都很容易判錯。

上面為實際為「1」判斷為「7」,下面為實際為「1」判斷為「1」。
上面為實際為「7」判斷為「1」,下面為實際為「7」判斷為「7」。

想要得到更好的準確率,或許可以把 VGG16 換成自己訓練的 autoencoder,畢竟 ImageNet 的圖片和 MNIST 還是有些落差。換成其他資料差別更大資料及,模型的表現只會更差。想要挑戰的話,也可以嘗試用文中提到 Deep SVDD,把資料轉移到一個真的可以用一個圓劃分正常值與異常值的空間。

— Deep SVDD—

當然真正的 Deep SVDD 是比這個方法複雜得多,有興趣的話可以參考作者提供的 PyTorch 實作,其方法概略如下:

  1. 建立一個 CNN 模型將圖片轉成 d 維的向量, 一個 d 維向量— 中心和 一個 1 維 向量 — 半徑。
  2. 訓練一個 autoencoder,其 encoder 的架構和前述的 CNN 模型相同),將圖片 encode 到 d 維的向量,並作為CNN 模型的 pre-trained model。
  3. 最小化 CNN output 到中心的距離平方減半徑平方的平均,再加上半徑平方 (regularization), 來訓練CNN 模型、中心和半徑 。

— 參考資料 —

[1] ICML 2018 Paper Deep One-Class Classification

[2] PyTorch implementation from the paper’s author

[3] Scikit Learn documentation

--

--