แก้ปัญหา Imbalanced Classes ด้วย Focal Loss

Natthawat Phongchit
4 min readMay 7, 2020

--

สวัสดีผู้อ่านทุกท่านครับ ก่อนหน้านี้ผมเคยเขียนบทความที่เล่าประวัติความเป็นมาของ Object Detection (กดได้เลย) ซึ่งได้มีการอ้างถึง loss function ตัวนึงที่ชื่อว่า Focal Loss สำหรับบทความนี้เราจะมาเจาะลึกลงไปใน loss function ตัวนี้กันว่ามันมีการทำงานยังไง ช่วยแก้ปัญหาอะไรบ้าง และสามารถนำมาประยุกต์กับงานอื่นๆได้ไหม

Introduction

  • Focal Loss หรืองานวิจัยที่ชื่อว่า Focal Loss for Dense Object Detection ซึ่งตีพิมพ์ในงานวิจัย ICCV ปี 2017 และยังได้รับรางวัล Best Student Paper Awards อีกด้วย
  • Focal Loss เข้ามาช่วยแก้ปัญหาในการฝึกเน็ตเวิร์คแบบ Single Stage ในงานประเภท Object Detection (หากใครยังไม่เคยได้ยิน กดเข้าไปอ่านได้เลยครับ) ที่พบว่าปัญหาในการฝึกตัวแบบตกอยู่ที่ความไม่สมดุลของชุดข้อมูลฝึก
  • ปัจจุบันเปเปอร์ Focal Loss มีค่าอ้างอิง (Citation) อยู่ที่ 3082 (04 May 2020)
  • หลายๆเปเปอร์ในปัจจุบันมีการอ้างอิงงานวิจัยนี้ มาปรับแต่งเพิ่มเติม เช่น การปรับปรุงในส่วน Feature Extractor และทำให้ผลลัพธ์ยังอยู่ในระดับ State-of-the-Art อีกด้วย อย่างเปเปอร์ EfficientDet: Scalable and Efficient Object Detection ที่ทำการปรับปรุง Backbone ของโมเดลให้ใช้งาน EfficientNet

Problems

ตัวอย่างการสุ่มยิ่ง Anchors ของเน็ตเวิร์ค Yolo
  • ในการฝึกเน็ตเวิร์คแบบ Single-Stage ตัวเน็ตเวิร์คจะพยายามสุ่มยิ่งสิ่งที่เรียกว่า Anchor (กรอบสี่เหลี่ยมที่ถูกสุ่มวาดบนรูปภาพเพื่อตามหาวัตถุ) ไปบนรูปภาพในจำนวนหนึ่ง เพื่อตามหาวัตถุที่สนใจ (ใน Two-Stage ไม่เกิดปัญหานี้เพราะมีการแบ่งโครงข่ายที่ออกเป็น 2 สเตจ คือส่วนเสนอพื้นที่ และส่วนรู้จำวัตถุ)
  • เนื่องจากเราต้องฝึกว่า “พื้นที่ตรงนี้เป็นวัตถุที่สนใจ ส่วนตรงนี้ไม่มีวัตถุที่ สนใจ” ซึ่งในความเป็นจริงวัตถุจะมีน้อยกว่าพื้นที่ที่ไม่ใช่วัตถุมาก เราเรียกพื้นที่ที่ไม่ใช่วัตถุว่าพื้นหลัง (background)
  • ปัญหาของความไม่สมดุลเกิดขึ้นจากพวกพื้นหลังนับเป็นตัวอย่างที่ง่าย ทำให้เกิดปัญหาว่าเครื่องเน้นไปที่การรู้จำพื้นหลังที่มันง่าย แต่ไม่เก่งกับวัตถุที่สนใจ

จำนวนตัวเลขของ Anchor ที่ถูกยิงในแต่ละโมเดล

