# Pruning in Deep Learning: The efficacy of pruning for model compression

This article takes a closer look at the effectiveness of pruning for model compression.

# Introduction

The primary property of deep learning is that its **accuracy empirically scales with the size of the model and the amount of training data**. This property has dramatically improved state-of-the-art performance across sectors. At the same time, the resources required for training and serving these models also scale with model and data size. This link between scale and model quality is the catalyst for the pursuit of efficiency in the machine-learning and systems community. A promising avenue for improving the efficiency of deep neural networks is *exploiting sparsity*.

**DNNs primarily consist of input data, weight matrices learned during training, and activation matrices computed from the weights and data. Sparsity can be induced in these matrices to make the model more efficient.**

# Pruning in DL

Model pruning seeks to induce sparsity in a deep neural network's various connection matrices, thereby reducing the number of nonzero-valued parameters in the model.

Recent reports prune deep networks at the cost of only a marginal loss in accuracy and achieve a sizable reduction in model size. This hints at the possibility that the baseline models in these experiments are severely over-parameterized, and that an alternative for **model compression** might be to *simply reduce the number of hidden units* while maintaining the model's dense structure, exposing a similar trade-off between model size and accuracy. Pruning, however, can reduce the size of the model by zeroing out parameters while affecting its accuracy only marginally.

Table 1 highlights work reductions from different forms of sparsity on popular neural network models.
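The core idea can be sketched in plain NumPy: magnitude pruning keeps the largest-magnitude weights and zeros out the rest until a target sparsity is reached. This is an illustrative helper, not part of any library:

```python
import numpy as np

def magnitude_prune(weights, sparsity):
    """Zero out the smallest-magnitude entries until the given
    fraction of the weights is zero (global magnitude pruning)."""
    flat = np.abs(weights).ravel()
    k = int(sparsity * flat.size)  # number of weights to remove
    if k == 0:
        return weights.copy()
    # Magnitude threshold: the k-th smallest absolute value.
    threshold = np.partition(flat, k - 1)[k - 1]
    mask = np.abs(weights) > threshold
    return weights * mask

rng = np.random.default_rng(0)
w = rng.normal(size=(100, 100))
pruned = magnitude_prune(w, 0.8)
print("sparsity:", np.mean(pruned == 0))  # fraction of zeros, here 0.8
```

In practice the mask is applied gradually during training (as the TFMOT API below does), so the surviving weights can adapt to the removed ones.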

# Implementation

We'll learn to do the following using **Pruning with Keras** (a TensorFlow implementation):

- Train a `tf.keras` model for MNIST from scratch.
- Fine-tune the model by applying the pruning API and compare the accuracy.

## Setup

```shell
pip install -q tensorflow-model-optimization
```

```python
import tempfile
import os

import numpy as np
import tensorflow as tf
from tensorflow import keras

%load_ext tensorboard
```

## Train a model for MNIST without pruning

```python
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Train the digit classification model.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    epochs=4,
    validation_split=0.1,
)
```

Output

```
Epoch 1/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.3422 - accuracy: 0.9004 - val_loss: 0.1760 - val_accuracy: 0.9498
Epoch 2/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1813 - accuracy: 0.9457 - val_loss: 0.1176 - val_accuracy: 0.9698
Epoch 3/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1220 - accuracy: 0.9648 - val_loss: 0.0864 - val_accuracy: 0.9770
Epoch 4/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0874 - accuracy: 0.9740 - val_loss: 0.0763 - val_accuracy: 0.9787
<tensorflow.python.keras.callbacks.History at 0x7f32cbeb9550>
```

Evaluate baseline test accuracy and save the model for later usage.

```python
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
```

Output

```
Baseline test accuracy: 0.972599983215332
Saved baseline model to: /tmp/tmp6quew9ig.h5
```

## Fine-tune pre-trained model with pruning

Apply pruning to the whole model, starting at 50% sparsity (50% zeros in the weights) and ending at 80% sparsity.

```python
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1  # 10% of training set will be used for validation.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])

model_for_pruning.summary()
```

Output

```
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:220: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
prune_low_magnitude_reshape  (None, 28, 28, 1)         1
_________________________________________________________________
prune_low_magnitude_conv2d ( (None, 26, 26, 12)        230
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 12)        1
_________________________________________________________________
prune_low_magnitude_flatten  (None, 2028)              1
_________________________________________________________________
prune_low_magnitude_dense (P (None, 10)                40572
=================================================================
Total params: 40,805
Trainable params: 20,410
Non-trainable params: 20,395
_________________________________________________________________
```
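The `PolynomialDecay` schedule ramps the sparsity target from 50% at `begin_step` to 80% at `end_step`. A minimal sketch of that ramp in plain Python, assuming TFMOT's default exponent of 3 and the `end_step` of 844 computed above (`polynomial_sparsity` is an illustrative helper, not part of the library):

```python
def polynomial_sparsity(step, initial=0.50, final=0.80,
                        begin_step=0, end_step=844, power=3):
    """Sparsity target at a given training step under a polynomial ramp."""
    # Clamp the step to the pruning window.
    t = min(max(step, begin_step), end_step)
    frac = (t - begin_step) / (end_step - begin_step)
    return final + (initial - final) * (1.0 - frac) ** power

print(polynomial_sparsity(0))    # initial sparsity, about 0.5
print(polynomial_sparsity(422))  # midway: already close to final
print(polynomial_sparsity(844))  # final sparsity, about 0.8
```

The cubic ramp prunes aggressively early (while the network can still recover) and slows down as it approaches the final sparsity.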

## Train and evaluate the model against the baseline

Fine-tune with pruning for two epochs.

`tfmot.sparsity.keras.UpdatePruningStep` is required during training, and `tfmot.sparsity.keras.PruningSummaries` provides logs for tracking progress and debugging.

```python
logdir = tempfile.mkdtemp()

callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs,
                      validation_split=validation_split,
                      callbacks=callbacks)
```

Output

```
Epoch 1/2
   1/422 [..............................] - ETA: 0s - loss: 0.2689 - accuracy: 0.8984
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
422/422 [==============================] - 3s 7ms/step - loss: 0.1105 - accuracy: 0.9691 - val_loss: 0.1247 - val_accuracy: 0.9682
Epoch 2/2
422/422 [==============================] - 3s 6ms/step - loss: 0.1197 - accuracy: 0.9667 - val_loss: 0.0969 - val_accuracy: 0.9763
<tensorflow.python.keras.callbacks.History at 0x7f32422a9550>
```

For this example, there is minimal loss in test accuracy after pruning, compared to the baseline.

```python
_, model_for_pruning_accuracy = model_for_pruning.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)
```

Output

```
Baseline test accuracy: 0.972599983215332
Pruned test accuracy: 0.9689000248908997
```
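A quick sanity check on the result is to measure the fraction of exactly-zero entries in a layer's weights directly. A minimal sketch with NumPy; in the notebook above you would pass in a kernel from the pruned model, e.g. obtained via `layer.get_weights()`:

```python
import numpy as np

def weight_sparsity(w):
    """Fraction of exactly-zero entries in a weight array."""
    return float(np.mean(w == 0))

# Illustration on a synthetic 80%-sparse matrix.
w = np.ones((10, 10))
w[:8, :] = 0.0  # zero out 80 of the 100 entries
print(weight_sparsity(w))  # 0.8
```

For the wrapped layers above, the schedule targets 80% sparsity, so a check like this on the `Conv2D` and `Dense` kernels should report a value close to 0.8.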

# Conclusion

Exploiting sparsity in deep learning is a complex task, but the opportunity for scaling existing applications and enabling new ones is large. In addition to the topics discussed in this post, increased software and hardware support for sparsity will be an important contributor to progress in the field.

# Contact

You can reach out to me at:

LinkedIn: https://www.linkedin.com/in/manmohan-dogra-4a6655169/

GitHub: https://github.com/immohann

Twitter: https://twitter.com/immohann