concrete.ml.torch.hybrid_model.md

module concrete.ml.torch.hybrid_model

Implement the conversion of a torch model into a hybrid FHE/torch model for inference.

Global Variables

  • MAX_BITWIDTH_BACKWARD_COMPATIBLE


function tuple_to_underscore_str

tuple_to_underscore_str(tup: Tuple) → str

Convert a tuple to a string representation.

Args:

  • tup (Tuple): a tuple to change into string representation

Returns:

  • str: a string representing the tuple


function underscore_str_to_tuple

underscore_str_to_tuple(tup: str) → Tuple

Convert a string representation of a tuple back to a tuple.

Args:

  • tup (str): a string representing the tuple

Returns:

  • Tuple: the tuple represented by the string
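The two helpers above are inverses of each other: a tuple such as an input shape is encoded as a string (the server methods in this module take input_shape as a str) and later decoded back. The sketch below only illustrates that round trip; it assumes a simple underscore-joined format for tuples of integers, which is not necessarily the exact format Concrete ML uses, and both function names are hypothetical.

```python
from typing import Tuple


def shape_to_underscore_str(shape: Tuple[int, ...]) -> str:
    """Hypothetical encoder: (1, 8) -> "1_8"."""
    return "_".join(str(dim) for dim in shape)


def underscore_str_to_shape(text: str) -> Tuple[int, ...]:
    """Hypothetical decoder: "1_8" -> (1, 8)."""
    if not text:
        # In this sketch the empty tuple is encoded as the empty string
        return ()
    return tuple(int(part) for part in text.split("_"))


encoded = shape_to_underscore_str((1, 8))
print(encoded)  # 1_8
print(underscore_str_to_shape(encoded))  # (1, 8)
```

Whatever the concrete format, the property that matters is that decoding the encoded string gives back the original tuple.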


function convert_conv1d_to_linear

convert_conv1d_to_linear(layer_or_module)

Convert all Conv1D layers in a module or a Conv1D layer itself to nn.Linear.

Args:

  • layer_or_module (nn.Module or Conv1D): The module which will be recursively searched for Conv1D layers, or a Conv1D layer itself.

Returns:

  • nn.Module or nn.Linear: The updated module with Conv1D layers converted to Linear layers, or the Conv1D layer converted to a Linear layer.
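Assuming Conv1D here is the Hugging Face transformers layer found in GPT-2-style models, it computes y = x @ W + b with W stored as (in_features, out_features), whereas nn.Linear computes y = x @ W.T + b with W stored as (out_features, in_features). Converting one into the other therefore amounts to transposing the weight and copying the bias. The dependency-free sketch below checks that equivalence on plain Python lists; it does not use torch, and all helper names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # (n, k) @ (k, m) -> (n, m)
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add_bias(y: Matrix, bias: List[float]) -> Matrix:
    return [[v + b for v, b in zip(row, bias)] for row in y]


def conv1d_style(x: Matrix, w_in_out: Matrix, bias: List[float]) -> Matrix:
    # Conv1D-style layer: y = x @ W + b, W has shape (in_features, out_features)
    return add_bias(matmul(x, w_in_out), bias)


def linear_style(x: Matrix, w_out_in: Matrix, bias: List[float]) -> Matrix:
    # nn.Linear-style layer: y = x @ W.T + b, W has shape (out_features, in_features)
    return add_bias(matmul(x, transpose(w_out_in)), bias)


x = [[1.0, 2.0, 3.0]]                          # batch of 1, in_features = 3
w_conv = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # (3, 2): in_features x out_features
bias = [0.5, -0.5]

w_linear = transpose(w_conv)                   # (2, 3): out_features x in_features
print(conv1d_style(x, w_conv, bias))  # [[4.5, 4.5]]
print(linear_style(x, w_linear, bias))  # [[4.5, 4.5]]
```

The same transpose-and-copy step is what a recursive conversion would apply to every Conv1D layer it finds.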


class HybridFHEMode

Simple enum for the different modes of execution of HybridFHEModel.


class RemoteModule

A wrapper class for the modules to be evaluated remotely with FHE.

method __init__

__init__(
    module: Optional[Module] = None,
    server_remote_address: Optional[str] = None,
    module_name: Optional[str] = None,
    model_name: Optional[str] = None,
    verbose: int = 0
)

method forward

forward(x: Tensor) → Union[Tensor, QuantTensor]

Forward pass of the remote module.

The behavior of this forward function is controlled by the fhe_local_mode attribute. Choices are:

  • disable: forward using the torch module

  • remote: forward with the FHE client-server

  • simulate: forward with local FHE simulation

  • calibrate: forward for calibration

Args:

  • x (torch.Tensor): The input tensor.

Returns:

  • (torch.Tensor): The output tensor.

Raises:

  • ValueError: if fhe_local_mode is not supported
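The mode-dependent behavior of forward can be pictured as a dispatch on the mode value. The sketch below is hypothetical: it mirrors the four choices listed above with a plain Enum (assuming the mode strings are exactly disable, remote, simulate and calibrate) and uses stand-in callables instead of the real torch module, FHE client-server and simulation.

```python
from enum import Enum
from typing import Callable, Dict, List, Union


class Mode(Enum):
    # Hypothetical mirror of the modes listed for RemoteModule.forward
    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


def dispatch_forward(mode: Union[str, Mode], x, handlers: Dict[Mode, Callable]):
    """Route x to the handler registered for mode; unknown modes raise ValueError."""
    mode = Mode(mode)  # accepts a Mode member or its string value
    return handlers[mode](x)


calibration_inputs: List[int] = []


def calibrate(x):
    # Stand-in: remember the input for a later compilation step, then compute in the clear
    calibration_inputs.append(x)
    return x * 2


handlers = {
    Mode.DISABLE: lambda x: x * 2,   # stand-in for the plain torch module
    Mode.REMOTE: lambda x: x * 2,    # stand-in for the encrypted client-server round trip
    Mode.SIMULATE: lambda x: x * 2,  # stand-in for local FHE simulation
    Mode.CALIBRATE: calibrate,
}

print(dispatch_forward("calibrate", 3, handlers))  # 6
print(calibration_inputs)  # [3]
```

Converting the raw value through the enum gives the ValueError on unsupported modes for free, which matches the Raises entry above.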


method init_fhe_client

init_fhe_client(
    path_to_client: Optional[Path] = None,
    path_to_keys: Optional[Path] = None
)

Set the client's keys.

Args:

  • path_to_client (Optional[Path]): Path where the client.zip is located.

  • path_to_keys (Optional[Path]): Path where keys are located.

Raises:

  • ValueError: if anything goes wrong with the server.


method remote_call

remote_call(x: Tensor) → Tensor

Call the remote server to get the private module inference.

Args:

  • x (torch.Tensor): The input tensor.

Returns:

  • torch.Tensor: The result of the FHE computation


class HybridFHEModel

Convert a model to a hybrid model.

This is done by replacing the targeted modules with RemoteModule instances. The model is modified in place.

Args:

  • model (nn.Module): The model to modify (in-place modification)

  • module_names (Union[str, List[str]]): The name(s) of the module(s) to replace with remote modules evaluated on the FHE server.

  • server_remote_address (Optional[str]): The remote address of the FHE server

  • model_name (str): Model name identifier

  • verbose (int): Whether logs should be printed when interacting with the FHE server
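Replacing targeted modules in place typically means resolving each dotted module name (the kind produced by PyTorch's named_modules()) to its parent object and reassigning the attribute. The sketch below shows that idea on plain Python objects; Node, Wrapper and replace_in_place are hypothetical stand-ins for nn.Module, RemoteModule and the conversion logic, not Concrete ML code.

```python
from typing import List, Union


class Node:
    """Hypothetical stand-in for an nn.Module with named child attributes."""

    def __init__(self, **children):
        for name, child in children.items():
            setattr(self, name, child)


class Wrapper:
    """Hypothetical stand-in for RemoteModule: keeps a reference to the wrapped module."""

    def __init__(self, module, module_name: str):
        self.module = module
        self.module_name = module_name


def replace_in_place(model, module_names: Union[str, List[str]]):
    if isinstance(module_names, str):
        module_names = [module_names]
    for name in module_names:
        *parent_path, attr = name.split(".")
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)
        # Swap the child for a wrapper; the model object itself is mutated
        setattr(parent, attr, Wrapper(getattr(parent, attr), name))
    return model


model = Node(encoder=Node(fc1="linear-1", fc2="linear-2"), head="linear-3")
replace_in_place(model, "encoder.fc1")
print(type(model.encoder.fc1).__name__, model.encoder.fc1.module)  # Wrapper linear-1
```

Because the original object is mutated rather than copied, any reference held by the caller sees the replaced modules, which is why the constructor documents an in-place modification.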

method __init__

__init__(
    model: Module,
    module_names: Union[str, List[str]],
    server_remote_address=None,
    model_name: str = 'model',
    verbose: int = 0
)

method compile_model

compile_model(
    x: Tensor,
    n_bits: Union[int, Dict[str, int]] = 8,
    rounding_threshold_bits: Optional[int] = None,
    p_error: Optional[float] = None,
    configuration: Optional[Configuration] = None
)

Compile the specified layers to FHE.

Args:

  • x (torch.Tensor): The input tensor for the model. This is used to run the model once for calibration.

  • n_bits (Union[int, Dict[str, int]]): The bit precision for quantization during FHE model compilation. Default is 8.

  • rounding_threshold_bits (Optional[int]): The number of bits to use for rounding threshold during FHE model compilation. Default is None.

  • p_error (Optional[float]): Error allowed for each table look-up in the circuit.

  • configuration (Optional[Configuration]): A concrete Configuration object specifying the FHE encryption parameters. If not specified, a default configuration is used.
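As a rough intuition for how x and n_bits interact, calibration data fixes the range of a tensor, from which a uniform quantizer with 2**n_bits levels is derived. The sketch below is a generic min/max affine quantizer written for illustration only; it is not Concrete ML's quantization scheme, and the helper names are hypothetical.

```python
from typing import List, Tuple


def calibrate_affine(values: List[float], n_bits: int) -> Tuple[float, int]:
    """Derive (scale, zero_point) mapping [min, max] onto the integers [0, 2**n_bits - 1]."""
    # Extend the range to include 0 so that 0.0 is exactly representable
    v_min = min(min(values), 0.0)
    v_max = max(max(values), 0.0)
    n_steps = 2**n_bits - 1
    scale = (v_max - v_min) / n_steps if v_max > v_min else 1.0
    zero_point = round(-v_min / scale)
    return scale, zero_point


def quantize(values: List[float], scale: float, zero_point: int, n_bits: int) -> List[int]:
    q_max = 2**n_bits - 1
    # Round to the nearest level, then clip to the representable range
    return [min(max(round(v / scale) + zero_point, 0), q_max) for v in values]


def dequantize(q_values: List[int], scale: float, zero_point: int) -> List[float]:
    return [(q - zero_point) * scale for q in q_values]


calibration = [-1.0, 0.0, 0.5, 1.0]  # plays the role of the calibration input x
scale, zero_point = calibrate_affine(calibration, n_bits=8)
q = quantize(calibration, scale, zero_point, n_bits=8)
print(q)
print(dequantize(q, scale, zero_point))
```

Fewer bits mean a larger scale and a coarser reconstruction, which is the accuracy trade-off n_bits controls.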


method init_client

init_client(
    path_to_clients: Optional[Path] = None,
    path_to_keys: Optional[Path] = None
)

Initialize client for all remote modules.

Args:

  • path_to_clients (Optional[Path]): Path to the client.zip files.

  • path_to_keys (Optional[Path]): Path to the keys folder.


method publish_to_hub

publish_to_hub()

Allow the user to push the model and the required FHE files to the Hugging Face Hub.


method save_and_clear_private_info

save_and_clear_private_info(path: Path, via_mlir=False)

Save the PyTorch model and the corresponding FHE circuits to the provided path.

Args:

  • path (Path): The directory where the model and the FHE circuits will be saved.

  • via_mlir (bool): Whether the FHE circuits should be serialized using the via_mlir option, which is useful for cross-platform use (compiling on one architecture and running on another).


method set_fhe_mode

set_fhe_mode(hybrid_fhe_mode: Union[str, HybridFHEMode])

Set Hybrid FHE mode for all remote modules.

Args:

  • hybrid_fhe_mode (Union[str, HybridFHEMode]): Hybrid FHE mode to set to all remote modules.
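Because hybrid_fhe_mode accepts either a string or an enum member, one straightforward approach is to normalize the value once and then assign it to every remote module. The sketch below assumes the enum values are the lowercase strings listed under RemoteModule.forward; Mode, FakeRemote and set_mode_for_all are hypothetical names.

```python
from enum import Enum
from typing import Dict, Union


class Mode(Enum):
    # Hypothetical mirror of HybridFHEMode
    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


class FakeRemote:
    """Hypothetical stand-in for a RemoteModule holding a mode attribute."""

    def __init__(self):
        self.fhe_local_mode = Mode.CALIBRATE  # initial mode chosen for the sketch


def set_mode_for_all(remote_modules: Dict[str, FakeRemote], mode: Union[str, Mode]) -> None:
    mode = Mode(mode)  # "simulate" -> Mode.SIMULATE; unknown strings raise ValueError
    for module in remote_modules.values():
        module.fhe_local_mode = mode


remotes = {"encoder.fc1": FakeRemote(), "head": FakeRemote()}
set_mode_for_all(remotes, "simulate")
print({name: m.fhe_local_mode.value for name, m in remotes.items()})
# {'encoder.fc1': 'simulate', 'head': 'simulate'}
```

Validating before the loop means an invalid mode leaves every module untouched instead of switching only some of them.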


class LoggerStub

Placeholder type for a typical logger like the one from loguru.


method info

info(msg: str)

Placeholder function for logger.info.

Args:

  • msg (str): the message to output
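Since LoggerStub only needs an info(msg) method, any object exposing that method (a loguru logger, a logging.Logger, or a small custom class) fits the role. A structural type makes that explicit; the sketch below uses typing.Protocol with hypothetical names and is not the actual definition of LoggerStub.

```python
import logging
from typing import List, Optional, Protocol


class InfoLogger(Protocol):
    """Hypothetical structural type: anything with an info(msg) method."""

    def info(self, msg: str) -> None:
        ...


class ListLogger:
    """Collects messages in memory; satisfies InfoLogger without inheriting from it."""

    def __init__(self):
        self.messages: List[str] = []

    def info(self, msg: str) -> None:
        self.messages.append(msg)


def announce(logger: Optional[InfoLogger], msg: str) -> None:
    # In this sketch, a None logger simply disables logging
    if logger is not None:
        logger.info(msg)


collector = ListLogger()
announce(collector, "server started")
announce(logging.getLogger("hybrid"), "also works with the standard library")
announce(None, "silently ignored")
print(collector.messages)  # ['server started']
```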


class HybridFHEModelServer

Hybrid FHE Model Server.

This class serves FHE models serialized using HybridFHEModel.

method __init__

__init__(key_path: Path, model_dir: Path, logger: Optional[LoggerStub])

method add_key

add_key(key: bytes, model_name: str, module_name: str, input_shape: str)

Add public key.

Args:

  • key (bytes): public key

  • model_name (str): model name

  • module_name (str): name of the module in the model

  • input_shape (str): input shape of the module

Returns:

  • Dict[str, str]: a dictionary holding a personal uid for the stored key under the "uid" entry


method check_inputs

check_inputs(
    model_name: str,
    module_name: Optional[str],
    input_shape: Optional[str]
)

Check that the given configuration exists in the compiled models folder.

Args:

  • model_name (str): name of the model

  • module_name (Optional[str]): name of the module in the model

  • input_shape (Optional[str]): input shape of the module

Raises:

  • ValueError: if the given configuration does not exist.
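The return types of list_modules and list_shapes suggest that the server indexes compiled circuits as a nested mapping model name → module name → input shape. Under that assumption, check_inputs reduces to successive membership tests, where module_name and input_shape may be omitted to check only a prefix of that path. The registry contents and the function below are hypothetical.

```python
from typing import Dict, Optional

# Hypothetical registry: model name -> module name -> input shape -> metadata
Registry = Dict[str, Dict[str, Dict[str, dict]]]

registry: Registry = {
    "model": {
        "encoder.fc1": {"1_8": {}},
    },
}


def check_inputs(
    registry: Registry,
    model_name: str,
    module_name: Optional[str] = None,
    input_shape: Optional[str] = None,
) -> None:
    """Raise ValueError if the (model, module, shape) configuration is unknown."""
    if model_name not in registry:
        raise ValueError(f"unknown model: {model_name}")
    if module_name is None:
        return
    modules = registry[model_name]
    if module_name not in modules:
        raise ValueError(f"unknown module {module_name} in model {model_name}")
    if input_shape is not None and input_shape not in modules[module_name]:
        raise ValueError(f"unknown input shape {input_shape} for module {module_name}")


check_inputs(registry, "model", "encoder.fc1", "1_8")  # passes silently
try:
    check_inputs(registry, "model", "encoder.fc1", "2_8")
except ValueError as err:
    print(err)  # unknown input shape 2_8 for module encoder.fc1
```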


method compute

compute(
    model_input: bytes,
    uid: str,
    model_name: str,
    module_name: str,
    input_shape: str
)

Compute the circuit over encrypted input.

Args:

  • model_input (bytes): input of the circuit

  • uid (str): uid of the public key to use

  • model_name (str): model name

  • module_name (str): name of the module in the model

  • input_shape (str): input shape of said module

Returns:

  • bytes: the result of the circuit


method dump_key

dump_key(key_bytes: bytes, uid: Union[UUID, str]) → None

Dump a serialized public key to the key path so that it can later be loaded by its uid.

Args:

  • key_bytes (bytes): the serialized public key to dump

  • uid (Union[str, uuid.UUID]): uid of the public key to dump


method get_circuit

get_circuit(model_name, module_name, input_shape)

Get circuit based on model name, module name and input shape.

Args:

  • model_name (str): name of the model

  • module_name (str): name of the module in the model

  • input_shape (str): input shape of the module

Returns:

  • FHEModelServer: an FHE model server for the given module of the given model and input shape


method get_client

get_client(model_name: str, module_name: str, input_shape: str)

Get the client for the given module of the given model and input shape.

Args:

  • model_name (str): name of the model

  • module_name (str): name of the module in the model

  • input_shape (str): input shape of the module

Returns:

  • Path: the path to the correct client

Raises:

  • ValueError: if client couldn't be found


method list_modules

list_modules(model_name: str)

List all modules in a model.

Args:

  • model_name (str): name of the model

Returns:

  • Dict[str, Dict[str, Dict]]


method list_shapes

list_shapes(model_name: str, module_name: str)

List all input shapes available for a module in a model.

Args:

  • model_name (str): name of the model

  • module_name (str): name of the module in the model

Returns:

  • Dict[str, Dict]


method load_key

load_key(uid: Union[str, UUID]) → bytes

Load a public key from the key path in the file system.

Args:

  • uid (Union[str, uuid.UUID]): uid of the public key to load

Returns:

  • bytes: the bytes of the public key
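add_key, dump_key and load_key together describe a small key store: a serialized public key is saved under a newly generated uid, which is returned to the client, and is later read back by that uid. The sketch below assumes one file per key, named after the uid, inside the server's key directory, and it ignores the model_name, module_name and input_shape arguments of add_key; KeyStore is a hypothetical class, not the Concrete ML implementation.

```python
import tempfile
import uuid
from pathlib import Path
from typing import Dict, Union


class KeyStore:
    """Hypothetical file-backed store: one file per public key, named after its uid."""

    def __init__(self, key_path: Path):
        self.key_path = key_path
        self.key_path.mkdir(parents=True, exist_ok=True)

    def add_key(self, key: bytes) -> Dict[str, str]:
        # Generate a fresh uid, persist the key, and hand the uid back to the client
        uid = uuid.uuid4()
        self.dump_key(key, uid)
        return {"uid": str(uid)}

    def dump_key(self, key_bytes: bytes, uid: Union[uuid.UUID, str]) -> None:
        (self.key_path / str(uid)).write_bytes(key_bytes)

    def load_key(self, uid: Union[str, uuid.UUID]) -> bytes:
        return (self.key_path / str(uid)).read_bytes()


with tempfile.TemporaryDirectory() as tmp:
    store = KeyStore(Path(tmp) / "keys")
    uid = store.add_key(b"serialized-evaluation-key")["uid"]
    assert store.load_key(uid) == b"serialized-evaluation-key"
    print("round trip ok")  # round trip ok
```

Using str(uid) as the file name lets both str and uuid.UUID identifiers resolve to the same key, which matches the Union[str, UUID] parameter types above.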