จะเห็นว่า RetinaNet ยิง Anchor ออกไปประมาณ 100k ก็จริง ซึ่งเยอะกว่าชาวบ้านเขา แต่ตัวเน็ตเวิร์คเองสามารถรองรับจำนวน Anchor ขนาดนี้ได้ด้วยการเปลี่ยนแค่ Loss Function

ในงานวิจัยมีการพูดถึงอัตราส่วนของคลาสที่ไม่สมดุล จะอยู่ประมาณนี้ 1:1000

Focal Loss

ก่อนไปลงลึกถึงรายละเอียดของ Focal Loss อยากกล่าวถึงตัว Cross-Entropy with Softmax ก่อน หากใครไม่เคยได้ยินหรืออาจจะจำไม่ได้ สามารถกดเข้าไปอ่านบทความผมได้ที่ Link

Balanced Cross Entropy

  • ปกติแล้วค่า Loss แบบ Cross Entropy (CE) ของคลาสคำตอบ t ซึ่งเครื่องทายมาด้วยความมั่นใจระดับ pt จะมีค่าเป็น
Croos Entropy (CE) formula
  • แต่พอปรับตามสัดส่วน เราจะใส่สัมประสิทธิ์เข้าไปด้วย ซึ่งเลือกได้หลายแบบซึ่งแบบที่ตรงไปตรงมาที่สุดก็จะเป็นสัดส่วนผกผัน
  • เช่น ในการทำ Object Detection ถ้าพื้นหลังมีสัดส่วน 90% สัมประสิทธิ์ของค่า loss ของพื้นหลังจะเป็น 0.1 ส่วนของวัตถุจะเป็น 0.9 ดังนี้
BCE Example
  • วิธีปรับค่านี้เรียกว่า balanced cross entropy วิธีปรับค่า loss เมื่อตัวอย่างในแต่ละคลาสไม่สมดุลนี้ช่วยแก้ปัญหาได้ระดับหนึ่ง
  • แต่เมื่อพบความไม่สมดุลระดับ 1:1000 ปัญหาจะกลับมารุนแรงอีกครั้ง
  • การใช้ค่า balanced CE แก้ปัญหาเรื่องตัวอย่างประเภทพื้นหลังที่มีจำนวนที่ท่วมท้น แต่มันไม่ได้แยกว่าตัวอย่างไหนง่ายตัวอย่างในยาก
  • ที่ต้องมาคอยสังเกตตัวอย่างยากง่ายนั้นก็เพราะเราต้องการให้เครื่องใส่ใจตัวอย่างที่ยากให้มากขึ้น ซึ่งมักมีปริมาณน้อยกว่าตัวอย่างที่ง่าย
  • งานวิจัยเรื่อง Online Hard-Example Mining (OHEM) แสดงให้เห็นถึงประโยชน์จากการสนใจตัวอย่างที่ยากให้มากขึ้น (โดย R. Girshick เจ้าเดิม)
  • Focal loss เป็นการปรับสมการคำนวณ loss ให้ดีขึ้นอีกระดับ ด้วยการทำให้ตัวอย่างง่ายส่งผลกับค่า loss น้อยลงไป (น้อยลงในเชิงเปรียบเทียบ)
  • ซึ่งคำว่าง่ายยากนี้ปรับแต่งได้ด้วยพารามิเตอร์ γ ที่เราเลือก
  • ใช้ร่วมกับแนวคิด balanced CE ได้ด้วย แก้ปัญหาหลายประเด็นพร้อมกัน

สมการ Focal Loss

  • เป้าหมายของ Focal Loss คือต้องการให้ตัวอย่างที่ยากจัด ๆ ส่งผลกับค่า loss มากขึ้น ในขณะที่ตัวอย่างที่ง่ายจะมีสัดส่วนต่อค่า loss น้อยลง กระทำได้โดยสมการ
Binary Cross Entropy Loss
Better Notation
Cross Entropy Loss alternate definition
Focal Loss
  • จากสมการ Focal Loss ข้างบนนั้น หากเครื่องทายคลาสที่ถูกต้องด้วยความมั่นใจ ค่า pt จะสูงทำให้สมประสิทธิ์ (1 − pt) ของตัวอย่างมีค่าน้อย
