Sınıflandırma Modellerinde Focal Kaybının Kullanımı ve Cross Entropy Kaybından Üstünlüğü

Sefa İşci
KaVe

--

Focal kaybına aslında Cross Entropy kaybının genelleştirilmiş hali de diyebiliriz. Ve parametreler ile oynayarak Cross Entropy kaybına kıyasla model üzerinde daha başarılı sonuçlar elde edebiliriz.

Focal kaybı aslında sınıf dengesizliği (class imbalance) problemini halletmek için kullanılmaktadır. Ama düzgün bir şekilde kullanılırsa modele regülarizasyon efekti vererek tahmin skorlarını (confidence scores) post processing ile manipüle edecek hale getirebilir. Ve bunun sonucunda da daha başarılı bir model oluşturulabilir. Bu kaybı daha iyi anlayabilmek için öncelik olarak Cross Entropy kaybını gözden geçirmemiz gerekir.

Cross Entropy Kaybı

1 numaralı denklemde ilgili etiket düğümüne giden lineer denklemin sonucu, 2 numaralı denklemde bütün etiketler için aldığımız lineer çıktıya göre softmax dönüşümü, 3 numaralı denklemde de aldığımız softmax sonuçlarına göre cross entropy kayıpları bulunmaktadır. Aslında aldığımız sonuçları target değerleri ile çarptığımız zaman (3 numaralı denklemdeki t bileşenleri) tek bir sonuç elde etmekteyiz. Çünkü target değerimizi one-hot encoding kullanarak bir vektör haline getirip, ilgili bileşenler ile sıra sıra çarptığımızda diğer cross entropy kayıpları sıfır değerini almakta ve dolayısı ile elimizde bir satır örneği için bir adet cross entropy kaybı oluşmaktadır. Bunu basit bir örnek üzerinde gösterirsem daha iyi anlaşılacağını düşünüyorum :)

elimizde 3 etiket olduğunu varsayalım.
bir satır örneği üzerinden 3 etiketli bir modelin softmax sonucu

yukarıdaki resimde bir örnek üzerinden softmax sonucunu görmektesiniz. Ve bu gerçek etiket değerinin 2 numaralı sınıf olduğunu varsayalım. Bu etiket değerini one-hot encoding matrisine çevirdiğimiz zaman [0 1 0] vektörünü elde ederiz. Cross entropy sonucu da aşağıda göründüğü gibidir.

Dengeleştirilmiş Cross Entropy Kaybı (Balanced Cross Entropy Loss)

Balanced CE kaybında ise CE kaybına ek bir alfa parametresi eklemekteyiz.

4 numaralı denklemde eklediğimiz parametrede dengesiz sınıflandırma olan eğitim verimizin veri sayısı az olan sınıfın hata kaybını artırarak o sınıfa daha fazla önem vermesini sağlıyor. Bazı kaynaklar ikili sınıflandırmada çok veriye sahip sınıfa alfa değerini verirken az veriye sahip sınıfa (1-alfa) değerini vermektedir. Diğer kaynaklarda özellikle çoklu sınıflandırma olan durumlarda bu durum çok rule-based olduğundan dolayı verilerin frekansına bakarak bu değişkeni değiştirmektedir. kabaca örnek vermek gerekirse; 1 numaralı etiketten 2000, 2 numaralı etiketten 10000 ve 3 numaralı etiketten 5000 tane olmak üzere toplam 17000 örnek olduğunu varsayalım. Alfa değeri;
1 numaralı etiket için 1–2000/17000 = 0.882
2 numaralı etiket için 1–10000/17000 = 0.411
3 numaralı etiket için 1–5000/17000 = 0.705 değerlerini almaktadır.
Bunun sonucunda elimizdeki frekans sayısına bağlı olarak kayıp fonksiyonumuzu arttırıp azaltmaktayız. 2000 örneği olan 1 numaralı etiket için cross entropy kaybını 0.882 çarparken, 10000 örneği olan 2 numaralı etiket için cross entropy kaybını 0.411 ile çarpmaktayız. Böylece hatayı az örnek olan sınıf için daha az azaltıp modelimizin o sınıf örneğini geriye yayılım (backpropagation) yaparken daha çok önemsemesini sağlıyoruz.

Focal Kaybı

Focal kaybında dengeleştirilmiş cross-entropy kaybına yeni bir değişken ekleniyor.

