One Class Classification on OCT image dataset

Nop-anon Singborvonan
Super AI Engineer
Published in
4 min readMar 28, 2021
(Source https://www.kaggle.com/paultimothymooney/kermany2018)

สะวีดัสครับ เอ๊ย!! สวัสดีครับ เอ๊ย!! ถูกแล้ว!! กลับเข้าสู่ซีรี่ Super ai กันอีกรอบ เเต่รอบนี้ไม่ใช่บทความในส่วนของ Ai ทั่วไปเเล้ว เเต่จะเป็นเนื้อหาทางด้านเทคนิคกันบ้าง ในบทความนี้ผมจะพาทุกคนสร้างโมเดลกันบ้างครับ ผมจะไม่ได้ลงลึกในเชิงทฤษฎีมากนะครับ แนะนำว่า ใครยังไม่รู้จัก One Class ลองศึกษาเพิ่มเติมว่าไอเดียหรือพื้นฐานของมันเป็นยังไงก่อนครับ สำหรับ dataset ที่ผมจะใช้ มาจาก kaggle ลิ้งนี้เลย( https://www.kaggle.com/paultimothymooney/kermany2018 ) dataset ที่เราจะใช้กันเป็นภาพของตาครับเป็นโรค 3 โรคที่เเตกต่างกัน ลักษณะของภาพต่างกันและอีกชนิดเป็นภาพของตาปกติครับ ผมขอข้ามบางช่วงเช่น การวิเคราะห์ข้อมูล หรือ preprocess ไปนะครับ เนื่องจากจะทำให้บทความนี้ยาวเกินไป เเต่จะบอกว่าแบ่งข้อมูลแบบไหนในสัดส่วนเท่าไหร่ครับ

ข้อมูลของเรามันจะมีทั้งหมด 3 โฟล์เดอร์ครับ โดยเขาแบ่งมาให้ คือ Train Test และ Validate แต่เนื่องจาก Validate ที่เขาแบ่งมาน้อยเหลือเกิน ดังนั้นเเล้ว ผมขอไม่ใช้นะ(555555+ หยิ่งอ่ะ แต่น้อยจริง Test ก็น้อยครับ)

จะเห็นว่า ข้อมูลในส่วนของ Test เนี่ยมัน balance อยู่เเล้วนะครับ ผมจะไม่ทำอะไรกับมัน เเต่ในส่วนของ Train เนื่องจากอย่างที่บอกว่า validate ที่เขาให้มามันน้อยเสียเหลือเกินครับ ดังนั้นเเล้ว ผมขอแบ่ง 25% ของTrain ไปเป็นตัว validate นะครับ

เเต่จากข้อมูลถ้าดูดีๆ เราจะพบว่าข้อมูลมันมี ID หรือตัวเลขตอนต้น ซึ่งเป็นรหัสผู้ป่วยเป็น ID เดี๋ยวกันครับ ดังนั้นแล้ว เราจึงควรที่จะให้ผู้ป่วยคนเดียวกัน อยู่ในส่วนของข้อมูลเดียวกัน เช่น ภาพID 1111111–1.jpg อยู่ใน Trainset เราก็ควรให้ 1111111–2.jpg อยู่ใน Trainset เหมือนกัน เพราะภาพมาจากผู้ป่วยคนเดียวกัน มันจะมีลักษณะที่คล้ายคลึงกัน ถ้าอีกภาพไปอยู่ใน validate มันจะเหมือนเป็นการ สปอยโมเดลเราครับ

หลังจาก แบ่งข้อมูลเสร็จเรียบร้อย ผมขอทำการ Oversamping กับตัว Trainset ให้ข้อมูลมัน balance เสียก่อน

เย้ ข้อมูล balance แล้วทำไรต่อดี……. ผมจะทำการสร้างโมเดลที่ใช้ในการ Extract Feature ของภาพเเต่ละภาพครับโดยใช้ Restnet50 และใช้ pretrain ของ image net (ผมได้ลองทั้งแบบที่ใช้ pretrain อย่างเดียวใน Extract Feature เพื่อเข้าสู่โมเดล One Class แล้ว พบว่าการเทรนเพิ่มเพื่อปรับ weight ของRestnet50 ให้ผลการการพยากรณ์ที่แม่งยำกว่า ใช้ pretrain เพียงอย่างเดียวเป็นอย่างมาก (acc~69%ของโมเดล One class))

Extract Feature

มาสร้าง Restnet50…

จะเห็นว่าผมตั้งชื่อโมเดลเป็น autoencoder 55555555+ อย่าไปสนใจนะครับ ผมลืมแก้ เพราะก่อนจะใช้ restnet50 ผมลองสร้าง autoencoder มาก่อน เเล้วภาพที่ได้มามันเบลอเกินไป อีกทั้ง พอเอาผลของการ Extract Feature ที่ laten space ไปใช้ต่อกับ โมเดล one class ผลมันไม่ได้เรื่องเลย 555+ ผมเลยเปลี่ยนมาใช้ restnet50 แทน จากโมเดลจะเห็นว่า เราจะได้ออกมา สองโมเดลหลังจากการ Fit เสร็จ คือ โมเดล autoencoder โมเดล และ restnet_model ซึ่งผมจะอธิบายให้ฟังครับว่ามันต่างกันอย่างไร (ภาพตอนโหลดเข้า เป็น graysclae ขนาด 224*224 ครับ ผมได้ลองภาพขนาดเล็กที่ 64*64,112*112,224*224 แล้วครับ ผลของการทดลองดีขึ้นตามขนาดของภาพที่ใหญ่ขึ้นครับ)

  1. โมเดล Autoencoder (อยากเปลี่ยนชื่อจัง555+ มันชื่อAutoencoder เเต่มันไม่ใช่ โมเดลAutoencoder นะ อื้มๆๆๆๆๆ Autoencoderที่ไม่ใช่Autoencoder getsunova มั๊ยครับ555555+) ไอ่เจ้าโมเดลตัวนี้ถ้าดูที่ layer สุดท้าย มันจะมี 4 node ซึ่งเท่ากับจำนวนของ class ของภาพเรา 4 โรค (Normal,DME,CNV,DRUSEM) ผมส้รางโมเดลตัวนี้เพื่อปรับ weight ของ restnet50 ข้างบน ให้เหมาะสมกับข้อมูลชุดนี้มากขึ้น
  2. โมเดล restnet_model ไอ่ตัวนี้เเหละครับที่จะนำไปใช้ต่อ เพราะ เมื่อเราสั่ง fit ทั้งโมเดลเเล้ว restnet_model มันจะโดน fit ไปด้วย ผลลัพธ์ของโมเดลนี้มันจะออกมาเป็น tensor 2 dim (-1,2048) หลังจากที่เราเรียกโมเดล restnet_model ผลลัพท์ที่ได้ เราก็จะนำไปใช้กับโมเดล One Class ต่อไป…

มาดูผลลัพท์ของโมเดล

Validate set :

Test set :

ผลออกมาโอเคมากเลย โอเคจนคิดว่า ทำไมเก่งเกิน…..5555555+ แต่ตามชื่อของบทความนี้ครับ One class ไม่ใช่ 4 Class ดังนั้นเเล้ว อย่าไปสนใจมันครับ (เอาตรงๆครับ ตอนผมทำตอนแรก ผมตั้งใจแค่ปรับ weight ของRestnet50 เท่านั้นครับ ไม่สนว่า โมเดลที่ทำนาย 4 class ผลมันจะเป็นอย่างไร แต่ผมมันดันออกมาดีเอง555555+)

One Class Model

ก่อนที่เราจะสร้าง one class เราต้องเตรียมข้อมูลของเราใหม่เสียก่อน คือไอ่เจ้า one class ของผมเนี่ย มันจะเรียนรู้บนข้อมูล class เดียว ซึ่งในที่นี้คือ Normal

ทำไมถึงเป็น normal เพราะผมต้องการที่จะสร้างโมเดลที่จะตรวจจับความเป็น normal ของภาพ oct เอาง่ายๆที่เป็นภาษาคนคือ ผมอยากบอกว่าตาของคนนี้ปกติ หรือตาของคนคนนี้ผิดปกติ ดังนั้นเเล้วถ้าเราสร้างโมเดล one class ได้สำเร็จ แปลว่า เมื่อเรามี dataset ชุดใหม่ ที่ไม่ใช่โรค 3 โรค ตามที่เราเคยทำตอน Extract Feature (DME,CNV,DRUSEM) โมเดล one class ของเรามันจะบอกได้ว่า ตาของคนคนนี้ปกติหรือไม่ปกติ แม้ผู้ป่วยเป็นโรคอื่นที่ไม่เคย train มาก่อน เพราะตอนเรา train one class ผมจะสอนมันเพียง class เดียว คือ class ของ normal

ดังนั้นเเล้ว เราต้องเตรียมข้อมูลใหม่เสียก่อนที่จะเข้าไปเทรน เราต้องเทรนมันแค่ normal เพียง class เดียว (*** อันนี้ไม่จำเป็นต้องเทรนเพียง normal เพียงคลาสเดียวนะครับ การทำ one class สามารถทำได้หลากหลาย เราอาจจะผสม abnormal เล็กน้อยปนไปกับ normal ก็ได้ แต่ในที่นี้ผมเลือกที่จะหยิบแค่ normal เพียวๆเลย เพราะมันเตรียมข้อมูลง่าย แค่นั้นเลยครับ เอาตรงๆ ขี้เกียด55555+)

ขั้นเเรก เตรียมข้อมูล บนข้อมูลชุดเดิมเลยนะครับ เเต่คราวนี้ ผมจะหยิบมาเเค่ 1 class ใน trainset คือ class normal ซึ่ง ผมหยิบเอา 75% ของ normal ทั้งหมดที่อยู่ใน trainset เดิมให้อยู่ในtrainset เช่นเดิม ส่วนอีก 25% ของ normal และ 100% ของอีก 3 class (DME,CNV,DRUSEM) ย้าย ไปรวมกันกลายเป็น Validate และกำหนดให้ชุด Validate เดิม กลายเป็นชุด Test ชุดใหม่ และชุด test เดิม กลายเป็น Smalltest แทน

Train set :

Validate set :

Test set :

Smalltest set :

(****ในที่นี้ 1 คือ Normal และ 0 คือ abnormal)

เมื่อเตรียมข้อมูล พร้อมเเล้ว มาลุยกันเลย

ขึ้นเเรก เรียก โมเดล restnet_model มา Extract Feature ของภาพเสียก่อน

จากนั้น ผมเลือกใช้โมเดล GaussianMixture (อันนี้ผมเลือกเอง ใครลองทำใช้อันอื่นได้นะครับ เช่น SVM.OneClass เช่นเดียวกับ restnet 50 ครับ ผมเลือกมาเอง อาจลองใช้อะไรที่มันซับซ้อนกว่านี้ parameter มากว่านี้ก็ได้ครับ)

ผลลัพธ์ที่ได้ตอนสุดท้ายมันจะเป็น prob ออกมาครับ ดังนั้นเเล้ว ก็ลองหา threshold ที่เหมาะสมสำหรับข้อมูลที่ทดลองครับ

Test set :

Smalltest set:

อย่าลืมนะครับ ว่า ผมกำหนด 1 เป็น normal(5555555+)

สุดท้ายเเล้วหากผมทำอะไรผิดพลาด หรือไม่แม่นยำตรงไหน ก็ต้องขออภัยไว้ ณ ที่นี้ด้วยครับ ผมหวังว่า บทความนี้จะมีประโยชน์ต่อผู้ที่กำลังศึกษาหาความรู้ครับ

--

--