Weighted Categorical Cross-Entropy Loss in Keras

Siladittya Manna
Published in The Owl
3 min read · Aug 28, 2023

In this article, we will be looking at the implementation of the Weighted Categorical Cross-Entropy loss.

For an implementation of the Weighted Binary Cross-Entropy loss, see the separate article on that loss.

Now, let us move on to the topic of this article and look at the mathematical formula for the weighted categorical cross-entropy loss.

Categorical Cross-Entropy Loss

where the summation is over the classes.
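Written out for a single sample, a commonly used form of the weighted loss is shown below. The symbols are our own labels: C is the number of classes, w_c the weight of class c, y_c the (possibly label-smoothed) target, and \hat{y}_c the predicted softmax probability for class c.

```latex
\mathcal{L}_{\text{WCCE}} = -\sum_{c=1}^{C} w_c \, y_c \log\left(\hat{y}_c\right)
```

Setting every w_c = 1 recovers the ordinary categorical cross-entropy; the loss reported for a batch is the mean of this quantity over the samples.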

Now, let us go through the implementation step by step, starting with the Weighted Categorical Cross-Entropy loss.

Disclaimer: All the code in the articles mentioned above and in this article was written with TF v2.12 and Keras 2.12.0 in a Kaggle Notebook environment. Newer versions of these frameworks may behave differently.

Import Required Libraries

import tensorflow as tf
from keras.saving import saving_lib
import tensorflow.keras as keras

Defining the class

To implement the weighted version of the loss, we will follow the Keras source code of the vanilla (unweighted) categorical cross-entropy and make only minor changes to it, so that we do not introduce any numerical instability into the loss calculation.

To instantiate the loss, we provide three arguments: weights, label_smoothing, and axis. We will not use the from_logits argument in this implementation, since we assume that softmax has already been applied to the logits before they are passed to the loss function.
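The label_smoothing argument presumably behaves as in keras.losses.CategoricalCrossentropy, which replaces each one-hot target y with y * (1 - label_smoothing) + label_smoothing / K, where K is the number of classes. The helper smooth_labels below is hypothetical and only sketches that transformation in plain Python for a single target vector, assuming the author's class mirrors the Keras behaviour:

```python
def smooth_labels(one_hot, label_smoothing):
    """Mix a one-hot target with the uniform distribution over K classes.

    Hypothetical helper; assumes Keras-style label smoothing.
    """
    k = len(one_hot)  # number of classes
    return [y * (1.0 - label_smoothing) + label_smoothing / k for y in one_hot]


smoothed = smooth_labels([0.0, 1.0, 0.0], 0.1)
print(smoothed)
```

The smoothed target still sums to 1, so it remains a valid probability distribution; only the confidence placed on the true class is reduced.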

The class consists of two primary methods, __init__ and __call__. In addition, we will use a separate function, weighted_categorical_crossentropy, for the final calculation of the loss.

In addition to calculating the loss with this custom loss function, we must also be able to save this custom object and load it later for inference. For this, we will use the decorator below:

@keras.saving.register_keras_serializable(name="WeightedCategoricalCrossentropy")

before the class WeightedCategoricalCrossentropy, and the same decorator, with the name changed accordingly, before the function weighted_categorical_crossentropy. Quoting from the TensorFlow documentation,

This is the preferred method, as custom object registration greatly simplifies saving and loading code. Adding the @keras.saving.register_keras_serializable decorator to the class definition of a custom object registers the object globally in a master list, allowing Keras to recognize the object when loading the model.

Furthermore, we also need to add the methods get_config and from_config to the class WeightedCategoricalCrossentropy, so that the object can be instantiated from its deserialized configuration when the model is loaded.

Before applying this decorator to the class or function definitions, we run the following line of code. It clears any previously registered custom objects, which is needed when we rerun the function or class definitions without restarting the kernel, or when we simply want to discard earlier registrations.

tf.keras.saving.get_custom_objects().clear()

Let us first look at the weighted_categorical_crossentropy function.
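A minimal sketch of such a function is given below; it is not the author's exact code. It assumes the inputs are probabilities (softmax already applied) and mirrors what the Keras backend does for categorical cross-entropy with from_logits=False: renormalise the predictions, clip them to [epsilon, 1 - epsilon], and sum -y * log(p) over the class axis. The only change is the multiplication by the per-class weights, which are reshaped so that they broadcast along axis.

```python
import tensorflow as tf
from tensorflow import keras


def weighted_categorical_crossentropy(y_true, y_pred, weights,
                                      label_smoothing=0.0, axis=-1):
    """Per-sample weighted CCE on probabilities (not logits).

    Sketch modelled on Keras' categorical cross-entropy; `weights`
    holds one weight per class.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    # Label smoothing: mix the targets with a uniform distribution.
    num_classes = tf.cast(tf.shape(y_true)[axis], y_pred.dtype)
    y_true = y_true * (1.0 - label_smoothing) + label_smoothing / num_classes

    # Reshape the class weights so they broadcast along `axis`,
    # e.g. shape (C,) for axis=-1, or (C, 1, 1) for axis=1 on NCHW data.
    ndim = y_pred.shape.rank
    axis_pos = axis % ndim
    w = tf.reshape(tf.cast(weights, y_pred.dtype),
                   [-1] + [1] * (ndim - 1 - axis_pos))

    # As in Keras: renormalise so the probabilities sum to 1, then clip
    # to [eps, 1 - eps] so that log() never receives exactly 0.
    y_pred = y_pred / tf.reduce_sum(y_pred, axis=axis, keepdims=True)
    eps = keras.backend.epsilon()
    y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)

    return -tf.reduce_sum(w * y_true * tf.math.log(y_pred), axis=axis)


targets = tf.constant([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
outputs = tf.constant([[0.2, 0.7, 0.1], [0.6, 0.3, 0.1]])
print(weighted_categorical_crossentropy(targets, outputs, [1.5, 1.2, 1.0]))
```

With all weights equal to 1, this reduces to the unweighted per-sample loss returned by keras.losses.categorical_crossentropy.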

Now, let us look at the definition of the Weighted Categorical Cross-Entropy loss class.
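The sketch below shows one way such a class could look; it is an illustration, not the author's implementation. Assumptions: it is a plain Python class (not a subclass of keras.losses.Loss), since the article only describes __init__ and __call__; it ignores sample weights; it averages the per-sample loss over the batch; and it inlines the same computation as the weighted_categorical_crossentropy sketch so that the block runs on its own. It is registered under the decorator's default package, "Custom".

```python
import tensorflow as tf
from tensorflow import keras

# As in the article: clear earlier registrations so that re-running this
# cell in the same kernel starts from a clean registry.
keras.saving.get_custom_objects().clear()


@keras.saving.register_keras_serializable(name="WeightedCategoricalCrossentropy")
class WeightedCategoricalCrossentropy:
    """Sketch of a serializable weighted CCE loss for probability outputs."""

    def __init__(self, weights, label_smoothing=0.0, axis=-1):
        # Plain Python types keep get_config() JSON-serializable.
        self.weights = [float(w) for w in weights]
        self.label_smoothing = float(label_smoothing)
        self.axis = int(axis)

    def __call__(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        ls = self.label_smoothing
        num_classes = tf.cast(tf.shape(y_true)[self.axis], y_pred.dtype)
        y_true = y_true * (1.0 - ls) + ls / num_classes
        # Broadcast the class weights along the class axis.
        ndim = y_pred.shape.rank
        w = tf.reshape(tf.constant(self.weights, y_pred.dtype),
                       [-1] + [1] * (ndim - 1 - self.axis % ndim))
        # Renormalise and clip the probabilities, as Keras does.
        y_pred = y_pred / tf.reduce_sum(y_pred, axis=self.axis, keepdims=True)
        eps = keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        per_sample = -tf.reduce_sum(w * y_true * tf.math.log(y_pred),
                                    axis=self.axis)
        # Average over the samples in the batch.
        return tf.reduce_mean(per_sample)

    def get_config(self):
        # Everything needed to rebuild the object when the model is loaded.
        return {"weights": self.weights,
                "label_smoothing": self.label_smoothing,
                "axis": self.axis}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


wcce = WeightedCategoricalCrossentropy(weights=[1.5, 1.2, 1.0])  # 3 classes
targets = tf.constant([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
outputs = tf.constant([[0.2, 0.7, 0.1], [0.6, 0.3, 0.1]])
print(float(wcce(targets, outputs)))
restored = WeightedCategoricalCrossentropy.from_config(wcce.get_config())
```

The final line imitates what Keras does at load time: it rebuilds the loss from the dictionary returned by get_config.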

The get_config and from_config methods are used when saving the model and when reconstructing the loss while loading the model.

To instantiate this loss, we have to do the following:

wcce = WeightedCategoricalCrossentropy(weights=[1.5, 1.2, 1.0])  # for 3 classes

When it is called as follows,

wcce(targets, outputs)

it returns the weighted categorical cross-entropy loss averaged over the samples in the batch.

Saving a model compiled with this loss raises no error. However, when loading the model in a completely different script or notebook, we need to define the class and function before loading the model. More details about loading custom objects are available in the article given below.

Clap and share if you liked the article. Follow for more.


Senior Research Fellow @ CVPR Unit, Indian Statistical Institute, Kolkata || Research Interest : Computer Vision, SSL, MIA. || https://sadimanna.github.io