Pruning in Deep Learning: The efficacy of pruning for model compression

In this article, we take a closer look at how effective pruning is for model compression.



A key property of deep learning is that its accuracy empirically scales with the size of the model and the amount of training data. This property has dramatically improved state-of-the-art performance across many domains. At the same time, the resources required to train and serve these models also scale with model and data size. This link between scale and model quality is what drives the pursuit of efficiency in the machine learning and systems communities. A promising avenue for improving the efficiency of deep neural networks (DNNs) is exploiting sparsity. A DNN primarily involves three kinds of matrices: the input data, the weight matrices learned during training, and the activation matrices computed from the weights and data. Sparsity can be induced in any of these matrices to obtain a more efficient model.

Pruning in DL

Model pruning seeks to produce sparsity in a deep neural network’s various connection matrices, thereby reducing the number of nonzero-valued parameters in the model.

Shrinking DNN models via pruning | Source

Recent work prunes deep networks with only a marginal loss in accuracy while achieving a sizable reduction in model size. This raises the possibility that the baseline models in these experiments are severely over-parameterized. If so, a simpler alternative for model compression might be to reduce the number of hidden units while keeping the model's dense structure, which could expose a similar trade-off between model size and accuracy.
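To make this trade-off concrete, here is a back-of-the-envelope count for a hypothetical one-hidden-layer MLP on MNIST-sized inputs. The layer sizes are assumptions chosen for illustration, and biases are ignored. Because the weight count is linear in the number of hidden units, an 80%-pruned model with 1,000 hidden units keeps the same number of nonzero weights as a dense model with 200 hidden units. Whether the two reach similar accuracy is the empirical question raised above.

```python
def hidden_layer_weights(n_in: int, n_hidden: int, n_out: int) -> int:
    """Number of weights (biases excluded) in an n_in -> n_hidden -> n_out MLP."""
    return n_in * n_hidden + n_hidden * n_out


# Hypothetical sizes for illustration: 28x28 inputs, 10 classes.
n_in, n_out = 784, 10
wide_hidden = 1000
sparsity = 0.8  # fraction of weights pruned to zero

# Nonzero weights left in the wide model after pruning.
wide_total = hidden_layer_weights(n_in, wide_hidden, n_out)
wide_nonzero = round(wide_total * (1 - sparsity))

# A dense model with (1 - sparsity) times as many hidden units has the
# same number of weights, because the count is linear in n_hidden.
narrow_hidden = round(wide_hidden * (1 - sparsity))
narrow_total = hidden_layer_weights(n_in, narrow_hidden, n_out)

print(f"wide model, 80% pruned: {wide_nonzero} nonzero weights")
print(f"narrow dense model:     {narrow_total} weights ({narrow_hidden} hidden units)")
```

The pruned model and the narrow dense model end up with the same parameter budget. The difference is that the pruned model chooses which connections to keep, while the narrow model fixes its structure up front.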

Pruning can reduce the size of a model by setting nonzero-valued parameters to zero, while affecting its accuracy only slightly.
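Conceptually, the simplest version of this is one-shot magnitude pruning, which removes the weights with the smallest absolute values. The NumPy sketch below is a simplified illustration under that assumption. `magnitude_prune` is a hypothetical helper, not part of any library. The Keras API used later, `prune_low_magnitude`, also ranks weights by magnitude but applies its masks gradually during training rather than in one shot.

```python
from typing import Tuple

import numpy as np


def magnitude_prune(weights: np.ndarray, sparsity: float) -> Tuple[np.ndarray, np.ndarray]:
    """Zero out the smallest-magnitude entries so that about `sparsity` of them are zero.

    Returns the pruned weights and the binary mask that was applied.
    """
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    k = int(round(sparsity * weights.size))  # number of weights to remove
    if k == 0:
        mask = np.ones_like(weights)
    else:
        # Threshold is the k-th smallest |w|; exact ties may remove a few more.
        flat = np.abs(weights).ravel()
        threshold = np.partition(flat, k - 1)[k - 1]
        mask = (np.abs(weights) > threshold).astype(weights.dtype)
    return weights * mask, mask


# Usage: prune a random 100x100 weight matrix to 80% sparsity.
rng = np.random.default_rng(0)
w = rng.normal(size=(100, 100))
pruned, mask = magnitude_prune(w, 0.8)
print("fraction of zeros:", 1 - np.count_nonzero(pruned) / pruned.size)
```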

Table 1 highlights work reductions from different forms of sparsity on popular neural network models.

Table 1: Sparsity levels and compute advantages on common models.| Source
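Compute savings like these come from skipping multiplications by zero. As a hedged illustration in plain NumPy, using hypothetical helpers `to_csr` and `csr_matvec`, a matrix-vector product over a compressed sparse row (CSR) representation performs one multiply-accumulate (MAC) per stored nonzero instead of one per matrix entry. Real speedups depend on kernel and hardware support. This Python loop only counts the work and checks that the result matches the dense product.

```python
import numpy as np


def to_csr(dense: np.ndarray):
    """Convert a 2-D array to CSR arrays: (values, column indices, row pointers)."""
    values, cols, row_ptr = [], [], [0]
    for row in dense:
        nz = np.flatnonzero(row)      # column indices of nonzeros in this row
        values.extend(row[nz])
        cols.extend(nz)
        row_ptr.append(len(values))   # where the next row starts
    return np.array(values, dtype=float), np.array(cols, dtype=np.int64), np.array(row_ptr)


def csr_matvec(values, cols, row_ptr, x):
    """Compute y = A @ x from CSR arrays and count the MACs performed."""
    n_rows = len(row_ptr) - 1
    y = np.zeros(n_rows)
    macs = 0
    for i in range(n_rows):
        start, end = row_ptr[i], row_ptr[i + 1]
        # Only the stored nonzeros of row i take part in the dot product.
        y[i] = np.dot(values[start:end], x[cols[start:end]])
        macs += int(end - start)
    return y, macs


# Usage: a random 256x512 matrix with roughly 90% of entries zeroed.
rng = np.random.default_rng(0)
A = rng.normal(size=(256, 512))
A[np.abs(A) < np.quantile(np.abs(A), 0.9)] = 0.0
x = rng.normal(size=512)
y, macs = csr_matvec(*to_csr(A), x)
print(f"MACs: sparse {macs} vs dense {A.size} ({macs / A.size:.1%} of dense)")
```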


We will use the Keras pruning API from the TensorFlow Model Optimization Toolkit to:

  1. Train a tf.keras model for MNIST from scratch.
  2. Fine-tune the model by applying the pruning API and compare its accuracy with the baseline.
! pip install -q tensorflow-model-optimization

import tempfile
import os

import tensorflow as tf
import numpy as np

from tensorflow import keras

%load_ext tensorboard
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
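model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)  # outputs logits for the 10 digit classes
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    epochs=4,
    validation_split=0.1,
)
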

Epoch 1/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.3422 - accuracy: 0.9004 - val_loss: 0.1760 - val_accuracy: 0.9498
Epoch 2/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1813 - accuracy: 0.9457 - val_loss: 0.1176 - val_accuracy: 0.9698
Epoch 3/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1220 - accuracy: 0.9648 - val_loss: 0.0864 - val_accuracy: 0.9770
Epoch 4/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0874 - accuracy: 0.9740 - val_loss: 0.0763 - val_accuracy: 0.9787

<tensorflow.python.keras.callbacks.History at 0x7f32cbeb9550>

Evaluate the baseline test accuracy and save the model for later use.

_, baseline_model_accuracy = model.evaluate(
test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)


Baseline test accuracy: 0.972599983215332
Saved baseline model to: /tmp/tmp6quew9ig.h5

Next, define a pruning schedule that starts the model at 50% sparsity (50% of the weights set to zero) and ends at 80% sparsity.

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

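# Define model for pruning.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])

model_for_pruning.summary()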

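Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
prune_low_magnitude_reshape  (None, 28, 28, 1)         1
prune_low_magnitude_conv2d   (None, 26, 26, 12)        230
prune_low_magnitude_max_pool (None, 13, 13, 12)        1
prune_low_magnitude_flatten  (None, 2028)              1
prune_low_magnitude_dense    (None, 10)                40572
=================================================================
Total params: 40,805
Trainable params: 20,410
Non-trainable params: 20,395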
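The PolynomialDecay schedule in the code above ramps the target sparsity from 50% to 80% between begin_step and end_step. The sketch below reproduces that ramp in plain Python, assuming the cubic form s(t) = s_f + (s_i - s_f)(1 - p)^3, where p is the fraction of the pruning window that has elapsed, clipped to [0, 1]. A power of 3 is, to our understanding, tfmot's default; the library also updates its masks only every `frequency` steps, which this sketch ignores. `polynomial_sparsity` is a hypothetical helper, and end_step=844 is what the article's formula gives: ceil(54,000 / 128) × 2.

```python
def polynomial_sparsity(step: int,
                        initial_sparsity: float = 0.50,
                        final_sparsity: float = 0.80,
                        begin_step: int = 0,
                        end_step: int = 844,
                        power: int = 3) -> float:
    """Target sparsity at `step` under a polynomial (cubic by default) schedule."""
    # Fraction of the pruning window that has elapsed, clipped to [0, 1].
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


# Usage: target sparsity at a few points across the two fine-tuning epochs.
for step in (0, 211, 422, 633, 844):
    print(step, round(polynomial_sparsity(step), 4))
```

With a cubic schedule most of the added sparsity arrives early, while the network still has many redundant weights, and the last steps change the mask only slightly, giving the model time to recover.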

Fine-tune with pruning for two epochs.

tfmot.sparsity.keras.UpdatePruningStep is required during training, and tfmot.sparsity.keras.PruningSummaries provides logs for tracking progress and debugging.

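logdir = tempfile.mkdtemp()
# UpdatePruningStep advances the pruning schedule; PruningSummaries writes TensorBoard logs.
callbacks = [tfmot.sparsity.keras.UpdatePruningStep(),
             tfmot.sparsity.keras.PruningSummaries(log_dir=logdir)]
model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                      callbacks=callbacks)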

Epoch 1/2
422/422 [==============================] - 3s 7ms/step - loss: 0.1105 - accuracy: 0.9691 - val_loss: 0.1247 - val_accuracy: 0.9682
Epoch 2/2
422/422 [==============================] - 3s 6ms/step - loss: 0.1197 - accuracy: 0.9667 - val_loss: 0.0969 - val_accuracy: 0.9763

<tensorflow.python.keras.callbacks.History at 0x7f32422a9550>

For this example, pruning causes only a minimal loss in test accuracy compared with the baseline (0.9726 vs. 0.9689 in the run below, a drop of about 0.4 percentage points).

_, model_for_pruning_accuracy = model_for_pruning.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)

Baseline test accuracy: 0.972599983215332
Pruned test accuracy: 0.9689000248908997
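The pruned Keras model still stores its zeros in ordinary dense tensors, so its raw size does not shrink by itself. The savings appear once the weights are compressed; tfmot provides tfmot.sparsity.keras.strip_pruning to remove the pruning wrappers before export. Below is a minimal sketch of the compression effect. It uses Python's standard gzip module and a random matrix shaped like the Dense layer above as a stand-in for trained weights, so the exact byte counts are illustrative only.

```python
import gzip

import numpy as np

rng = np.random.default_rng(0)
# Stand-in for trained weights, shaped like the Dense layer above (2028 x 10).
weights = rng.normal(size=(2028, 10)).astype(np.float32)

# Magnitude-prune about 80% of the entries (zero out the smallest |w|).
threshold = np.quantile(np.abs(weights), 0.8)
pruned = np.where(np.abs(weights) > threshold, weights, 0.0).astype(np.float32)


def gzipped_size(array: np.ndarray) -> int:
    """Size in bytes of the gzip-compressed raw array buffer."""
    return len(gzip.compress(array.tobytes()))


print("raw bytes (both):   ", weights.nbytes)
print("gzipped dense:      ", gzipped_size(weights))
print("gzipped 80% pruned: ", gzipped_size(pruned))
print("fraction of zeros:  ", float(np.mean(pruned == 0)))
```

Both arrays occupy the same number of raw bytes; only the compressed size reflects the sparsity, because long runs of zeros compress well while random nonzero floats barely compress at all.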


Exploiting sparsity in deep learning is a complex task, but the opportunity for scaling existing applications and enabling new ones is large. In addition to the topics discussed in this post, increased software and hardware support for sparsity will be an important contributor to progress in the field.



AI Enthusiast | Independent researcher
