แก้ปัญหา Imbalanced Classes ด้วย Focal Loss
สวัสดีผู้อ่านทุกท่านครับ ก่อนหน้านี้ผมเคยเขียนบทความที่เล่าประวัติความเป็นมาของ Object Detection (กดได้เลย) ซึ่งได้มีการอ้างถึง loss function ตัวนึงที่ชื่อว่า Focal Loss สำหรับบทความนี้เราจะมาเจาะลึกลงไปใน loss function ตัวนี้กันว่ามันมีการทำงานยังไง ช่วยแก้ปัญหาอะไรบ้าง และสามารถนำมาประยุกต์กับงานอื่นๆได้ไหม
Introduction
- Focal Loss หรืองานวิจัยที่ชื่อว่า Focal Loss for Dense Object Detection ซึ่งตีพิมพ์ในงานวิจัย ICCV ปี 2017 และยังได้รับรางวัล Best Student Paper Awards อีกด้วย
- Focal Loss เข้ามาช่วยแก้ปัญหาในการฝึกเน็ตเวิร์คแบบ Single Stage ในงานประเภท Object Detection (หากใครยังไม่เคยได้ยิน กดเข้าไปอ่านได้เลยครับ) ที่พบว่าปัญหาในการฝึกตัวแบบตกอยู่ที่ความไม่สมดุลของชุดข้อมูลฝึก
- ปัจจุบันเปเปอร์ Focal Loss มีค่าอ้างอิง (Citation) อยู่ที่ 3082 (04 May 2020)
- หลายๆเปเปอร์ในปัจจุบันมีการอ้างอิงงานวิจัยนี้ มาปรับแต่งเพิ่มเติม เช่น การปรับปรุงในส่วน Feature Extractor และทำให้ผลลัพธ์ยังอยู่ในระดับ State-of-the-Art อีกด้วย อย่างเปเปอร์ EfficientDet: Scalable and Efficient Object Detection ที่ทำการปรับปรุง Backbone ของโมเดลให้ใช้งาน EfficientNet
Problems
- ในการฝึกเน็ตเวิร์คแบบ Single-Stage ตัวเน็ตเวิร์คจะพยายามสุ่มยิ่งสิ่งที่เรียกว่า Anchor (กรอบสี่เหลี่ยมที่ถูกสุ่มวาดบนรูปภาพเพื่อตามหาวัตถุ) ไปบนรูปภาพในจำนวนหนึ่ง เพื่อตามหาวัตถุที่สนใจ (ใน Two-Stage ไม่เกิดปัญหานี้เพราะมีการแบ่งโครงข่ายที่ออกเป็น 2 สเตจ คือส่วนเสนอพื้นที่ และส่วนรู้จำวัตถุ)
- เนื่องจากเราต้องฝึกว่า “พื้นที่ตรงนี้เป็นวัตถุที่สนใจ ส่วนตรงนี้ไม่มีวัตถุที่ สนใจ” ซึ่งในความเป็นจริงวัตถุจะมีน้อยกว่าพื้นที่ที่ไม่ใช่วัตถุมาก เราเรียกพื้นที่ที่ไม่ใช่วัตถุว่าพื้นหลัง (background)
- ปัญหาของความไม่สมดุลเกิดขึ้นจากพวกพื้นหลังนับเป็นตัวอย่างที่ง่าย ทำให้เกิดปัญหาว่าเครื่องเน้นไปที่การรู้จำพื้นหลังที่มันง่าย แต่ไม่เก่งกับวัตถุที่สนใจ
จำนวนตัวเลขของ Anchor ที่ถูกยิงในแต่ละโมเดล
จะเห็นว่า RetinaNet ยิง Anchor ออกไปประมาณ 100k ก็จริง ซึ่งเยอะกว่าชาวบ้านเขา แต่ตัวเน็ตเวิร์คเองสามารถรองรับจำนวน Anchor ขนาดนี้ได้ด้วยการเปลี่ยนแค่ Loss Function
ในงานวิจัยมีการพูดถึงอัตราส่วนของคลาสที่ไม่สมดุล จะอยู่ประมาณนี้ 1:1000
Focal Loss
ก่อนไปลงลึกถึงรายละเอียดของ Focal Loss อยากกล่าวถึงตัว Cross-Entropy with Softmax ก่อน หากใครไม่เคยได้ยินหรืออาจจะจำไม่ได้ สามารถกดเข้าไปอ่านบทความผมได้ที่ Link
Balanced Cross Entropy
- ปกติแล้วค่า Loss แบบ Cross Entropy (CE) ของคลาสคำตอบ t ซึ่งเครื่องทายมาด้วยความมั่นใจระดับ pt จะมีค่าเป็น
- แต่พอปรับตามสัดส่วน เราจะใส่สัมประสิทธิ์เข้าไปด้วย ซึ่งเลือกได้หลายแบบซึ่งแบบที่ตรงไปตรงมาที่สุดก็จะเป็นสัดส่วนผกผัน
- เช่น ในการทำ Object Detection ถ้าพื้นหลังมีสัดส่วน 90% สัมประสิทธิ์ของค่า loss ของพื้นหลังจะเป็น 0.1 ส่วนของวัตถุจะเป็น 0.9 ดังนี้
- วิธีปรับค่านี้เรียกว่า balanced cross entropy วิธีปรับค่า loss เมื่อตัวอย่างในแต่ละคลาสไม่สมดุลนี้ช่วยแก้ปัญหาได้ระดับหนึ่ง
- แต่เมื่อพบความไม่สมดุลระดับ 1:1000 ปัญหาจะกลับมารุนแรงอีกครั้ง
- การใช้ค่า balanced CE แก้ปัญหาเรื่องตัวอย่างประเภทพื้นหลังที่มีจำนวนที่ท่วมท้น แต่มันไม่ได้แยกว่าตัวอย่างไหนง่ายตัวอย่างในยาก
- ที่ต้องมาคอยสังเกตตัวอย่างยากง่ายนั้นก็เพราะเราต้องการให้เครื่องใส่ใจตัวอย่างที่ยากให้มากขึ้น ซึ่งมักมีปริมาณน้อยกว่าตัวอย่างที่ง่าย
- งานวิจัยเรื่อง Online Hard-Example Mining (OHEM) แสดงให้เห็นถึงประโยชน์จากการสนใจตัวอย่างที่ยากให้มากขึ้น (โดย R. Girshick เจ้าเดิม)
- Focal loss เป็นการปรับสมการคำนวณ loss ให้ดีขึ้นอีกระดับ ด้วยการทำให้ตัวอย่างง่ายส่งผลกับค่า loss น้อยลงไป (น้อยลงในเชิงเปรียบเทียบ)
- ซึ่งคำว่าง่ายยากนี้ปรับแต่งได้ด้วยพารามิเตอร์ γ ที่เราเลือก
- ใช้ร่วมกับแนวคิด balanced CE ได้ด้วย แก้ปัญหาหลายประเด็นพร้อมกัน
สมการ Focal Loss
- เป้าหมายของ Focal Loss คือต้องการให้ตัวอย่างที่ยากจัด ๆ ส่งผลกับค่า loss มากขึ้น ในขณะที่ตัวอย่างที่ง่ายจะมีสัดส่วนต่อค่า loss น้อยลง กระทำได้โดยสมการ
- จากสมการ Focal Loss ข้างบนนั้น หากเครื่องทายคลาสที่ถูกต้องด้วยความมั่นใจ ค่า pt จะสูงทำให้สมประสิทธิ์ (1 − pt) ของตัวอย่างมีค่าน้อย
- จากภาพด้านบน เรามาลองขยายความกันสักนิด ลองสักเกตุเส้นสีม่วงที่กำหนดให้ค่า γ มีค่าเท่ากับสอง ทางด้านขวาของทางกราฟ ค่า Loss ของตัวอย่างที่ตอบถูกด้วยความมั่นใจสูงๆ จะถูกกดให้มีค่าลดลงไปมากกว่าเดิมค่อนข้างมาก ยิ่งถ้าเครื่องตอบด้วยความมั่นใจสูงเท่าไร Loss ก็จะยิ่งเข้าใกล้ 0
แล้วมันหมายความว่ายังไง มันจะมาช่วยแก้ปัญหา Imbalanced Classes และ Hard Example ได้ยังไง ?
เรามาลองดูตัวอย่างกันสักนิด
- สมมุติว่าเราต้องการให้โมเดลทำนายค่า 10 (เริ่มต้นจาก 0–9) โดยเราจะสนใจการทำนายคลาสที่ 9 ซึ่งคือค่าสุดท้ายในลิสต์ zval
- ฝั่งซ้ายเป็นการคำนวณ Loss Function แบบปกติ ซึ่งค่า Loss ของการทำนายคลาสที่ 9 อยู่ที่ 0.000408
- แต่สำหรับฝั่งขวา เป็นการคำนวณ Loss แบบ Focal Loss ซึ่งค่า Loss จะอยู่ที่ 0.0000000000681
- จากตัวอย่างที่ 1 จะเห็นว่าถ้าเน็ตเวิร์คเราสามารถทำนายได้ด้วยความแม่นยำที่สูง Focal Loss จะกด Loss ให้น้อยลงกว่าเดิมค่อนข้างมากแทบจะใกล้เข้า 0 เลยที่เดียว
- ตัวอย่างนี้จะเห็นว่าค่าที่อยู่ใน zval มีบ้างค่าที่ต่างกับค่าของคลาสที่ 9 คือ 5 เพียงแค่นิดเดียว ซึ่งนั่นหมายความว่าโมเดลจะทำนายออกมาด้วยความมั่นใจที่ค่อนข้างน้อย
- เราลองมาดูค่า Loss จากการคำนวณปกติ ค่า Loss ที่ได้จะอยู่ที่ 0.7437 แต่กลับกันถ้าเรามาลองคำนวณ Loss จาก Focal Loss ค่าที่ได้จะอยู่ที่ 0.2047
- Loss ของ Focal Loss จะถูกกดลงมาจาก 0.7437 ทำให้ Loss มีค่าน้อยลงกว่าเดิมค่อนข้างมาก
การกด Loss ให้มีค่าน้อยลงมันช่วยอะไร ?
- ปกติแล้วถ้าเราใช้ Loss แบบปกติจำพวก Binary Cross-Entropy ค่า Loss โดยร่วมทั้งหมดจะถูกคิดจากการที่เน็ตเวิร์คตอบผิดและตอบถูกด้วยความมั่นใจที่สูงมากๆ ซึ่งเน็ตเวิร์คจะพยายามปรับจูนให้ตัวมันเองสามารถตอบคำถามให้ถูกต้องและตอบคำถามง่ายๆที่มีความมั่นใจสูงอยู่แล้วให้สูงขึ้นไปอีก
- แต่การที่เรากด Loss ตัวอย่างง่ายๆลงไป มันหมายความว่า เรากำลังต้องการให้เน็ตเวิร์คเราโฟกัสกับสิ่งที่ยังไม่สามารถตอบถูก หรือตอบถูกแต่ความมั่นใจค่อนข้างน้อย
- ถึง Loss จากการคำนวณแบบ Focal Loss จะน้อยลงกว่าปกติ แต่อย่าลืมว่า Loss ที่ถูกคำนวณมานั่นมาเหลือแค่ที่มาจากการตอบผิดและตอบถูกด้วยความไม่มั่นใจ มันจึงทำให้เน็ตเวิร์คเรียนรู้กับตัวอย่างแบบนั่นมากขึ้น
Focal Loss มีปัญหาอะไรไหมในการใช้งาน ?
- จากข้างบนก็ดูเหมือนจะดี แต่ปัญหาหลักๆในการใช้งาน Focal Loss คือ ยิ่งเราให้ค่า γ สูงเท่าไร ค่า Loss ที่เครื่องตอบถูกด้วยความมั่นใจที่สูงจะถูกกดลงมาเข้าใกล้ 0 มากขึ้น
- มันทำให้เน็ตเวิร์คสามารถเรียนตัวอย่างง่ายและตอบด้วยความมั่นใจประมาณนึงเช่น 70% แต่หากเราใช้ γ ค่า Loss ในระดับความมั่นใจ 70% อาจเหลือ 0 ทำให้เน็ตเวิร์คไม่สามารถเรียนรู้เพื่อตอบคำถามเหล่านี้ด้วยความมั่นใจที่สูงขึ้นได้อีก (70% เป็นค่าที่ยกตัวอย่างขึ้นมา)
- จริงๆแล้วเราสามารถแก้ปัญหาเหล่านั้นได้ โดยในตอนแรกนั่นเราอาจจะเริ่มฝึกเน็ตเวิร์คด้วย Focal Loss ที่มีค่า γ ค่อนข้างสูง เพื่อให้โมเดลสามารถทำนายตัวอย่างที่ทั้งง่ายและยากได้ถูกต้อง
- พอฝึกมาถึงจังหวะนึง เราก็เทรนเน็ตเวิร์คอีกรอบโดยใช้น้ำหนัก (Weights) ตัวเดิมที่เทรนมาแล้ว แต่ปรับลดค่า γ ลงเพื่อ Optimize ค่าความมั่นใจจากการเทรนรอบแรก
Summary
Focal Loss เองถูกคิดค้นมาเพื่อแก้ปัญหาในงานด้าน Object Detection แบบ Single-Stage-Detection ที่มีความ Imbalanced ของ Data ค่อนข้างสูงระหว่างวัถตุและพื้นหลังก็จริง แต่ด้วยหลักการของมัน ,มันพยายามเข้ามาแก้ปัญหา Imbalanced Classes และ Hard Example ซึ่งเราสามารถประยุกต์ใช้กับงานที่เกี่ยวกับ Classification ได้เกือบทุกประเภท ไม่ว่าจะเป็น Text-Classification หรือแม้แต่ Speech Classification หรืองานประเภท Representation Learning ที่ต้องใช้ Task Classification มาช่วยในการเทรน (เช่นพวก Face Recognition)
เป็นไงกันมั้งครับกับ Loss Function ที่ชื่อว่า Focal Loss ยังไงใครที่ชอบบทความนี้สามารถกดตบมือและกดติดตามเพื่อรอบทความอื่นๆจากผมได้อีกนะครับ ขอบคุณครับ
References
[1] Focal Loss for Dense Object Detection: https://arxiv.org/abs/1708.02002
[2] Review: RetinaNet — Focal Loss (Object Detection): https://towardsdatascience.com/review-retinanet-focal-loss-object-detection-38fba6afabe4