Weight Pruning with Keras

Sayan Nath
Published in Analytics Vidhya
5 min read · Mar 28, 2021

In this blog, we will explore the concept of weight pruning with Keras. Weight pruning is a model optimization technique that gradually zeroes out model weights during training in order to achieve model sparsity.

Because a sparse model compresses much better, this technique is widely used to reduce model size, and it is also used to decrease model latency.
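To build intuition for what "zeroing out" means, here is a minimal NumPy sketch of magnitude-based pruning applied to a single weight matrix. It is only an illustration, not the implementation used by the TensorFlow Model Optimization toolkit: the smallest-magnitude weights are set to zero until a target sparsity is reached.

import numpy as np

def magnitude_prune(weights, sparsity):
    # Zero out the smallest-magnitude entries until `sparsity`
    # (a fraction between 0 and 1) of the weights are zero.
    k = int(sparsity * weights.size)
    if k == 0:
        return weights
    threshold = np.sort(np.abs(weights).flatten())[k - 1]
    mask = np.abs(weights) > threshold
    return weights * mask

w = np.random.randn(4, 4)
w_pruned = magnitude_prune(w, sparsity=0.5)
print("Fraction of zero weights:", np.mean(w_pruned == 0))  # roughly 0.5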

I will implement weight pruning on the Fashion MNIST dataset and compare a normally trained baseline model with its pruned counterpart.

The example requires TensorFlow 2.4 as well as TensorFlow Model Optimization, so we first install that package:

pip install -q tensorflow-model-optimization

Import the necessary dependencies

import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tempfile
from sklearn.metrics import accuracy_score
from sys import getsizeof
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D, GlobalAvgPool2D, Dropout
%load_ext tensorboard

Let’s define some helper functions to determine the file sizes of the models we will generate.

def get_file_size(file_path):
    size = os.path.getsize(file_path)
    return size

def convert_bytes(size, unit=None):
    if unit == "KB":
        return print('File Size: ' + str(round(size / 1024, 3)) + ' Kilobytes')
    elif unit == 'MB':
        return print('File Size: ' + str(round(size / (1024 * 1024), 3)) + ' Megabytes')
    else:
        return print('File Size: ' + str(size) + ' bytes')

Load the Fashion MNIST dataset

The Fashion MNIST dataset contains 70,000 grayscale images in 10 categories. The images show individual articles of clothing at low resolution (28 by 28 pixels), as seen here:

Fashion MNIST Dataset
fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

#Storing test labels
test_labels = y_test

x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
y_train = tf.one_hot(y_train, 10)

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
y_test = tf.one_hot(y_test, 10)

Define the Labels

class_name = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot']

Display the shapes of the training and testing images and labels

print("Training Image Shape: ",x_train.shape) 
print("Training Label Shape", y_train.shape)
print("Testing Image Shape: ",x_test.shape)
print("Testing Label Shape", y_test.shape)
Training Image Shape: (60000, 28, 28, 1)
Training Label Shape (60000, 10)
Testing Image Shape: (10000, 28, 28, 1)
Testing Label Shape (10000, 10)

Define the Hyperparameters

AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 64
EPOCHS = 10
NUM_CLASSES=10

Create the Data Pipeline

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))

train_ds = (
    train_ds
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))

test_ds = (
    test_ds
    .batch(BATCH_SIZE)
)

The pipeline is ready!

Visualise the Training Images

sample_images, sample_labels = next(iter(train_ds))
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().squeeze())
    plt.title(class_name[np.argmax(label.numpy().tolist())])
    plt.axis("off")
Training Images

Define the Model

def training_model():
    model = tf.keras.Sequential(
        [
            Conv2D(16, (5, 5), activation="relu", input_shape=(28, 28, 1)),
            MaxPooling2D(pool_size=(2, 2)),
            Conv2D(32, (5, 5), activation="relu"),
            MaxPooling2D(pool_size=(2, 2)),
            Dropout(0.2),
            GlobalAvgPool2D(),
            Flatten(),
            Dense(128, activation="relu"),
            Dense(NUM_CLASSES, activation="softmax"),
        ]
    )
    return model

For the sake of reproducibility, we serialize the initial random weights of our shallow network.

initial_model = training_model()
initial_model.save_weights("initial_weights.h5")

Let’s Compile and Train our Model

model = training_model()
model.load_weights("initial_weights.h5")
model.summary()

model.compile(optimizer='adam',
              loss="categorical_crossentropy",
              metrics=['accuracy'])

model.fit(train_ds, validation_data=test_ds, epochs=EPOCHS)

test_loss, test_acc = model.evaluate(test_ds)
print("Baseline Test accuracy: {:.2f}%".format(test_acc * 100))

Model Summary

Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_2 (Conv2D) (None, 24, 24, 16) 416
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 12, 12, 16) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 8, 8, 32) 12832
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 4, 4, 32) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 4, 4, 32) 0
_________________________________________________________________
global_average_pooling2d_1 ( (None, 32) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 32) 0
_________________________________________________________________
dense_2 (Dense) (None, 128) 4224
_________________________________________________________________
dense_3 (Dense) (None, 10) 1290
=================================================================
Total params: 18,762
Trainable params: 18,762
Non-trainable params: 0
_________________________________________________________________

After training our model, we get our baseline accuracy of 86.97%.

Save our Baseline Model

_, keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, keras_file, include_optimizer=False)

print('Saved Baseline Model to:', keras_file)

Fine-tune Model with Pruning

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

Define the Hyperparameters

VALIDATION_SPLIT = 0.1 # 10% of training set will be used for validation set.
EPOCHS=6

Note: we train for fewer epochs here (6) than for the baseline model (10).

# end_step is the training step at which the schedule reaches its final sparsity.
num_images = x_train.shape[0] * (1 - VALIDATION_SPLIT)
end_step = np.ceil(num_images / BATCH_SIZE).astype(np.int32) * EPOCHS

Define Model for Pruning

In this example, you start the model with 50% sparsity (50% zeros in weights) and end with 80% sparsity.

# Define model for pruning
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}
model = training_model()
model.load_weights("initial_weights.h5")

model_for_pruning = prune_low_magnitude(model, **pruning_params)

prune_low_magnitude requires a recompile

model_for_pruning.compile(optimizer='adam',
                          loss="categorical_crossentropy",
                          metrics=['accuracy'])

model_for_pruning.summary()

Summary

Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
prune_low_magnitude_conv2d_4 (None, 24, 24, 16) 818
_________________________________________________________________
prune_low_magnitude_max_pool (None, 12, 12, 16) 1
_________________________________________________________________
prune_low_magnitude_conv2d_5 (None, 8, 8, 32) 25634
_________________________________________________________________
prune_low_magnitude_max_pool (None, 4, 4, 32) 1
_________________________________________________________________
prune_low_magnitude_dropout_ (None, 4, 4, 32) 1
_________________________________________________________________
prune_low_magnitude_global_a (None, 32) 1
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 32) 1
_________________________________________________________________
prune_low_magnitude_dense_4 (None, 128) 8322
_________________________________________________________________
prune_low_magnitude_dense_5 (None, 10) 2572
=================================================================
Total params: 37,351
Trainable params: 18,762
Non-trainable params: 18,589
_________________________________________________________________

Train the Model with Pruning

logdir = tempfile.mkdtemp()

callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]
model_for_pruning.fit(train_ds, validation_data=test_ds, epochs=EPOCHS, callbacks=callbacks)
_, model_for_pruning_accuracy = model_for_pruning.evaluate(test_ds)
print("Pruned test accuracy: {:.2f}%".format(model_for_pruning_accuracy * 100))

Pruned Model Accuracy is 82.90%

Save the Pruned Model

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)

After saving both models, I converted them to TF-Lite models and then ran inference with the TF-Lite interpreter.

Note: the inference was run on the test images.
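A minimal sketch of what that conversion and evaluation could look like is shown below (the file and variable names here are illustrative, and the baseline model is converted the same way from its saved Keras file):

# Convert the pruned Keras model to a TF-Lite flatbuffer.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()

_, pruned_tflite_file = tempfile.mkstemp('.tflite')
with open(pruned_tflite_file, 'wb') as f:
    f.write(pruned_tflite_model)

# Run inference on the test images with the TF-Lite interpreter.
interpreter = tf.lite.Interpreter(model_path=pruned_tflite_file)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

predictions = []
for image in x_test:
    interpreter.set_tensor(input_index, np.expand_dims(image, 0).astype(np.float32))
    interpreter.invoke()
    predictions.append(np.argmax(interpreter.get_tensor(output_index)[0]))

print("TF-Lite pruned model accuracy:", accuracy_score(test_labels, predictions))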

Results

Test accuracy TFLITE Baseline Model : 0.829
Test accuracy TFLITE Pruned Model : 0.829

We can see that both models reach the same accuracy when run through the TF-Lite interpreter.

Table

From this table, we can conclude that the pruned model is the better choice: it matches the baseline's accuracy while its sparse weights compress to a much smaller file.
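The file-size helpers defined at the beginning can be used to produce such a comparison. Here is a sketch under that assumption (the gzip helper below is added for illustration); because pruned weights are mostly zeros, the pruned model compresses far better than the baseline:

import gzip
import shutil

def get_gzipped_size(file_path):
    # Compress a saved model with gzip and return the compressed size in bytes.
    _, zipped_file = tempfile.mkstemp('.gz')
    with open(file_path, 'rb') as f_in, gzip.open(zipped_file, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    return get_file_size(zipped_file)

convert_bytes(get_file_size(keras_file), 'KB')            # baseline .h5
convert_bytes(get_file_size(pruned_keras_file), 'KB')     # pruned .h5
convert_bytes(get_gzipped_size(keras_file), 'KB')         # baseline, gzipped
convert_bytes(get_gzipped_size(pruned_keras_file), 'KB')  # pruned, gzipped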

Notebook Link:

Github Profile

Social Handles

Instagram: https://www.instagram.com/sayannath235/

LinkedIn: https://www.linkedin.com/in/sayannath235/

Mail: sayannath235@gmail.com
