In addition to the built-in models, Concrete-ML supports generic machine learning models implemented with Torch, or exported as ONNX graphs.
The following example uses a simple torch model that implements a fully connected neural network with two hidden units. Due to its small size, making this model respect FHE constraints is relatively easy.
Once the model is trained, calling the compile_torch_model function from Concrete-ML automatically performs post-training quantization and compilation to FHE. Here, we use 3-bit quantization for both the weights and the activations.
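The sketch below illustrates what 3-bit post-training quantization means. It applies uniform affine quantization to a handful of made-up weights, with a scale and zero-point derived from the min/max of the values. It is plain Python written under that assumption. It is not Concrete-ML's internal algorithm, and compute_quant_params, quantize and dequantize are hypothetical helpers.

```python
def compute_quant_params(values, n_bits=3):
    """Hypothetical helper: affine (scale, zero-point) from the value range."""
    v_min, v_max = min(values), max(values)
    n_steps = 2 ** n_bits - 1  # 3 bits -> 7 steps between 8 integer levels
    scale = (v_max - v_min) / n_steps if v_max > v_min else 1.0
    zero_point = round(-v_min / scale)  # integer that stands for the real value 0.0
    return scale, zero_point


def quantize(values, scale, zero_point, n_bits=3):
    """Map floats to integers in [0, 2**n_bits - 1], clipping out-of-range values."""
    q_max = 2 ** n_bits - 1
    return [min(max(round(v / scale) + zero_point, 0), q_max) for v in values]


def dequantize(q_values, scale, zero_point):
    """Map integers back to approximate floats."""
    return [(q - zero_point) * scale for q in q_values]


# Made-up weights, used only for illustration.
weights = [-0.9, -0.35, 0.0, 0.2, 0.55, 1.2]
scale, zero_point = compute_quant_params(weights, n_bits=3)
q_weights = quantize(weights, scale, zero_point, n_bits=3)
restored = dequantize(q_weights, scale, zero_point)
print(q_weights)
print([round(r, 3) for r in restored])
```

With 3 bits there are only 8 integer levels. Each weight is therefore reconstructed with an error of at most one quantization step (scale), which is the precision budget the compiled model has to work with.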
The model can now be used to perform encrypted inference. Next, the test data is quantized, and inference is run using either:
quantized_numpy_module.forward_and_dequant() to compute predictions in the clear on quantized data and then dequantize the result. This function returns the dequantized (float) output of running the model in the clear. Running the forward function on clear data is useful for debugging, since the results in FHE will be the same as those on clear quantized data.
quantized_numpy_module.forward_fhe.encrypt_run_decrypt() to perform the FHE inference. In this case, dequantization is done in a second step using quantized_numpy_module.dequantize_output().
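A single linear neuron shows how the integer computation is separated from the final dequantization. This is a plain-Python illustration under stated assumptions, not Concrete-ML code:

- Inputs and weights are quantized symmetrically to 3 bits.
- run_integer_part stands in for the kind of integer-only computation that forward_fhe.encrypt_run_decrypt() evaluates on encrypted data.
- dequantize_output_sketch plays the role of dequantize_output().

All helper names are hypothetical.

```python
def quantize_symmetric(values, n_bits=3):
    """Hypothetical helper: symmetric quantization with the zero-point fixed at 0."""
    q_max = 2 ** (n_bits - 1) - 1  # 3 bits -> integers in [-4, 3]
    scale = max(abs(v) for v in values) / q_max or 1.0  # avoid a zero scale
    q = [max(-q_max - 1, min(q_max, round(v / scale))) for v in values]
    return q, scale


def run_integer_part(q_weights, q_inputs):
    """Integer-only dot product: the kind of arithmetic performed on encrypted data."""
    return sum(w * x for w, x in zip(q_weights, q_inputs))


def dequantize_output_sketch(accumulator, w_scale, x_scale):
    """Separate, clear-text step mapping the integer result back to a float."""
    return accumulator * w_scale * x_scale


weights = [0.5, -0.25, 0.75]
inputs = [1.0, 2.0, -1.0]  # made-up test sample

q_w, s_w = quantize_symmetric(weights)
q_x, s_x = quantize_symmetric(inputs)   # quantize the test data first
acc = run_integer_part(q_w, q_x)        # integer result (stays an int)
approx = dequantize_output_sketch(acc, s_w, s_x)
exact = sum(w * x for w, x in zip(weights, inputs))
print(acc, round(approx, 3), exact)
```

The integer accumulator is the same whether it is computed in the clear or under encryption. Debugging on clear quantized data therefore exercises the same arithmetic, and only the float conversion is a separate step.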
While the example above shows how to import a floating-point model for post-training quantization, Concrete-ML also provides an option to import quantization-aware trained (QAT) models.
QAT models contain quantizers in the torch graph. These quantizers ensure that the inputs to the Linear/Dense and Conv layers are quantized. Torch quantizers are not included in Concrete-ML, so you can either implement your own or use a third-party library such as Brevitas, as shown in the FHE-friendly models documentation. Custom models can have a more generic architecture and training procedure than the Concrete-ML built-in models.
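QAT commonly relies on "fake quantization". During training, the forward pass quantizes and immediately dequantizes the weights (and activations). Gradients bypass the rounding through a straight-through estimator. The plain-Python sketch below trains a single weight this way with a 3-bit setting, using a parameter named n_bits_qat.

This sketch illustrates the general technique only. It is not Brevitas or Concrete-ML code, and the fixed scale, learning rate and helper names are assumptions.

```python
def fake_quantize(value, scale, n_bits):
    """Quantize then immediately dequantize (symmetric, zero-point 0)."""
    q_max = 2 ** (n_bits - 1) - 1
    q = max(-q_max - 1, min(q_max, round(value / scale)))
    return q * scale


def train_qat_weight(xs, ys, n_bits_qat=3, scale=0.5, lr=0.05, epochs=100):
    """Fit y = w * x while the forward pass only ever sees the quantized weight."""
    w = 0.1  # float "shadow" weight updated by gradient descent
    for _ in range(epochs):
        w_q = fake_quantize(w, scale, n_bits_qat)  # forward pass uses w_q, not w
        grad = 0.0
        for x, y in zip(xs, ys):
            err = w_q * x - y
            # Straight-through estimator: treat d(w_q)/dw as 1, so the
            # rounding step does not zero out the gradient.
            grad += 2 * err * x
        w -= lr * grad / len(xs)
    return w, fake_quantize(w, scale, n_bits_qat)


xs = [1.0, 2.0, 3.0]
ys = [1.0 * x for x in xs]  # target slope 1.0 = 2 * scale, so it is representable
w_float, w_q = train_qat_weight(xs, ys)
print(w_float, w_q)
```

Because training already saw the quantized weight, the model needs no further rounding when it is imported with the same bit-width.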
Suppose that n_bits_qat is the bit-width of activations and weights used during the QAT process. To import a torch QAT network, use the dedicated Concrete-ML library function, configured with this same n_bits_qat bit-width.
Concrete-ML supports a variety of torch operators that can be used to build fully connected or convolutional neural networks with normalization and activation layers. Many element-wise operators are also supported. The following supported activations come with caveats:
torch.nn.GELU: accuracy issues in some cases
torch.nn.LogSigmoid: accuracy issues in some cases
torch.nn.Threshold: partial support
Note that the equivalent functional versions from torch.nn.functional are also supported.
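One way to reason about the accuracy caveats above is to look at how a non-linear activation behaves when its input can only take 2**n_bits values. The sketch below assumes the activation is evaluated as a lookup table indexed by the quantized input, which is the usual approach for univariate functions in FHE schemes with programmable bootstrapping. It then measures the worst-case error of such a table for GELU.

Limitations of this sketch:

- It is plain Python and quantizes only the input.
- It does not reproduce Concrete-ML's actual quantization scheme or the specific issues listed.
- build_lut, apply_lut and max_lut_error are hypothetical helpers.

```python
import math


def gelu(x):
    """Exact (erf-based) GELU."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def build_lut(func, x_min, x_max, n_bits):
    """Hypothetical helper: tabulate func on the 2**n_bits representable inputs."""
    n_levels = 2 ** n_bits
    step = (x_max - x_min) / (n_levels - 1)
    table = [func(x_min + i * step) for i in range(n_levels)]
    return table, step


def apply_lut(x, table, x_min, step):
    """Quantize x to the nearest grid index and read the precomputed output."""
    idx = min(max(round((x - x_min) / step), 0), len(table) - 1)
    return table[idx]


def max_lut_error(func, x_min, x_max, n_bits, n_samples=1001):
    """Worst-case |LUT(x) - func(x)| over evenly spaced x in [x_min, x_max]."""
    table, step = build_lut(func, x_min, x_max, n_bits)
    xs = [x_min + (x_max - x_min) * i / (n_samples - 1) for i in range(n_samples)]
    return max(abs(apply_lut(x, table, x_min, step) - func(x)) for x in xs)


errors = {n: max_lut_error(gelu, -4.0, 4.0, n) for n in (3, 5, 7)}
for n, err in errors.items():
    print(f"{n}-bit input: max |LUT - GELU| = {err:.3f}")
```

The error shrinks as the input bit-width grows. Functions that change quickly within a single quantization step are the ones most affected at low bit-widths.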