Step-by-step Guide
This guide provides a complete example of converting a PyTorch neural network into its FHE-friendly, quantized counterpart. It focuses on Quantization Aware Training of a simple network on a synthetic dataset.
In general, quantization can be carried out in two different ways: either during training with Quantization Aware Training (QAT) or after the training phase with Post-Training Quantization (PTQ).
For FHE-friendly neural networks, QAT is the best way to reach optimal accuracy under FHE constraints. This technique allows weights and activations to be reduced to very low bit-widths (e.g., 2-3 bits), which, combined with pruning, can keep accumulator bit-widths low.
Concrete-ML uses the third-party library Brevitas to perform QAT for PyTorch NNs, but options exist for other frameworks such as Keras/TensorFlow.
Several demos and tutorials that use Brevitas are available in the Concrete-ML library, such as the CIFAR classification tutorial.
This guide is based on a notebook tutorial, from which some code blocks are documented here.
For a more formal description of the usage of Brevitas to build FHE-compatible neural networks, please see the Brevitas usage reference.
Baseline PyTorch model
In PyTorch, using standard layers, a fully connected neural network would look as follows:
The notebook tutorial shows how to train a fully-connected neural network, similar to the one above, on a synthetic 2D dataset with a checkerboard grid pattern of 100 x 100 points. The data is split into 9500 training and 500 test samples.
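A dataset with this shape can be sketched in plain Python. This is an illustration only, not the tutorial's code. The [0, 1) coordinate range, the 4 x 4 board, the fixed shuffling seed, and the helper names make_checkerboard and split are all hypothetical.

```python
import random


def make_checkerboard(grid_size=100, n_cells=4):
    """Hypothetical checkerboard dataset: grid_size x grid_size points in [0, 1)^2.

    The label is the parity of the board cell a point falls into, so
    neighbouring cells get opposite classes (board size is an assumption).
    """
    points, labels = [], []
    for i in range(grid_size):
        for j in range(grid_size):
            # Integer arithmetic avoids floating-point rounding at cell borders
            cell_x = i * n_cells // grid_size
            cell_y = j * n_cells // grid_size
            points.append((i / grid_size, j / grid_size))
            labels.append((cell_x + cell_y) % 2)
    return points, labels


def split(points, labels, n_test=500, seed=0):
    """Shuffle the indices with a fixed seed and hold out n_test samples."""
    indices = list(range(len(points)))
    random.Random(seed).shuffle(indices)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    train = ([points[k] for k in train_idx], [labels[k] for k in train_idx])
    test = ([points[k] for k in test_idx], [labels[k] for k in test_idx])
    return train, test


points, labels = make_checkerboard()
(train_x, train_y), (test_x, test_y) = split(points, labels)
print(len(train_x), len(test_x))  # 9500 500
```

The 100 x 100 grid yields 10,000 points, which matches the 9500 + 500 split described above.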
Once trained, this PyTorch network can be imported using the compile_torch_model function, which uses simple Post-Training Quantization.
The network was trained using different numbers of neurons in the hidden layers and quantized using 3-bit weights and activations. The mean accumulator size shown below was extracted using the Virtual Library and is the mean over 10 runs of the experiment. A mean accumulator size of 6.6 means that in 4 out of 10 runs the measured accumulator was 6 bits, while in the other 6 runs it was 7 bits.
| Neurons | 10 | 30 | 100 |
|---|---|---|---|
| fp32 accuracy | 68.70% | 83.32% | 88.06% |
| 3-bit accuracy | 56.44% | 55.54% | 56.50% |
| Mean accumulator size | 6.6 | 6.9 | 7.4 |
This shows that the fp32 accuracy and the accumulator size increase with the number of hidden neurons, while the 3-bit accuracy remains low irrespective of the number of neurons. While all the configurations tried here were FHE-compatible (accumulator < 16 bits), a lower accumulator size is often preferable because it speeds up inference.
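The mean accumulator sizes in the table are averages over 10 runs. The interpretation given for 6.6 (4 runs at 6 bits, 6 runs at 7 bits) can be checked with a small calculation. The helper runs_per_size is hypothetical. It assumes every run measured one of two adjacent sizes, which the text states only for the 6.6 case.

```python
import math
from statistics import mean


def runs_per_size(mean_bits, n_runs=10):
    """Recover how many runs measured floor(mean_bits) and floor(mean_bits) + 1 bits.

    Assumption: every run produced one of these two adjacent accumulator sizes.
    """
    low = math.floor(mean_bits)
    # round() absorbs floating-point error: (6.6 - 6) * 10 is slightly below 6
    n_high = round((mean_bits - low) * n_runs)
    return {low: n_runs - n_high, low + 1: n_high}


print(runs_per_size(6.6))       # {6: 4, 7: 6}
print(mean([6] * 4 + [7] * 6))  # 6.6
```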
Concrete-Numpy determines the accumulator size as the maximum bit-width encountered anywhere in the encrypted circuit.
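To make this definition concrete, the sketch below computes the bit-width needed for each intermediate value range of a circuit and takes the maximum. It is not Concrete-Numpy code. The helper names, the hand-written list of ranges, and the sign convention are assumptions made for illustration: values are treated as unsigned when none is negative and as two's-complement signed otherwise.

```python
def bits_for_range(lo: int, hi: int) -> int:
    """Smallest bit-width able to represent every integer in [lo, hi]."""
    if lo > hi:
        raise ValueError("empty range")
    if lo >= 0:
        # Unsigned (assumed convention): hi must fit in b bits
        return max(hi.bit_length(), 1)
    # Signed two's complement: need -2**(b-1) <= lo and hi <= 2**(b-1) - 1
    bits = 1
    while -(2 ** (bits - 1)) > lo or hi > 2 ** (bits - 1) - 1:
        bits += 1
    return bits


def accumulator_size(value_ranges):
    """Maximum bit-width over all intermediate (lo, hi) value ranges."""
    return max(bits_for_range(lo, hi) for lo, hi in value_ranges)


# Hypothetical ranges: 3-bit inputs, a signed dot product, a re-quantized output
ranges = [(0, 7), (-100, 90), (0, 7)]
print(accumulator_size(ranges))  # 8
```

The dot product dominates here. Its range [-100, 90] needs 8 signed bits, so the whole circuit is sized at 8 bits even though its inputs and outputs fit in 3.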
Quantization Aware Training
Quantization Aware Training using Brevitas is the best way to guarantee good accuracy for Concrete-ML-compatible neural networks.
Brevitas provides a quantized version of almost all PyTorch layers (the Linear layer becomes QuantLinear, the ReLU layer becomes QuantReLU, and so on), plus some extra quantization parameters, such as:
- bit_width: precision (number of quantization bits) for activations
- act_quant: quantization protocol for the activations
- weight_bit_width: precision (number of quantization bits) for weights
- weight_quant: quantization protocol for the weights
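The effect of a bit-width setting can be illustrated with a plain-Python symmetric quantizer. Real values are mapped to integers in the signed n-bit range and then multiplied back by the scale, which is often called fake quantization. This sketch is not Brevitas's quantizer: deriving the scale from the largest absolute value, and the name fake_quantize, are illustrative choices.

```python
def fake_quantize(values, bit_width):
    """Symmetric signed quantization to bit_width bits, then de-quantization.

    Returns the integer codes and the de-quantized values. The scale choice
    (largest magnitude mapped to the top code) is an assumption.
    """
    if bit_width < 2:
        raise ValueError("bit_width must be at least 2")
    q_min = -(2 ** (bit_width - 1))
    q_max = 2 ** (bit_width - 1) - 1
    max_abs = max(abs(v) for v in values) or 1.0
    scale = max_abs / q_max
    # Round to the nearest integer code, then clip to the n-bit range
    codes = [min(max(round(v / scale), q_min), q_max) for v in values]
    return codes, [c * scale for c in codes]


weights = [-0.9, -0.3, 0.0, 0.2, 0.6]
codes, dequantized = fake_quantize(weights, bit_width=3)
print(codes)  # [-3, -1, 0, 1, 2]
```

With 3 bits, every weight collapses onto one of at most 8 levels. The rounding error of each weight is bounded by half the scale (0.15 here).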
In order to use FHE, the network must be quantized from end to end. Thanks to Brevitas's QuantIdentity layer, the input can be quantized by placing this layer at the entry point of the network. It is also possible to combine PyTorch and Brevitas layers, provided that a QuantIdentity layer is placed after each such PyTorch layer. The following table gives the replacements to be made to convert a PyTorch NN for Concrete-ML compatibility.
PyTorch fp32 layer  Concrete-ML model with PyTorch/Brevitas 









Furthermore, some PyTorch operators (from the PyTorch functional API) require a brevitas.quant.QuantIdentity to be applied to their inputs.
PyTorch ops that require QuantIdentity 





The QAT import tool in Concrete-ML is a work in progress. While it has been tested with some networks built with Brevitas, it is possible to use other tools to obtain QAT networks.
For instance, with Brevitas, the network above becomes:
Note that in the network above, biases are used for linear layers but are not quantized ("bias": True, "bias_quant": None). The addition of the bias is a univariate operation and is fused into the activation function.
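The bias is added to a single integer, the accumulator of one neuron, and the activation is applied to that same value. The whole chain (de-quantize, add bias, ReLU, re-quantize) is therefore one univariate function. Concrete evaluates univariate functions of encrypted integers as table lookups, so such a function can be precomputed over every possible accumulator value even when the bias stays in floating point. The scales, bias, accumulator range, and function names below are hypothetical and are not Concrete-ML internals.

```python
def fused_bias_relu(acc, in_scale=0.1, bias=0.37, out_scale=0.5, out_bits=3):
    """Univariate function of one integer accumulator value.

    De-quantize, add the floating-point bias, apply ReLU, then re-quantize
    to an unsigned out_bits integer. All default values are hypothetical.
    """
    real = acc * in_scale + bias
    code = round(max(real, 0.0) / out_scale)
    return min(code, 2 ** out_bits - 1)


def build_table(func, acc_min, acc_max):
    """Precompute func for every integer accumulator value in [acc_min, acc_max]."""
    return [func(acc) for acc in range(acc_min, acc_max + 1)]


ACC_MIN, ACC_MAX = -20, 20  # hypothetical accumulator range
table = build_table(fused_bias_relu, ACC_MIN, ACC_MAX)

# At inference time, bias + activation is a single lookup indexed by the accumulator
for acc in (-20, 0, 10, 20):
    print(acc, table[acc - ACC_MIN])
```

The table has one entry per possible accumulator value, so its size depends on the accumulator bit-width and not on the bias value.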
Training this Brevitas network with pruning (see below), keeping 30 non-zero neurons out of 100, gives good accuracy while keeping the accumulator size low.
| Non-zero neurons | 30 |
|---|---|
| 3-bit accuracy (Brevitas) | 95.4% |
| 3-bit accuracy in Concrete-ML | 95.4% |
| Accumulator size | 7 |
The PyTorch QAT training loop is the same as the standard floating-point training loop, but hyperparameters such as the learning rate might need to be adjusted.
Quantization Aware Training is somewhat slower than normal training. QAT introduces quantization during both the forward and backward passes. The quantization process is inefficient on GPUs, as its computational intensity is low with respect to data transfer time.
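The difference between a QAT loop and a floating-point loop can be sketched on a one-weight model. In the forward pass, the weight is fake-quantized to 3 bits. In the backward pass, the rounding step, whose derivative is zero almost everywhere, is replaced by the identity. This is the straight-through estimator, a common choice in QAT. The data, scale, learning rate, and helper names are hypothetical, and this is not Brevitas code.

```python
SCALE, Q_MIN, Q_MAX = 0.25, -4, 3  # hypothetical 3-bit signed weight grid


def quantize(w):
    """Fake-quantize a weight to the 3-bit grid {-1.0, -0.75, ..., 0.75}."""
    return min(max(round(w / SCALE), Q_MIN), Q_MAX) * SCALE


def train(xs, ys, w=0.0, lr=0.01, steps=50):
    """Gradient descent on mean squared error, with a quantized forward pass."""
    for _ in range(steps):
        q = quantize(w)  # forward pass uses the quantized weight
        # Derivative of the MSE of y_hat = q * x with respect to q
        grad_q = sum(2 * (q * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        # Straight-through estimator: pass grad_q to w unless the weight is clipped
        inside = Q_MIN <= w / SCALE <= Q_MAX
        w -= lr * (grad_q if inside else 0.0)
    return w


xs = [1.0, 2.0, 3.0, 4.0]
ys = [0.75 * x for x in xs]  # the target weight 0.75 lies on the grid
w = train(xs, ys)
print(quantize(w))  # 0.75
```

Apart from the quantize call and the gradient pass-through, this is an ordinary training loop. This matches the statement that the QAT loop mirrors the floating-point one.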
Pruning using torch
Considering that FHE only works with limited integer precision, there is a risk of overflowing the accumulator, which will make Concrete-ML raise an error.
To understand how to overcome this limitation, consider a scenario where 2 bits are used for weights and layer inputs/outputs. The Linear layer computes a dot product between weights and inputs, $y = \sum_i w_i x_i$. With 2 bits, no overflow can occur during the computation of the Linear layer as long as the number of neurons does not exceed 14, i.e., the sum of 14 products of 2-bit numbers does not exceed 7 bits.
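The bound of 14 neurons can be reproduced with a short calculation. It assumes the worst case in which both operands are unsigned 2-bit values at their maximum of 3, so each product is at most 9, and the accumulator is an unsigned 7-bit integer, so it holds at most 127. The helper name max_fan_in is hypothetical.

```python
def max_fan_in(n_bits: int, acc_bits: int) -> int:
    """Largest number of products w_i * x_i whose sum always fits in acc_bits.

    Worst case assumed: unsigned n_bits operands, all at their maximum value.
    """
    max_operand = 2 ** n_bits - 1        # 3 for 2-bit values
    max_product = max_operand ** 2       # 9 for 2-bit values
    max_accumulator = 2 ** acc_bits - 1  # 127 for 7 bits
    return max_accumulator // max_product


print(max_fan_in(2, 7))  # 14, since 14 * 9 = 126 <= 127 but 15 * 9 = 135 > 127
```

With 3-bit operands, the same 7-bit budget allows only max_fan_in(3, 7) == 2 terms. This shows how quickly the worst case grows with the bit-width.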
By default, Concrete-ML uses symmetric quantization for model weights, with values in the interval $\left[-2^{n_{bits}-1}, 2^{n_{bits}-1}-1\right]$. For example, for $n_{bits}=2$ the possible values are $[-2, -1, 0, 1]$, and for $n_{bits}=3$ the values can be $[-4, -3, -2, -1, 0, 1, 2, 3]$.
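The interval $\left[-2^{n_{bits}-1}, 2^{n_{bits}-1}-1\right]$ can be enumerated directly. symmetric_weight_values is a hypothetical helper that lists its integers and reproduces the two examples.

```python
def symmetric_weight_values(n_bits: int) -> list:
    """All integers in [-2**(n_bits - 1), 2**(n_bits - 1) - 1]."""
    return list(range(-(2 ** (n_bits - 1)), 2 ** (n_bits - 1)))


print(symmetric_weight_values(2))  # [-2, -1, 0, 1]
print(symmetric_weight_values(3))  # [-4, -3, -2, -1, 0, 1, 2, 3]
```

The interval always contains exactly $2^{n_{bits}}$ values, with one more negative value than positive ones.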
However, in a typical setting, the weights will not all have the maximum or minimum values (e.g., $-2^{n_{bits}-1}$). Instead, weights typically have a normal distribution around 0, which is one of the motivating factors for their symmetric quantization. A symmetric distribution and many zero-valued weights are desirable because opposite-sign weights can cancel each other out and zero weights do not increase the accumulator size.
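The effect of signs and zeros on the accumulator can be sketched by computing the range of $y = \sum_i w_i x_i$ over all inputs. The sketch assumes unsigned 3-bit inputs ($0 \le x_i \le 7$), 3-bit symmetric weights, and hypothetical weight vectors. In this worst case, positive weights only raise the maximum, negative weights only lower the minimum, and zero weights change neither.

```python
def accumulator_range(weights, x_max):
    """Smallest and largest y = sum(w_i * x_i) over all integer inputs 0 <= x_i <= x_max."""
    lowest = sum(w for w in weights if w < 0) * x_max   # negative weights, inputs at max
    highest = sum(w for w in weights if w > 0) * x_max  # positive weights, inputs at max
    return lowest, highest


def signed_bits(lo, hi):
    """Bits of a two's-complement integer holding every value in [lo, hi]."""
    bits = 1
    while -(2 ** (bits - 1)) > lo or hi > 2 ** (bits - 1) - 1:
        bits += 1
    return bits


X_MAX = 7  # unsigned 3-bit inputs (assumption)
same_sign = [-4] * 100            # hypothetical: every weight at the minimum value
symmetric = [-4, 3] * 50          # hypothetical: opposite signs balance each other
sparse = [-4, 3] * 15 + [0] * 70  # hypothetical: 30 non-zero weights, 70 zeros

for name, w in [("same sign", same_sign), ("symmetric", symmetric), ("sparse", sparse)]:
    lo, hi = accumulator_range(w, X_MAX)
    print(name, (lo, hi), signed_bits(lo, hi), "bits")
```

Mixing signs splits the worst case between the two ends of the range, and zeroing weights shrinks both ends. The three vectors need 13, 12, and 10 bits respectively.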
The fact that zero-valued weights do not increase the accumulator size can be leveraged to train a network with more neurons without overflowing the accumulator, using a technique called pruning, where the developer imposes a number of zero-valued weights. Torch provides support for pruning out of the box.
The following code shows how to use pruning in the previous example:
Results with PrunedQuantNet, a pruned version of the QuantSimpleNet with 100 neurons on the hidden layers, are given below, showing a mean accumulator size measured over 10 runs of the experiment:
| Non-zero neurons | 10 | 30 |
|---|---|---|
| 3-bit accuracy | 82.50% | 88.06% |
| Mean accumulator size | 6.6 | 6.8 |
This shows that the 3-bit accuracy has been improved while the mean accumulator size stays at a similar level.
When pruning a larger neural network during training, it is easier to obtain a low bit-width accumulator while maintaining better final accuracy. Thus, pruning is more robust than training a similar, smaller network.
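The pruning idea described in this section can be sketched without Torch, which itself provides pruning utilities in torch.nn.utils.prune. The sketch keeps, for each neuron, the k incoming weights with the largest magnitude and sets the others to zero. It then compares a worst-case accumulator bound before and after pruning. The random 3-bit weights, unsigned 3-bit inputs, k = 30, and all helper names are hypothetical, and this is not the tutorial's pruning code.

```python
import random


def prune_rows(weight_rows, k):
    """Keep the k largest-magnitude weights of each row (one row per neuron)."""
    pruned = []
    for row in weight_rows:
        order = sorted(range(len(row)), key=lambda i: abs(row[i]), reverse=True)
        keep = set(order[:k])
        pruned.append([w if i in keep else 0 for i, w in enumerate(row)])
    return pruned


def worst_case_bits(row, x_max):
    """Signed bits covering |sum(w_i * x_i)| <= x_max * sum(|w_i|), with 0 <= x_i <= x_max."""
    bound = x_max * sum(abs(w) for w in row)
    return bound.bit_length() + 1  # + 1 for the sign bit


rng = random.Random(0)
# Hypothetical layer: 100 neurons with 100 inputs each, 3-bit weights in [-4, 3]
layer = [[rng.randint(-4, 3) for _ in range(100)] for _ in range(100)]
pruned = prune_rows(layer, k=30)

before = max(worst_case_bits(row, x_max=7) for row in layer)
after = max(worst_case_bits(row, x_max=7) for row in pruned)
print(before, after)
```

Because pruning only removes terms from each sum, the worst-case bound can never grow. The layer keeps its 100 neurons while each neuron's accumulator is sized as if it had only 30 inputs.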