Weight Pruning with Keras
In this post, we will look at the concept of weight pruning with Keras. Weight pruning is a model optimization technique: it gradually zeroes out model weights during training so that the model becomes sparse.
Sparse models are easier to compress, so pruning reduces model size. The technique is also widely used to decrease the latency of the model.
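To make the idea concrete, here is a minimal, framework-free sketch of magnitude-based pruning on a flat list of weights. The function name `prune_by_magnitude` and the sample weights are illustrative assumptions and not part of the TensorFlow Model Optimization API, which applies the same idea per layer and gradually over the course of training.

```python
def prune_by_magnitude(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries set to 0.0.

    `sparsity` is the fraction of entries to zero out (0.5 means 50% zeros).
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be between 0 and 1")
    num_to_zero = int(round(len(weights) * sparsity))
    # Indices ordered from the smallest to the largest absolute value.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:num_to_zero]:
        pruned[i] = 0.0
    return pruned


# Hypothetical weights, chosen only for illustration.
weights = [0.9, -0.05, 0.3, 0.01, -0.7, 0.2, -0.02, 0.6, 0.1, -0.4]
pruned = prune_by_magnitude(weights, sparsity=0.5)
print(pruned)
```

The large weights survive unchanged, while the five entries closest to zero are removed.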
I will apply weight pruning to a model trained on the Fashion MNIST dataset and compare the normally trained baseline with the pruned model.
The example requires TensorFlow version 2.4 as well as the tensorflow-model-optimization package, which we need to install:
pip install -q tensorflow-model-optimization
Import the necessary dependencies
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tempfile
from sklearn.metrics import accuracy_score
from sys import getsizeof
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D, GlobalAvgPool2D, Dropout

%load_ext tensorboard
Let’s define some helper functions to determine the file size of the models we will generate
def get_file_size(file_path):
    # Size of the file on disk, in bytes.
    size = os.path.getsize(file_path)
    return size

def convert_bytes(size, unit=None):
    # Print `size` (given in bytes) in the requested unit: "KB", "MB", or bytes by default.
    if unit == "KB":
        return print('File Size: ' + str(round(size / 1024, 3)) + ' Kilobytes')
    elif unit == 'MB':
        return print('File Size: ' + str(round(size / (1024 * 1024), 3)) + ' Megabytes')
    else:
        return print('File Size: ' + str(size) + ' bytes')
Load the Fashion MNIST dataset
The Fashion MNIST dataset contains 70,000 grayscale images in 10 categories. The images show individual articles of clothing at low resolution (28 by 28 pixels).
fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
#Storing test labels
test_labels = y_test
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
y_train = tf.one_hot(y_train, 10)
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
y_test = tf.one_hot(y_test, 10)
Define the Labels
class_name = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot']
Display the shape of the training as well as testing images and labels
print("Training Image Shape: ",x_train.shape)
print("Training Label Shape", y_train.shape)
print("Testing Image Shape: ",x_test.shape)
print("Testing Label Shape", y_test.shape)

Training Image Shape: (60000, 28, 28, 1)
Training Label Shape (60000, 10)
Testing Image Shape: (10000, 28, 28, 1)
Testing Label Shape (10000, 10)
Define the Hyperparameters
AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 64
EPOCHS = 10
NUM_CLASSES=10
Create the Data Pipeline
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
    train_ds
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
    test_ds
    .batch(BATCH_SIZE)
)
The pipeline is ready!
Visualise the Training Images
sample_images, sample_labels = next(iter(train_ds))
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().squeeze())
    plt.title(class_name[np.argmax(label.numpy().tolist())])
    plt.axis("off")
Define the Model
def training_model():
    model = tf.keras.Sequential(
        [
            Conv2D(16, (5, 5), activation="relu", input_shape=(28, 28, 1)),
            MaxPooling2D(pool_size=(2, 2)),
            Conv2D(32, (5, 5), activation="relu"),
            MaxPooling2D(pool_size=(2, 2)),
            Dropout(0.2),
            GlobalAvgPool2D(),
            Flatten(),
            Dense(128, activation="relu"),
            Dense(NUM_CLASSES, activation="softmax"),
        ]
    )
    return model
For the sake of reproducibility, we serialize the initial random weights of our shallow network.
initial_model = training_model()
initial_model.save_weights("initial_weights.h5")
Let’s Compile and Train our Model
# Build a fresh model and start from the serialized initial weights.
model = training_model()
model.load_weights("initial_weights.h5")

model.summary()

model.compile(optimizer='adam',
              loss="categorical_crossentropy",
              metrics=['accuracy'])

model.fit(train_ds, validation_data=test_ds, epochs=EPOCHS)

test_loss, test_acc = model.evaluate(test_ds)
print("Baseline Test accuracy: {:.2f}%".format(test_acc * 100))
Model Summary
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_2 (Conv2D)            (None, 24, 24, 16)        416
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 12, 12, 16)        0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 8, 8, 32)          12832
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 4, 4, 32)          0
_________________________________________________________________
dropout_1 (Dropout)          (None, 4, 4, 32)          0
_________________________________________________________________
global_average_pooling2d_1 ( (None, 32)                0
_________________________________________________________________
flatten_1 (Flatten)          (None, 32)                0
_________________________________________________________________
dense_2 (Dense)              (None, 128)               4224
_________________________________________________________________
dense_3 (Dense)              (None, 10)                1290
=================================================================
Total params: 18,762
Trainable params: 18,762
Non-trainable params: 0
_________________________________________________________________
After training our model, we get our baseline accuracy of 86.97%.
Save the Baseline Model
_, keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved Baseline Model to:', keras_file)
Fine-tune Model with Pruning
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
Define the Hyperparameters
VALIDATION_SPLIT = 0.1 # 10% of training set will be used for validation set.
EPOCHS=6
Note: We train for fewer epochs here; the baseline model was trained for 10.
# Count steps from the size of the full training set, not from a single batch.
num_images = x_train.shape[0] * (1 - VALIDATION_SPLIT)
end_step = np.ceil(num_images / BATCH_SIZE).astype(np.int32) * EPOCHS
Define Model for Pruning
In this example, you start the model with 50% sparsity (50% zeros in weights) and end with 80% sparsity.
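How the sparsity target moves from 50% to 80% can be sketched in plain Python. This is a simplified reading of a polynomial decay schedule, assuming an exponent of 3 (which, to my knowledge, is the default `power` of `PolynomialDecay`) and ignoring that the real schedule only applies pruning every few steps. The function name and the `end_step` of 1000 are hypothetical.

```python
def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity at `step` for a polynomial decay schedule (sketch)."""
    # Fraction of the schedule that has elapsed, clamped to [0, 1].
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    # Sparsity rises quickly at first and flattens out near the end.
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


EXAMPLE_END_STEP = 1000  # hypothetical value, for illustration only
for step in (0, 250, 500, 1000):
    target = polynomial_decay_sparsity(step, 0.50, 0.80, 0, EXAMPLE_END_STEP)
    print(f"step {step:4d}: target sparsity {target:.4f}")
```

Most of the increase happens early in training, which leaves the later steps for the network to recover accuracy at nearly constant sparsity.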
# Define model for pruning
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50, final_sparsity=0.80, begin_step=0, end_step=end_step)
}

model = training_model()
model.load_weights("initial_weights.h5")
model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
                          loss="categorical_crossentropy",
                          metrics=['accuracy'])

model_for_pruning.summary()
Summary
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
prune_low_magnitude_conv2d_4 (None, 24, 24, 16)        818
_________________________________________________________________
prune_low_magnitude_max_pool (None, 12, 12, 16)        1
_________________________________________________________________
prune_low_magnitude_conv2d_5 (None, 8, 8, 32)          25634
_________________________________________________________________
prune_low_magnitude_max_pool (None, 4, 4, 32)          1
_________________________________________________________________
prune_low_magnitude_dropout_ (None, 4, 4, 32)          1
_________________________________________________________________
prune_low_magnitude_global_a (None, 32)                1
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 32)                1
_________________________________________________________________
prune_low_magnitude_dense_4  (None, 128)               8322
_________________________________________________________________
prune_low_magnitude_dense_5  (None, 10)                2572
=================================================================
Total params: 37,351
Trainable params: 18,762
Non-trainable params: 18,589
_________________________________________________________________
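The wrapped model reports about twice as many parameters as the baseline. One accounting that reproduces the totals exactly is sketched below. It assumes that the pruning wrapper adds a mask with the same shape as the kernel plus two scalars (a threshold and a step counter) to each prunable layer, and a single step counter to every other layer. The layer names in the dictionary are placeholders.

```python
# (kernel size, bias size) of each prunable layer, derived from the layer shapes.
prunable_layers = {
    "conv_a": (5 * 5 * 1 * 16, 16),    # 5x5 kernel, 1 input channel, 16 filters
    "conv_b": (5 * 5 * 16 * 32, 32),   # 5x5 kernel, 16 input channels, 32 filters
    "dense_a": (32 * 128, 128),
    "dense_b": (128 * 10, 10),
}
# Two MaxPooling2D layers, Dropout, GlobalAvgPool2D and Flatten have no weights.
NUM_WEIGHTLESS_LAYERS = 5


def wrapped_param_count(kernel, bias):
    # Trainable kernel and bias, plus (assumed) mask, threshold and step counter.
    return kernel + bias + kernel + 1 + 1


trainable = sum(k + b for k, b in prunable_layers.values())
total = (sum(wrapped_param_count(k, b) for k, b in prunable_layers.values())
         + NUM_WEIGHTLESS_LAYERS)

print("Trainable params:    ", trainable)
print("Non-trainable params:", total - trainable)
print("Total params:        ", total)
```

The extra parameters are bookkeeping for pruning only. They are not trained, which is why the trainable count matches the baseline.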
Train your Pruning Model
logdir = tempfile.mkdtemp()

callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_ds, validation_data=test_ds, epochs=EPOCHS, callbacks=callbacks)
_, model_for_pruning_accuracy = model_for_pruning.evaluate(test_ds)
print("Pruned test accuracy: {:.2f}%".format(model_for_pruning_accuracy * 100))
Pruned Model Accuracy is 82.90%
Save the Pruned Model
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)
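A pruned model saved this way still stores every zero at full width, so the size benefit mainly shows up once the file is compressed. The sketch below illustrates that effect with the standard library only. The helper names and the synthetic weights are assumptions made for illustration; in practice `get_gzipped_size` would be called on the saved model files.

```python
import gzip
import os
import random
import struct
import tempfile


def get_gzipped_size(file_path):
    """Return the size in bytes of the file after gzip compression."""
    with open(file_path, "rb") as f:
        return len(gzip.compress(f.read()))


def write_float32_file(values):
    """Write the values as raw float32 numbers to a temporary file."""
    fd, path = tempfile.mkstemp(".bin")
    with os.fdopen(fd, "wb") as f:
        f.write(struct.pack(f"{len(values)}f", *values))
    return path


random.seed(0)
dense = [random.uniform(-1.0, 1.0) for _ in range(10_000)]
# Keep roughly 20% of the weights and zero the rest (about 80% sparsity).
sparse = [w if random.random() < 0.2 else 0.0 for w in dense]

dense_path = write_float32_file(dense)
sparse_path = write_float32_file(sparse)
dense_raw, sparse_raw = os.path.getsize(dense_path), os.path.getsize(sparse_path)
dense_gz, sparse_gz = get_gzipped_size(dense_path), get_gzipped_size(sparse_path)
os.remove(dense_path)
os.remove(sparse_path)

print("raw bytes:     dense", dense_raw, "sparse", sparse_raw)
print("gzipped bytes: dense", dense_gz, "sparse", sparse_gz)
```

Both files have the same raw size, but the sparse one compresses to a much smaller file because long runs of zero bytes are cheap to encode.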
After saving both models, I converted them to TF-Lite models and then ran inference with the converted models.
Note: Inference was run on the test images.
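The conversion and inference code is not shown in this post, so the following is only a sketch of how it could be done with the standard TF-Lite converter and interpreter, not necessarily the exact code behind the results. The function names `convert_to_tflite` and `evaluate_tflite` are hypothetical.

```python
import numpy as np
import tensorflow as tf


def convert_to_tflite(keras_model):
    """Convert an in-memory Keras model to a TF-Lite flatbuffer (bytes)."""
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    return converter.convert()


def evaluate_tflite(tflite_model, images, labels):
    """Run the TF-Lite model on one image at a time and return the accuracy.

    `labels` must be integer class ids, not one-hot vectors.
    """
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    correct = 0
    for image, label in zip(images, labels):
        # The converted model expects float32 input with a batch dimension of 1.
        batch = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, batch)
        interpreter.invoke()
        probabilities = interpreter.get_tensor(output_index)[0]
        correct += int(np.argmax(probabilities) == label)
    return correct / len(labels)
```

With the names used in this post, the pruned model could be evaluated as `evaluate_tflite(convert_to_tflite(model_for_export), x_test, test_labels)`. For the baseline, the model would first be reloaded with `tf.keras.models.load_model(keras_file)`, since `model` is reassigned in the pruning section.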
Results
Test accuracy TFLITE Baseline Model : 0.829
Test accuracy TFLITE Pruned Model : 0.829
We can clearly see that both accuracies remain the same when running inference on the TF-Lite models.
From these results, we can conclude that the pruned model is preferable to the baseline model: it reaches the same TF-Lite test accuracy while being trained towards 80% sparsity.
Notebook Link:
Github Profile
Social Handles
Instagram: https://www.instagram.com/sayannath235/
LinkedIn: https://www.linkedin.com/in/sayannath235/
Mail: sayannath235@gmail.com