Unsupervised Anomaly Detection with Generative Adversarial Networks to Guide Marker Discovery (AnoGAN)

การตรวจจับความผิดปกติในภาพด้วย Generative Adversarial Network (GAN)

Mail'Suesarn Wilainuch
Super AI Engineer

--

สวัสดีครับ วันนี้ผมจะมา รีวิว paper งานวิจัยที่ใช้ โมเดลอย่าง Generative Adversarial Network (GAN) มาทำการตรวจจับความผิดปกติในข้อมูลประเภทรูปภาพครับ ถ้าพร้อมแล้วเริ่มกันเลยครับ!

paper ที่จะมารีวิวในวันนี้ชื่อ Unsupervised Anomaly Detection with Generative Adversarial Networks to Guide Marker Discovery (หรือ AnoGAN)

สำหรับใครที่ยังไม่รู้จักว่า Generative Adversarial Network (GAN) คืออะไรผมแนะนำให้ลองหาอ่านศึกษาได้จากที่อื่นก่อนครับ เพราะผมจะไม่ได้ลงรายละเอียดเกี่ยวกับ GAN เนื่องจากในงานนี้จะเป็นการนำ GAN มาประยุกต์ใช้งานครับ ถ้าเพื่อนๆมีความเข้าใจในตัว GAN ระดับนึงแล้วจะเข้าใจงานนี้ได้ไม่ยากครับ

ก่อนอื่นขอบอกว่าในงาน Anomaly Detection นั้นมีหลายประเภท ขึ้นอยู่กับข้อมูลที่เราสนใจ หรือลักษณะงาน แต่ในงานนี้จะเป็นข้อมูลรูปภาพครับ และ จริงๆแล้วนิยามของคำว่า “Anomaly” หรือ ความผิดปกตินั้นก็มีหลายความหมาย เพราะขึ้นอยู่กับมุมมอง

ขอเท้าความถึง GAN เล็กน้อย GAN หรือ Generative Adversarial Network เป็นโมเดลตระกูล Generative model ที่เรียนรู้ ที่จะสร้างของ หรือข้อมูลที่เราต้องการ ข้อมูลของเราในที่นี้คือรูปภาพ โดย GAN จะประกอบด้วย AI 2 ตัว หรือ 2 โมเดล ที่แข่งกัน ตัวหนึ่งทำหน้าที่สร้างภาพ เรียก Generator จะเรียนรู้ที่จะสร้างภาพให้เหมือนกับใน Dataset ซึ่งก็คือภาพที่เราใช้เทรน และอีกตัวทำหน้าที่ ตรวจสอบว่าภาพที่มันได้รับเป็นภาพจริงใน Dataset หรือภาพที่ถูกสร้างจาก Generator เรียก Discriminator ซึ่งจะเรียนรู้ที่จะแยกระหว่างภาพจาก Dataset กับ ภาพที่ถูกสร้างจาก Generator ดังนั้นถ้าการเทรนเป็นไปด้วยดีสุดท้าย Generator จะสามารถสร้างภาพที่เหมือนกับใน Dataset มากจนแม้แต่บางทีเราก็แยกไม่ออก

รูปที่ 1 : GAN Architecture Image source

จากรูป GAN Architecture จะเห็นว่า Discriminator รับ Input เป็นภาพจาก Training set หรือก็คือ Dataset ของเรา บางครั้งเราจะเรียกว่า ภาพจริง(Real image) และรับ ภาพจาก Generator หรือเรียกว่า ภาพปลอม(Fake image) และทายว่าภาพนั้นจริงหรือปลอม ส่วน Generator จะรับ Input เป็น Random noise หรือก็คือ ชุดตัวเลขจำนวนจริงที่เราสุ่มขึ้นมา โดยปกติจะเป็นการสุ่มแบบ Gaussian-distribution หรือ Uniform-distribution บางครั้งเราจะเรียกชุดตัวเลขนี้ว่า latent space หรือ latent code มักแทนด้วยตัวแปร z (ชุดตัวเลข 1 มิติ หรือมองเป็น vector ก็ได้ แต่ถ้าเรานำไปพล็อต 2 มิติ ก็จะเห็นเป็นจุดๆ (noise image) อย่างในรูปที่ 1)

