concrete.ml.sklearn.qnn_module
Sparse Quantized Neural Network torch module.
SparseQuantNeuralNetwork
Sparse Quantized Neural Network.
This class implements an MLP that is compatible with FHE constraints. The weights and activations are quantized to low bit-widths, and pruning is used to ensure accumulators do not surpass a user-provided accumulator bit-width. The number of classes, the number of layers, and the breadth of the network are specified by the user.
__init__
Sparse Quantized Neural Network constructor.
Args:
input_dim (int): Number of dimensions of the input data.
n_layers (int): Number of linear layers for this network.
n_outputs (int): Number of output classes or regression targets.
n_w_bits (int): Number of weight bits.
n_a_bits (int): Number of activation and input bits.
n_accum_bits (int): Maximal allowed bit-width of intermediate accumulators.
n_hidden_neurons_multiplier (int): The number of neurons in the hidden layers will be the number of dimensions of the input multiplied by n_hidden_neurons_multiplier. Note that pruning is used to adjust the accumulator size, attempting to keep the maximum accumulator bit-width at n_accum_bits, meaning that not all hidden layer neurons will be active. The default value of n_hidden_neurons_multiplier is chosen for small input dimensions. Reducing this value decreases the FHE inference time considerably but also decreases the robustness and accuracy of model training.
n_prune_neurons_percentage (float): The percentage of neurons to prune in the hidden layers. This can be used when setting n_hidden_neurons_multiplier to a high value (3-4), once good accuracy is obtained, in order to speed up the model in FHE.
activation_function (Type): The activation function to use in the network (e.g., torch.nn.ReLU, torch.nn.SELU, torch.nn.Sigmoid, ...).
quant_narrow (bool): Whether this network should quantize the values using a narrow range (e.g., a 2-bit signed quantization uses [-1, 0, 1] instead of [-2, -1, 0, 1]).
quant_signed (bool): Whether this network should quantize the values using signed integers.
power_of_two_scaling (bool): Force quantization scales to be powers of two to enable inference speed optimizations. Defaults to True.
Raises:
ValueError: If the parameters have invalid values or the computed accumulator bit-width is zero.
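As a minimal construction sketch, using the documented parameter names as keyword arguments (the values here are illustrative, not recommendations):

```python
import torch.nn as nn

from concrete.ml.sklearn.qnn_module import SparseQuantNeuralNetwork

# Illustrative setup: a binary classifier over 10-dimensional inputs with
# 3 linear layers, 3-bit weights/activations, and accumulators capped at 8 bits.
model = SparseQuantNeuralNetwork(
    input_dim=10,
    n_layers=3,
    n_outputs=2,
    n_w_bits=3,
    n_a_bits=3,
    n_accum_bits=8,
    activation_function=nn.ReLU,
)
```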
enable_pruning
Enable pruning in the network. Pruning must be made permanent to recover pruned weights.
Raises:
ValueError: If the quantization parameters are invalid.
forward
Forward pass.
Args:
x (torch.Tensor): network input
Returns:
x (torch.Tensor): network prediction
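Continuing the construction sketch above, a forward pass on a float batch might look like the following (the batch size and contents are illustrative):

```python
import torch

x = torch.randn(4, 10)  # batch of 4 samples with input_dim=10 features each
y = model(x)            # network prediction, shape (4, n_outputs)
```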
make_pruning_permanent
Make the learned pruning permanent in the network.
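A sketch of the intended lifecycle of enable_pruning and make_pruning_permanent, assuming pruning works as a removable re-parametrization (the training loop is elided):

```python
# Enable pruning so accumulator bit-widths stay within n_accum_bits
# while the network trains.
model.enable_pruning()

# ... a standard torch training loop would run here ...

# Make the learned pruning permanent before using the trained weights.
model.make_pruning_permanent()
```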
max_active_neurons
Compute the maximum number of active (non-zero weight) neurons.
The computation is done using the quantization parameters passed to the constructor. Warning: with the current (asymmetric) quantization algorithm, the value returned by this function is not guaranteed to ensure FHE compatibility. For some weight distributions, weights that are 0 (i.e., pruned weights) will not be quantized to 0. Therefore, the total number of active quantized neurons will not be equal to max_active_neurons.
Returns:
int: The maximum number of active neurons.
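To see where such a bound comes from, here is a hedged back-of-envelope estimate. Note that estimate_max_active_neurons is a hypothetical helper, not the library's method, and the formula is an assumption based on worst-case accumulator magnitudes, not necessarily what the library computes:

```python
import math

# Back-of-envelope estimate (an assumption, not the library's exact formula):
# each product of a w-bit weight and an a-bit activation has magnitude at most
# (2**n_w_bits - 1) * (2**n_a_bits - 1), so the number of such products that
# can be summed without exceeding n_accum_bits bounds the active neurons.
def estimate_max_active_neurons(n_w_bits: int, n_a_bits: int, n_accum_bits: int) -> int:
    return int(math.floor(
        (2**n_accum_bits - 1) / ((2**n_w_bits - 1) * (2**n_a_bits - 1))
    ))

print(estimate_max_active_neurons(n_w_bits=3, n_a_bits=3, n_accum_bits=8))  # 5
```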