Pruning in Deep Learning: The efficacy of pruning for model compression
In this article, we take a closer look at how effective pruning is for model compression.
Introduction
A defining property of deep learning is that model accuracy empirically scales with the size of the model and the amount of training data. This property has dramatically improved state-of-the-art performance across domains. At the same time, the resources required to train and serve these models scale with model and data size. This link between scale and model quality is the catalyst for the pursuit of efficiency in the machine learning and systems communities. A promising avenue for improving the efficiency of deep neural networks (DNNs) is exploiting sparsity. A DNN primarily consists of input data, weight matrices learned during training, and activation matrices computed from the weights and data. Inducing sparsity in these matrices yields a more efficient model.
Pruning in DL
Model pruning seeks to induce sparsity in a deep neural network's various connection matrices, thereby reducing the number of nonzero-valued parameters in the model.
Recent work prunes deep networks with only a marginal loss in accuracy while achieving a sizable reduction in model size. This hints that the baseline models in these experiments are severely over-parameterized, and that an alternative for model compression might be to simply reduce the number of hidden units while maintaining the model's dense structure, exposing a similar trade-off between model size and accuracy.
In other words, pruning can substantially reduce the size of a model by zeroing out its least important parameters while affecting its accuracy only slightly.
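To make the idea concrete, here is a minimal, hypothetical sketch of magnitude-based pruning in plain NumPy. The magnitude_prune helper is illustrative only and is not part of the TensorFlow API used later; it simply zeros out the smallest-magnitude entries of a weight matrix until the desired sparsity is reached.
import numpy as np

def magnitude_prune(weights, sparsity):
    # Hypothetical helper: zero out the `sparsity` fraction of entries with the smallest magnitude.
    k = int(np.round(sparsity * weights.size))  # number of entries to zero out
    if k == 0:
        return weights.copy()
    threshold = np.sort(np.abs(weights), axis=None)[k - 1]  # k-th smallest magnitude
    mask = np.abs(weights) > threshold  # keep only the larger-magnitude entries
    return weights * mask

# Example: prune 80% of a random 4x5 weight matrix.
w = np.random.randn(4, 5)
w_pruned = magnitude_prune(w, sparsity=0.8)
print('Fraction of zeros:', np.mean(w_pruned == 0))
The rest of this article uses the TensorFlow Model Optimization toolkit, which applies the same low-magnitude criterion but manages the pruning masks and the sparsity schedule for us during training.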
Implementation
We'll learn to do the following using pruning with Keras (via the TensorFlow Model Optimization toolkit):
- Train a tf.keras model for MNIST from scratch.
- Fine-tune the model by applying the pruning API and check the accuracy.
Setup
pip install -q tensorflow-model-optimization

import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
%load_ext tensorboard
Train a model for MNIST without pruning
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Define the model architecture.
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
    train_images,
    train_labels,
    epochs=4,
    validation_split=0.1,
)
Output
Epoch 1/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.3422 - accuracy: 0.9004 - val_loss: 0.1760 - val_accuracy: 0.9498
Epoch 2/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1813 - accuracy: 0.9457 - val_loss: 0.1176 - val_accuracy: 0.9698
Epoch 3/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1220 - accuracy: 0.9648 - val_loss: 0.0864 - val_accuracy: 0.9770
Epoch 4/4
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0874 - accuracy: 0.9740 - val_loss: 0.0763 - val_accuracy: 0.9787
<tensorflow.python.keras.callbacks.History at 0x7f32cbeb9550>
Evaluate baseline test accuracy and save the model for later usage.
_, baseline_model_accuracy = model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
_, keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
Output
Baseline test accuracy: 0.972599983215332
Saved baseline model to: /tmp/tmp6quew9ig.h5
Fine-tune pre-trained model with pruning
We apply pruning to the whole model, starting at 50% sparsity (50% zeros in weights) and ending at 80% sparsity. A PolynomialDecay schedule gradually increases the sparsity from the initial to the final value between begin_step and end_step.
import tensorflow_model_optimization as tfmot
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.
num_images = train_images.shape[0] * (1 - validation_split)  # 60,000 * 0.9 = 54,000 images
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs  # ceil(54,000 / 128) * 2 = 844 steps
# Define model for pruning.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)
# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])
model_for_pruning.summary()
Output
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:220: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
prune_low_magnitude_reshape (None, 28, 28, 1) 1
_________________________________________________________________
prune_low_magnitude_conv2d ( (None, 26, 26, 12) 230
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 12) 1
_________________________________________________________________
prune_low_magnitude_flatten (None, 2028) 1
_________________________________________________________________
prune_low_magnitude_dense (P (None, 10) 40572
=================================================================
Total params: 40,805
Trainable params: 20,410
Non-trainable params: 20,395
_________________________________________________________________
Train and evaluate the model against the baseline
Fine-tune with pruning for two epochs. Note in the summary above that prune_low_magnitude wraps each layer and adds non-trainable variables (the pruning masks and thresholds), which is why the total parameter count roughly doubles relative to the baseline.
The tfmot.sparsity.keras.UpdatePruningStep callback is required during training, and tfmot.sparsity.keras.PruningSummaries provides logs for tracking progress and debugging.
logdir = tempfile.mkdtemp()

callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                      callbacks=callbacks)
Output
Epoch 1/2
1/422 [..............................] - ETA: 0s - loss: 0.2689 - accuracy: 0.8984WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/summary_ops_v2.py:1277: stop (from tensorflow.python.eager.profiler) is deprecated and will be removed after 2020-07-01.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
422/422 [==============================] - 3s 7ms/step - loss: 0.1105 - accuracy: 0.9691 - val_loss: 0.1247 - val_accuracy: 0.9682
Epoch 2/2
422/422 [==============================] - 3s 6ms/step - loss: 0.1197 - accuracy: 0.9667 - val_loss: 0.0969 - val_accuracy: 0.9763
<tensorflow.python.keras.callbacks.History at 0x7f32422a9550>
For this example, there is minimal loss in test accuracy after pruning, compared to the baseline.
_, model_for_pruning_accuracy = model_for_pruning.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)
Output
Baseline test accuracy: 0.972599983215332
Pruned test accuracy: 0.9689000248908997
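To see the compression benefit on disk, a common follow-up step is to strip the pruning wrappers and compare gzip-compressed file sizes, since the zeroed weights compress far better than the dense baseline. The sketch below assumes the standard tfmot.sparsity.keras.strip_pruning helper and Python's gzip module; get_gzipped_model_size is just an illustrative helper, not part of the toolkit.
import gzip

# Remove the pruning wrappers so only the sparse weights remain.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)

def get_gzipped_model_size(file):
    # Illustrative helper: returns the size of the gzip-compressed model file in bytes.
    _, zipped_file = tempfile.mkstemp('.gz')
    with open(file, 'rb') as f_in, gzip.open(zipped_file, 'wb') as f_out:
        f_out.write(f_in.read())
    return os.path.getsize(zipped_file)

print('Size of gzipped baseline Keras model:', get_gzipped_model_size(keras_file))
print('Size of gzipped pruned Keras model:', get_gzipped_model_size(pruned_keras_file))
Because most of the pruned model's weights are zero, the compressed pruned model is typically several times smaller than the compressed baseline, even though the two uncompressed .h5 files are the same size.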
Conclusion
Exploiting sparsity in deep learning is a complex task, but the opportunity for scaling existing applications and enabling new ones is large. In addition to the topics discussed in this post, increased software and hardware support for sparsity will be an important contributor to progress in the field.