Bu değişken sayesinde gamma değerininde etkisiyle modelimiz düşük tahmin skoru almış örneklere daha çok önem verirken yüksek bir tahminde bulunmuş örneklere daha az önem vermesini sağlamaktayız. gamma değişkenini arttırdıkça bu özellik üzerine daha çok düşülmektedir. Verimizin eşit bir dağılıma sahip olduğunu varsaydığımızda pt=0.6 ve pt=0.9 ile gamma=1 ve gamma=2 değerlerinin kıyaslamasını aşağıda sizlere paylaşıyorum.

gamma=1 ve alfa=1 olduğunda 0.9 tahmini yapılan kayıp 1–0.9 = 0.1 ile çarpılırken, 0.6 tahmini yapılan kayıp 1–0.6 = 0.4 ile çarpılmaktadır. Bu sonuçta kayıp hesaplanırken 0.6 tahmini yapılmış olan kayıbın 0.9'a oranla 4 kat daha öneme sahip olmasını sağlar.

gamma=2 ve alfa=1 olduğunda 0.9 tahmini yapılan kayıp (1–0.9)² = 0.01 ile çarpılırken, 0.6 tahmini yapılan kayıp (1–0.6)² = 0.16 ile çarpılmaktadır. Bu sonuçta kayıp hesaplanırken 0.6 tahmini yapılmış olan kayıbın 0.9'a oranla 16 kat daha öneme sahip olmasını sağlar.

Bu değişkenler sayesinde modelimiz dengesiz veri setinde (imbalanced dataset) ve güven~tahmin skorunu çok yüksek tahmin edip yanlış sınıf tahmini yapan durumlarda (tahmini 1 numaralı sınıf olarak 0.95 güven skoru ile yaparken gerçek etiketimizin aslında 2 numaralı sınıf olması gibi vb…) daha güzel çalışmaktadır. 0.95 ile tahmin eden model focal kaybı ile eğitim sonucunda skor tahmini 0.8 e düşürececek iken 0.6 ile yaptığı tahmini de 0.7~0.8 gibi bir değere çıkaracaktır. Bir nevi gamma değeri güven skorları üzerinden regülarizasyon işlemini gerçekleştirmektedir. Ayrıca ileride uygulayacağımız threshold ile (post processing işleminden bahsediyorum) FP ve FN değerlerini manipüle etmemiz daha kolay olmaktadır.

Ayrıca alfa = 1 ve gamma = 0 olduğunda da Cross Entropy kaybını elde ettiğimizi unutmamamız gerekir.

Normalde bu kaybı obje tespiti için background sınıfların (objelerin olmadığı gridler) çok fazla olmasından dolayı (Objelerin bulunduğu gridler genellikle bulunmayan yerlere göre daha azdır) kullanmışlar. Ama bu kaybı yapılı verilerde, NLP işlerinde ve diğer Computer Vision algoritmalarında da kullanabilirsiniz.
İlgili makaleye bakmak isterseniz bu linke tıklamanız yeterli :)

pytorch’da da focal kaybını kullanmanız için kodu gist ile ekledim.

cats_dogs açık veri setini kullanarak örnek bir alpha_dict oluşturma işlemi

Tavsiye

Focal kaybını kullanırken ilk başta gamma değerini 1'den başlatmanızı tavsiye ediyorum. Fakat eğitimin sonlarına doğru gamma değerini 5'e kadar çıkarabilirsiniz. gamma=5 olduğu zaman kayıp çok düşük olacağından güven olasılığı yüksek gelmiş olan örnekler yokmuş gibi olacak, yani bir nevi sadece güven olasılığı düşük olan değerler ile eğitim yapıyormuşsunuz havasını vereceksiniz.
“Bunu eğittiğiniz modeli başarısız olduğu örnekler üzerinden fine-tuning etmek gibi de düşünebilirsiniz.”
Her bir epoch başına farklı gamma değerleri ile eğitim yapabilirsiniz. Aşağıda sizlere pytorch kodu ile nasıl kullanabileceğinizi göstermiş oldum.

Aynı zamanda LightGBM içerisine de eklenebileceği bir linkte buldum. Bu linke de bakarak nasıl kullandığını görebilirsiniz. Fakat kendinizin focal kaybını yazmanızı tavsiye ederim. Sonradan bunu zaten istediğiniz yere custom loss function olarak ekleyebilirsiniz. Yapılı (structured) veya yapısız (unstructured) veri olması da fark etmez!

Umarım faydalı olmuşumdur :)
Okuduğunuz için teşekkürler!

--

--