Using Torch

In addition to the built-in models, Concrete-ML supports generic machine learning models implemented with Torch, or exported as ONNX graphs.

As Quantization Aware Training (QAT) is the most appropriate method of training neural networks that are compatible with FHE constraints, Concrete-ML works with Brevitas, a library providing QAT support for PyTorch.

The following example uses a simple QAT PyTorch model that implements a fully connected neural network with two hidden layers. Due to its small size, making this model respect FHE constraints is relatively easy.

import brevitas.nn as qnn
import torch.nn as nn
import torch

N_FEAT = 12
n_bits = 3

class QATSimpleNet(nn.Module):
    def __init__(self, n_hidden):
        super().__init__()

        # Quantize the input, then alternate quantized linear layers with
        # quantized identities that re-quantize the activations
        self.quant_inp = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc1 = qnn.QuantLinear(N_FEAT, n_hidden, True, weight_bit_width=n_bits, bias_quant=None)
        self.quant2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc2 = qnn.QuantLinear(n_hidden, n_hidden, True, weight_bit_width=n_bits, bias_quant=None)
        self.quant3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc3 = qnn.QuantLinear(n_hidden, 2, True, weight_bit_width=n_bits, bias_quant=None)

    def forward(self, x):
        x = self.quant_inp(x)
        x = self.quant2(torch.relu(self.fc1(x)))
        x = self.quant3(torch.relu(self.fc2(x)))
        x = self.fc3(x)
        return x

Once the model is trained, calling the compile_brevitas_qat_model function from Concrete-ML will automatically convert and compile the QAT network. Here, 3-bit quantization is used for both the weights and activations.

from concrete.ml.torch.compile import compile_brevitas_qat_model
import numpy

torch_input = torch.randn(100, N_FEAT)
torch_model = QATSimpleNet(30)
quantized_numpy_module = compile_brevitas_qat_model(
    torch_model, # our model
    torch_input, # a representative input-set to be used for both quantization and compilation
    n_bits=n_bits,
)

The model can now be used to perform encrypted inference. Next, the test data is quantized:

x_test = numpy.array([numpy.random.randn(N_FEAT)])
x_test_quantized = quantized_numpy_module.quantize_input(x_test)

and the encrypted inference can be run using either of the following (both calls are sketched after the list):

  • quantized_numpy_module.forward_and_dequant() to compute predictions in the clear on quantized data, then de-quantize the result. This function returns the de-quantized (float) output of running the model in the clear, which is useful for debugging: the results in FHE will be the same as those on clear quantized data.

  • quantized_numpy_module.forward_fhe.encrypt_run_decrypt() to perform the FHE inference. In this case, de-quantization is done in a second stage using quantized_numpy_module.dequantize_output().
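
As a recap, the two options look as follows, reusing quantized_numpy_module and x_test_quantized from above:

# Option 1: run in the clear on quantized data, returning de-quantized (float) predictions
y_pred_clear = quantized_numpy_module.forward_and_dequant(x_test_quantized)

# Option 2: run the actual FHE inference, then de-quantize in a second stage
y_pred_fhe_q = quantized_numpy_module.forward_fhe.encrypt_run_decrypt(x_test_quantized)
y_pred_fhe = quantized_numpy_module.dequantize_output(y_pred_fhe_q)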

Generic Quantization Aware Training import

While the example above shows how to import a Brevitas/PyTorch model, Concrete-ML also provides an option to import generic QAT models implemented either in PyTorch or through ONNX. Deep learning models built with TensorFlow or Keras should also be usable, by first converting them to ONNX.
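
For example, a Keras model could be converted with a third-party tool such as tf2onnx; the snippet below is a sketch of that idea, not an official Concrete-ML workflow:

import tensorflow as tf
import tf2onnx

# Hypothetical Keras model standing in for a real trained QAT network
keras_model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(12,))])

# Convert the Keras model to an ONNX graph, which can then be imported by Concrete-ML
onnx_model, _ = tf2onnx.convert.from_keras(keras_model)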

QAT models contain quantizers in the PyTorch graph. These quantizers ensure that the inputs to the Linear/Dense and Conv layers are quantized.
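
As an illustration of this idea, the following hypothetical module (not Brevitas or Concrete-ML code) performs a "fake quantization" step: it rounds activations to a discrete grid in the forward pass, while a straight-through estimator keeps gradients usable during training:

import torch
import torch.nn as nn

class FakeQuantIdentity(nn.Module):
    def __init__(self, n_bits=3):
        super().__init__()
        self.n_levels = 2 ** n_bits

    def forward(self, x):
        # Quantize-dequantize: map values onto 2**n_bits levels and back
        x_min, x_max = x.min(), x.max()
        scale = (x_max - x_min) / (self.n_levels - 1)
        x_q = torch.round((x - x_min) / scale) * scale + x_min
        # Straight-through estimator: quantized values in the forward pass,
        # identity gradients in the backward pass
        return x + (x_q - x).detach()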

Suppose that n_bits_qat is the bit-width of activations and weights during the QAT process. To import a PyTorch QAT network, you can use the compile_torch_model library function, passing import_qat=True:

from concrete.ml.torch.compile import compile_torch_model
n_bits_qat = 3

quantized_numpy_module = compile_torch_model(
    torch_model,
    torch_input,
    import_qat=True,
    n_bits=n_bits_qat,
)

Alternatively, if you want to import an ONNX model directly, please see the ONNX guide. The compile_onnx_model function also supports the import_qat parameter.
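
As a sketch, assuming compile_onnx_model mirrors the interface of compile_torch_model shown above:

from concrete.ml.torch.compile import compile_onnx_model
import onnx

# Load a previously exported QAT model; the file name is illustrative
onnx_model = onnx.load("qat_model.onnx")

quantized_numpy_module = compile_onnx_model(
    onnx_model,
    torch_input,  # representative input-set, as in the PyTorch pipeline
    import_qat=True,
    n_bits=n_bits_qat,
)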

When importing QAT models through this generic pipeline, a representative calibration set must be provided, as the quantization parameters of the model need to be inferred from the statistics of the values encountered during inference.

Supported operators and activations

Concrete-ML supports a variety of PyTorch operators that can be used to build fully connected or convolutional neural networks, with normalization and activation layers. Moreover, many element-wise operators are supported.

Operators

Supported operators fall into the following categories:

  • univariate operators

  • shape modifying operators

  • operators that take an encrypted input and unencrypted constants

  • operators that can take both encrypted+unencrypted and encrypted+encrypted inputs

In addition to these operators, Concrete-ML supports their QAT equivalents from Brevitas:

  • brevitas.nn.QuantLinear

  • brevitas.nn.QuantConv2d

Quantizers

  • brevitas.nn.QuantIdentity

Activations

Note that the equivalent versions from torch.nn.functional are also supported.
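
Putting the Brevitas operators listed above together, a small quantized convolutional block could look like the following sketch (the input is assumed to be a single-channel 28x28 image):

import brevitas.nn as qnn
import torch
import torch.nn as nn

class QATConvNet(nn.Module):
    def __init__(self, n_bits=3):
        super().__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(1, 4, 3, weight_bit_width=n_bits, bias_quant=None)
        self.quant2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc1 = qnn.QuantLinear(4 * 26 * 26, 2, True, weight_bit_width=n_bits, bias_quant=None)

    def forward(self, x):
        x = self.quant_inp(x)
        x = torch.relu(self.conv1(x))  # 28x28 input, 3x3 kernel -> 26x26 output
        x = x.flatten(1)               # flatten spatial dimensions before the linear layer
        x = self.quant2(x)             # re-quantize the activations
        x = self.fc1(x)
        return x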
