concrete.ml.common.utils
Utilities that can be reused by other parts of the module.
SUPPORTED_FLOAT_TYPES
SUPPORTED_INT_TYPES
SUPPORTED_TYPES
MAX_BITWIDTH_BACKWARD_COMPATIBLE
USE_OLD_VL
QUANT_ROUND_LIKE_ROUND_PBS
replace_invalid_arg_name_chars
Sanitize arg_name by replacing invalid characters with _.
This does not check that the starting character of arg_name is valid.
Args:
arg_name
(str): the arg name to sanitize.
Returns:
str
: the sanitized arg name, with only chars in _VALID_ARG_CHARS.
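A minimal sketch of this sanitization. It assumes that _VALID_ARG_CHARS contains ASCII letters, digits and the underscore, because the actual set is not listed here:

```python
import string

# Assumption: letters, digits and "_" are the only valid characters.
# The real _VALID_ARG_CHARS may differ.
_VALID_ARG_CHARS = set(string.ascii_letters + string.digits + "_")


def replace_invalid_arg_name_chars(arg_name: str) -> str:
    """Replace every character that is not in _VALID_ARG_CHARS with an underscore."""
    # The first character is not treated specially, so a leading digit is kept as-is.
    return "".join(c if c in _VALID_ARG_CHARS else "_" for c in arg_name)


print(replace_invalid_arg_name_chars("input.0/x"))  # input_0_x
```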
generate_proxy_function
Generate a proxy function for a function accepting only *args type arguments.
This returns a function compiled at runtime whose parameters are the sanitized versions of the names passed in desired_functions_arg_names.
Args:
function_to_proxy
(Callable): the function defined like def f(*args) for which to return a function like f_proxy(arg_1, arg_2) for any number of arguments.
desired_functions_arg_names
(Iterable[str]): the argument names to use. These names are sanitized, and the mapping from each original argument name to its sanitized version is returned in a dictionary. Only the sanitized names work when calling the proxy function.
Returns:
Tuple[Callable, Dict[str, str]]
: the proxy function and the mapping of the original arg name to the new and sanitized arg names.
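One way to build such a proxy is to generate its source text and compile it with exec. The sketch below makes several assumptions. It uses a simple letters/digits/underscore sanitizer, and both the helper _sanitize and the example function add_all are hypothetical:

```python
import inspect
import string
from typing import Callable, Dict, Iterable, Tuple

# Assumption: letters, digits and "_" are the only valid characters.
_VALID_CHARS = set(string.ascii_letters + string.digits + "_")


def _sanitize(name: str) -> str:
    """Hypothetical helper: replace invalid characters with '_'."""
    return "".join(c if c in _VALID_CHARS else "_" for c in name)


def generate_proxy_function(
    function_to_proxy: Callable, desired_functions_arg_names: Iterable[str]
) -> Tuple[Callable, Dict[str, str]]:
    """Build proxy(a, b, ...) that forwards its arguments to function_to_proxy(*args)."""
    orig_to_sanitized: Dict[str, str] = {
        name: _sanitize(name) for name in desired_functions_arg_names
    }
    params = ", ".join(orig_to_sanitized.values())
    # Source of the proxy: the sanitized names become real named parameters.
    source = f"def proxy({params}):\n    return function_to_proxy({params})\n"
    # The namespace dict serves as the proxy's globals, so it can see function_to_proxy.
    namespace = {"function_to_proxy": function_to_proxy}
    exec(source, namespace)
    return namespace["proxy"], orig_to_sanitized


def add_all(*args):
    return sum(args)


proxy, mapping = generate_proxy_function(add_all, ["x.0", "x.1"])
print(mapping)  # {'x.0': 'x_0', 'x.1': 'x_1'}
print(list(inspect.signature(proxy).parameters))  # ['x_0', 'x_1']
print(proxy(x_0=1, x_1=2))  # 3
```

The proxy has named parameters, so signature inspection sees x_0 and x_1 rather than *args. In this sketch, two original names that sanitize to the same string would cause a SyntaxError for the duplicate parameter.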
get_onnx_opset_version
Return the ONNX opset_version.
Args:
onnx_model
(onnx.ModelProto): the model.
Returns:
int
: the opset version of the model.
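A sketch that assumes the opset is read from the model's opset_import entries. It also assumes the default ONNX domain is identified by an empty string or "ai.onnx". A SimpleNamespace stands in for an onnx.ModelProto, so the example runs without the onnx package:

```python
from types import SimpleNamespace


def get_onnx_opset_version(onnx_model) -> int:
    """Return the version of the default-domain ONNX opset imported by the model."""
    # Each opset_import entry has a domain and a version; "" (or "ai.onnx") is the default domain.
    for opset in onnx_model.opset_import:
        if opset.domain in ("", "ai.onnx"):
            return int(opset.version)
    raise ValueError("The model does not import the default ONNX opset")


# Stand-in for an onnx.ModelProto with one custom domain and one default-domain opset.
fake_model = SimpleNamespace(
    opset_import=[
        SimpleNamespace(domain="com.example", version=1),
        SimpleNamespace(domain="", version=14),
    ]
)
print(get_onnx_opset_version(fake_model))  # 14
```

If the onnx package is installed, the same function can be called on the ModelProto that onnx.load returns.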
manage_parameters_for_pbs_errors
Return (p_error, global_p_error) that we want to give to Concrete.
The returned (p_error, global_p_error) depend on the user's parameters and on the way we want to manage defaults in Concrete ML, which may differ from the way defaults are managed in Concrete.
Principle:
- if neither is set, we set global_p_error to a default value of our choice
- if both are set, we raise an error
- if only one is set, we use it and forward it to Concrete
Note that global_p_error is currently set to 0 in the FHE simulation mode.
Args:
p_error
(Optional[float]): probability of error of a single PBS.
global_p_error
(Optional[float]): probability of error of the full circuit.
Returns:
(p_error, global_p_error)
: parameters to give to the compiler
Raises:
ValueError
: if both parameters are set (this behavior differs from Concrete-Python)
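The principle above can be sketched as follows. DEFAULT_GLOBAL_P_ERROR is a hypothetical placeholder because the default that Concrete ML actually uses is not stated here. The sketch also leaves out the FHE simulation special case:

```python
from typing import Optional, Tuple

# Hypothetical default; not necessarily the value Concrete ML uses.
DEFAULT_GLOBAL_P_ERROR = 0.01


def manage_parameters_for_pbs_errors(
    p_error: Optional[float] = None, global_p_error: Optional[float] = None
) -> Tuple[Optional[float], Optional[float]]:
    """Resolve (p_error, global_p_error) following the principle described above."""
    if p_error is None and global_p_error is None:
        # Neither is set: use our own default for the whole circuit.
        return None, DEFAULT_GLOBAL_P_ERROR
    if p_error is not None and global_p_error is not None:
        # Both are set: refuse (this differs from Concrete-Python).
        raise ValueError("Please only set one of p_error and global_p_error")
    # Exactly one is set: forward both values unchanged.
    return p_error, global_p_error


print(manage_parameters_for_pbs_errors())  # (None, 0.01)
```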
check_there_is_no_p_error_options_in_configuration
Check that the user did not set p_error or global_p_error in the configuration.
Setting them there would be dangerous, since we pass them as direct arguments in our calls to Concrete-Python.
Args:
configuration
: Configuration object to use during compilation
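A minimal sketch that makes two assumptions. First, the configuration object exposes p_error and global_p_error attributes that are None when unset. Second, a None configuration is accepted. SimpleNamespace objects stand in for a real compilation configuration:

```python
from types import SimpleNamespace


def check_there_is_no_p_error_options_in_configuration(configuration) -> None:
    """Raise if p_error or global_p_error is set on the configuration object."""
    # Assumption: a missing configuration is allowed and means there is nothing to check.
    if configuration is None:
        return
    # Assumption: unset error probabilities are stored as None on the configuration.
    for name in ("p_error", "global_p_error"):
        if getattr(configuration, name, None) is not None:
            raise ValueError(
                f"Do not set {name} in the configuration, pass it as a direct argument instead"
            )


clean_config = SimpleNamespace(p_error=None, global_p_error=None)
check_there_is_no_p_error_options_in_configuration(clean_config)  # no error

try:
    check_there_is_no_p_error_options_in_configuration(SimpleNamespace(p_error=0.01))
except ValueError as error:
    print(error)
```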
get_model_class
Return the class of the model (instantiated or not), which can be a partial() instance.
Args:
model_class
: The model, which can be a partial() instance.
Returns: The model's class.
is_model_class_in_a_list
Indicate if a model class, which can be a partial() instance, is an element of a_list.
Args:
model_class
: The model, which can be a partial() instance.
a_list
: The list in which to look.
Returns: Whether the model's class is in the list.
get_model_name
Return the name of the model, which can be a partial() instance.
Args:
model_class
: The model, which can be a partial() instance.
Returns: the model's name.
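A sketch of how these three helpers can unwrap functools.partial objects. It assumes the wrapped class is available through partial.func and that an instantiated model is mapped to its class. LinearModel is a hypothetical class used only for illustration:

```python
from functools import partial
from typing import Any, Sequence


def get_model_class(model_class: Any) -> type:
    """Return the class behind a class, an instance, or a functools.partial."""
    # Unwrap partial(SomeModel, ...) to SomeModel.
    if isinstance(model_class, partial):
        model_class = model_class.func
    # For an instantiated model, return its class.
    return model_class if isinstance(model_class, type) else type(model_class)


def is_model_class_in_a_list(model_class: Any, a_list: Sequence[type]) -> bool:
    return get_model_class(model_class) in a_list


def get_model_name(model_class: Any) -> str:
    return get_model_class(model_class).__name__


class LinearModel:  # hypothetical model used for illustration
    def __init__(self, n_bits: int = 8):
        self.n_bits = n_bits


partial_class = partial(LinearModel, n_bits=4)
print(get_model_name(partial_class))  # LinearModel
print(is_model_class_in_a_list(LinearModel(), [LinearModel]))  # True
```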
is_classifier_or_partial_classifier
Indicate if the model class represents a classifier.
Args:
model_class
: The model class, which can be a functools.partial instance.
Returns:
bool
: Whether the model class represents a classifier.
is_regressor_or_partial_regressor
Indicate if the model class represents a regressor.
Args:
model_class
: The model class, which can be a functools.partial instance.
Returns:
bool
: Whether the model class represents a regressor.
is_pandas_dataframe
Indicate if the input container is a Pandas DataFrame.
This function is inspired by Scikit-Learn's input validation tools and avoids the need to add and import Pandas as an additional dependency to the project. See https://github.com/scikit-learn/scikit-learn/blob/98cf537f5/sklearn/utils/validation.py#L629
Args:
input_container
(Any): The input container to consider
Returns:
bool
: Whether the input container is a DataFrame.
is_pandas_series
Indicate if the input container is a Pandas Series.
This function is inspired by Scikit-Learn's input validation tools and avoids the need to add and import Pandas as an additional dependency to the project. See https://github.com/scikit-learn/scikit-learn/blob/98cf537f5/sklearn/utils/validation.py#L629
Args:
input_container
(Any): The input container to consider
Returns:
bool
: Whether the input container is a Series.
is_pandas_type
Indicate if the input container is a Pandas DataFrame or Series.
Args:
input_container
(Any): The input container to consider
Returns:
bool
: Whether the input container is a DataFrame or a Series.
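A duck-typing sketch in the spirit of the Scikit-Learn approach. It assumes a DataFrame is recognized by an array-like dtypes attribute and a Series by having both dtype and iloc. These attribute checks are heuristics and may not match the exact implementation:

```python
from types import SimpleNamespace
from typing import Any


def is_pandas_dataframe(input_container: Any) -> bool:
    # A DataFrame exposes its per-column dtypes as an array-like object.
    return hasattr(input_container, "dtypes") and hasattr(input_container.dtypes, "__array__")


def is_pandas_series(input_container: Any) -> bool:
    # A Series has a single dtype and positional indexing through iloc.
    return hasattr(input_container, "iloc") and hasattr(input_container, "dtype")


def is_pandas_type(input_container: Any) -> bool:
    return is_pandas_dataframe(input_container) or is_pandas_series(input_container)


# Stand-in exposing the attributes checked above (pandas itself is not imported).
fake_series = SimpleNamespace(iloc=None, dtype="int64")
print(is_pandas_series(fake_series))  # True
print(is_pandas_type([1, 2, 3]))  # False
```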
check_dtype_and_cast
Convert any allowed type into an array and cast it if required.
If the values' type matches neither a supported type nor the expected dtype, a ValueError is raised.
Args:
values
(Any): The values to consider
expected_dtype
(str): The expected dtype, either "float32" or "int64"
error_information
(str): Additional information to put in front of the error message when raising a ValueError. Defaults to None.
Returns:
(Union[numpy.ndarray, torch.utils.data.dataset.Subset])
: The values with proper dtype.
Raises:
ValueError
: If the values' dtype does not match the expected one or casting is not possible.
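A minimal NumPy-only sketch. It assumes hypothetical subsets of SUPPORTED_FLOAT_TYPES and SUPPORTED_INT_TYPES, and it assumes integers may be cast to float32 while floats are never cast to int64. The torch Subset return case is not handled:

```python
from typing import Any, Optional

import numpy

# Hypothetical stand-ins for SUPPORTED_FLOAT_TYPES and SUPPORTED_INT_TYPES.
FLOAT_TYPES = {"float16", "float32", "float64"}
INT_TYPES = {"int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"}


def check_dtype_and_cast(
    values: Any, expected_dtype: str, error_information: Optional[str] = None
) -> numpy.ndarray:
    """Convert values to a NumPy array and cast it to expected_dtype when allowed."""
    prefix = f"{error_information}: " if error_information else ""
    if expected_dtype not in ("float32", "int64"):
        raise ValueError(f"{prefix}expected_dtype must be 'float32' or 'int64'")

    array = numpy.asarray(values)
    dtype_name = array.dtype.name

    if expected_dtype == "float32" and dtype_name in FLOAT_TYPES | INT_TYPES:
        # Floats and integers are cast to float32 (large values may lose precision).
        return array.astype(numpy.float32)
    if expected_dtype == "int64" and dtype_name in INT_TYPES:
        # Floats are rejected instead of being silently truncated.
        return array.astype(numpy.int64)

    raise ValueError(f"{prefix}values of dtype {dtype_name} cannot be cast to {expected_dtype}")


print(check_dtype_and_cast([1, 2, 3], "float32").dtype)  # float32
```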
compute_bits_precision
Compute the number of bits required to represent x.
Args:
x
(numpy.ndarray): Integer data
Returns:
int
: the number of bits required to represent x
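For non-negative integers, the number of bits is ceil(log2(max(x) + 1)). The sketch below assumes x holds no negative values, since negative values would need an extra sign bit. It uses int.bit_length, which gives the same result without floating-point rounding in log2:

```python
import numpy


def compute_bits_precision(x: numpy.ndarray) -> int:
    """Bits needed to represent the largest value of x (non-negative integers assumed)."""
    # int.bit_length() equals ceil(log2(max + 1)) for non-negative integers.
    return int(numpy.max(x)).bit_length()


print(compute_bits_precision(numpy.array([0, 5, 7])))  # 3
print(compute_bits_precision(numpy.array([8])))  # 4
```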
is_brevitas_model
Check if a model is a Brevitas type.
Args:
model
: PyTorch model.
Returns:
bool
: True if model is a Brevitas network.
to_tuple
Make the input a tuple if it is not already the case.
Args:
x
(Any): The input to consider. It can already be a tuple.
Returns:
tuple
: The input as a tuple.
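A minimal sketch. It assumes only tuples are returned unchanged and that every other value, lists included, is wrapped in a one-element tuple:

```python
from typing import Any, Tuple


def to_tuple(x: Any) -> Tuple[Any, ...]:
    """Return x unchanged if it is a tuple, otherwise wrap it in a one-element tuple."""
    if isinstance(x, tuple):
        return x
    return (x,)


print(to_tuple(3))  # (3,)
print(to_tuple((1, 2)))  # (1, 2)
```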
all_values_are_integers
Indicate if all unpacked values are of a supported integer dtype.
Args:
*values (Any)
: The values to consider.
Returns:
bool
: Whether all values are supported integers or not.
all_values_are_floats
Indicate if all unpacked values are of a supported float dtype.
Args:
*values (Any)
: The values to consider.
Returns:
bool
: Whether all values are supported floats or not.
all_values_are_of_dtype
Indicate if all unpacked values are of the specified dtype(s).
Args:
*values (Any)
: The values to consider.
dtypes
(Union[str, List[str]]): The dtype(s) to consider.
Returns:
bool
: Whether all values are of the specified dtype(s) or not.
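A sketch of the three helpers. It assumes values are NumPy arrays or scalars, and it compares each value's dtype name against hypothetical stand-ins for SUPPORTED_INT_TYPES and SUPPORTED_FLOAT_TYPES. Because dtypes follows *values, it must be passed by keyword:

```python
from typing import Any, List, Union

import numpy

# Hypothetical stand-ins for SUPPORTED_INT_TYPES and SUPPORTED_FLOAT_TYPES.
INT_DTYPES = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
FLOAT_DTYPES = ["float16", "float32", "float64"]


def all_values_are_of_dtype(*values: Any, dtypes: Union[str, List[str]]) -> bool:
    """Return True if every value has a dtype whose name is in dtypes."""
    if isinstance(dtypes, str):
        dtypes = [dtypes]
    # Values without a dtype (for example Python lists) make the check fail.
    return all(getattr(getattr(value, "dtype", None), "name", None) in dtypes for value in values)


def all_values_are_integers(*values: Any) -> bool:
    return all_values_are_of_dtype(*values, dtypes=INT_DTYPES)


def all_values_are_floats(*values: Any) -> bool:
    return all_values_are_of_dtype(*values, dtypes=FLOAT_DTYPES)


ints = numpy.array([1, 2, 3])
floats = numpy.array([0.5, 1.5])
print(all_values_are_integers(ints, ints))  # True
print(all_values_are_integers(ints, floats))  # False
print(all_values_are_of_dtype(floats, dtypes="float64"))  # True
```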
array_allclose_and_same_shape
Check if two numpy arrays are equal within tolerances and have the same shape.
Args:
a
(numpy.ndarray): The first input array
b
(numpy.ndarray): The second input array
rtol
(float): The relative tolerance parameter
atol
(float): The absolute tolerance parameter
equal_nan
(bool): Whether to compare NaNs as equal. If True, NaNs in a are considered equal to NaNs in b.
Returns:
bool
: True if the arrays have the same shape and all elements are equal within the specified tolerances, False otherwise.
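A minimal sketch that assumes the defaults mirror numpy.allclose. The explicit shape check is needed because numpy.allclose broadcasts its inputs. Without it, arrays of shapes (3,) and (1, 3) would compare as close:

```python
import numpy


def array_allclose_and_same_shape(
    a: numpy.ndarray,
    b: numpy.ndarray,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """True if a and b have the same shape and are element-wise close."""
    # Compare shapes first: numpy.allclose alone would broadcast (3,) against (1, 3).
    if a.shape != b.shape:
        return False
    return bool(numpy.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))


x = numpy.array([1.0, 2.0, numpy.nan])
print(array_allclose_and_same_shape(x, x.copy(), equal_nan=True))  # True
print(array_allclose_and_same_shape(x, x.reshape(1, 3), equal_nan=True))  # False
```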
FheMode
Enum representing the execution mode.
This enum inherits from str in order to be able to easily compare a string parameter to its equivalent Enum attribute.
Examples:
>>> fhe_disable = FheMode.DISABLE
>>> fhe_disable == "disable"
True
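A sketch of such a string-backed enum. The SIMULATE and EXECUTE members and their values are assumptions based on the simulation and execution modes mentioned in this module. Only DISABLE appears in the example above:

```python
import enum


class FheMode(str, enum.Enum):
    """Execution mode; inheriting from str makes members compare equal to their values."""

    DISABLE = "disable"
    SIMULATE = "simulate"  # assumed member
    EXECUTE = "execute"  # assumed member


fhe_disable = FheMode.DISABLE
print(fhe_disable == "disable")  # True
print(FheMode("simulate") is FheMode.SIMULATE)  # True
```

With this design, a user-supplied string such as "simulate" can be compared directly against the enum, or converted with FheMode(...) to validate it.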