Step-by-Step Guide


This section includes a complete example of converting a neural network to use Quantization Aware Training (QAT). This tutorial uses PyTorch and Brevitas to train a simple network on a synthetic data-set. You can find the demo of the final network in the custom-model with quantization aware training demo. To see how to apply these network design principles to a real-world data-set, please see the MNIST use-case example.

For a more formal description of the usage of Brevitas to build FHE-compatible neural networks, please see the Brevitas usage reference.

Summary

This guide covers the following steps:

  • Building a standard baseline PyTorch model
  • Adding pruning to make learning more robust
  • Converting to Quantization Aware Training with Brevitas

Baseline model

This example shows how to train a fully-connected neural network on a synthetic 2D data-set with a checkerboard grid pattern of 100 x 100 points. The data is split into 9500 training and 500 test samples.
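The data generation code itself is not shown in this guide. As a rough illustration, a comparable checkerboard data-set could be built as follows (a minimal sketch: the grid bounds, cell size and random seed are assumptions, not values from the original tutorial):

import numpy as np
from sklearn.model_selection import train_test_split

# Build a 100 x 100 grid of 2D points and label them with a checkerboard pattern
grid = np.linspace(-1, 1, 100)
xx, yy = np.meshgrid(grid, grid)
X = np.stack([xx.ravel(), yy.ravel()], axis=1).astype(np.float32)
y = ((np.floor(xx.ravel() * 5) + np.floor(yy.ravel() * 5)) % 2).astype(np.int64)

# Split the 10,000 points into 9500 training and 500 test samples
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=500, random_state=42)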

In PyTorch, using standard layers, this network would look as follows:

from torch import nn
import torch

N_FEAT = 2
class SimpleNet(nn.Module):
    """Simple MLP with PyTorch"""

    def __init__(self, n_hidden=30):
        super().__init__()
        self.fc1 = nn.Linear(in_features=N_FEAT, out_features=n_hidden)
        self.fc2 = nn.Linear(in_features=n_hidden, out_features=n_hidden)
        self.fc3 = nn.Linear(in_features=n_hidden, out_features=2)


    def forward(self, x):
        """Forward pass."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Once trained, this network can be imported using the compile_torch_model function. This function uses simple Post-Training Quantization.
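As a sketch of what this import step could look like (hedged: the exact signature of compile_torch_model and the availability of parameters such as n_bits may differ between Concrete-ML versions; X_train refers to the hypothetical data-set sketch above):

from concrete.ml.torch.compile import compile_torch_model

torch_model = SimpleNet(n_hidden=30)
# ... train torch_model with a standard floating point PyTorch loop ...

# Post-Training Quantization: quantize weights/activations to 3 bits and compile
quantized_module = compile_torch_model(
    torch_model,
    X_train,   # representative input set used for calibration and compilation
    n_bits=3,
)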

The network was trained using different numbers of neurons in the hidden layers, and quantized using 3-bit weights and activations. The mean accumulator size shown below was extracted using the Virtual Library.

neurons                  10       30       100
fp32 accuracy            68.70%   83.32%   88.06%
3-bit accuracy           56.44%   55.54%   56.50%
mean accumulator size    6.6      6.9      7.4

This shows that the fp32 accuracy and the accumulator size increase with the number of hidden neurons, while the 3-bit accuracy remains low irrespective of the number of neurons. While all the configurations tried here were FHE-compatible (accumulator < 8 bits), it is sometimes preferable to have a lower accumulator size so that inference is faster.

The accumulator size is determined by Concrete-Numpy as the maximum bit-width encountered anywhere in the encrypted circuit.

Pruning using Torch

Considering that FHE only works with limited integer precision, there is a risk of overflowing in the accumulator, resulting in unpredictable results.

To understand how to overcome this limitation, consider a scenario where 2 bits are used for weights and layer inputs/outputs. The Linear layer computes a dot product between weights and inputs $y = \sum_i w_i x_i$. With 2 bits, no overflow can occur during the computation of the Linear layer as long as the number of neurons does not exceed 14, i.e. the sum of 14 products of 2-bit numbers does not exceed 7 bits.
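The bound in the paragraph above can be checked with a few lines of arithmetic (illustration only):

import math

n_bits = 2                                   # bits for weights and inputs
max_magnitude = 2 ** (n_bits - 1)            # largest magnitude of a signed 2-bit value: 2
max_product = max_magnitude * max_magnitude  # 4
n_terms = 14                                 # number of neurons, i.e. terms in the dot product

max_accumulator = n_terms * max_product      # 56
accumulator_bits = math.ceil(math.log2(max_accumulator)) + 1  # +1 for the sign bit
print(accumulator_bits)                      # 7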

By default, Concrete-ML uses symmetric quantization for model weights, with values in the interval $\left[-2^{n_{bits}-1}, 2^{n_{bits}-1}-1\right]$. For example, for $n_{bits}=2$ the possible values are $[-2, -1, 0, 1]$; for $n_{bits}=3$, the values can be $[-4, -3, -2, -1, 0, 1, 2, 3]$.
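The interval above translates directly into code; a quick illustration:

n_bits = 3
min_val = -(2 ** (n_bits - 1))            # -4
max_val = 2 ** (n_bits - 1) - 1           # 3
print(list(range(min_val, max_val + 1)))  # [-4, -3, -2, -1, 0, 1, 2, 3]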

However, in a typical setting, the weights will not all have the maximum or minimum value (e.g. $-2^{n_{bits}-1}$). Instead, weights typically have a normal distribution around 0, which is one of the motivating factors for their symmetric quantization. A symmetric distribution and many zero-valued weights are desirable because opposite-sign weights can cancel each other out and zero weights do not increase the accumulator size.

This can be leveraged to train a network with more neurons, while not overflowing the accumulator, using a technique called pruning, where the developer can impose a number of zero-valued weights. Torch provides support for pruning out of the box.

The following code shows how to use pruning in the previous example:

import torch.nn.utils.prune as prune

class PrunedSimpleNet(SimpleNet):
    """Pruned version of the simple MLP."""

    def prune(self, max_non_zero, enable):
        # Linear layer weight has dimensions NumOutputs x NumInputs
        for _, layer in self.named_modules():
            if isinstance(layer, nn.Linear):
                num_zero_weights = (layer.weight.shape[1] - max_non_zero) * layer.weight.shape[0]
                if num_zero_weights <= 0:
                    continue

                if enable:
                    # Zero out the smallest-magnitude weights so that, on average,
                    # at most max_non_zero inputs per neuron remain active
                    prune.l1_unstructured(layer, "weight", amount=num_zero_weights)
                else:
                    # Make pruning permanent by removing the re-parametrization
                    prune.remove(layer, "weight")
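
An illustrative way to use this class (not shown in the original tutorial) is to enable pruning before training and then make it permanent once training is done:

pruned_net = PrunedSimpleNet(n_hidden=100)

# Keep at most 30 non-zero input weights per neuron during training
pruned_net.prune(max_non_zero=30, enable=True)

# ... run the usual floating point training loop here ...

# Make the pruning permanent (removes the torch pruning re-parametrization)
pruned_net.prune(max_non_zero=30, enable=False)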

Results with PrunedSimpleNet, a pruned version of SimpleNet with 100 neurons in the hidden layers, are given below:

non-zero neurons         10       30
fp32 accuracy            82.50%   88.06%
3-bit accuracy           57.74%   57.82%
mean accumulator size    6.6      6.8

This shows that the fp32 accuracy has been improved while maintaining constant mean accumulator size.

When pruning a larger neural network during training, it is easier to obtain a low bit-width accumulator while maintaining better final accuracy. Thus, pruning is more robust than training a similar smaller network.

Quantization Aware Training

While pruning helps maintain the post-quantization level of accuracy in low-precision settings, it does not help maintain accuracy when quantizing from floating point models. The best way to guarantee accuracy is to use QAT (read more in the quantization documentation).

In this example, QAT is done using Brevitas, changing Linear layers to QuantLinear and adding quantizers on the inputs of linear layers using QuantIdentity.

The QAT import tool in Concrete-ML is a work in progress. While it has been tested with some networks built with Brevitas, it is possible to use other tools to obtain QAT networks.

import brevitas.nn as qnn


from brevitas.core.bit_width import BitWidthImplType
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import FloatToIntImplType, RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.core.zero_point import ZeroZeroPoint
from brevitas.inject import ExtendedInjector
from brevitas.quant.solver import ActQuantSolver, WeightQuantSolver
from dependencies import value

# Configure quantization options
class CommonQuant(ExtendedInjector):
    bit_width_impl_type = BitWidthImplType.CONST
    scaling_impl_type = ScalingImplType.CONST
    restrict_scaling_type = RestrictValueType.FP
    zero_point_impl = ZeroZeroPoint
    float_to_int_impl_type = FloatToIntImplType.ROUND
    scaling_per_output_channel = False
    narrow_range = True
    signed = True

    @value
    def quant_type(bit_width):
        if bit_width is None:
            return QuantType.FP
        elif bit_width == 1:
            return QuantType.BINARY
        else:
            return QuantType.INT

# Quantization options for weights/activations
class CommonWeightQuant(CommonQuant, WeightQuantSolver):
    scaling_const = 1.0
    signed = True


class CommonActQuant(CommonQuant, ActQuantSolver):
    min_val = -1.0
    max_val = 1.0

class QATPrunedSimpleNet(nn.Module):
    def __init__(self, n_hidden):
        super(QATPrunedSimpleNet, self).__init__()

        n_bits = 3
        self.quant_inp = qnn.QuantIdentity(
            act_quant=CommonActQuant,
            bit_width=n_bits,
            return_quant_tensor=True,
        )

        self.fc1 = qnn.QuantLinear(
            N_FEAT,
            n_hidden,
            True,
            weight_quant=CommonWeightQuant,
            weight_bit_width=n_bits,
            bias_quant=None,
        )

        self.q1 = qnn.QuantIdentity(
            act_quant=CommonActQuant, bit_width=n_bits, return_quant_tensor=True
        )

        self.fc2 = qnn.QuantLinear(
            n_hidden,
            n_hidden,
            True,
            weight_quant=CommonWeightQuant,
            weight_bit_width=n_bits,
            bias_quant=None
        )

        self.q2 = qnn.QuantIdentity(
            act_quant=CommonActQuant, bit_width=n_bits, return_quant_tensor=True
        )

        self.fc3 = qnn.QuantLinear(
            n_hidden,
            2,
            True,
            weight_quant=CommonWeightQuant,
            weight_bit_width=n_bits,
            bias_quant=None,
        )

        for m in self.modules():
            if isinstance(m, qnn.QuantLinear):
                torch.nn.init.uniform_(m.weight.data, -1, 1)

    def forward(self, x):
        x = self.quant_inp(x)
        x = self.q1(torch.relu(self.fc1(x)))
        x = self.q2(torch.relu(self.fc2(x)))
        x = self.fc3(x)
        return x

    def prune(self, max_non_zero, enable):
        # Linear layer weight has dimensions NumOutputs x NumInputs
        for name, layer in self.named_modules():
            if isinstance(layer, nn.Linear):
                num_zero_weights = (layer.weight.shape[1] - max_non_zero) * layer.weight.shape[0]
                if num_zero_weights <= 0:
                    continue

                if enable:
                    print(f"Pruning layer {name} factor {num_zero_weights}")
                    prune.l1_unstructured(layer, "weight", amount=num_zero_weights)
                else:
                    prune.remove(layer, "weight")

Training this network with 30 out of 100 total non-zero neurons gives good accuracy while being FHE-compatible (accumulator size < 8 bits).

non-zero neurons                 30
3-bit accuracy (Brevitas)        95.4%
3-bit accuracy (Concrete-ML)     92.4%
accumulator size                 7

The PyTorch QAT training loop is the same as the standard floating point training loop, but hyper-parameters such as learning rate might need to be adjusted.
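
For illustration, such a loop could look as follows (a sketch: the optimizer, learning rate and number of epochs are assumptions, and X_train/y_train refer to the hypothetical data-set sketch above):

model = QATPrunedSimpleNet(n_hidden=100)
model.prune(max_non_zero=30, enable=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

x_tensor, y_tensor = torch.tensor(X_train), torch.tensor(y_train)
for epoch in range(100):
    optimizer.zero_grad()
    loss = criterion(model(x_tensor), y_tensor)
    loss.backward()
    optimizer.step()

# Make the pruning permanent once training is finished
model.prune(max_non_zero=30, enable=False)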

Quantization Aware Training is somewhat slower than normal training. QAT introduces quantization during both the forward and backward passes. The quantization process is inefficient on GPUs as its computational intensity is low with respect to data transfer time.
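
Once trained, the Brevitas network can be imported into Concrete-ML in a similar way to the floating point model. A hedged sketch follows (the import_qat flag is an assumption about this version of compile_torch_model; check the API reference for the exact signature):

from concrete.ml.torch.compile import compile_torch_model

# import_qat=True (assumed flag) tells the importer that the network is
# already quantized by Brevitas, so its quantization parameters are reused
quantized_module = compile_torch_model(
    model,
    X_train,
    import_qat=True,
    n_bits=3,
)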
