Concrete ML
1.1

Using Torch

Last updated 1 year ago

In addition to the built-in models, Concrete ML supports generic machine learning models implemented with Torch, or exported as ONNX graphs.

As Quantization Aware Training (QAT) is the most appropriate method of training neural networks that are compatible with FHE constraints, Concrete ML works with Brevitas, a library providing QAT support for PyTorch.

The following example uses a simple QAT PyTorch model that implements a fully connected neural network with two hidden layers. Due to its small size, making this model respect FHE constraints is relatively easy.

import brevitas.nn as qnn
import torch.nn as nn
import torch

N_FEAT = 12
n_bits = 3

class QATSimpleNet(nn.Module):
    def __init__(self, n_hidden):
        super().__init__()

        self.quant_inp = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc1 = qnn.QuantLinear(N_FEAT, n_hidden, True, weight_bit_width=n_bits, bias_quant=None)
        self.quant2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc2 = qnn.QuantLinear(n_hidden, n_hidden, True, weight_bit_width=n_bits, bias_quant=None)
        self.quant3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc3 = qnn.QuantLinear(n_hidden, 2, True, weight_bit_width=n_bits, bias_quant=None)

    def forward(self, x):
        x = self.quant_inp(x)
        x = self.quant2(torch.relu(self.fc1(x)))
        x = self.quant3(torch.relu(self.fc2(x)))
        x = self.fc3(x)
        return x

from concrete.ml.torch.compile import compile_brevitas_qat_model
import numpy

torch_input = torch.randn(100, N_FEAT)
torch_model = QATSimpleNet(30)
quantized_module = compile_brevitas_qat_model(
    torch_model, # our model
    torch_input, # a representative input-set to be used for both quantization and compilation
)

Configuring quantization parameters

The PyTorch/Brevitas models created following the example above require the user to configure quantization parameters such as bit_width (activation bit-width) and weight_bit_width. Together with the number of neurons in each layer, these parameters determine the accumulator bit-width of the network. Larger accumulator bit-widths give higher accuracy but slower FHE inference.
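To illustrate how bit-widths and layer size combine, the sketch below computes a conservative worst-case bound on the accumulator bit-width of a single dense layer. It assumes unsigned activations and signed weights that all take their extreme values at once. The function name is illustrative and is not part of Concrete ML or Brevitas. Trained networks rarely reach this bound, so accumulator bit-widths measured on real data can be noticeably lower.

```python
def worst_case_accumulator_bits(act_bits: int, weight_bits: int, n_neurons: int) -> int:
    """Upper bound on the signed bit-width needed to accumulate one dot product.

    Assumption (not taken from Concrete ML): activations are unsigned integers
    in [0, 2**act_bits - 1] and weights are signed integers in
    [-2**(weight_bits - 1), 2**(weight_bits - 1) - 1].
    """
    max_activation = 2**act_bits - 1
    max_weight_magnitude = 2 ** (weight_bits - 1)
    # Largest magnitude that the sum of n_neurons products can reach.
    max_sum = n_neurons * max_activation * max_weight_magnitude
    # Bits for the magnitude, plus one sign bit.
    return max_sum.bit_length() + 1


# Example values: (activation bit-width, weight bit-width, neurons).
for act_bits, weight_bits, n_neurons in [(3, 3, 80), (4, 3, 90), (7, 6, 120)]:
    bound = worst_case_accumulator_bits(act_bits, weight_bits, n_neurons)
    print(f"act={act_bits} weight={weight_bits} neurons={n_neurons}: at most {bound} bits")
```

The bound grows by one bit each time the activation bit-width, the weight bit-width, or (roughly) the doubling of the neuron count increases, which is why all three must be tuned together.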

The following configurations were determined through experimentation for convolutional and dense layers.

| target accumulator bit-width | activation bit-width | weight bit-width | number of active neurons |
|---|---|---|---|
| 8 | 3 | 3 | 80 |
| 10 | 4 | 3 | 90 |
| 12 | 5 | 5 | 110 |
| 14 | 6 | 6 | 110 |
| 16 | 7 | 6 | 120 |
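The templates can be kept as a small lookup structure when experimenting with several accumulator budgets. The sketch below copies the values from the table; the names TEMPLATES and template_for are hypothetical helpers, not part of Concrete ML.

```python
# Templates copied from the table above:
# target accumulator bit-width -> (activation bit-width, weight bit-width, active neurons).
TEMPLATES = {
    8: (3, 3, 80),
    10: (4, 3, 90),
    12: (5, 5, 110),
    14: (6, 6, 110),
    16: (7, 6, 120),
}


def template_for(max_accumulator_bits: int) -> dict:
    """Return the largest template whose target fits within the given budget."""
    candidates = [target for target in TEMPLATES if target <= max_accumulator_bits]
    if not candidates:
        raise ValueError("no template targets an accumulator this small")
    target = max(candidates)
    act_bits, weight_bits, n_active = TEMPLATES[target]
    return {
        "target_accumulator_bits": target,
        "bit_width": act_bits,  # activation bit-width
        "weight_bit_width": weight_bits,
        "active_neurons": n_active,
    }


print(template_for(13))
```

The keys bit_width and weight_bit_width mirror the Brevitas parameter names used in the example model, so the returned values can be passed on when building the layers.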

Using the templates above, the probability of obtaining the target accumulator bit-width, for a single layer, was determined experimentally by training 10 models for each of the following data-sets.

| probability of obtaining the accumulator bit-width | 8 | 10 | 12 | 14 | 16 |
|---|---|---|---|---|---|
| mnist,fashion | 72% | 100% | 72% | 85% | 100% |
| cifar10 | 88% | 88% | 75% | 75% | 88% |
| cifar100 | 73% | 88% | 61% | 66% | 100% |

Note that on larger data-sets, accuracy also drops sharply when the accumulator bit-width is low.

| accuracy for target accumulator bit-width | 8 | 10 | 12 | 14 | 16 |
|---|---|---|---|---|---|
| cifar10 | 20% | 37% | 89% | 90% | 90% |
| cifar100 | 6% | 30% | 67% | 69% | 69% |

Running encrypted inference

The model can now perform encrypted inference.

x_test = numpy.array([numpy.random.randn(N_FEAT)])

y_pred = quantized_module.forward(x_test, fhe="execute")

In this example, the input values x_test and the predicted values y_pred are floating-point values. Quantization and de-quantization are done in the clear within the forward method, respectively before and after the FHE computations.
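As a rough illustration of what the quantization and de-quantization steps do, the sketch below implements a simple uniform affine quantizer in plain Python. It is a toy under stated assumptions (unsigned integers, min/max calibration) and is not Concrete ML's implementation; the function names are illustrative.

```python
def calibrate(values, n_bits):
    """Derive (scale, zero_point) for unsigned n_bits integers from sample values."""
    v_min, v_max = min(values), max(values)
    if v_max == v_min:
        raise ValueError("calibration values must span a non-empty range")
    scale = (v_max - v_min) / (2**n_bits - 1)
    zero_point = round(-v_min / scale)
    return scale, zero_point


def quantize(x, scale, zero_point, n_bits):
    """Map a float to an integer in [0, 2**n_bits - 1]."""
    q = round(x / scale) + zero_point
    return max(0, min(2**n_bits - 1, q))


def dequantize(q, scale, zero_point):
    """Map a quantized integer back to a float."""
    return (q - zero_point) * scale


inputs = [-2.0, -0.8, 0.3, 1.5]
scale, zero_point = calibrate(inputs, n_bits=3)

# Clear-text step before the integer computation.
quantized = [quantize(x, scale, zero_point, 3) for x in inputs]
# Clear-text step after the integer computation.
recovered = [dequantize(q, scale, zero_point) for q in quantized]

print(quantized)  # [0, 2, 5, 7]
print(recovered)  # [-2.0, -1.0, 0.5, 1.5]
```

Only the integers in the middle would be encrypted in an FHE setting; the two mapping steps operate on clear data on the client side.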

Simulated FHE Inference in the clear

The user can also perform the inference on clear data. Two approaches exist:

  • quantized_module.forward(quantized_x, fhe="simulate"): simulates FHE execution, taking Table Lookup errors into account. De-quantization must be done in a second step, as for actual FHE execution. Simulation takes the p_error/global_p_error parameters into account.

  • quantized_module.forward(quantized_x, fhe="disable"): computes predictions in the clear on quantized data, then de-quantizes the result. The return value contains the de-quantized (float) output of running the model in the clear. Calling this function on clear data is useful for debugging, but it does not perform actual FHE simulation.
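To get a feel for how a per-lookup error rate relates to a whole-model error rate, the sketch below works through the arithmetic. It assumes that p_error is the error probability of a single table lookup, that global_p_error is the probability that at least one lookup in the model is wrong, and that lookup errors are independent. These are simplifying assumptions for illustration, not a description of how Concrete ML selects its parameters, and both function names are hypothetical.

```python
def global_error_probability(p_error: float, n_lookups: int) -> float:
    """Probability that at least one of n_lookups table lookups is wrong.

    Assumption: errors are independent and each lookup fails with
    probability p_error.
    """
    return 1.0 - (1.0 - p_error) ** n_lookups


def p_error_for_target(global_p_error: float, n_lookups: int) -> float:
    """Per-lookup error probability that yields the given whole-model error."""
    return 1.0 - (1.0 - global_p_error) ** (1.0 / n_lookups)


# A 1% per-lookup error rate compounds quickly over 100 lookups.
print(round(global_error_probability(0.01, 100), 3))  # 0.634

# Per-lookup rate needed to keep a 1000-lookup model at 5% overall.
print(p_error_for_target(0.05, 1000))
```

The compounding effect is the reason to measure accuracy with fhe="simulate" before running actual FHE execution: a per-lookup rate that looks small can still affect a model with many table lookups.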

Generic Quantization Aware Training import

While the example above shows how to import a Brevitas/PyTorch model, Concrete ML also provides an option to import generic QAT models implemented in PyTorch or through ONNX. Deep learning models made with TensorFlow or Keras should be usable after first converting them to ONNX.

QAT models contain quantizers in the PyTorch graph. These quantizers ensure that the inputs to the Linear/Dense and Conv layers are quantized.

from concrete.ml.torch.compile import compile_torch_model
n_bits_qat = 3

quantized_module = compile_torch_model(
    torch_model,
    torch_input,
    import_qat=True,
    n_bits=n_bits_qat,
)

When importing QAT models using this generic pipeline, a representative calibration set must be provided, because the quantization parameters of the model need to be inferred from the statistics of the values encountered during inference.
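The toy sketch below shows why the calibration set needs to be representative. It infers a quantization range from the minimum and maximum of a calibration set, then measures the error on values seen at inference time. The min/max calibration and all names here are assumptions made for illustration, not Concrete ML's calibration procedure.

```python
def fake_quantize(x, v_min, v_max, n_bits):
    """Quantize x to n_bits over [v_min, v_max], then map it back to a float."""
    levels = 2**n_bits - 1
    scale = (v_max - v_min) / levels
    # Values outside the calibrated range saturate at its edges.
    clipped = min(max(x, v_min), v_max)
    return v_min + round((clipped - v_min) / scale) * scale


def mean_abs_error(calibration, values, n_bits=3):
    """Average quantization error on `values` for a range inferred from `calibration`."""
    v_min, v_max = min(calibration), max(calibration)
    errors = [abs(x - fake_quantize(x, v_min, v_max, n_bits)) for x in values]
    return sum(errors) / len(errors)


# Values encountered during inference: an even grid over [-3, 3].
inference_values = [i / 100 for i in range(-300, 301)]

# A representative calibration set covers the same range.
representative = inference_values[::10]
# A narrow calibration set only covers [-0.5, 0.5].
narrow = [v for v in inference_values if abs(v) <= 0.5]

print("representative:", mean_abs_error(representative, inference_values))
print("narrow:", mean_abs_error(narrow, inference_values))
```

With the narrow set, most inference values fall outside the calibrated range and saturate, so the error is much larger even though the quantization step itself is finer.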

Supported operators and activations

Concrete ML supports a variety of PyTorch operators that can be used to build fully connected or convolutional neural networks, with normalization and activation layers. Moreover, many element-wise operators are supported.

Operators

univariate operators

  • torch.abs
  • torch.clip
  • torch.exp
  • torch.log
  • torch.gt
  • torch.clamp
  • torch.mul, torch.Tensor operator *
  • torch.div, torch.Tensor operator /
  • torch.nn.Identity
  • torch.nn.BatchNorm2d

shape modifying operators

  • torch.reshape
  • torch.Tensor.view
  • torch.flatten
  • torch.transpose

operators that take an encrypted input and unencrypted constants

  • torch.conv2d, torch.nn.Conv2d
  • torch.matmul
  • torch.nn.Linear

Concrete ML supports these operators but also the QAT equivalents from Brevitas.

  • brevitas.nn.QuantLinear
  • brevitas.nn.QuantConv2d

operators that can take both encrypted+unencrypted and encrypted+encrypted inputs

  • torch.add, torch.Tensor operator +
  • torch.sub, torch.Tensor operator -

Quantizers

  • brevitas.nn.QuantIdentity

Activations

  • torch.nn.CELU
  • torch.nn.ELU
  • torch.nn.GELU
  • torch.nn.Hardshrink
  • torch.nn.Hardsigmoid
  • torch.nn.Hardswish
  • torch.nn.Hardtanh
  • torch.nn.LeakyReLU
  • torch.nn.LogSigmoid
  • torch.nn.Mish
  • torch.nn.PReLU
  • torch.nn.ReLU6
  • torch.nn.ReLU
  • torch.nn.SELU
  • torch.nn.Sigmoid
  • torch.nn.SiLU
  • torch.nn.Softplus
  • torch.nn.Softshrink
  • torch.nn.Softsign
  • torch.nn.Tanh
  • torch.nn.Tanhshrink
  • torch.nn.Threshold

The equivalent versions from torch.nn.functional are also supported.

Once the Brevitas model from the first example is trained, calling the compile_brevitas_qat_model function from Concrete ML automatically performs conversion and compilation of the QAT network. In that example, 3-bit quantization is used for both the weights and activations. The compile_brevitas_qat_model function automatically identifies the number of quantization bits used in the Brevitas model.

FHE simulation makes it possible to measure the impact of the Table Lookup error on model accuracy. The Table Lookup error can be adjusted using p_error/global_p_error, as described in the approximate computation section.

Suppose that n_bits_qat is the bit-width of activations and weights during the QAT process. To import a PyTorch QAT network, use the compile_torch_model library function, passing import_qat=True, as in the example under "Generic Quantization Aware Training import".

Alternatively, if you want to import an ONNX model directly, please see the ONNX guide. The compile_onnx_model function also supports the import_qat parameter.

-- partial support