Python Classes and Their Use in Keras

Introduction to Classes

In object-oriented languages, such as Python, classes are one of the fundamental building blocks.

Defining a new class creates a new type of object, where every class instance can be characterized by its attributes, which maintain its state, and its methods, which modify that state.

Defining a Class

The class keyword allows for the creation of a new class definition, immediately followed by the class name:

class MyClass:
    <statements>

In this manner, a new class object bound to the specified class name (MyClass, in this particular case) is created. Each class object can support instantiation and attribute references, as we will see shortly.

Instantiation and Attribute References

Instantiation is the creation of a new instance of a class.

To create a new instance of a class, we can call it using its class name and assign the result to a variable. This will create a new instance object of that class:

x = MyClass()

Upon creating a new instance of a class, Python calls its object constructor method, __init__(), which often takes arguments that are used to set the instantiated object’s attributes.

Let’s say, for instance, that we would like to define a new class named, Dog:

class Dog:
    family = "Canine"

    def __init__(self, name, breed):
        self.name = name
        self.breed = breed

Here, the constructor method takes two arguments, name and breed, which can be passed to it upon instantiating the object:

dog1 = Dog("tommy", "labra")

In the example that we are considering, name and breed are known as instance variables (or attributes), because they are bound to a specific instance. This means that such attributes belong only to the object in which they have been set, but not to any other object instantiated from the same class.

On the other hand, family is a class variable (or attribute), because it is shared by all instances of the same class.
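The difference can be seen in a short sketch, re-using the Dog class from above (the second instance, dog2, is hypothetical and added here for illustration): each instance carries its own name and breed, while family is looked up on the class and shared by all instances.

```python
class Dog:
    family = "Canine"          # class variable, shared by all instances

    def __init__(self, name, breed):
        self.name = name       # instance variables, unique to each object
        self.breed = breed

dog1 = Dog("tommy", "labra")
dog2 = Dog("lucy", "beagle")

print(dog1.name, dog2.name)        # tommy lucy
print(dog1.family, dog2.family)    # Canine Canine
```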

You may also note that the first argument of the constructor method (or any other method) is often called self. This argument refers to the object that we are in the process of creating. It is good practice to follow the convention of setting the first argument to self, to ensure the readability of your code for other programmers.

Once we have set our object’s attributes, they can be accessed using the dot operator. For example, considering again the dog1 instance of the Dog class, its name attribute may be accessed as follows:

print(

Producing the following output:

tommy
Creating Methods and Passing Arguments

In addition to having a constructor method, a class object can also have several other methods for modifying its state.

Similar to the constructor method, each instance method can take several arguments, with the first one being the argument self that lets us set and access the object’s attributes:

class Dog:
    family = "Canine"

    def __init__(self, name, breed):
        self.name = name
        self.breed = breed

    def info(self):
        print(self.name, "is a female", self.breed)

Different methods of the same object can also use the self argument to call each other:

class Dog:
    family = "Canine"

    def __init__(self, name, breed):
        self.name = name
        self.breed = breed
        self.tricks = []

    def add_tricks(self, x):
        self.tricks.append(x)

    def info(self, x):
        self.add_tricks(x)
        print(self.name, "is a female", self.breed, "that", self.tricks[0])

An output string can then be generated as follows:

dog1 = Dog("tommy", "labra")"barks on command")

We find that, in doing so, the "barks on command" input string is appended to the tricks list when the info() method calls the add_tricks() method. The following output is produced:

tommy is a female labra that barks on command

Class Inheritance

Another feature that Python supports is class inheritance.

Inheritance is a mechanism that allows a subclass (also known as a derived or child class) to access all attributes and methods of a superclass (also known as a base or parent class).

The syntax for defining a subclass is the following:

class SubClass(BaseClass):

A subclass can also inherit from multiple base classes. In this case, the syntax would be as follows:

class SubClass(BaseClass1, BaseClass2, BaseClass3):

Class attributes and methods that are not found in the subclass itself are searched for in the base class, and then in subsequent base classes in the case of multiple inheritance.
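This search order can be sketched with a minimal example (the class names here are hypothetical):

```python
class Walker:
    def move(self):
        return "walking"

class Swimmer:
    def move(self):
        return "swimming"

# Duck inherits from both; attribute lookup tries Duck first,
# then Walker, then Swimmer, in the order the bases are listed.
class Duck(Walker, Swimmer):
    pass

print(Duck().move())  # walking
```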

Python further allows that a method in a subclass overrides another method in the base class that carries the same name. An overriding method in the subclass may be replacing the base class method, or simply extending its capabilities. When an overriding subclass method is available, it is this method that is executed when called, rather than the method with the same name in the base class.
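As a sketch of overriding, a hypothetical Puppy subclass could override an info() method of a base Dog class, extending it via super() rather than replacing it outright:

```python
class Dog:
    def __init__(self, name, breed):
        self.name = name
        self.breed = breed

    def info(self):
        return self.name + " is a " + self.breed

class Puppy(Dog):
    # Overrides Dog.info(): the base version is still reachable via super()
    def info(self):
        return super().info() + " puppy"

print(Puppy("tommy", "labra").info())  # tommy is a labra puppy
```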

Using Classes in Keras

A practical use of classes in Keras is to write one’s own callbacks.

A callback is a powerful tool in Keras that allows us to have a look at the behaviour of our model during the different stages of training, testing and prediction.

Indeed, we may pass a list of callbacks to any of the following:

  •
  • keras.Model.evaluate()
  • keras.Model.predict()

The Keras API comes with several built-in callbacks. Nonetheless, we might wish to write our own and, for this purpose, we shall see how to build a custom callback class. In order to do so, we can inherit several methods from the callback base class, which can provide us with information on when:

  • Training, testing and prediction starts and ends.
  • An epoch starts and ends.
  • A training, testing and prediction batch starts and ends.
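As a sketch of one of the batch-level hooks listed above, a hypothetical BatchCounter callback (the class name is our own) could simply count the training batches it sees:

```python
import tensorflow.keras as keras

class BatchCounter(keras.callbacks.Callback):
    def __init__(self):
        super(BatchCounter, self).__init__()
        self.batches = 0

    # Called by Keras at the end of every training batch
    def on_train_batch_end(self, batch, logs=None):
        self.batches += 1
```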

Let’s first consider a simple example of a custom callback that reports back every time that an epoch starts and ends. We will name this custom callback class, EpochCallback, and override the epoch-level methods, on_epoch_begin() and on_epoch_end(), from the base class, keras.callbacks.Callback:

import tensorflow.keras as keras

class EpochCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        print("Starting epoch {}".format(epoch + 1))

    def on_epoch_end(self, epoch, logs=None):
        print("Finished epoch {}".format(epoch + 1))

In order to test the custom callback that we have just defined, we need a model to train. For this purpose, let’s define a simple Keras model:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

def simple_model():
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28)))
    model.add(Dense(128, activation="relu"))
    model.add(Dense(10, activation="softmax"))
    return model

We also need a dataset to train on, for which purpose we will be using the MNIST dataset:

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

# Loading the MNIST training and testing data splits
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Pre-processing the training data: the (60000, 28, 28) images
# already match the Flatten layer's input shape
x_train = x_train / 255.0
y_train_cat = to_categorical(y_train, 10)

Now, let’s try out the custom callback by adding it to the list of callbacks that we pass as input to the fit() method:

model = simple_model()
model.compile(optimizer="adam", loss="categorical_crossentropy")
model.fit(x_train, y_train_cat, epochs=5, callbacks=[EpochCallback()])

The callback that we have just created produces the following output:

Starting epoch 1
Finished epoch 1
Starting epoch 2
Finished epoch 2
Starting epoch 3
Finished epoch 3
Starting epoch 4
Finished epoch 4
Starting epoch 5
Finished epoch 5

We can create another custom callback that monitors the loss value at the end of each epoch, and stores the model weights only if the loss has decreased. To this end, we will be reading the loss value from the logs dict, which stores the metrics at the end of each batch and epoch. We will also be accessing the model corresponding to the current round of training, testing or prediction, by means of self.model.

Let’s call this custom callback, CheckpointCallback:

import numpy as np

class CheckpointCallback(keras.callbacks.Callback):
    def __init__(self):
        super(CheckpointCallback, self).__init__()
        self.best_weights = None

    def on_train_begin(self, logs=None):
        self.best_loss = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current_loss = logs.get("loss")
        print("Current loss is {}".format(current_loss))
        if np.less(current_loss, self.best_loss):
            self.best_loss = current_loss
            self.best_weights = self.model.get_weights()
            print("Storing the model weights at epoch {}\n".format(epoch + 1))

We can try this out again, this time including the CheckpointCallback into the list of callbacks:

model = simple_model()
model.compile(optimizer="adam", loss="categorical_crossentropy")
model.fit(x_train, y_train_cat, epochs=5,
          callbacks=[EpochCallback(), CheckpointCallback()])

The following output of the two callbacks together is now produced:

Starting epoch 1
Finished epoch 1
Current loss is 0.6327750086784363
Storing the model weights at epoch 1
Starting epoch 2
Finished epoch 2
Current loss is 0.3391888439655304
Storing the model weights at epoch 2
Starting epoch 3
Finished epoch 3
Current loss is 0.29216915369033813
Storing the model weights at epoch 3
Starting epoch 4
Finished epoch 4
Current loss is 0.2625095248222351
Storing the model weights at epoch 4
Starting epoch 5
Finished epoch 5
Current loss is 0.23906977474689484
Storing the model weights at epoch 5
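CheckpointCallback stores the best weights but never loads them back. After training, they could be restored with model.set_weights(); a minimal self-contained sketch of that weight round-trip (re-using the simple model architecture, with the real best_weights stood in for by the model’s current weights):

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

# The same simple architecture used above
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(128, activation="relu"))
model.add(Dense(10, activation="softmax"))

# During training, CheckpointCallback would have filled best_weights;
# here we stand in for it with the model's current weights.
best_weights = model.get_weights()

# Restoring the stored weights after training finishes
model.set_weights(best_weights)
```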

Other Classes in Keras

Besides callbacks, we can also write derived classes in Keras for custom metrics (derived from keras.metrics.Metric), custom layers (derived from keras.layers.Layer), custom regularizers (derived from keras.regularizers.Regularizer), or even custom models (derived from keras.Model, for example to change the behavior of invoking a model). All you have to do is follow the guidelines and override the appropriate member functions of the class, using exactly the same names and parameters.
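As a small sketch of one of these, a hypothetical custom layer derived from keras.layers.Layer, with a single trainable scaling factor, could look as follows:

```python
import tensorflow as tf

class ScaleLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(ScaleLayer, self).__init__(**kwargs)
        # One trainable scalar weight, initialized to 1
        self.scale = self.add_weight(name="scale", shape=(), initializer="ones")

    def call(self, inputs):
        return inputs * self.scale

layer = ScaleLayer()
print(layer(tf.constant([1.0, 2.0, 3.0])).numpy())  # [1. 2. 3.] since scale starts at 1
```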

Below is an example from Keras documentation:

import tensorflow as tf

class BinaryTruePositives(tf.keras.metrics.Metric):
    def __init__(self, name='binary_true_positives', **kwargs):
        super(BinaryTruePositives, self).__init__(name=name, **kwargs)
        self.true_positives = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = tf.cast(y_pred, tf.bool)
        values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
        values = tf.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        return self.true_positives

    def reset_states(self):
        self.true_positives.assign(0)

m = BinaryTruePositives()
m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
print('Intermediate result:', float(m.result()))
m.update_state([1, 1, 1, 1], [0, 1, 1, 0])
print('Final result:', float(m.result()))

This reveals why we need a class for a custom metric: a metric is not just a function, but one that computes its value incrementally, once per batch of training data, during the training cycle. The result is reported by the result() method at the end of an epoch, and its memory is reset using the reset_states() method so that it can start afresh in the next epoch.

For the details on what exactly has to be derived, you should refer to the Keras documentation.

Till then, follow me for more!


