# concrete.ml.search_parameters.p_error_search.md

module

`concrete.ml.search_parameters.p_error_search`

p_error binary search for classification and regression tasks.

Only PyTorch neural networks and Concrete built-in models are supported:

- Concrete built-in models include trees and QNNs.
- Quantization-aware trained (QAT) models are supported through the Brevitas framework.
- Torch models can be converted into post-training quantized (PTQ) models.

The `p_error` is an essential hyper-parameter of FHE computation at Zama, as it impacts both the speed of the FHE computations and the model's performance.

In this script, we provide an approach to find an optimal `p_error`, one that offers a good compromise between speed and efficiency.

The `p_error` represents the probability of a single PBS being incorrect. The FHE scheme allows two types of operations:

- Linear operations: additions and multiplications
- Non-linear operations: uni-variate activation functions

At Zama, non-linear operations are represented by table lookups (TLU), which are implemented through Programmable Bootstrapping (PBS). A single PBS operation has a probability `p_error` of being incorrect.
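To get a feel for how `p_error` compounds across a circuit, the following minimal sketch computes the probability that at least one PBS is incorrect. It assumes that PBS errors are independent, which is a simplifying assumption made here and not a statement from the library; the function name and `n_pbs` are hypothetical.

```python
def prob_at_least_one_pbs_error(p_error: float, n_pbs: int) -> float:
    """Probability that at least one of n_pbs PBS operations is incorrect.

    Assumption (not from Concrete ML): each PBS fails independently
    with probability p_error.
    """
    if not 0.0 <= p_error <= 1.0:
        raise ValueError("p_error must be a probability in [0, 1]")
    if n_pbs < 0:
        raise ValueError("n_pbs must be non-negative")
    # All n_pbs operations succeed with probability (1 - p_error) ** n_pbs.
    return 1.0 - (1.0 - p_error) ** n_pbs


# The same p_error matters more as the number of PBS operations grows.
for n_pbs in (1, 100, 10_000):
    print(n_pbs, prob_at_least_one_pbs_error(0.01, n_pbs))
```

Under this assumption, a model with many table lookups tolerates a smaller `p_error` than a model with few, which is one reason the value is worth tuning per model and data-set.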

It's highly recommended to adjust the `p_error`, as it is linked to the data-set.

The inference is performed via the FHE simulation mode.

The goal is to look for the largest `p_error_i`, a float ∈ ]0,0.9[, which gives a model_i that has `accuracy_i`, such that | accuracy_i - accuracy_0 | <= Threshold, where Threshold ∈ R is given by the user and `accuracy_0` refers to the original model_0 with `p_error_0 ≈ 0.0`.

`p_error` is bounded between 0 and 0.9. `p_error ≈ 0.0` refers to the original model in clear, which gives an accuracy that we denote `accuracy_0`.

We assume that the condition is satisfied when we have a match. A match is defined by a uni-variate function, given by the user through the `strategy` argument. It can be:

- `any = lambda all_matches: any(all_matches)`
- `all = lambda all_matches: all(all_matches)`
- `mean = lambda all_matches: numpy.mean(all_matches) >= 0.5`
- `median = lambda all_matches: numpy.median(all_matches) == 1`
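As an illustration of how these strategies aggregate several simulations, here is a minimal sketch. The helper `is_match`, the sample accuracies, and the use of the standard library instead of `numpy` are choices made for this example only; they are not taken from the library.

```python
from statistics import median


def is_match(accuracy_i: float, accuracy_0: float, threshold: float) -> bool:
    """Check the condition |accuracy_i - accuracy_0| <= threshold."""
    return abs(accuracy_i - accuracy_0) <= threshold


# Strategies from the list above, written with the standard library
# (the documented versions use numpy.mean and numpy.median).
strategies = {
    "any": any,
    "all": all,
    "mean": lambda all_matches: sum(all_matches) / len(all_matches) >= 0.5,
    "median": lambda all_matches: median([int(m) for m in all_matches]) == 1,
}

accuracy_0 = 0.90                         # hypothetical reference accuracy
simulated = [0.895, 0.905, 0.88]          # hypothetical simulated accuracies
all_matches = [is_match(acc, accuracy_0, threshold=0.01) for acc in simulated]

results = {name: bool(fn(all_matches)) for name, fn in strategies.items()}
print(all_matches, results)
```

With two matching simulations out of three, `all` rejects the candidate `p_error` while `any`, `mean`, and `median` accept it, which shows how the choice of strategy controls how conservative the search is.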

To validate the results of the FHE simulation and get a stable estimation, we run several simulations:

- If there is a match, we update the lower bound to be the current `p_error`.
- Otherwise, we update the upper bound to be the current `p_error`.
- We then update the current `p_error` with the mean of the bounds.

We stop the search when the maximum number of iterations is reached.

If the search does not converge, a user warning is raised.
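The update rule described above can be sketched as a plain bisection loop. This is an illustrative sketch, not the library's implementation: `simulated_accuracy` is a hypothetical stand-in for an FHE-simulated evaluation, the starting point (midpoint of the bounds) is an assumption, and the convergence warning is left out.

```python
import random


def simulated_accuracy(p_error: float, rng: random.Random) -> float:
    """Hypothetical evaluation: accuracy drops linearly with p_error, plus noise."""
    return 0.95 - 0.5 * p_error + rng.uniform(-0.002, 0.002)


def search_p_error(lower=0.0, upper=0.9, max_iter=20, n_simulation=5,
                   max_metric_loss=0.01, strategy=all, seed=0) -> float:
    rng = random.Random(seed)
    accuracy_0 = 0.95                      # noise-free reference (p_error ~ 0)
    p_error = (lower + upper) / 2          # assumed starting point
    for _iteration in range(max_iter):
        # Several simulations per candidate to get a stable estimate.
        all_matches = [
            abs(simulated_accuracy(p_error, rng) - accuracy_0) <= max_metric_loss
            for _simulation in range(n_simulation)
        ]
        if strategy(all_matches):
            lower = p_error                # match: try a larger p_error
        else:
            upper = p_error                # no match: try a smaller p_error
        p_error = (lower + upper) / 2      # mean of the bounds
    return p_error


best_p_error = search_p_error()
print(best_p_error)
```

In this toy setting the accuracy loss reaches the 0.01 threshold around `p_error = 0.02`, so the loop settles near that value; with a real model the relationship between `p_error` and the metric is not known in advance, which is why the search is needed.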

function

`compile_and_simulated_fhe_inference`

Get the quantized module of a given model in FHE, simulated or not.

Supported models are:

- Built-in models, including trees and QNNs,
- Quantization-aware trained (QAT) models, supported through the Brevitas framework,
- Torch models, which can be converted into post-training quantized (PTQ) models.

**Args:**

- `estimator` (torch.nn.Module): Torch model or a built-in model
- `calibration_data` (numpy.ndarray): Calibration data required for compilation
- `ground_truth` (numpy.ndarray): The ground truth
- `p_error` (float): Concrete ML uses table lookup (TLU) to represent any non-linear
- `n_bits` (int): Quantization bits
- `is_qat` (bool): True, if the NN has been trained through QAT. If `False` it is converted into post-trained quantized model.
- `metric` (Callable): Classification or regression evaluation metric.
- `predict` (str): The predict method to use.
- `kwargs` (Dict): Hyper-parameters to use for the metric.

**Returns:**

- `Tuple[numpy.ndarray, float]`: De-quantized or quantized output model depending on `is_benchmark_test` and the score.

**Raises:**

- `ValueError`: If the model is neither a built-in model nor a torch neural network.

class

`BinarySearch`

Class for `p_error` hyper-parameter search for classification and regression tasks.

method

`__init__`

`p_error` binary search algorithm.

**Args:**

- `estimator`: Custom model (Brevitas or PyTorch) or built-in models (trees or QNNs).
- `predict` (str): The prediction method to use for built-in tree models.
- `metric` (Callable): Evaluation metric for classification or regression tasks.
- `n_bits` (int): Quantization bits, for PTQ models. Default is 4.
- `is_qat` (bool): Flag that indicates whether the `estimator` has been trained through QAT (quantization-aware training). Default is True.
- `lower` (float): The lower bound of the search space for the `p_error`. Default is 0.0.
- `upper` (float): The upper bound of the search space for the `p_error`. Default is 0.9. Increasing the upper bound beyond this range may result in longer execution times, especially when `p_error≈1`.
- `max_iter` (int): The maximum number of iterations to run the binary search algorithm. Default is 20.
- `n_simulation` (int): The number of simulations to validate the results of the FHE simulation. Default is 5.
- `strategy` (Any): A uni-variate function that defines a "match". It can be a Python built-in function, such as any() or all(), or a custom function, like `mean = lambda all_matches: numpy.mean(all_matches) >= 0.5` or `median = lambda all_matches: numpy.median(all_matches) == 1`. Default is 'all'.
- `max_metric_loss` (float): The threshold to use to satisfy the condition: | accuracy_i - accuracy_0 | <= `max_metric_loss`. Default is 0.01.
- `save` (bool): Flag that indicates whether to save some meta data in a log file. Default is False.
- `log_file` (str): The log file name. Default is None.
- `directory` (str): The directory to save the meta data. Default is None.
- `verbose` (bool): Flag that indicates whether to print detailed information. Default is False.
- `kwargs`: Parameter of the evaluation metric.

method

`eval_match`

Evaluate the matches.

**Args:**

- `strategy` (Callable): A uni-variate function that defines a "match". It can be a Python built-in function, such as any() or all(), or a custom function, like `mean = lambda all_matches: numpy.mean(all_matches) >= 0.5` or `median = lambda all_matches: numpy.median(all_matches) == 1`.
- `all_matches` (List[bool]): List of matches.

**Returns:**

- `bool`: Evaluation of the matches according to the given strategy.

**Raises:**

- `TypeError`: If the `strategy` function is not valid.
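A minimal sketch of this contract is shown below. It assumes that "not valid" means "not callable", which is an interpretation made for this example; the name `eval_match_sketch` is hypothetical and the real method may perform additional checks.

```python
from typing import Callable, List


def eval_match_sketch(strategy: Callable, all_matches: List[bool]) -> bool:
    """Apply a strategy to a list of matches (illustrative only)."""
    # Assumption: an invalid strategy is one that cannot be called.
    if not callable(strategy):
        raise TypeError(f"strategy must be callable, got {type(strategy).__name__}")
    return bool(strategy(all_matches))


print(eval_match_sketch(all, [True, True, False]))
print(eval_match_sketch(any, [True, True, False]))

try:
    eval_match_sketch("all", [True])       # a string, not a function
except TypeError as error:
    print("rejected:", error)
```

Passing the built-in `all` itself, and not the string `"all"`, is what keeps the call valid in this sketch.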

method

`reset_history`

Clear the history.

method

`run`

Get an optimal `p_error` using binary search for classification and regression tasks.

PyTorch models and built-in models are supported.

To find an optimal `p_error` that offers a balance between speed and efficiency, we use a binary search approach. The goal is to look for the largest `p_error_i`, a float ∈ ]0,1[, which gives a model_i that has `accuracy_i`, such that | accuracy_i - accuracy_0 | <= max_metric_loss, where max_metric_loss ∈ R and `accuracy_0` refers to the original model_0 with `p_error ≈ 0.0`.

We assume that the condition is satisfied when we have a match. A match is defined by a uni-variate function, specified through the `strategy` argument.

To validate the results of the FHE simulation and get a stable estimation, we perform multiple samplings. If there is a match, we update the lower bound to be the current `p_error`. Otherwise, we update the upper bound to be the current `p_error`. We then update the current `p_error` with the mean of the bounds.

We stop the search either when the maximum number of iterations is reached or when the update of the `p_error` falls below a given threshold.
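To relate the two stopping criteria, the sketch below computes how fast the search interval shrinks under pure bisection, where each iteration halves the distance between the bounds. The `tolerance` parameter and both function names are hypothetical; the source does not state the value of the threshold that is used.

```python
import math


def interval_width(lower: float, upper: float, n_iter: int) -> float:
    """Width of the search interval after n_iter halvings."""
    return (upper - lower) / 2 ** n_iter


def iterations_needed(lower: float, upper: float, tolerance: float) -> int:
    """Smallest number of halvings that brings the width down to tolerance."""
    if tolerance <= 0:
        raise ValueError("tolerance must be positive")
    return max(0, math.ceil(math.log2((upper - lower) / tolerance)))


# Documented defaults: lower=0.0, upper=0.9, max_iter=20.
print(interval_width(0.0, 0.9, 20))
# Hypothetical tolerance of 1e-3 on the interval width.
print(iterations_needed(0.0, 0.9, 1e-3))
```

Since each new `p_error` lies inside the current interval, the size of an update is bounded by the interval width, so a tolerance-based stop can trigger well before `max_iter` is reached.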

**Args:**

- `x` (numpy.ndarray): Data-set which is used for calibration and evaluation
- `ground_truth` (numpy.ndarray): The ground truth
- `kwargs` (Dict): Class parameters
- `strategy` (Callable): A uni-variate function that defines a "match". It can be a Python built-in function, like any or all, or a custom function, like `mean = lambda all_matches: numpy.mean(all_matches) >= 0.5` or `median = lambda all_matches: numpy.median(all_matches) == 1`. Default is `all`.

**Returns:**

- `float`: The optimal `p_error` that aims to speed up computations while maintaining good performance.