module concrete.ml.search_parameters.p_error_search
p_error binary search for classification and regression tasks.
Only PyTorch neural networks and Concrete ML built-in models are supported:
Concrete ML built-in models include trees and QNNs.
Quantization-aware trained (QAT) models are supported through the Brevitas framework.
Torch models can be converted into post-training quantized (PTQ) models.
The p_error is an essential hyper-parameter in FHE computation at Zama, as it impacts both the speed of the FHE computations and the model's performance.
In this script, we provide an approach to find an optimal p_error that offers a good compromise between speed and model performance.
The p_error represents the probability of a single PBS being incorrect. The FHE scheme supports two types of operations:
Linear operations: additions and multiplications.
Non-linear operations: uni-variate activation functions.
At Zama, non-linear operations are represented by table lookups (TLUs), which are implemented through Programmable Bootstrapping (PBS). Each PBS operation has a probability p_error of being incorrect.
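Since p_error applies to each PBS individually, its effect compounds over a whole model. As a rough illustration, and assuming PBS errors are independent (an assumption made only for this sketch), the probability that at least one of n PBS operations is incorrect is 1 - (1 - p_error)^n. The helper name `prob_any_pbs_error` is hypothetical.

```python
def prob_any_pbs_error(p_error: float, n_pbs: int) -> float:
    """Probability that at least one of n_pbs PBS operations is incorrect.

    Hypothetical helper for illustration only; it assumes PBS errors are independent.
    """
    if not 0.0 <= p_error < 1.0:
        raise ValueError("p_error must be in [0, 1)")
    # Complement of the event "every PBS is correct"
    return 1.0 - (1.0 - p_error) ** n_pbs


for n in (1, 10, 100, 1000):
    print(n, round(prob_any_pbs_error(0.01, n), 4))
```

This is why a p_error that looks small for a single PBS can still noticeably affect a model that performs many TLUs, which motivates tuning it per model and data-set.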
It is highly recommended to tune the p_error, as its optimal value depends on the data-set.
The inference is performed in FHE simulation mode.
The goal is to find the largest p_error_i, a float ∈ ]0, 0.9[, that gives a model_i with accuracy_i such that |accuracy_i - accuracy_0| <= Threshold, where Threshold ∈ R is given by the user and accuracy_0 refers to the original model_0 with p_error_0 ≈ 0.0.
p_error is bounded between 0 and 0.9. p_error ≈ 0.0 refers to the original model in the clear, whose accuracy we denote accuracy_0.
We assume that the condition is satisfied when we have a match. A match is defined by a uni-variate function, given by the user through the strategy argument. It can be:
any = lambda all_matches: any(all_matches)
all = lambda all_matches: all(all_matches)
mean = lambda all_matches: numpy.mean(all_matches) >= 0.5
median = lambda all_matches: numpy.median(all_matches) == 1
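The four strategies can be compared on the same list of per-simulation matches. This is a minimal sketch: the strategies are stored in a dictionary, a choice made here because literally assigning a lambda to the name `any` or `all` would shadow the Python built-in that the lambda itself calls. The match list is made up for illustration.

```python
import numpy

# Each strategy maps a list of per-simulation match flags to a single decision.
strategies = {
    "any": any,  # at least one simulation matched
    "all": all,  # every simulation matched
    "mean": lambda all_matches: numpy.mean(all_matches) >= 0.5,  # at least half matched
    "median": lambda all_matches: numpy.median(all_matches) == 1,  # the median flag is a match
}

# Hypothetical outcome of 5 FHE simulations at a given p_error
all_matches = [True, True, False, True, False]

for name, strategy in strategies.items():
    print(name, bool(strategy(all_matches)))
```

all is the most conservative choice and any the most permissive; mean and median sit in between.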
To validate the results of the FHE simulation and get a stable estimate, we run several simulations. Then:
If there is a match, we update the lower bound to the current p_error.
Otherwise, we update the upper bound to the current p_error.
We update the current p_error to the mean of the bounds.
We stop the search when the maximum number of iterations is reached.
If the search does not converge, a user warning is raised.
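The search loop can be sketched as follows. This is a minimal illustration, not the Concrete ML implementation: `binary_search_p_error`, the `evaluate(p_error)` callback (a stand-in for compiling the model and scoring it in FHE simulation) and the `delta_tolerance` stopping threshold are hypothetical names introduced here.

```python
import warnings
from typing import Callable, List


def binary_search_p_error(
    evaluate: Callable[[float], float],
    reference_score: float,
    max_metric_loss: float = 0.01,
    lower: float = 0.0,
    upper: float = 0.9,
    max_iter: int = 20,
    n_simulation: int = 5,
    strategy: Callable[[List[bool]], bool] = all,
    delta_tolerance: float = 1e-5,
) -> float:
    """Return the largest p_error found whose score stays within max_metric_loss."""
    current = (lower + upper) / 2
    for _ in range(max_iter):
        # One match flag per simulation: does this run stay close to the reference?
        all_matches = [
            abs(evaluate(current) - reference_score) <= max_metric_loss
            for _ in range(n_simulation)
        ]
        if strategy(all_matches):
            lower = current  # condition holds: move towards larger p_error values
        else:
            upper = current  # condition fails: move towards smaller p_error values
        new = (lower + upper) / 2  # next candidate is the mean of the bounds
        if abs(new - current) < delta_tolerance:
            return lower  # update below the threshold: the search has converged
        current = new
    warnings.warn("p_error search did not converge within max_iter iterations", UserWarning)
    return lower


# Toy, deterministic stand-in for "compile and score in FHE simulation":
# the score drops once p_error exceeds 0.3.
def toy_evaluate(p_error: float) -> float:
    return 1.0 if p_error <= 0.3 else 0.5


best = binary_search_p_error(toy_evaluate, reference_score=1.0)
print(best)
```

Returning the lower bound guarantees that the reported value is one that actually satisfied the strategy; with a noisy real evaluation, n_simulation and strategy control how strict that check is.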
function compile_and_simulated_fhe_inference
Get the quantized module of a given model in FHE, simulated or not.
Supported models are:
Built-in models, including trees and QNNs.
Quantization-aware trained (QAT) models, through the Brevitas framework.
Torch models, which can be converted into post-training quantized (PTQ) models.
Args:
estimator (torch.nn.Module): Torch model or a built-in model.
calibration_data (numpy.ndarray): Calibration data required for compilation.
ground_truth (numpy.ndarray): The ground truth.
p_error (float): Concrete ML uses table lookup (TLU) to represent any non-linear
n_bits (int): Quantization bits.
is_qat (bool): True if the NN has been trained through QAT. If False, it is converted into a post-trained quantized model.
metric (Callable): Classification or regression evaluation metric.
predict (str): The predict method to use.
kwargs (Dict): Hyper-parameters to use for the metric.
Returns:
Tuple[numpy.ndarray, float]: De-quantized or quantized output model depending on is_benchmark_test, and the score.
Raises:
ValueError: If the model is neither a built-in model nor a torch neural network.
class BinarySearch
Class for p_error hyper-parameter search for classification and regression tasks.
method __init__
p_error binary search algorithm.
Args:
estimator: Custom model (Brevitas or PyTorch) or built-in models (trees or QNNs).
predict (str): The prediction method to use for built-in tree models.
metric (Callable): Evaluation metric for classification or regression tasks.
n_bits (int): Quantization bits, for PTQ models. Default is 4.
is_qat (bool): Flag that indicates whether the estimator has been trained through QAT (quantization-aware training). Default is True.
lower (float): The lower bound of the search space for the p_error. Default is 0.0.
upper (float): The upper bound of the search space for the p_error. Default is 0.9. Increasing the upper bound beyond this range may result in longer execution times, especially when p_error ≈ 1.
max_iter (int): The maximum number of iterations to run the binary search algorithm. Default is 20.
n_simulation (int): The number of simulations used to validate the results of the FHE simulation. Default is 5.
strategy (Any): A uni-variate function that defines a "match". It can be a built-in Python function, such as any() or all(), or a custom function, like:
mean = lambda all_matches: numpy.mean(all_matches) >= 0.5
median = lambda all_matches: numpy.median(all_matches) == 1
Default is 'all'.
max_metric_loss (float): The threshold used to satisfy the condition |accuracy_i - accuracy_0| <= max_metric_loss. Default is 0.01.
save (bool): Flag that indicates whether to save some metadata in a log file. Default is False.
log_file (str): The log file name. Default is None.
directory (str): The directory in which to save the metadata. Default is None.
verbose (bool): Flag that indicates whether to print detailed information. Default is False.
kwargs: Parameters of the evaluation metric.
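The lower, upper and max_iter parameters together bound how precise the search can be. Assuming every iteration replaces one bound with the midpoint (as in the search described for this module), the interval halves at each step, so the finest p_error resolution after max_iter iterations is (upper - lower) / 2**max_iter. The helper name `search_resolution` is hypothetical.

```python
def search_resolution(lower: float = 0.0, upper: float = 0.9, max_iter: int = 20) -> float:
    """Width of the p_error interval left after max_iter halving steps (hypothetical helper)."""
    return (upper - lower) / 2**max_iter


print(search_resolution())  # defaults: lower=0.0, upper=0.9, max_iter=20
print(search_resolution(max_iter=10))
```

With the default values the remaining interval is narrower than 1e-6, so the default max_iter already gives a fine-grained search.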
method eval_match
Evaluate the matches.
Args:
strategy (Callable): A uni-variate function that defines a "match". It can be a built-in Python function, such as any() or all(), or a custom function, like:
mean = lambda all_matches: numpy.mean(all_matches) >= 0.5
median = lambda all_matches: numpy.median(all_matches) == 1
all_matches (List[bool]): List of matches.
Returns:
bool: Evaluation of the matches according to the given strategy.
Raises:
TypeError: If the strategy function is not valid.
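A minimal sketch of evaluating matches with a strategy. The function name `eval_match_sketch` is hypothetical, and treating a non-callable strategy, or one that does not return a boolean, as "not valid" is an assumption about what triggers the TypeError, not a description of the library's checks.

```python
from typing import Callable, List

import numpy


def eval_match_sketch(strategy: Callable[[List[bool]], bool], all_matches: List[bool]) -> bool:
    """Apply strategy to the per-simulation matches and return a plain bool."""
    if not callable(strategy):
        raise TypeError(f"strategy must be callable, got {type(strategy).__name__}")
    result = strategy(all_matches)
    # numpy comparisons return numpy.bool_, so accept it alongside bool
    if not isinstance(result, (bool, numpy.bool_)):
        raise TypeError(f"strategy must return a boolean, got {type(result).__name__}")
    return bool(result)


mean = lambda all_matches: numpy.mean(all_matches) >= 0.5
print(eval_match_sketch(all, [True, False, True]))
print(eval_match_sketch(mean, [True, False, True]))
```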
method reset_history
Clear the history.
method run
Get an optimal p_error using binary search for classification and regression tasks.
PyTorch models and built-in models are supported.
To find an optimal p_error that offers a balance between speed and model performance, we use a binary search approach. The goal is to find the largest p_error_i, a float ∈ ]0, 1[, that gives a model_i with accuracy_i such that |accuracy_i - accuracy_0| <= max_metric_loss, where max_metric_loss ∈ R and accuracy_0 refers to the original model_0 with p_error ≈ 0.0.
We assume that the condition is satisfied when we have a match. A match is defined by a uni-variate function, specified through the strategy argument.
To validate the results of the FHE simulation and get a stable estimate, we run multiple simulations. If there is a match, we update the lower bound to the current p_error; otherwise, we update the upper bound to the current p_error. We then update the current p_error to the mean of the bounds.
We stop the search either when the maximum number of iterations is reached or when the update of the p_error is below a given threshold.
Args:
x (numpy.ndarray): Data-set used for calibration and evaluation.
ground_truth (numpy.ndarray): The ground truth.
kwargs (Dict): Class parameters.
strategy (Callable): A uni-variate function that defines a "match". It can be a built-in Python function, like any or all, or a custom function, like:
mean = lambda all_matches: numpy.mean(all_matches) >= 0.5
median = lambda all_matches: numpy.median(all_matches) == 1
Default is all.
Returns:
float: The optimal p_error that aims to speed up computations while maintaining good performance.