Source: https://arxiv.org/pdf/1708.02002.pdf
  • จากภาพด้านบน เรามาลองขยายความกันสักนิด ลองสักเกตุเส้นสีม่วงที่กำหนดให้ค่า γ มีค่าเท่ากับสอง ทางด้านขวาของทางกราฟ ค่า Loss ของตัวอย่างที่ตอบถูกด้วยความมั่นใจสูงๆ จะถูกกดให้มีค่าลดลงไปมากกว่าเดิมค่อนข้างมาก ยิ่งถ้าเครื่องตอบด้วยความมั่นใจสูงเท่าไร Loss ก็จะยิ่งเข้าใกล้ 0

แล้วมันหมายความว่ายังไง มันจะมาช่วยแก้ปัญหา Imbalanced Classes และ Hard Example ได้ยังไง ?

เรามาลองดูตัวอย่างกันสักนิด

  • สมมุติว่าเราต้องการให้โมเดลทำนายค่า 10 (เริ่มต้นจาก 0–9) โดยเราจะสนใจการทำนายคลาสที่ 9 ซึ่งคือค่าสุดท้ายในลิสต์ zval
ตัวอย่างที่ 1 ซ้าย: การคำนวณแบบ Binary Cross Entropy ขวา: การคำนวณแบบ Focal Loss
  • ฝั่งซ้ายเป็นการคำนวณ Loss Function แบบปกติ ซึ่งค่า Loss ของการทำนายคลาสที่ 9 อยู่ที่ 0.000408
  • แต่สำหรับฝั่งขวา เป็นการคำนวณ Loss แบบ Focal Loss ซึ่งค่า Loss จะอยู่ที่ 0.0000000000681
  • จากตัวอย่างที่ 1 จะเห็นว่าถ้าเน็ตเวิร์คเราสามารถทำนายได้ด้วยความแม่นยำที่สูง Focal Loss จะกด Loss ให้น้อยลงกว่าเดิมค่อนข้างมากแทบจะใกล้เข้า 0 เลยที่เดียว
ตัวอย่างที่ 2ซ้าย: การคำนวณแบบ Binary Cross Entropy ขวา: การคำนวณแบบ Focal Loss
  • ตัวอย่างนี้จะเห็นว่าค่าที่อยู่ใน zval มีบ้างค่าที่ต่างกับค่าของคลาสที่ 9 คือ 5 เพียงแค่นิดเดียว ซึ่งนั่นหมายความว่าโมเดลจะทำนายออกมาด้วยความมั่นใจที่ค่อนข้างน้อย
  • เราลองมาดูค่า Loss จากการคำนวณปกติ ค่า Loss ที่ได้จะอยู่ที่ 0.7437 แต่กลับกันถ้าเรามาลองคำนวณ Loss จาก Focal Loss ค่าที่ได้จะอยู่ที่ 0.2047
  • Loss ของ Focal Loss จะถูกกดลงมาจาก 0.7437 ทำให้ Loss มีค่าน้อยลงกว่าเดิมค่อนข้างมาก

การกด Loss ให้มีค่าน้อยลงมันช่วยอะไร ?

  • ปกติแล้วถ้าเราใช้ Loss แบบปกติจำพวก Binary Cross-Entropy ค่า Loss โดยร่วมทั้งหมดจะถูกคิดจากการที่เน็ตเวิร์คตอบผิดและตอบถูกด้วยความมั่นใจที่สูงมากๆ ซึ่งเน็ตเวิร์คจะพยายามปรับจูนให้ตัวมันเองสามารถตอบคำถามให้ถูกต้องและตอบคำถามง่ายๆที่มีความมั่นใจสูงอยู่แล้วให้สูงขึ้นไปอีก
  • แต่การที่เรากด Loss ตัวอย่างง่ายๆลงไป มันหมายความว่า เรากำลังต้องการให้เน็ตเวิร์คเราโฟกัสกับสิ่งที่ยังไม่สามารถตอบถูก หรือตอบถูกแต่ความมั่นใจค่อนข้างน้อย
  • ถึง Loss จากการคำนวณแบบ Focal Loss จะน้อยลงกว่าปกติ แต่อย่าลืมว่า Loss ที่ถูกคำนวณมานั่นมาเหลือแค่ที่มาจากการตอบผิดและตอบถูกด้วยความไม่มั่นใจ มันจึงทำให้เน็ตเวิร์คเรียนรู้กับตัวอย่างแบบนั่นมากขึ้น

