module concrete.ml.pytest.utils
Common functions and lists for test files that cannot be put in fixtures.
Global Variables
sklearn_models_and_datasets
function get_random_extract_of_sklearn_models_and_datasets
Return a random sublist of sklearn_models_and_datasets.
The sublist contains exactly one model of each kind.
Returns: the sublist
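The "one model of each kind" selection can be illustrated with a minimal sketch. It assumes that each entry of sklearn_models_and_datasets is a (model_class, dataset_parameters) pair and that the model class defines its "kind"; the helper name random_extract_one_per_kind, the placeholder classes and the list contents are all hypothetical, and the real helper may store its entries differently.

```python
import random
from collections import defaultdict


def random_extract_one_per_kind(models_and_datasets, seed=None):
    """Hypothetical sketch: return a random sublist with exactly one entry per model kind.

    Entries are assumed to be (model_class, dataset_params) tuples, and the model class
    is used as the "kind".
    """
    rng = random.Random(seed)

    # Group entries by model class; dicts keep the first-seen order of the kinds
    entries_by_kind = defaultdict(list)
    for model_class, dataset_params in models_and_datasets:
        entries_by_kind[model_class].append((model_class, dataset_params))

    # Draw one entry uniformly at random from each group
    return [rng.choice(entries) for entries in entries_by_kind.values()]


class LinearModel:  # hypothetical placeholder model kind
    pass


class TreeModel:  # hypothetical placeholder model kind
    pass


# Hypothetical list: several dataset configurations for each model kind
models_and_datasets = [
    (LinearModel, {"n_samples": 100}),
    (LinearModel, {"n_samples": 1000}),
    (TreeModel, {"n_classes": 2}),
    (TreeModel, {"n_classes": 4}),
]

extract = random_extract_one_per_kind(models_and_datasets, seed=0)
print([model_class.__name__ for model_class, _ in extract])  # ['LinearModel', 'TreeModel']
```

Grouping first and sampling second guarantees the "exactly one per kind" property regardless of how many configurations each kind has.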
function instantiate_model_generic
Instantiate any Concrete ML model type.
Args:
model_class (class): The type of the model to instantiate.
parameters (dict): Hyper-parameters for the model instantiation.
Returns:
model_name (str): The type of the model as a string.
model (object): The model instance.
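The documented return pair (model_name, model) can be sketched as follows, assuming the model name is simply the class name and that the hyper-parameters are forwarded as keyword arguments to the constructor. The function instantiate_model_generic_sketch and the class DummyClassifier are hypothetical placeholders, not Concrete ML code, and the real helper may apply extra model-specific logic.

```python
from typing import Any, Dict, Tuple, Type


def instantiate_model_generic_sketch(
    model_class: Type, parameters: Dict[str, Any]
) -> Tuple[str, Any]:
    """Hypothetical sketch: build a model instance and report its type as a string."""
    # Assumption: the class name is used as the string form of the model type
    model_name = model_class.__name__

    # Forward the hyper-parameters as keyword arguments to the constructor
    model = model_class(**parameters)
    return model_name, model


class DummyClassifier:  # hypothetical placeholder, not a Concrete ML class
    def __init__(self, alpha: float = 1.0, max_iter: int = 100):
        self.alpha = alpha
        self.max_iter = max_iter


model_name, model = instantiate_model_generic_sketch(
    DummyClassifier, {"alpha": 0.5, "max_iter": 10}
)
print(model_name, model.alpha, model.max_iter)  # DummyClassifier 0.5 10
```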
function get_torchvision_dataset
Get the training or testing data-set.
Args:
param (Dict): Set of hyper-parameters to use based on the selected torchvision data-set. It must contain the data-set transformations (torchvision.transforms.Compose) and the data-set size (Optional[int]).
train_set (bool): Use the training data-set if True, else the testing data-set.
Returns: A torchvision data-set.
function data_calibration_processing
Reduce the size of the given data-set.
Args:
data: The input container to consider.
n_sample (int): Number of samples to keep from the given data-set.
targets: If data is a torch.utils.data.Dataset, it typically contains both the data and the corresponding targets. In this case, targets must be set to None. If data is an instance of torch.Tensor or numpy.ndarray, targets is expected.
Returns:
Tuple[numpy.ndarray, numpy.ndarray]: The input data and the target (respectively x and y).
Raises:
TypeError: If the data-set does not match any expected type.
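The two input modes described above (a data-set object that already holds its targets, or a raw array plus separate targets) can be illustrated with a NumPy-only sketch. To stay free of PyTorch, the hypothetical PairDataset class stands in for a torch.utils.data.Dataset and torch.Tensor inputs are left out; the name reduce_dataset and the choice to keep the first n_sample samples are assumptions, not the helper's confirmed behavior.

```python
import numpy


class PairDataset:
    """Hypothetical stand-in for a torch.utils.data.Dataset yielding (x, y) pairs."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        return self.x[index], self.y[index]


def reduce_dataset(data, n_sample, targets=None):
    """Hypothetical sketch: keep at most n_sample samples, returned as (x, y) arrays."""
    if isinstance(data, numpy.ndarray):
        # Raw arrays carry no labels, so the targets must be passed separately
        if targets is None:
            raise TypeError("targets is expected when data is a numpy.ndarray")
        x, y = data, numpy.asarray(targets)
    elif hasattr(data, "__getitem__") and hasattr(data, "__len__"):
        # Data-set objects already contain the targets, so targets must be None
        if targets is not None:
            raise TypeError("targets must be None when data is a data-set object")
        pairs = [data[i] for i in range(min(n_sample, len(data)))]
        x = numpy.stack([numpy.asarray(sample) for sample, _ in pairs])
        y = numpy.asarray([target for _, target in pairs])
    else:
        raise TypeError(f"Unsupported data type: {type(data).__name__}")

    # Assumption: keep the first n_sample samples (the real helper may select differently)
    return x[:n_sample], y[:n_sample]


x_all = numpy.arange(20, dtype=numpy.float32).reshape(10, 2)
y_all = numpy.arange(10)

x, y = reduce_dataset(x_all, 4, targets=y_all)
print(x.shape, y.shape)  # (4, 2) (4,)

x, y = reduce_dataset(PairDataset(x_all, y_all), 3)
print(x.shape, y.tolist())  # (3, 2) [0, 1, 2]
```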
function load_torch_model
Load an object saved with torch.save() from a file or dict.
Args:
model_class (torch.nn.Module): A PyTorch or Brevitas network.
state_dict_or_path (Optional[Union[str, Path, Dict[str, Any]]]): Path or state_dict.
params (Dict): Model's parameters.
device (str): Device type.
Returns:
torch.nn.Module: A PyTorch or Brevitas network.
function values_are_equal
Indicate if two values are equal.
This method takes into account objects of type None, numpy.ndarray, numpy.floating, numpy.integer, numpy.random.RandomState or any instance that provides an __eq__ method.
Args:
value_1 (Any): The first value to consider.
value_2 (Any): The second value to consider.
Returns:
bool: True if the two values are equal, False otherwise.
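A type-aware comparison covering the listed cases can be sketched as follows. This is a hypothetical illustration, not the Concrete ML implementation. It makes several design choices of its own: NumPy values must share the exact dtype or scalar type, NaN is treated as equal to NaN, and two RandomState objects are compared through numpy.random.RandomState.get_state(), because RandomState has no value-based __eq__.

```python
import numpy


def values_are_equal_sketch(value_1, value_2) -> bool:
    """Hypothetical type-aware equality check (not the Concrete ML implementation)."""
    # None is only equal to None
    if value_1 is None or value_2 is None:
        return value_1 is None and value_2 is None

    # Arrays: same dtype and same elements (array_equal also checks the shapes)
    if isinstance(value_1, numpy.ndarray) or isinstance(value_2, numpy.ndarray):
        return bool(
            isinstance(value_1, numpy.ndarray)
            and isinstance(value_2, numpy.ndarray)
            and value_1.dtype == value_2.dtype
            and numpy.array_equal(value_1, value_2)
        )

    # NumPy scalars: require the exact same scalar type (e.g. float32 is not float64)
    numpy_scalar_types = (numpy.floating, numpy.integer)
    if isinstance(value_1, numpy_scalar_types) or isinstance(value_2, numpy_scalar_types):
        if type(value_1) is not type(value_2):
            return False
        # Design choice of this sketch: NaN is considered equal to NaN
        if isinstance(value_1, numpy.floating) and numpy.isnan(value_1) and numpy.isnan(value_2):
            return True
        return bool(value_1 == value_2)

    # RandomState has no value-based __eq__, so compare the generator states instead
    random_state_type = numpy.random.RandomState
    if isinstance(value_1, random_state_type) or isinstance(value_2, random_state_type):
        if not (isinstance(value_1, random_state_type) and isinstance(value_2, random_state_type)):
            return False
        return all(
            values_are_equal_sketch(part_1, part_2)
            for part_1, part_2 in zip(value_1.get_state(), value_2.get_state())
        )

    # Anything else falls back to its own __eq__ method
    return bool(value_1 == value_2)


a = numpy.array([1.0, 2.0])
print(values_are_equal_sketch(a, a.copy()))  # True
print(values_are_equal_sketch(a, a.astype(numpy.float32)))  # False
print(values_are_equal_sketch(numpy.random.RandomState(0), numpy.random.RandomState(0)))  # True
```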
function check_serialization
Check that the given object can properly be serialized.
This function serializes the object and loads it back using the dump, dumps, load and loads functions from Concrete ML. If the given object provides dump and dumps methods, it is also serialized using these.
Args:
object_to_serialize (Any): The object to serialize.
expected_type (Type): The object's expected type.
equal_method (Optional[Callable]): The function to use to compare the two loaded objects. Defaults to values_are_equal.
check_str (bool): Whether the JSON strings should also be checked. Defaults to True.
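The round-trip idea behind this check can be sketched with Python's standard json module. Here json.dump, json.dumps, json.load and json.loads are swapped in for Concrete ML's own serialization functions, which this sketch does not call. The names check_json_round_trip and default_equal are hypothetical, default_equal is a plain == stand-in for values_are_equal, and the branch that also serializes through the object's own dump and dumps methods is left out.

```python
import io
import json
from typing import Any, Callable, Optional, Type


def default_equal(value_1: Any, value_2: Any) -> bool:
    """Hypothetical stand-in for values_are_equal: a plain == comparison."""
    return bool(value_1 == value_2)


def check_json_round_trip(
    object_to_serialize: Any,
    expected_type: Type,
    equal_method: Optional[Callable[[Any, Any], bool]] = None,
    check_str: bool = True,
) -> None:
    """Hypothetical round-trip check using json instead of Concrete ML's serializers."""
    if equal_method is None:
        equal_method = default_equal

    # Serialize once to a string (dumps) and once to a file-like object (dump)
    dumped_str = json.dumps(object_to_serialize)
    buffer = io.StringIO()
    json.dump(object_to_serialize, buffer)

    # Load both serialized forms back (loads and load)
    loaded_from_str = json.loads(dumped_str)
    buffer.seek(0)
    loaded_from_file = json.load(buffer)

    # Both loaded objects must have the expected type and compare equal
    assert isinstance(loaded_from_str, expected_type)
    assert isinstance(loaded_from_file, expected_type)
    assert equal_method(loaded_from_str, loaded_from_file)

    # Optionally require dump and dumps to produce exactly the same JSON text
    if check_str:
        assert buffer.getvalue() == dumped_str


check_json_round_trip({"name": "model", "weights": [0.5, -1.25]}, dict)
# JSON has no tuple type, so a tuple is loaded back as a list
check_json_round_trip((1, 2, 3), list)
```

The expected_type argument is what catches silent type changes such as the tuple-to-list conversion above.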