
Implement the conversion of a torch model to a hybrid FHE/torch inference.

Global Variables


function tuple_to_underscore_str

tuple_to_underscore_str(tup: Tuple) → str

Convert a tuple to a string representation.

Args:

  • tup (Tuple): a tuple to change into string representation

Returns:

  • str: a string representing the tuple
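The exact string format is not specified here. A minimal sketch, assuming a hypothetical encoding that replaces the parentheses with `po_`/`_pc` markers and the `, ` separators with single underscores, so that a shape such as `(1, 8)` becomes a string without parentheses, commas or spaces (for example, safe to use in folder names). The function name and the markers are illustrative assumptions, not necessarily the library's actual format:

```python
from typing import Tuple


def tuple_to_underscore_str_sketch(tup: Tuple) -> str:
    """Encode a flat tuple of ints as an underscore-separated string (hypothetical format)."""
    # str((1, 8)) == "(1, 8)": swap the parentheses for "po_"/"_pc" markers
    # and the ", " separators for single underscores.
    return str(tup).replace("(", "po_").replace(")", "_pc").replace(", ", "_")


print(tuple_to_underscore_str_sketch((1, 8)))  # po_1_8_pc
```

This only targets flat tuples of integers, such as tensor shapes; nested tuples would need a richer encoding.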

function underscore_str_to_tuple

underscore_str_to_tuple(tup: str) → Tuple

Convert a string representation of a tuple to a tuple.

Args:

  • tup (str): a string representing the tuple

Returns:

  • Tuple: the tuple represented by the string
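Decoding can be sketched by reversing the substitutions and parsing the result with `ast.literal_eval`. This assumes a hypothetical format in which `(1, 8)` is stored as `po_1_8_pc` (parentheses replaced by `po_`/`_pc` markers, separators by underscores); the function name and format are illustrative, not the library's documented behavior:

```python
import ast
from typing import Tuple


def underscore_str_to_tuple_sketch(tup: str) -> Tuple:
    """Decode a string such as "po_1_8_pc" back to (1, 8) (hypothetical format)."""
    # Restore the parentheses first, then turn the remaining underscores back into ", ".
    literal = tup.replace("po_", "(").replace("_pc", ")").replace("_", ", ")
    # literal_eval only accepts Python literals, so no arbitrary code is executed.
    return ast.literal_eval(literal)


print(underscore_str_to_tuple_sketch("po_1_8_pc"))  # (1, 8)
```

Using `ast.literal_eval` rather than `eval` matters here because the string may come from a file name or a network request.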

function convert_conv1d_to_linear

convert_conv1d_to_linear(layer_or_module)

Convert all Conv1D layers in a module, or a single Conv1D layer, to nn.Linear.

Args:

  • layer_or_module (nn.Module or Conv1D): The module to search recursively for Conv1D layers, or a Conv1D layer itself.

Returns:

  • nn.Module or nn.Linear: The updated module with its Conv1D layers converted to Linear layers, or the Conv1D layer converted to a Linear layer.
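The main subtlety in such a conversion is the weight layout. The Hugging Face `transformers` `Conv1D` layer (used, for example, in GPT-2) stores its weight with shape (in_features, out_features) and computes `x @ W + b`, while `nn.Linear` stores (out_features, in_features) and computes `x @ W.T + b`. A converted layer therefore needs the transposed weight. The dependency-free sketch below checks this equivalence on plain Python lists; the helper names are hypothetical, and with torch the same idea amounts to copying the transposed Conv1D weight into the new `nn.Linear`:

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def conv1d_forward(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    # Conv1D convention: weight has shape (in_features, out_features), y = x @ W + b
    return [[v + b for v, b in zip(row, bias)] for row in matmul(x, weight)]


def linear_forward(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    # nn.Linear convention: weight has shape (out_features, in_features), y = x @ W.T + b
    return [[v + b for v, b in zip(row, bias)] for row in matmul(x, transpose(weight))]


x = [[1.0, 2.0, 3.0]]                            # batch of 1, in_features = 3
w_conv1d = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # shape (3, 2)
bias = [0.5, -0.5]

# Converting to the Linear layout means storing the transposed weight, shape (2, 3).
w_linear = transpose(w_conv1d)
assert conv1d_forward(x, w_conv1d, bias) == linear_forward(x, w_linear, bias)
print(linear_forward(x, w_linear, bias))  # [[4.5, 4.5]]
```

The bias is unchanged by the conversion; only the weight matrix changes orientation.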

class HybridFHEMode

Simple enum for the different execution modes of HybridFHEModel.
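Because `set_fhe_mode` accepts either a string or a `HybridFHEMode`, a string-valued enum is a natural fit. A minimal sketch, assuming the member values are the lowercase mode names listed for `RemoteModule.forward` (disable, remote, simulate, calibrate); the class name `HybridFHEModeSketch` and the `to_mode` helper are hypothetical, and the real enum may define other members:

```python
import enum
from typing import Union


class HybridFHEModeSketch(enum.Enum):
    """Hypothetical stand-in for HybridFHEMode with string values."""

    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


def to_mode(mode: Union[str, HybridFHEModeSketch]) -> HybridFHEModeSketch:
    """Normalize a string or an enum member to an enum member."""
    # Enum lookup by value raises ValueError for unknown strings.
    return mode if isinstance(mode, HybridFHEModeSketch) else HybridFHEModeSketch(mode)


print(to_mode("simulate"))  # HybridFHEModeSketch.SIMULATE
```

Normalizing once at the entry point lets the rest of the code compare against enum members only.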

class RemoteModule

A wrapper class for the modules to be evaluated remotely with FHE.

method __init__

__init__(
    module: Optional[Module] = None,
    server_remote_address: Optional[str] = None,
    module_name: Optional[str] = None,
    model_name: Optional[str] = None,
    verbose: int = 0
)

method forward

forward(x: Tensor) → Union[Tensor, QuantTensor]

Forward pass of the remote module.

To change the behavior of this forward function, set the fhe_local_mode attribute. The choices are:

  • disable: forward pass using the torch module

  • remote: forward pass through the FHE client-server

  • simulate: forward pass with local FHE simulation

  • calibrate: forward pass for calibration

Args:

  • x (torch.Tensor): The input tensor.

Returns:

  • Union[torch.Tensor, QuantTensor]: The output tensor.

Raises:

  • ValueError: if fhe_local_mode is not supported
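The mode-dependent behavior described above can be sketched as a simple dispatch. This is a minimal illustration, not the actual implementation: scalars stand in for tensors, the callables stand in for the clear torch module, the local FHE simulation and the client-server call, and the assumption that calibration records the input and then runs the clear module is ours:

```python
from typing import Callable, List

Fn = Callable[[float], float]


def forward_sketch(
    x: float,
    mode: str,
    clear_fn: Fn,
    simulate_fn: Fn,
    remote_fn: Fn,
    calibration_data: List[float],
) -> float:
    """Route a forward pass according to a hypothetical fhe_local_mode string."""
    if mode == "disable":
        return clear_fn(x)  # plain torch execution
    if mode == "remote":
        return remote_fn(x)  # encrypted execution on the FHE server
    if mode == "simulate":
        return simulate_fn(x)  # local FHE simulation
    if mode == "calibrate":
        calibration_data.append(x)  # keep the input for later compilation (assumption)
        return clear_fn(x)
    raise ValueError(f"{mode} is not recognized")


inputs: List[float] = []
triple = lambda v: v * 3
y = forward_sketch(2.0, "calibrate", triple, triple, triple, inputs)
print(y, inputs)  # 6.0 [2.0]
```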

method init_fhe_client

init_fhe_client(
    path_to_client: Optional[Path] = None,
    path_to_keys: Optional[Path] = None
)

Set the client's keys.

Args:

  • path_to_client (Optional[Path]): Path where the client is located.

  • path_to_keys (Optional[Path]): Path where the keys are located.

Raises:

  • ValueError: if anything goes wrong with the server.

method remote_call

remote_call(x: Tensor) → Tensor

Call the remote server to get the private module inference.

Args:

  • x (torch.Tensor): The input tensor.

Returns:

  • torch.Tensor: The result of the FHE computation

class HybridFHEModel

Convert a model to a hybrid model.

This is done by replacing the targeted modules with RemoteModule instances. The model is modified in place.

Args:

  • model (nn.Module): The model to modify (in-place modification)

  • module_names (Union[str, List[str]]): The name(s) of the module(s) to replace with the FHE server.

  • server_remote_address (Optional[str]): The remote address of the FHE server

  • model_name (str): Model name identifier

  • verbose (int): Whether logs should be printed when interacting with the FHE server
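Replacing modules in place by name comes down to resolving each dotted name (for example `encoder.fc`) to its parent object and swapping that attribute for a wrapper. The sketch below shows the mechanism on plain Python objects, which behave like `nn.Module` for `getattr`/`setattr`; `Wrapper`, its `private_module` attribute and `replace_modules` are hypothetical names, and this is not presented as the actual HybridFHEModel code:

```python
from types import SimpleNamespace
from typing import Any, List, Union


class Wrapper:
    """Hypothetical stand-in for RemoteModule: keeps the original module and its name."""

    def __init__(self, module: Any, module_name: str):
        self.private_module = module
        self.module_name = module_name


def replace_modules(model: Any, module_names: Union[str, List[str]]) -> None:
    """Replace each named submodule of `model` with a Wrapper, in place."""
    if isinstance(module_names, str):
        module_names = [module_names]  # accept a single name or a list
    for name in module_names:
        *parents, leaf = name.split(".")
        parent = model
        for attr in parents:  # walk down to the parent of the target
            parent = getattr(parent, attr)
        setattr(parent, leaf, Wrapper(getattr(parent, leaf), name))


model = SimpleNamespace(encoder=SimpleNamespace(fc="linear layer"), head="output layer")
replace_modules(model, "encoder.fc")
print(type(model.encoder.fc).__name__, model.encoder.fc.private_module)  # Wrapper linear layer
```

Because the wrapper keeps a reference to the original module, the clear computation remains available (for example, for a disabled or calibration mode).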

method __init__

__init__(
    model: Module,
    module_names: Union[str, List[str]],
    model_name: str = 'model',
    verbose: int = 0
)

method compile_model

compile_model(
    x: Tensor,
    n_bits: Union[int, Dict[str, int]] = 8,
    rounding_threshold_bits: Optional[int] = None,
    p_error: Optional[float] = None,
    configuration: Optional[Configuration] = None
)

Compile the specific layers to FHE.

Args:

  • x (torch.Tensor): The input tensor for the model. This is used to run the model once for calibration.

  • n_bits (Union[int, Dict[str, int]]): The bit precision for quantization during FHE model compilation. Default is 8.

  • rounding_threshold_bits (Optional[int]): The number of bits to use for rounding threshold during FHE model compilation. Default is None.

  • p_error (Optional[float]): Error probability allowed for each table look-up in the circuit.

  • configuration (Optional[Configuration]): A Concrete Configuration object specifying the FHE encryption parameters. If not specified, a default configuration is used.
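Since p_error applies to each table look-up (TLU) individually, the chance that at least one TLU in a circuit is wrong grows with the number of TLUs. Assuming independent errors, with N TLUs each failing with probability p, the probability of at least one failure is 1 - (1 - p)^N. A small helper (hypothetical name) makes the arithmetic concrete:

```python
import math


def prob_any_tlu_error(p_error: float, n_tlus: int) -> float:
    """Probability that at least one of n_tlus independent TLUs is wrong."""
    # Complement of "every TLU is correct"; log1p/expm1 keep precision for tiny p_error.
    return -math.expm1(n_tlus * math.log1p(-p_error))


print(round(prob_any_tlu_error(0.01, 2), 6))  # 0.0199
```

For small p the result is close to, and never above, N * p, which is a handy back-of-the-envelope bound when choosing p_error.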

method init_client

init_client(
    path_to_clients: Optional[Path] = None,
    path_to_keys: Optional[Path] = None
)

Initialize the client for all remote modules.

Args:

  • path_to_clients (Optional[Path]): Path to the client files.

  • path_to_keys (Optional[Path]): Path to the keys folder.

method publish_to_hub


Allow the user to push the model and the files required for FHE to the Hugging Face Hub.

method save_and_clear_private_info

save_and_clear_private_info(path: Path, via_mlir=False)

Save the PyTorch model to the provided path and also save the corresponding FHE circuit.

Args:

  • path (Path): The directory where the model and the FHE circuit will be saved.

  • via_mlir (bool): Whether the FHE circuits should be serialized using the via_mlir option, which is useful for cross-platform use (compile on one architecture and run on another).

method set_fhe_mode

set_fhe_mode(hybrid_fhe_mode: Union[str, HybridFHEMode])

Set the hybrid FHE mode of all remote modules.

Args:

  • hybrid_fhe_mode (Union[str, HybridFHEMode]): Hybrid FHE mode to set on all remote modules.

class LoggerStub

Placeholder type for a typical logger like the one from loguru.

method info

info(msg: str)

Placeholder function for logging an informational message.

Args:

  • msg (str): the message to output
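Any object exposing an `info(msg)` method, such as loguru's logger, fits this placeholder type. One way to express that contract is a `typing.Protocol`; the sketch below is an assumption about how such a stub could be typed, with a hypothetical `ListLogger` that records messages instead of printing them:

```python
from typing import List, Protocol, runtime_checkable


@runtime_checkable
class LoggerLike(Protocol):
    """Structural type: anything with an info(msg: str) method."""

    def info(self, msg: str) -> None:
        ...


class ListLogger:
    """Hypothetical logger that stores messages instead of printing them."""

    def __init__(self) -> None:
        self.messages: List[str] = []

    def info(self, msg: str) -> None:
        self.messages.append(msg)


logger = ListLogger()
assert isinstance(logger, LoggerLike)  # no inheritance needed, only the method
logger.info("server started")
print(logger.messages)  # ['server started']
```

Structural typing keeps the server decoupled from any particular logging library.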

class HybridFHEModelServer

Hybrid FHE Model Server.

This class serves FHE models serialized using HybridFHEModel.

method __init__

__init__(key_path: Path, model_dir: Path, logger: Optional[LoggerStub])

method add_key

add_key(key: bytes, model_name: str, module_name: str, input_shape: str)

Add a public key.

Args:

  • key (bytes): public key

  • model_name (str): model name

  • module_name (str): name of the module in the model

  • input_shape (str): input shape of said module

Returns: Dict[str, str] - uid: a personal uid identifying the added key
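The returned dictionary can be pictured with a minimal in-memory sketch, assuming the uid is a random UUID4 and that keys are indexed by that uid. The class `KeyRegistrySketch` is hypothetical; the server itself loads keys from its key path in the file system rather than keeping them in a dict:

```python
import uuid
from typing import Dict, Tuple


class KeyRegistrySketch:
    """Hypothetical in-memory stand-in for the key storage of a model server."""

    def __init__(self) -> None:
        # uid -> (key bytes, (model_name, module_name, input_shape))
        self.keys: Dict[str, Tuple[bytes, Tuple[str, str, str]]] = {}

    def add_key(
        self, key: bytes, model_name: str, module_name: str, input_shape: str
    ) -> Dict[str, str]:
        uid = str(uuid.uuid4())  # random, hard-to-guess identifier for this client
        self.keys[uid] = (key, (model_name, module_name, input_shape))
        return {"uid": uid}


registry = KeyRegistrySketch()
response = registry.add_key(b"serialized-public-key", "model", "fc1", "some_shape")
print(list(response))  # ['uid']
```

The client then sends this uid along with each encrypted input so the server can pick the matching evaluation key.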

method check_inputs

check_inputs(
    model_name: str,
    module_name: Optional[str],
    input_shape: Optional[str]
)

Check that the given configuration exists in the compiled models folder.

Args:

  • model_name (str): name of the model

  • module_name (Optional[str]): name of the module in the model

  • input_shape (Optional[str]): input shape of the module

Raises:

  • ValueError: if the given configuration does not exist.
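If compiled modules are stored in nested folders, such a check reduces to path existence tests. This sketch assumes a hypothetical layout `model_dir/<model_name>/<module_name>/<input_shape>/` and only checks the levels that are given; the layout, the error messages and the function name are assumptions:

```python
import tempfile
from pathlib import Path
from typing import Optional


def check_inputs_sketch(
    model_dir: Path, model_name: str, module_name: Optional[str], input_shape: Optional[str]
) -> None:
    """Raise ValueError if the requested model/module/shape folder does not exist."""
    path = model_dir / model_name
    if not path.is_dir():
        raise ValueError(f"{model_name} is not an available model")
    if module_name is not None:
        path = path / module_name
        if not path.is_dir():
            raise ValueError(f"{module_name} is not an available module of {model_name}")
        # In this sketch, an input shape is only checked together with a module name.
        if input_shape is not None and not (path / input_shape).is_dir():
            raise ValueError(f"{input_shape} is not an available input shape of {module_name}")


error_message = ""
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "model" / "fc1" / "shape_a").mkdir(parents=True)
    check_inputs_sketch(root, "model", "fc1", "shape_a")  # exists: no exception
    try:
        check_inputs_sketch(root, "model", "fc2", None)
    except ValueError as err:
        error_message = str(err)

print(error_message)  # fc2 is not an available module of model
```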

method compute

compute(
    model_input: bytes,
    uid: str,
    model_name: str,
    module_name: str,
    input_shape: str
)

Compute the circuit over the encrypted input.

Args:

  • model_input (bytes): input of the circuit

  • uid (str): uid of the public key to use

  • model_name (str): model name

  • module_name (str): name of the module in the model

  • input_shape (str): input shape of said module

Returns:

  • bytes: the result of the circuit

method dump_key

dump_key(key_bytes: bytes, uid: Union[UUID, str]) → None

Dump a serialized public key.

Args:

  • key_bytes (bytes): the serialized public key to dump

  • uid (Union[str, uuid.UUID]): uid of the public key to dump
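Storing keys on the file system can be sketched with one file per uid. This minimal sketch assumes each key is written as raw bytes to `key_path/<uid>`; the file naming and the function names are hypothetical, and `load_key_sketch` is included only to show the round trip that `load_key` provides:

```python
import tempfile
import uuid
from pathlib import Path
from typing import Union


def dump_key_sketch(key_path: Path, key_bytes: bytes, uid: Union[uuid.UUID, str]) -> None:
    """Write a serialized public key to key_path/<uid> (hypothetical layout)."""
    key_path.mkdir(parents=True, exist_ok=True)
    (key_path / str(uid)).write_bytes(key_bytes)


def load_key_sketch(key_path: Path, uid: Union[uuid.UUID, str]) -> bytes:
    """Read back the key written by dump_key_sketch."""
    return (key_path / str(uid)).read_bytes()


with tempfile.TemporaryDirectory() as tmp:
    keys_dir = Path(tmp) / "keys"
    key_uid = uuid.uuid4()
    dump_key_sketch(keys_dir, b"serialized-public-key", key_uid)
    # str(uid) and the UUID object map to the same file name.
    loaded = load_key_sketch(keys_dir, str(key_uid))

print(loaded)  # b'serialized-public-key'
```

Accepting both `str` and `uuid.UUID` works because `str(uuid_obj)` yields the canonical hyphenated form that a client would send back.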

method get_circuit

get_circuit(model_name, module_name, input_shape)

Get the circuit based on model name, module name and input shape.

Args:

  • model_name (str): name of the model

  • module_name (str): name of the module in the model

  • input_shape (str): input shape of the module

Returns:

  • FHEModelServer: an FHE model server for the given module of the given model and the given input shape

method get_client

get_client(model_name: str, module_name: str, input_shape: str)

Get the client.

Args:

  • model_name (str): name of the model

  • module_name (str): name of the module in the model

  • input_shape (str): input shape of the module

Returns:

  • Path: the path to the correct client

Raises:

  • ValueError: if the client couldn't be found

method list_modules

list_modules(model_name: str)

List all modules in a model.

Args:

  • model_name (str): name of the model

Returns: Dict[str, Dict[str, Dict]]

method list_shapes

list_shapes(model_name: str, module_name: str)

List all input shapes of a module in a model.

Args:

  • model_name (str): name of the model

  • module_name (str): name of the module in the model

Returns: Dict[str, Dict]
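The nested dictionaries returned by list_modules and list_shapes can mirror a folder hierarchy. Assuming a hypothetical layout `model_dir/<model_name>/<module_name>/<input_shape>/`, a sketch that builds such dictionaries from the directory tree could look like this; the function names and the empty inner dicts are assumptions about what the real methods return:

```python
import tempfile
from pathlib import Path
from typing import Dict


def list_shapes_sketch(model_dir: Path, model_name: str, module_name: str) -> Dict[str, Dict]:
    """Map each input-shape folder of a module to an (empty) info dict."""
    module_dir = model_dir / model_name / module_name
    return {p.name: {} for p in sorted(module_dir.iterdir()) if p.is_dir()}


def list_modules_sketch(model_dir: Path, model_name: str) -> Dict[str, Dict[str, Dict]]:
    """Map each module folder of a model to the result of list_shapes_sketch."""
    model_path = model_dir / model_name
    return {
        p.name: list_shapes_sketch(model_dir, model_name, p.name)
        for p in sorted(model_path.iterdir())
        if p.is_dir()
    }


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for module, shape in [("fc1", "shape_a"), ("fc1", "shape_b"), ("fc2", "shape_a")]:
        (root / "model" / module / shape).mkdir(parents=True)
    modules = list_modules_sketch(root, "model")

print(modules)  # {'fc1': {'shape_a': {}, 'shape_b': {}}, 'fc2': {'shape_a': {}}}
```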

method load_key

load_key(uid: Union[str, UUID]) → bytes

Load a public key from the key path in the file system.

Args:

  • uid (Union[str, uuid.UUID]): uid of the public key to load

Returns:

  • bytes: the bytes of the public key