Focal Loss มีปัญหาอะไรไหมในการใช้งาน ?

  • จากข้างบนก็ดูเหมือนจะดี แต่ปัญหาหลักๆในการใช้งาน Focal Loss คือ ยิ่งเราให้ค่า γ สูงเท่าไร ค่า Loss ที่เครื่องตอบถูกด้วยความมั่นใจที่สูงจะถูกกดลงมาเข้าใกล้ 0 มากขึ้น
  • มันทำให้เน็ตเวิร์คสามารถเรียนตัวอย่างง่ายและตอบด้วยความมั่นใจประมาณนึงเช่น 70% แต่หากเราใช้ γ ค่า Loss ในระดับความมั่นใจ 70% อาจเหลือ 0 ทำให้เน็ตเวิร์คไม่สามารถเรียนรู้เพื่อตอบคำถามเหล่านี้ด้วยความมั่นใจที่สูงขึ้นได้อีก (70% เป็นค่าที่ยกตัวอย่างขึ้นมา)
  • จริงๆแล้วเราสามารถแก้ปัญหาเหล่านั้นได้ โดยในตอนแรกนั่นเราอาจจะเริ่มฝึกเน็ตเวิร์คด้วย Focal Loss ที่มีค่า γ ค่อนข้างสูง เพื่อให้โมเดลสามารถทำนายตัวอย่างที่ทั้งง่ายและยากได้ถูกต้อง
  • พอฝึกมาถึงจังหวะนึง เราก็เทรนเน็ตเวิร์คอีกรอบโดยใช้น้ำหนัก (Weights) ตัวเดิมที่เทรนมาแล้ว แต่ปรับลดค่า γ ลงเพื่อ Optimize ค่าความมั่นใจจากการเทรนรอบแรก

Summary

Focal Loss เองถูกคิดค้นมาเพื่อแก้ปัญหาในงานด้าน Object Detection แบบ Single-Stage-Detection ที่มีความ Imbalanced ของ Data ค่อนข้างสูงระหว่างวัถตุและพื้นหลังก็จริง แต่ด้วยหลักการของมัน ,มันพยายามเข้ามาแก้ปัญหา Imbalanced Classes และ Hard Example ซึ่งเราสามารถประยุกต์ใช้กับงานที่เกี่ยวกับ Classification ได้เกือบทุกประเภท ไม่ว่าจะเป็น Text-Classification หรือแม้แต่ Speech Classification หรืองานประเภท Representation Learning ที่ต้องใช้ Task Classification มาช่วยในการเทรน (เช่นพวก Face Recognition)

เป็นไงกันมั้งครับกับ Loss Function ที่ชื่อว่า Focal Loss ยังไงใครที่ชอบบทความนี้สามารถกดตบมือและกดติดตามเพื่อรอบทความอื่นๆจากผมได้อีกนะครับ ขอบคุณครับ

References

[1] Focal Loss for Dense Object Detection: https://arxiv.org/abs/1708.02002

[2] Review: RetinaNet — Focal Loss (Object Detection): https://towardsdatascience.com/review-retinanet-focal-loss-object-detection-38fba6afabe4

--

--

Natthawat Phongchit

Interested in Computer Vision, NLP, Reinforcement Learning.