ดังนั้นเราจะเห็นว่า Generator เรียนรู้ที่จะจับคู่ชุดตัวเลข (latent space) ไปเป็นภาพ ( image space) หรือเรามองกลับกันได้ว่า ภาพถูกเข้ารหัสไปเป็นชุดตัวเลขผ่าน Generator ซึ่งนี้คือไอเดียหลักที่ paper ใช้ในการประยุกต์ใช้กับการตรวจจับความผิดปกติ

Introduction

งานวิจัยนี้ได้ใช้ข้อมูลภาพถ่ายทางการแพทย์ เกี่ยวกับดวงตา optical coherence tomography (OCT) scan โดยจะมีการทำ preprocessing คือสกัดเอาเฉพาะบริเวณที่ต้องการ และทำ normalization

รูปที่ 2 : OCT Image source

ดังนั้นเป้าหมายของการตรวจจับความผิดปกติในที่นี้ ก็คือ ความผิดปกติในจอประสาทตานั่นเอง (Retina) ซึ่งความผิดปกติที่เกิดขึ้น ก็จะถูกแพทย์วินิจฉัยต่อไปว่าเป็นโรคใด เนื่องจากวิธีการที่งานนี้ใช้จัดเป็น Unsupervised learning ดังนั้นเราจะไม่ได้ระบุคลาสหรือระบุว่าเป็นลักษณะของโรคใด แต่จะเป็นการระบุว่าพบความผิดปกติ

โดยปกติโมเดลทั่วไปที่เป็น classification จะต้องใช้ข้อมูลที่มีเฉลยจำนวนมาก เพื่อที่จะเทรนให้สามารถทำงานได้ดี แต่ในบางครั้งเราก็ไม่สามารถหาข้อมูลได้มากพอ ในบางคลาส ดังนั้น วิธี Unsupervised learning จะไม่มีข้อจำกัดนี้ และเราอาจจะตรวจพบความผิดปกติใหม่ที่เราไม่รู้จักมาก่อนก็เป็นได้

Proposed method

วิธีการคือ เราจะสร้าง Dataset ที่มีเฉพาะ healthy cases คือภาพของ Retina ที่ปกติสุขภาพดีไม่มีโรค หรือความผิดปกติใดๆ โดยที่เราจะ ใช้ Dataset นี้ในการเทรน GAN จนกระทั่ง Generator สามารถสร้างภาพที่เหมือนกับใน Dataset ได้ นั่นคือ Generator เรียนรู้ที่จะทำการ mapping จาก latent space ไปเป็น image space

รูปที่ 3 : เทรน GAN ด้วย healthy cases

ดังนั้นเป้าหมายคือการ หา z (latent space) ที่สอดคล้องกับ G(z) (ภาพที่สร้างโดย Generator) ที่มีความคล้ายกับภาพที่เราต้องการทดสอบมากที่สุด พูดง่ายๆคือ เราต้องการค้นหา z ของภาพที่เราต้องการจะทดสอบ x (query image) ถ้าภาพที่เราทดสอบ (x) คือ 1 ในภาพที่เราใช้เทรน (โมเดลเคยเห็นแล้ว) เราก็จะสามารถหา z ที่เป็นเหมือนรหัสของภาพนั้นได้ แต่ถ้าเราใช้ภาพที่โมเดลไม่เคยเห็นมาก่อน (Testset) แต่ภาพนั้นอยู่ใน Domain เดียวกัน คือภาพ healthy cases เราก็น่าจะสามารถค้นหา z ที่ใกล้เคียงที่สุดได้ แต่ถ้าภาพที่เราทดสอบแตกต่างจาก trainset มากก็จะไม่สามารถหา z ที่ใกล้เคียงได้ หรืออาจกล่าวได้ว่าภาพนั้น เป็นภาพที่ผิดปกติไปจาก healthy cases ดังนั้นถ้าเรามี Dataset ที่ครอบคลุมกรณีของ healthy cases มากเท่าใดก็ยิ่งทำให้โมเดลสามารถแยกระหว่างภาพปกติที่สุขภาพดีกับภาพที่มีความผิดปกติได้ดีมากขึ้นเท่านั้น

