Weighted Binary Cross-Entropy Loss in Keras

Siladittya Manna
Published in The Owl
Aug 28, 2023

Several implementations of weighted binary and categorical cross-entropy losses are widely available on the web. In this article, we present a structured way of calculating the weighted binary cross-entropy loss during training.

For a similar implementation of the weighted categorical cross-entropy loss, look here.

First, let us look at the mathematical formula.

Weighted Binary Cross-Entropy Loss
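A standard form of the weighted binary cross-entropy over a batch of $N$ samples is sketched below. We assume that $w_0$ weights the negative class ($y_i = 0$), $w_1$ weights the positive class ($y_i = 1$), and $\hat{y}_i$ is the predicted probability after the sigmoid. The exact convention in the author's code may differ.

$$
\mathcal{L}_{\mathrm{WBCE}} = -\frac{1}{N} \sum_{i=1}^{N} \Big[ w_1\, y_i \log \hat{y}_i + w_0\, (1 - y_i) \log (1 - \hat{y}_i) \Big]
$$

Setting $w_0 = w_1 = 1$ recovers the ordinary binary cross-entropy. With a label-smoothing factor $\alpha$, Keras replaces each target $y_i$ by $y_i (1 - \alpha) + \alpha / 2$ before applying the formula.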

Now we will walk through the implementation of the Weighted Binary Cross-Entropy loss step by step.

Disclaimer: All the code in the article mentioned above and in this article was written with TF v2.12 and Keras 2.12.0 in a Kaggle Notebook environment. Newer versions of these frameworks may behave differently.

Import Required Libraries

```python
import tensorflow as tf
from tensorflow import keras  # tf.keras, as used throughout this article
from keras.saving import saving_lib
```

Defining the class

To implement the weighted version of the loss, we follow the source code of the standard (unweighted) Keras loss and make only minor changes, so that no numerical instability is introduced into the loss calculation.

To instantiate the loss object, we provide three arguments: weights, label_smoothing, and axis. We do not use the from_logits argument in this implementation, because we assume that a sigmoid (or softmax) has already been applied to the logits before they are passed to the loss function.

The class has two primary methods, __init__ and __call__. In addition, a separate function, weighted_binary_crossentropy, performs the final calculation of the loss.

Besides computing the loss, this custom object must also be serializable, so that a model compiled with it can be saved and loaded later for inference. For this, we use the following decorator:

```python
@keras.saving.register_keras_serializable(name="WeightedBinaryCrossentropy")
```

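We place it before both the function weighted_binary_crossentropy and the class WeightedBinaryCrossentropy. Each object gets its own name argument so that the two registrations do not collide. Quoting from the TensorFlow documentation: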

This is the preferred method, as custom object registration greatly simplifies saving and loading code. Adding the @keras.saving.register_keras_serializable decorator to the class definition of a custom object registers the object globally in a master list, allowing Keras to recognize the object when loading the model.

We also need to add the methods get_config and from_config to the class WeightedBinaryCrossentropy, so that the object can be re-instantiated from its deserialized configuration when the model is loaded.

Before applying this decorator to the class and function definitions, we run the following line of code. It clears all previously registered custom objects, so re-running the definitions without restarting the kernel does not clash with earlier registrations. It is also useful whenever we simply want to discard old registrations.

```python
tf.keras.saving.get_custom_objects().clear()
```

Let us first look at the weighted_binary_crossentropy function.

Note how the predicted values are clipped to ensure numerical stability. The weights argument holds the class weights.
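To see why the clipping matters, here is a minimal sketch comparing the logarithm of raw and clipped probabilities. It assumes Keras' default epsilon of 1e-7 as the clipping margin; `EPSILON` is an illustrative constant, not a name from the original code.

```python
import tensorflow as tf

EPSILON = 1e-7  # assumed: same value as Keras' default backend epsilon

# Probabilities at both extremes and in the middle
y_pred = tf.constant([0.0, 0.5, 1.0])

# Without clipping, log(0) is -inf, so a confident wrong prediction gives an infinite loss
raw_log = tf.math.log(y_pred)

# Clipping keeps every probability strictly inside (0, 1), so both log terms stay finite
clipped = tf.clip_by_value(y_pred, EPSILON, 1.0 - EPSILON)
clipped_log_p = tf.math.log(clipped)          # used by the y * log(p) term
clipped_log_1mp = tf.math.log(1.0 - clipped)  # used by the (1 - y) * log(1 - p) term

print(raw_log.numpy())          # first entry is -inf
print(clipped_log_p.numpy())    # all entries finite
print(clipped_log_1mp.numpy())  # all entries finite
```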

Now, let us look at the Weighted Binary Cross-Entropy loss object.

The get_config and from_config methods are used when saving the model and when reconstructing the loss while loading the model.

To instantiate this loss, we have to do the following:

```python
wbce = WeightedBinaryCrossentropy(weights=[1.5, 1.0])
```

The loss is then called on the targets and the model outputs:

```python
# targets: ground-truth labels; outputs: predicted probabilities (after sigmoid)
wbce(targets, outputs)
```

This returns the weighted binary cross-entropy loss averaged over the samples in the batch.
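To sanity-check the value returned by the loss, the same batch average can be computed without any framework. This sketch uses a hypothetical helper, `reference_wbce`, with one probability per sample. It assumes that `weights[0]` applies to the negative class and `weights[1]` to the positive class, and that clipping uses Keras' default epsilon of 1e-7.

```python
import math

EPSILON = 1e-7  # assumed: same value as Keras' default backend epsilon


def reference_wbce(targets, outputs, weights):
    """Batch-averaged weighted BCE for one probability per sample (hypothetical helper)."""
    w_neg, w_pos = weights  # assumption: [negative-class weight, positive-class weight]
    total = 0.0
    for y, p in zip(targets, outputs):
        # Same clipping idea as the Keras version: keep p strictly inside (0, 1)
        p = min(max(p, EPSILON), 1.0 - EPSILON)
        total += -(w_pos * y * math.log(p) + w_neg * (1.0 - y) * math.log(1.0 - p))
    return total / len(targets)


# With weights=[1.5, 1.0], errors on negative samples cost 1.5x as much
print(reference_wbce([1.0, 0.0, 1.0], [0.9, 0.2, 0.4], [1.5, 1.0]))  # roughly 0.4521
```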

Saving a model compiled with this loss raises no error. However, when loading the model in a different script or notebook, the loss class and function must be defined, with their decorators, before the model is loaded. More details about loading custom objects are available in the linked article.


Senior Research Fellow @ CVPR Unit, Indian Statistical Institute, Kolkata || Research Interest : Computer Vision, SSL, MIA. || https://sadimanna.github.io