Concrete ML
1.6
Using Torch

Last updated 9 months ago

In addition to the built-in models, Concrete ML supports generic machine learning models implemented with Torch, or exported as ONNX graphs.

There are two approaches to build FHE-compatible deep networks:

  • Quantization Aware Training (QAT): This mode requires using custom layers, but can quantize weights and activations to low bit-widths. Concrete ML works with Brevitas, a library providing QAT support for PyTorch. To use this mode, compile models using compile_brevitas_qat_model.

  • Post-training Quantization: This mode allows a vanilla PyTorch model to be compiled. However, when quantizing weights and activations to fewer than 7 bits, accuracy can drop significantly. On the other hand, depending on the model size, quantizing with 6-8 bits can be incompatible with FHE constraints. To use this mode, compile models with compile_torch_model.

Both approaches require the rounding_threshold_bits parameter to be set accordingly. The best values for this parameter need to be determined through experimentation. A good initial value to try is 6. See the Advanced features page for more details.

See the common compilation errors page for an explanation of some error messages that the compilation function may raise.
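To build intuition for what rounding_threshold_bits controls, the sketch below rounds an integer accumulator so that only its most significant bits carry information. It is a plain-Python illustration, not Concrete ML's implementation: the helper name round_to_msbs and the assumption of a non-negative accumulator of known bit-width are hypothetical.

```python
def round_to_msbs(value: int, acc_bits: int, keep_bits: int) -> int:
    """Round a non-negative `acc_bits`-bit integer to its `keep_bits` most significant bits."""
    drop = acc_bits - keep_bits          # number of low-order bits to discard
    if drop <= 0:
        return value                     # nothing to discard
    half = 1 << (drop - 1)               # add half a step so we round to nearest
    return ((value + half) >> drop) << drop


# A 10-bit accumulator rounded to 6 bits moves in steps of 2**4 = 16.
for v in (7, 8, 1000):
    print(v, "->", round_to_msbs(v, acc_bits=10, keep_bits=6))
```

Fewer kept bits mean coarser steps, which is why the value has to be tuned against model accuracy by experiment.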

Quantization-aware training

The following example uses a simple QAT PyTorch model that implements a fully connected neural network with two hidden layers. Due to its small size, making this model respect FHE constraints is relatively easy. To use QAT, Brevitas QuantIdentity nodes must be inserted in the PyTorch model, including one that quantizes the input of the forward function.

import brevitas.nn as qnn
import torch.nn as nn
import torch

N_FEAT = 12
n_bits = 3

class QATSimpleNet(nn.Module):
    def __init__(self, n_hidden):
        super().__init__()

        self.quant_inp = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc1 = qnn.QuantLinear(N_FEAT, n_hidden, True, weight_bit_width=n_bits, bias_quant=None)
        self.quant2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc2 = qnn.QuantLinear(n_hidden, n_hidden, True, weight_bit_width=n_bits, bias_quant=None)
        self.quant3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc3 = qnn.QuantLinear(n_hidden, 2, True, weight_bit_width=n_bits, bias_quant=None)

    def forward(self, x):
        x = self.quant_inp(x)
        x = self.quant2(torch.relu(self.fc1(x)))
        x = self.quant3(torch.relu(self.fc2(x)))
        x = self.fc3(x)
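        return x

Once the model is trained, calling compile_brevitas_qat_model from Concrete ML automatically converts and compiles the QAT network. Here, 3-bit quantization is used for both the weights and activations. The compile_brevitas_qat_model function automatically identifies the number of quantization bits used in the Brevitas model.

from concrete.ml.torch.compile import compile_brevitas_qat_model
import numpy

torch_input = torch.randn(100, N_FEAT)
torch_model = QATSimpleNet(30)
quantized_module = compile_brevitas_qat_model(
    torch_model, # our model
    torch_input, # a representative input-set to be used for both quantization and compilation
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
)

If QuantIdentity layers are missing for any input or intermediate value, the compile function will raise an error. See the common compilation errors page for an explanation.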

Post-training quantization

The following example uses a simple PyTorch model that implements a fully connected neural network with two hidden layers. The model is compiled for FHE using compile_torch_model.

import torch.nn as nn
import torch

N_FEAT = 12
n_bits = 6

class PTQSimpleNet(nn.Module):
    def __init__(self, n_hidden):
        super().__init__()

        self.fc1 = nn.Linear(N_FEAT, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

from concrete.ml.torch.compile import compile_torch_model
import numpy

torch_input = torch.randn(100, N_FEAT)
torch_model = PTQSimpleNet(5)
quantized_module = compile_torch_model(
    torch_model, # our model
    torch_input, # a representative input-set to be used for both quantization and compilation
    n_bits=6,
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"}
)

Configuring quantization parameters

With QAT (the PyTorch/Brevitas models created following the example above), you configure quantization parameters such as bit_width (activation bit-width) and weight_bit_width in the model itself. When using this mode, set n_bits=None when calling compile_brevitas_qat_model.

With PTQ, you need to set the n_bits value in the compile_torch_model function and must manually determine the trade-off between accuracy, FHE compatibility, and latency.
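The accuracy side of that trade-off can be pictured with a toy experiment. The sketch below applies a symmetric uniform quantizer at several bit-widths and reports the worst-case round-trip error. The quantizer, the helper name quantize_dequantize, and the toy data are assumptions made for illustration; they are not the scheme Concrete ML applies internally.

```python
def quantize_dequantize(x: float, n_bits: int, max_abs: float) -> float:
    """Symmetric uniform quantization of x in [-max_abs, max_abs], then back to float."""
    levels = 2 ** (n_bits - 1) - 1        # largest positive integer code
    scale = max_abs / levels
    q = round(x / scale)                  # integer code
    q = max(-levels, min(levels, q))      # clip to the representable range
    return q * scale


values = [i / 100 for i in range(-100, 101)]   # toy data in [-1, 1]
for n_bits in (8, 6, 4, 3, 2):
    worst = max(abs(v - quantize_dequantize(v, n_bits, 1.0)) for v in values)
    print(f"n_bits={n_bits}: worst-case error {worst:.4f}")
```

The error grows as n_bits shrinks. On a real model, the same kind of sweep would be run on validation accuracy rather than on raw values.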

The quantization parameters, along with the number of neurons in each layer, determine the accumulator bit-width of the network. Larger accumulator bit-widths result in higher accuracy but slower FHE inference.
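As a rough guide to how these quantities interact, the following sketch computes a pessimistic upper bound on the accumulator bit-width of one fully connected layer. It assumes unsigned activations and signed weights at their extreme values. The function name is hypothetical, and the bit-widths actually reached depend on the data and the trained weights, so treat the result as a bound rather than what compilation will report.

```python
def worst_case_accumulator_bits(n_inputs: int, act_bits: int, weight_bits: int) -> int:
    """Upper bound on the signed bit-width of sum(a_i * w_i) over n_inputs terms.

    Assumes unsigned activations in [0, 2**act_bits - 1] and signed weights
    with magnitude at most 2**(weight_bits - 1).
    """
    max_abs = n_inputs * (2 ** act_bits - 1) * 2 ** (weight_bits - 1)
    return max_abs.bit_length() + 1       # +1 for the sign bit


# 30 inputs per neuron, first with 3-bit then with 6-bit weights and activations.
print(worst_case_accumulator_bits(30, 3, 3))
print(worst_case_accumulator_bits(30, 6, 6))
```

Both wider layers and higher quantization bit-widths push the bound up, which is why small low-bit networks are easier to make FHE-compatible.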

Running encrypted inference

The model can now perform encrypted inference.

x_test = numpy.array([numpy.random.randn(N_FEAT)])

y_pred = quantized_module.forward(x_test, fhe="execute")

In this example, the input values x_test and the predicted values y_pred are floating-point values. Quantization is done in the clear within the forward method before any FHE computation, and de-quantization is done in the clear after it.
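The split between clear and encrypted work can be sketched for a single neuron. In the plain-Python example below, floats are quantized to integers, the dot product runs on integers only (standing in for the FHE step), and the result is de-quantized. The symmetric scheme with zero-point 0 and the scales s_x and s_w are assumptions chosen for readability, not the quantizers Concrete ML uses.

```python
def quantize(values, scale):
    """Map floats to integer codes (symmetric scheme, zero-point 0)."""
    return [round(v / scale) for v in values]


x = [0.5, -1.3, 2.0]           # float input, as passed to forward
w = [1.0, 0.5, -0.75]          # float weights of a single neuron
s_x, s_w = 0.25, 0.25          # hypothetical quantization scales

q_x = quantize(x, s_x)         # done in the clear, before encryption
q_w = quantize(w, s_w)
acc = sum(a * b for a, b in zip(q_x, q_w))   # integer-only part (stands in for FHE)
y = acc * s_x * s_w            # de-quantization, in the clear, after decryption

float_result = sum(a * b for a, b in zip(x, w))
print("quantized pipeline:", y, "float reference:", float_result)
```

The small gap between the two printed values is the quantization error introduced before any encrypted computation takes place.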

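Simulated FHE Inference in the clear

You can perform inference on clear data to evaluate the impact of quantization and of FHE computation on the accuracy of your model. See the simulation section for more details. Two approaches exist:

  • quantized_module.forward(quantized_x, fhe="simulate"): simulates FHE execution taking into account Table Lookup errors. De-quantization must be done in a second step as for actual FHE execution. Simulation takes into account the p_error/global_p_error parameters.

  • quantized_module.forward(quantized_x, fhe="disable"): computes predictions in the clear on quantized data, and then de-quantizes the result. The return value of this function contains the de-quantized (float) output of running the model in the clear. Calling this function on clear data is useful when debugging, but this does not perform actual FHE simulation.

FHE simulation makes it possible to measure the impact of the Table Lookup error on the model accuracy. The Table Lookup error can be adjusted using p_error/global_p_error, as described in the approximate computation section.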
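The relationship between the two error parameters mentioned above can be approximated with a short calculation. The sketch assumes, as a simplification, that p_error is the failure probability of one table lookup, that global_p_error is the failure probability of the whole circuit, and that lookups fail independently. Both function names are hypothetical.

```python
def circuit_error_probability(p_error: float, n_lookups: int) -> float:
    """Chance that at least one of n independent table lookups is wrong."""
    return 1.0 - (1.0 - p_error) ** n_lookups


def per_lookup_error(global_p_error: float, n_lookups: int) -> float:
    """Per-lookup error rate that yields the requested whole-circuit rate."""
    return 1.0 - (1.0 - global_p_error) ** (1.0 / n_lookups)


# A model with many lookups accumulates error quickly.
for n in (10, 1_000, 100_000):
    print(n, circuit_error_probability(1e-5, n))
```

Under these assumptions, a per-lookup rate that looks negligible can still produce a noticeable whole-circuit error rate in a large network, which is what simulation lets you measure on real data.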

Supported operators and activations

Concrete ML supports a variety of PyTorch operators that can be used to build fully connected or convolutional neural networks, with normalization and activation layers. Moreover, many element-wise operators are supported.

Operators

Univariate operators

  • torch.nn.Identity
  • torch.clip
  • torch.clamp
  • torch.round
  • torch.floor
  • torch.min
  • torch.max
  • torch.abs
  • torch.neg
  • torch.sign
  • torch.logical_or, torch.Tensor operator ||
  • torch.logical_not
  • torch.gt, torch.greater
  • torch.ge, torch.greater_equal
  • torch.lt, torch.less
  • torch.le, torch.less_equal
  • torch.eq
  • torch.where
  • torch.exp
  • torch.log
  • torch.pow
  • torch.sum
  • torch.mul, torch.Tensor operator *
  • torch.div, torch.Tensor operator /
  • torch.nn.BatchNorm2d
  • torch.nn.BatchNorm3d
  • torch.erf, torch.special.erf
  • torch.nn.functional.pad

Shape modifying operators

  • torch.reshape
  • torch.Tensor.view
  • torch.flatten
  • torch.unsqueeze
  • torch.squeeze
  • torch.transpose
  • torch.concat, torch.cat
  • torch.nn.Unfold

Tensor operators

  • torch.Tensor.expand
  • torch.Tensor.to (for casting to dtype)

Multi-variate operators: encrypted input and unencrypted constants

  • torch.nn.Linear
  • torch.conv1d, torch.nn.Conv1d
  • torch.conv2d, torch.nn.Conv2d
  • torch.nn.AvgPool2d
  • torch.nn.MaxPool2d

Concrete ML also supports some of their QAT equivalents from Brevitas.

  • brevitas.nn.QuantLinear
  • brevitas.nn.QuantConv1d
  • brevitas.nn.QuantConv2d

Multi-variate operators: encrypted+unencrypted or encrypted+encrypted inputs

  • torch.add, torch.Tensor operator +
  • torch.sub, torch.Tensor operator -
  • torch.matmul

Quantizers

  • brevitas.nn.QuantIdentity

Activation functions

  • torch.nn.CELU
  • torch.nn.ELU
  • torch.nn.GELU
  • torch.nn.Hardshrink
  • torch.nn.Hardsigmoid
  • torch.nn.Hardswish
  • torch.nn.Hardtanh
  • torch.nn.LeakyReLU
  • torch.nn.LogSigmoid
  • torch.nn.Mish
  • torch.nn.PReLU
  • torch.nn.ReLU6
  • torch.nn.ReLU
  • torch.nn.SELU
  • torch.nn.Sigmoid
  • torch.nn.SiLU
  • torch.nn.Softplus
  • torch.nn.Softshrink
  • torch.nn.Softsign
  • torch.nn.Tanh
  • torch.nn.Tanhshrink
  • torch.nn.Threshold (partial support)

The equivalent versions from torch.nn.functional are also supported.