ส่วนขั้นตอนการค้นหา z นั้นจะทำผ่านการ อัพเดต z ที่เราสุ่มขึ้นมาจนกระทั่งได้ค่าที่ใกล้เคียงที่สุด โดยกระบวนการนี้ จะใช้ loss function 2 ตัวในการอัพเดต z

1. Residual Loss

residual loss จะวัดความแตกต่างระหว่าง query image x กับ generated image G(z) ใน image space กำหนดโดย

Residual Loss

ซึ่งก็คือการนำภาพที่ต้องการทดสอบ x มาลบกับภาพที่สร้างจาก Generator โดย z1 จะสุ่มขึ้นมา จากสมการนี้เราจะพบว่า ถ้า x เป็นรูปปกติไม่มีความผิดปกติ ก็จะมีความแตกต่างน้อยส่งผลให้ค่าที่ได้น้อย

2. Discrimination Loss

Discrimination Loss

discrimination loss ใช้ประโยชน์จาก discriminator ที่เรียนรู้ที่จะแยกความแตกต่างระหว่างภาพจริงจาก Dataset และภาพปลอมจาก Generator ซึ่งก็คือภาพใน Domain ของภาพปกติสุขภาพดี ดังนั้น parameter ของ discriminator ก็จะสอดคล้องกับภาพปกติ ดังนั้นจึงใช้เทคนิค feature matching ที่ last convolution layer ของ discriminator ที่เป็น richer intermediate feature ดังนั้นถ้าภาพทดสอบเป็นภาพปกติสุขภาพดีก็จะทำให้ intermediate feature ของ discriminator มีความคล้ายกับภาพจริงใน Dataset เป็นเหมือนการใช้ discriminator เป็น feature extractor

สุดท้ายเราจะนำ loss function ทั้งสองมารวมเข้าด้วยกันเป็น overall loss และทำการ weight ผลรวมของทั้ง 2 เทอม และใช้ในการอัพเดต z หรือก็คือการ mapping จาก image space ไปเป็น latent space

overall loss

โดยจะมีเพียง z ที่จะถูกอัพเดตผ่าน backpropagation ส่วน parameter หรือ weight ของ Generator และ Discriminator จะถูก fixed ไว้

Detection of Anomalies

ส่วนวิธีการประเมินและระบุว่าภาพที่ทดสอบเป็น ภาพปกติหรือผิดปกติสามารถใช้ overall loss ที่รอบสุดท้ายของการอัพเดต z เป็น anomaly score ได้โดยตรง

anomaly score

anomaly score จะมีค่ามากสำหรับ ภาพที่ผิดปกติ แต่ถ้ามีค่าน้อยหมายความว่าภาพนั้นมีความคล้ายกับ Dataset มากซึ่งก็คือ ภาพปกติสุขภาพดี

นอกจาก anomaly score ยังสามารถระบุตำแหน่งบริเวณของภาพที่มีความผิดปกติได้ จากการ นำภาพที่ต้องการทดสอบลบกับ ภาพที่สร้างจาก generator โดยใช้ z ที่ถูกอัพเดตในรอบสุดท้าย เรียกว่า residual image

residual image

จากการทดลอง paper นี้รัน 500 รอบ backpropagation ในการอัพเดต z และใช ้λ = 0.1

Results

รูปที่ 4 : Results

จากรูปที่ 4 สามารถอธิบายได้ดังนี้ แถวแรกหรือ Query image คือรูปที่ใช้ทดสอบ แถวที่สอง Generated image คือรูปที่ได้จาก Generator ที่มี input เป็น z ที่ถูกอัพเดตรอบสุดท้าย แถวที่สาม Residual overlay เกิดจากการนำ Query image แถวที่ 1 ลบกับ Generated image แถวที่ 2 บริเวณที่ไฮไลท์สีแดงคือ บริเวณที่มีความผิดปกติ แถวที่สี่ คือ Ground truth ของโรค retinal fluid

ซึ่งเราจะเห็นว่า ภาพปกติแม้โมเดลจะไม่เคยเห็นมาก่อน (Testset)แต่ก็สามารถหา z ที่ใกล้เคียงกับภาพที่ใช้เทรนได้ ดูได้จาก generated image ที่ได้มีความคล้ายกับภาพทดสอบ ซึ่งก็คือภาพปกติสุขภาพดี ส่วนภาพที่มีความผิดปกติ โมเดลไม่สามารถหา z ที่ใกล้เคียงได้ จึงทำให้ generated image ที่ได้ไม่เหมือนกับภาพทดสอบซึ่งก็คือภาพผิดปกติ

