Step-by-Step Guide
This section includes a complete example of converting a neural network to Quantization Aware Training (QAT). This tutorial uses PyTorch and Brevitas to train a simple network on a synthetic data-set. A demo of the final network is available in the custom-model with quantization aware training demo. To see how to apply these network design principles to a real-world data-set, see the MNIST use-case example.
For a more formal description of how to use Brevitas to build FHE-compatible neural networks, see the Brevitas usage reference.
Summary
Baseline model
This example shows how to train a fully-connected neural network on a synthetic 2D data-set with a checkerboard grid pattern of 100 x 100 points. The data is split into 9500 training and 500 test samples.
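As a rough illustration of such a data-set, the NumPy sketch below places 100 x 100 points on a regular grid, labels them by the parity of the checkerboard cell they fall in, then shuffles and splits them 9500 / 500. The function name `make_checkerboard` and the number of cells per axis (`n_cells=4`) are hypothetical choices; the guide does not specify how many squares the checkerboard has.

```python
import numpy as np


def make_checkerboard(n_points_per_axis: int = 100, n_cells: int = 4,
                      n_test: int = 500, seed: int = 0):
    """Hypothetical generator for a 2D checkerboard data-set on a regular grid.

    n_cells (squares per axis) is an assumption; only the 100 x 100 grid
    and the 9500 / 500 split come from the guide.
    """
    # Regular grid of points in [0, 1) x [0, 1)
    coords = np.arange(n_points_per_axis) / n_points_per_axis
    xx, yy = np.meshgrid(coords, coords)
    points = np.stack([xx.ravel(), yy.ravel()], axis=1)

    # Label = parity of the checkerboard cell containing the point
    cell = np.floor(points * n_cells).astype(np.int64)
    labels = (cell[:, 0] + cell[:, 1]) % 2

    # Shuffle, then keep the last n_test samples for testing
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(points))
    x, y = points[perm].astype(np.float32), labels[perm]
    return (x[:-n_test], y[:-n_test]), (x[-n_test:], y[-n_test:])
```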
In PyTorch, using standard layers, this network would look as follows:
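The original listing is not reproduced here. A minimal sketch, assuming two hidden layers of `n_hidden` neurons with ReLU activations and a two-class output (the guide only says the hidden layers share a neuron count, e.g. 100), could look like this:

```python
import torch
from torch import nn

N_FEAT = 2  # 2D inputs: the (x, y) coordinates of a grid point


class SimpleNet(nn.Module):
    """Fully-connected network; the depth and default width are assumptions."""

    def __init__(self, n_hidden: int = 100, n_classes: int = 2):
        super().__init__()
        self.fc1 = nn.Linear(N_FEAT, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # raw logits, meant for nn.CrossEntropyLoss
```

Varying `n_hidden` gives the different network sizes compared below.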
Once trained, this network can be imported using the compile_torch_model function, which uses simple Post-Training Quantization.
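compile_torch_model performs this quantization internally. To make "simple Post-Training Quantization" concrete, the NumPy sketch below quantizes an already-trained tensor to `n_bits` with a uniform affine scheme calibrated on its observed min/max. This is a generic scheme chosen for illustration, not necessarily Concrete-ML's exact quantizer; `quantize_uniform` and `dequantize` are hypothetical helper names.

```python
import numpy as np


def quantize_uniform(values: np.ndarray, n_bits: int):
    """Uniform affine post-training quantization of `values` to n_bits.

    The range is calibrated on the observed min/max. Returns integer codes q in
    [0, 2**n_bits - 1], the scale and the zero-point, such that
    values ~= (q - zero_point) * scale.
    """
    v_min = min(float(values.min()), 0.0)  # keep 0.0 exactly representable
    v_max = max(float(values.max()), 0.0)
    n_levels = 2**n_bits - 1
    scale = (v_max - v_min) / n_levels if v_max > v_min else 1.0
    zero_point = int(round(-v_min / scale))
    q = np.clip(np.round(values / scale) + zero_point, 0, n_levels).astype(np.int64)
    return q, scale, zero_point


def dequantize(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """Map integer codes back to approximate real values."""
    return (q - zero_point) * scale
```

With `n_bits=3`, every weight is mapped to one of only 8 levels, which is why accuracy can drop sharply when a floating point model is quantized after training.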
The network was trained with different numbers of neurons in the hidden layers, and quantized using 3-bit weights and activations. The mean accumulator size shown below was extracted using the Virtual Library.
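| Metric | Smallest network | Medium network | Largest network |
|---|---|---|---|
| fp32 accuracy | 68.70% | 83.32% | 88.06% |
| 3-bit accuracy | 56.44% | 55.54% | 56.50% |
| Mean accumulator size (bits) | 6.6 | 6.9 | 7.4 |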
This shows that the fp32 accuracy and the accumulator size increase with the number of hidden neurons, while the 3-bit accuracy remains low irrespective of the number of neurons. While all the configurations tried here were FHE-compatible (accumulator < 8 bits), a lower accumulator size is sometimes preferable because it makes inference faster.
The accumulator size is determined by Concrete-Numpy as the maximum bit-width encountered anywhere in the encrypted circuit.
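To make this definition concrete, the sketch below computes the bit-width needed by the integer accumulator of a quantized linear layer over a batch of inputs, taking the maximum over all accumulated values. The helpers `bits_needed` and `accumulator_bits` are hypothetical, and counting negative values as two's complement is an assumption, not necessarily Concrete-Numpy's exact rule.

```python
import numpy as np


def bits_needed(lo: int, hi: int) -> int:
    """Smallest bit-width holding every integer in [lo, hi].

    Unsigned when lo >= 0, two's complement otherwise (an assumption about
    how signed values are counted).
    """
    if lo >= 0:
        return max(1, hi.bit_length())
    return max(max(hi, 0).bit_length(), (-lo - 1).bit_length()) + 1


def accumulator_bits(x_int: np.ndarray, w_int: np.ndarray) -> int:
    """Bit-width of the widest accumulator value of x_int @ w_int.T over a batch.

    x_int: (n_samples, n_in) integer activations
    w_int: (n_out, n_in) integer weights
    """
    acc = x_int.astype(np.int64) @ w_int.astype(np.int64).T  # one sum per (sample, neuron)
    return bits_needed(int(acc.min()), int(acc.max()))
```

For example, with inputs `[[7, 0], [1, 2]]` and a single neuron with weights `[3, -4]`, the accumulated values are 21 and -5, so 6 bits are needed.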
Pruning using Torch
Because FHE only works with limited integer precision, there is a risk of overflowing the accumulator, which leads to unpredictable results.
A technique called pruning, where the developer forces a number of weights to be zero, makes it possible to train a network with more neurons without overflowing the accumulator: zero-valued weights add nothing to the accumulated sum, so limiting the number of non-zero weights per neuron limits the accumulator bit-width. Torch provides support for pruning out of the box.
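This effect can be checked with a worst-case bound. With integer weights and unsigned activations in [0, a_max], the accumulator of a neuron lies between a_max times the sum of its negative weights and a_max times the sum of its positive weights. The NumPy sketch below uses hypothetical helpers and per-neuron magnitude pruning, chosen for illustration. For a layer whose 3-bit weights all equal 3 and whose activations reach 7, keeping 30 of 100 weights per neuron shrinks the bound from 12 to 10 bits. The bound is pessimistic; measured accumulator sizes, like those reported in this guide, are usually smaller.

```python
import numpy as np


def worst_case_accumulator_bits(w_int: np.ndarray, a_max: int) -> int:
    """Worst-case accumulator bit-width of a linear layer.

    w_int: (n_out, n_in) integer weights; activations are assumed unsigned in [0, a_max].
    """
    hi = int((np.clip(w_int, 0, None).sum(axis=1) * a_max).max())  # largest reachable sum
    lo = int((np.clip(w_int, None, 0).sum(axis=1) * a_max).min())  # smallest reachable sum
    if lo >= 0:
        return max(1, hi.bit_length())
    return max(max(hi, 0).bit_length(), (-lo - 1).bit_length()) + 1  # two's complement


def prune_per_neuron(w: np.ndarray, max_non_zero: int) -> np.ndarray:
    """Keep the max_non_zero largest-magnitude weights of each neuron (row), zero the rest."""
    keep = np.argsort(-np.abs(w), axis=1)[:, :max_non_zero]
    rows = np.arange(w.shape[0])[:, None]
    pruned = np.zeros_like(w)
    pruned[rows, keep] = w[rows, keep]
    return pruned
```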
The following code shows how to use pruning in the previous example:
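The original listing is not reproduced here. As a hedged sketch, the model below assumes a two-hidden-layer architecture and adds a hypothetical `set_pruning` method built on `torch.nn.utils.prune`. Each neuron keeps only its `max_non_zero` largest-magnitude input weights, applied through `prune.custom_from_mask`. Enforcing the limit per neuron is a design choice that matches the accumulator argument. `prune.l1_unstructured` would instead bound only the total number of zeros in each layer.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


class PrunedSimpleNet(nn.Module):
    """MLP whose neurons can be limited to max_non_zero input weights (sketch)."""

    def __init__(self, n_hidden: int = 100, n_feat: int = 2, n_classes: int = 2):
        super().__init__()
        self.fc1 = nn.Linear(n_feat, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

    def set_pruning(self, max_non_zero: int, enable: bool = True) -> None:
        """Hypothetical helper: enable, or make permanent, per-neuron magnitude pruning."""
        for layer in (self.fc1, self.fc2, self.fc3):
            if not enable:
                if prune.is_pruned(layer):
                    # Bake the mask into layer.weight and drop the reparametrization
                    prune.remove(layer, "weight")
                continue
            if layer.weight.shape[1] <= max_non_zero:
                continue  # this layer already has few enough inputs per neuron
            with torch.no_grad():
                # Linear weights are (n_out, n_in): keep the top-k magnitudes of each row
                top = layer.weight.abs().topk(max_non_zero, dim=1).indices
                mask = torch.zeros_like(layer.weight).scatter_(1, top, 1.0)
            prune.custom_from_mask(layer, name="weight", mask=mask)
```

One plausible workflow, which is an assumption, is to call `set_pruning(max_non_zero)` before training so the remaining weights can adapt. Calling `set_pruning(max_non_zero, enable=False)` afterwards then makes the zeroed weights permanent and removes the pruning masks.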
Results with PrunedSimpleNet, a pruned version of the SimpleNet with 100 neurons on the hidden layers, are given below:
| Metric | Pruned network 1 | Pruned network 2 |
|---|---|---|
| fp32 accuracy | 82.50% | 88.06% |
| 3-bit accuracy | 57.74% | 57.82% |
| Mean accumulator size (bits) | 6.6 | 6.8 |
This shows that, for a similar mean accumulator size, the pruned networks reach a higher fp32 accuracy than the unpruned ones.
Pruning a larger neural network during training makes it easier to obtain a low bit-width accumulator while keeping a better final accuracy. Pruning is therefore more robust than training a similar, smaller network.
Quantization Aware Training
While pruning helps maintain the post-quantization level of accuracy in low-precision settings, it does not recover the accuracy lost when quantizing a floating point model: in the results above, 3-bit accuracy stays below 58% even when fp32 accuracy reaches 88%. The best way to preserve accuracy is to use QAT (read more in the quantization documentation).
In this example, QAT is done using Brevitas, changing Linear layers to QuantLinear and adding quantizers on the inputs of linear layers using QuantIdentity.
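Brevitas layers are not reproduced here. As a hedged illustration of the mechanism that quantizers such as QuantIdentity and QuantLinear rely on during training, the plain-PyTorch sketch below applies symmetric fake quantization with a straight-through estimator (STE). The forward pass sees quantized values, while gradients pass through the rounding unchanged. `fake_quantize` and `FakeQuantLinear` are hypothetical names, not part of Brevitas or Concrete-ML. Networks meant for import into Concrete-ML should use the Brevitas layers described above.

```python
import torch
from torch import nn


def fake_quantize(x: torch.Tensor, n_bits: int) -> torch.Tensor:
    """Symmetric uniform fake quantization with a straight-through estimator."""
    q_max = 2 ** (n_bits - 1) - 1  # e.g. 3 for 3 bits: integer range [-4, 3]
    scale = x.detach().abs().max().clamp(min=1e-8) / q_max
    x_q = torch.clamp(torch.round(x / scale), -q_max - 1, q_max) * scale
    # Forward returns x_q; backward treats the rounding as the identity
    return x + (x_q - x).detach()


class FakeQuantLinear(nn.Linear):
    """Linear layer that fake-quantizes its input (like an input quantizer) and its weights."""

    def __init__(self, in_features: int, out_features: int, n_bits: int = 3):
        super().__init__(in_features, out_features)
        self.n_bits = n_bits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = fake_quantize(x, self.n_bits)
        w = fake_quantize(self.weight, self.n_bits)
        return nn.functional.linear(x, w, self.bias)
```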
The QAT import tool in Concrete-ML is a work in progress. While it has been tested with some networks built with Brevitas, it is possible to use other tools to obtain QAT networks.
Training this network with 30 non-zero neurons out of a total of 100 gives good accuracy while remaining FHE-compatible (accumulator size < 8 bits):
| Metric | Value |
|---|---|
| 3-bit accuracy (Brevitas) | 95.4% |
| 3-bit accuracy (Concrete-ML) | 92.4% |
| Accumulator size (bits) | 7 |
The PyTorch QAT training loop is the same as the standard floating point training loop, but hyper-parameters such as the learning rate might need to be adjusted.
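A minimal sketch of such a loop follows. It assumes full-batch training with Adam and cross-entropy, both illustrative choices, and `train` is a hypothetical helper. Only the model passed in changes between floating point training and QAT, and `lr` is exposed because it may need retuning.

```python
import torch
from torch import nn


def train(model: nn.Module, x: torch.Tensor, y: torch.Tensor,
          epochs: int = 200, lr: float = 1e-2) -> list:
    """Standard full-batch training loop; returns the loss at each epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    losses = []
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses
```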
Quantization Aware Training is somewhat slower than normal training. QAT introduces quantization during both the forward and backward passes. The quantization process is inefficient on GPUs as its computational intensity is low with respect to data transfer time.