นอกจากอาการ retinal fluid ที่สนใจเราจะเห็นว่าโมเดลสามารถตรวจพบความผิดปกติอย่างอื่นที่เราไม่ได้ระบุไว้ใน ground truth (กรอบสีเขียว) เรียก Hyperreflective foci

เพื่อให้เห็นภาพของการ อัพเดต z ลองดูตัวอย่างต่อไปนี้ สมมติว่าเราเทรน GAN ด้วย MNIST Dataset ซึ่งเป็นภาพตัวเลข 0–9 ในรอบแรกเราสุ่ม z1 และป้อนให้ Generator ปรากฎว่าได้เป็น เลข 1 ออกมา โดยที่ภาพที่เราต้องการทดลองคือ ภาพเลข 7 ที่โมเดลไม่เคยเห็นมาก่อน

จากนั้นเราอัพเดต z1 ด้วย residual loss และ discriminative loss โดย fixed weight ของ Generator และ Discriminator

สุดท้ายเราจะได้ generated image ที่คล้ายกับ query image ซึ่งถึงแม้ว่าโมเดลจะไม่เคยเห็นมาก่อน แต่เราก็สามารถหา z ที่ใกล้เคียงมากๆได้ เนื่องจากภาพเลข 7 เป็นภาพใน domain เดียวกันกับที่เราใช้เทรนโมเดล ซึ่งสามารถมองเป็นภาพปกติได้เช่นกัน

Conclusion

ข้อดีของวิธีนี้ คือการที่เราไม่จำเป็นต้องใช้ label ในการเทรนโมเดลและข้อมูลภาพปกติสุขภาพดีก็ยังหาได้ง่ายกว่าภาพผิดปกติของแต่ละโรค ทำให้การวิเคราะห์ความผิดปกติในเบื้องต้นเป็นไปได้อย่างรวดเร็ว ถึงแม้ว่าเราจะไม่รู้ว่าเป็นโรคอะไรในตอนแรกแต่เราก็จะบอกได้ว่ามีความผิดปกติเกิดขึ้น ภาพที่ถูกทำนายว่ามีความผิดปกติก็จะเป็น candidate ให้แพทย์ได้วินิจฉัยต่อไป

ส่วนข้อเสียที่เห็นได้ชัดจะเป็นเรื่องของความเร็วในการประมวลผล ที่ใช้เวลานานเนื่องจากต้องใช้เวลาในการอัพเดต z สำหรับแต่ละ query image และอีกประเด็นที่สำคัญคือจะเห็นว่าวิธีนี้จะใช้งานตัว generator ที่ต้องสามารถสร้าง image ที่เหมือนกับ dataset ได้ หรือพูดอีกอย่างก็คือ การ train โมเดล GAN จะต้องสำเร็จ และนั่นก็เป็นเรื่องที่ยากเหมือนกัน ขึ้นกับหลายปัจจัย ทั้งความซับซ้อนของตัว image, ขนาดภาพและ ข้อจำกัดต่างๆ ที่แต่ละโจทย์ไม่เหมือนกัน ดังนั้นถ้าจะนำวิธีนี้ไปประยุกต์ใช้ คงต้องคิดวิเคราะห์ให้ดีว่าเหมาะสมหรือไม่

ขอขอบคุณทุกๆท่านที่ให้ความสนใจ หวังว่าบทความนี้จะมีประโยชน์ต่อผู้ที่สนใจไม่มากก็น้อย ถ้ามีข้อผิดพลาดประการใดต้องขออภัยไว้ ณ ที่นี้ด้วยครับ

“A MAN IS NOT WHAT HE OWNS.
A MAN IS WHAT HE LOVES.”
-CHATRI SITYODTONG

Reference

[1] https://arxiv.org/abs/1703.05921

--

--

Mail'Suesarn Wilainuch
Super AI Engineer

Researcher as Machine Learning & Deep Learning Engineer at Perceptra 🩺