module concrete.ml.torch.hybrid_model
Implement the conversion of a torch model to a hybrid FHE/torch inference.
Global Variables
MAX_BITWIDTH_BACKWARD_COMPATIBLE
function tuple_to_underscore_str
Convert a tuple to a string representation.
Args:
tup (Tuple): a tuple to change into its string representation
Returns:
str: a string representing the tuple
function underscore_str_to_tuple
Convert a string representation of a tuple to a tuple.
Args:
tup (str): a string representing the tuple
Returns:
Tuple: the tuple represented by the string
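The two helpers above are inverses of each other. The sketch below shows one possible round-trip encoding for tuples of integers such as input shapes. The `_sketch` functions are hypothetical stand-ins written for illustration; Concrete ML's actual string format may differ.

```python
from typing import Tuple


def tuple_to_underscore_str_sketch(tup: Tuple[int, ...]) -> str:
    """Hypothetical encoder: join integer items with underscores, e.g. (1, 3, 32) -> '1_3_32'."""
    return "_".join(str(item) for item in tup)


def underscore_str_to_tuple_sketch(text: str) -> Tuple[int, ...]:
    """Hypothetical decoder: split on underscores and parse each part as an int."""
    if not text:
        return ()  # the empty tuple encodes to the empty string
    return tuple(int(part) for part in text.split("_"))


shape = (1, 3, 32)
encoded = tuple_to_underscore_str_sketch(shape)
decoded = underscore_str_to_tuple_sketch(encoded)
print(encoded, decoded)
```

An underscore-only encoding like this is convenient as a directory or key name because it avoids parentheses, commas and spaces; it only round-trips tuples of integers.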
function convert_conv1d_to_linear
Convert all Conv1D layers in a module, or a single Conv1D layer, to nn.Linear.
Args:
layer_or_module (nn.Module or Conv1D): The module to search recursively for Conv1D layers, or a Conv1D layer itself.
Returns:
nn.Module or nn.Linear: The updated module with its Conv1D layers converted to Linear layers, or the Conv1D layer converted to a Linear layer.
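Assuming Conv1D refers to the GPT-2-style layer from Hugging Face transformers, which stores its weight as (in_features, out_features) and computes `x @ W + b`, while `torch.nn.Linear` stores (out_features, in_features) and computes `x @ W.T + b`, the conversion amounts to transposing the weight and keeping the bias. The plain-Python sketch below checks that equivalence on a small example; it does not use torch and is not Concrete ML's implementation.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def add_bias(m: Matrix, bias: List[float]) -> Matrix:
    return [[value + b for value, b in zip(row, bias)] for row in m]


def conv1d_forward(x: Matrix, weight_in_out: Matrix, bias: List[float]) -> Matrix:
    # Conv1D-style layer: weight is (in_features, out_features), y = x @ W + b
    return add_bias(matmul(x, weight_in_out), bias)


def linear_forward(x: Matrix, weight_out_in: Matrix, bias: List[float]) -> Matrix:
    # nn.Linear-style layer: weight is (out_features, in_features), y = x @ W.T + b
    return add_bias(matmul(x, transpose(weight_out_in)), bias)


x = [[1.0, 2.0], [3.0, 4.0]]                      # batch of 2, in_features = 2
w_conv1d = [[1.0, 0.5, -1.0], [2.0, 0.0, 1.0]]    # shape (in=2, out=3)
bias = [0.1, 0.2, 0.3]

w_linear = transpose(w_conv1d)                    # shape (out=3, in=2)
out_conv = conv1d_forward(x, w_conv1d, bias)
out_linear = linear_forward(x, w_linear, bias)
```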
class HybridFHEMode
Simple enum for the different modes of execution of HybridFHEModel.
class RemoteModule
A wrapper class for the modules to be evaluated remotely with FHE.
method __init__
method forward
Forward pass of the remote module.
To change the behavior of this forward function, set the fhe_local_mode attribute. The choices are:
- disable: forward using the torch module
- remote: forward with the FHE client-server
- simulate: forward with local FHE simulation
- calibrate: forward for calibration
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor.
Raises:
ValueError: if fhe_local_mode is not supported
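The mode dispatch described above can be sketched as follows. `Mode` and `RemoteModuleSketch` are hypothetical stand-ins, the wrapped "module" is a plain Python function instead of a torch module, and the simulate/remote branches are placeholders. Recording inputs in calibrate mode is an assumption about what "forward for calibration" involves.

```python
from enum import Enum
from typing import Callable, List


class Mode(Enum):
    """Hypothetical stand-in mirroring the four modes listed above."""
    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


class RemoteModuleSketch:
    """Toy wrapper that routes forward() according to its fhe_local_mode attribute."""

    def __init__(self, clear_fn: Callable[[float], float]):
        self.clear_fn = clear_fn
        self.fhe_local_mode = Mode.CALIBRATE
        self.calibration_inputs: List[float] = []

    def forward(self, x: float) -> float:
        if self.fhe_local_mode == Mode.DISABLE:
            return self.clear_fn(x)  # plain execution, no FHE involved
        if self.fhe_local_mode == Mode.CALIBRATE:
            self.calibration_inputs.append(x)  # assumed: keep inputs for later quantization
            return self.clear_fn(x)
        if self.fhe_local_mode == Mode.SIMULATE:
            return self.simulate(x)
        if self.fhe_local_mode == Mode.REMOTE:
            return self.remote_call(x)
        raise ValueError(f"{self.fhe_local_mode!r} is not supported")

    def simulate(self, x: float) -> float:
        raise NotImplementedError("placeholder for local FHE simulation")

    def remote_call(self, x: float) -> float:
        raise NotImplementedError("placeholder for the FHE client-server call")


module = RemoteModuleSketch(lambda v: 2 * v)
first_out = module.forward(3.0)        # calibrate mode (the default here)
module.fhe_local_mode = Mode.DISABLE
disabled_out = module.forward(1.0)
module.fhe_local_mode = "bogus"
try:
    module.forward(1.0)
    unsupported_rejected = False
except ValueError:
    unsupported_rejected = True
```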
method init_fhe_client
Set the client's keys.
Args:
path_to_client (str): Path where the client.zip is located.
path_to_keys (str): Path where the keys are located.
Raises:
ValueError: if anything goes wrong with the server.
method remote_call
Call the remote server to get the private module inference.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The result of the FHE computation.
class HybridFHEModel
Convert a model to a hybrid model.
This is done by replacing the targeted modules with RemoteModule instances. The model is modified in place.
Args:
model (nn.Module): The model to modify (in-place modification).
module_names (Union[str, List[str]]): The module name(s) to offload to the FHE server.
server_remote_address: The remote address of the FHE server.
model_name (str): Model name identifier.
verbose (int): Whether logs should be printed when interacting with the FHE server.
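The in-place replacement of named submodules can be sketched with plain Python objects standing in for torch modules. `Node`, `RemoteStub` and `replace_modules` are hypothetical names; the sketch only shows the idea of walking a dotted module name with `getattr` and swapping the leaf with `setattr`, and of accepting either a single name or a list, as `module_names` does.

```python
from typing import List, Union


class Node:
    """Minimal stand-in for nn.Module: children are plain attributes."""

    def __init__(self, **children):
        for name, child in children.items():
            setattr(self, name, child)


class RemoteStub:
    """Hypothetical placeholder for a RemoteModule wrapping the original submodule."""

    def __init__(self, wrapped, module_name: str):
        self.wrapped = wrapped
        self.module_name = module_name


def replace_modules(model, module_names: Union[str, List[str]]):
    """Replace each dotted submodule path with a RemoteStub, modifying model in place."""
    if isinstance(module_names, str):
        module_names = [module_names]  # accept a single name or a list of names
    for dotted in module_names:
        *parents, leaf = dotted.split(".")
        owner = model
        for part in parents:
            owner = getattr(owner, part)  # walk down to the parent of the target
        setattr(owner, leaf, RemoteStub(getattr(owner, leaf), dotted))
    return model


model = Node(encoder=Node(fc1="linear-1", fc2="linear-2"), head="linear-3")
replace_modules(model, "encoder.fc1")
replace_modules(model, ["head"])
```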
method __init__
method compile_model
Compile the specific layers to FHE.
Args:
x (torch.Tensor): The input tensor for the model. This is used to run the model once for calibration.
n_bits (int): The bit precision for quantization during FHE model compilation. Default is 8.
rounding_threshold_bits (int): The number of bits to use for rounding threshold during FHE model compilation. Default is 8.
p_error (float): Error allowed for each table look-up in the circuit.
configuration (Configuration): A Concrete Configuration object specifying the FHE encryption parameters. If not specified, a default configuration is used.
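To build intuition for `n_bits` and `rounding_threshold_bits`, the sketch below uses a textbook uniform quantizer and a round-to-nearest that keeps only the most significant bits of an integer accumulator. This is a generic illustration, not Concrete ML's internal quantizer or rounding operator, whose calibration and rounding details may differ; all function names are hypothetical.

```python
from typing import List, Tuple


def quantize(values: List[float], n_bits: int) -> Tuple[List[int], float, float]:
    """Uniform quantization of floats to unsigned integers in [0, 2**n_bits - 1]."""
    lo, hi = min(values), max(values)
    levels = 2**n_bits - 1
    scale = (hi - lo) / levels if hi > lo else 1.0
    q = [min(levels, max(0, round((v - lo) / scale))) for v in values]
    return q, scale, lo


def dequantize(q: List[int], scale: float, offset: float) -> List[float]:
    return [qi * scale + offset for qi in q]


def round_to_msbs(acc: int, acc_bits: int, keep_bits: int) -> int:
    """Round a non-negative acc_bits-wide integer to its keep_bits most significant bits."""
    shift = acc_bits - keep_bits
    if shift <= 0:
        return acc  # nothing to drop
    # Add half of the dropped range, then clear the low bits (round to nearest).
    return ((acc + (1 << (shift - 1))) >> shift) << shift


values = [-1.0, -0.25, 0.3, 1.0]
q, scale, offset = quantize(values, n_bits=8)
restored = dequantize(q, scale, offset)
max_error = max(abs(r - v) for r, v in zip(restored, values))
rounded = round_to_msbs(183, acc_bits=8, keep_bits=4)
```

More bits shrink the quantization step (and hence the error), while rounding an accumulator to fewer bits trades some accuracy for smaller table look-ups.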
method init_client
Initialize the client for all remote modules.
Args:
path_to_clients (Optional[Path]): Path to the client.zip files.
path_to_keys (Optional[Path]): Path to the keys folder.
method publish_to_hub
Allow the user to push the model and the required FHE files to the Hugging Face Hub.
method save_and_clear_private_info
Save the PyTorch model to the provided path and also save the corresponding FHE circuit.
Args:
path (Path): The directory where the model and the FHE circuit will be saved.
via_mlir (bool): Whether FHE circuits should be serialized using the via_mlir option, which is useful for cross-platform use (compile on one architecture and run on another).
method set_fhe_mode
Set the hybrid FHE mode for all remote modules.
Args:
hybrid_fhe_mode (Union[str, HybridFHEMode]): Hybrid FHE mode to set on all remote modules.
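Because the mode may be given either as a string or as an enum member, a natural approach is to normalize it through an enum lookup by value and then assign it to every remote module. The sketch below uses a hypothetical `HybridModeSketch` enum mirroring the values listed for `RemoteModule.forward`, and a stub in place of `RemoteModule`.

```python
from enum import Enum
from typing import Dict, Union


class HybridModeSketch(Enum):
    """Hypothetical mirror of the mode values listed for RemoteModule.forward."""
    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


class RemoteStub:
    """Stand-in for a RemoteModule: only carries the mode attribute."""

    def __init__(self):
        self.fhe_local_mode = HybridModeSketch.CALIBRATE


def set_fhe_mode(remote_modules: Dict[str, RemoteStub],
                 hybrid_fhe_mode: Union[str, HybridModeSketch]) -> None:
    # Lookup by value maps "simulate" to HybridModeSketch.SIMULATE; passing a
    # member returns it unchanged; an unknown string raises ValueError.
    mode = HybridModeSketch(hybrid_fhe_mode)
    for module in remote_modules.values():
        module.fhe_local_mode = mode


mods = {"fc1": RemoteStub(), "fc2": RemoteStub()}
set_fhe_mode(mods, "simulate")
after_string = [m.fhe_local_mode for m in mods.values()]
set_fhe_mode(mods, HybridModeSketch.DISABLE)
after_enum = [m.fhe_local_mode for m in mods.values()]
try:
    set_fhe_mode(mods, "bogus")
    bad_mode_rejected = False
except ValueError:
    bad_mode_rejected = True
```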
class LoggerStub
Placeholder type for a typical logger like the one from loguru.
method info
Placeholder function for logger.info.
Args:
msg (str): the message to output
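One way to express "any logger with an `info(msg)` method" in Python is a structural type. The sketch below uses `typing.Protocol`; `LoggerLike`, `PrintLogger` and `report` are hypothetical names, and this is not necessarily how `LoggerStub` itself is defined.

```python
import logging
from typing import List, Protocol, runtime_checkable


@runtime_checkable
class LoggerLike(Protocol):
    """Structural type: anything with an info(msg) method qualifies."""

    def info(self, msg: str) -> None:
        ...


class PrintLogger:
    """A trivial logger that satisfies LoggerLike without inheriting from it."""

    def __init__(self):
        self.messages: List[str] = []

    def info(self, msg: str) -> None:
        self.messages.append(msg)
        print(msg)


def report(logger: LoggerLike, n_modules: int) -> None:
    logger.info(f"Compiled {n_modules} remote module(s)")


custom = PrintLogger()
report(custom, 2)
report(logging.getLogger("hybrid"), 2)  # the standard library logger also has info()
```

`runtime_checkable` only checks that an `info` attribute exists, so `isinstance` accepts both loggers without any inheritance relationship.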
class HybridFHEModelServer
Hybrid FHE Model Server.
This is a class object to serve FHE models serialized using HybridFHEModel.
method __init__
method add_key
Add a public key.
Args:
key (bytes): public key
model_name (str): model name
module_name (str): name of the module in the model
input_shape (str): input shape of said module
Returns:
Dict[str, str]: a dictionary mapping "uid" to a personal uid
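The key-registration flow (store the key under a fresh identifier and hand that identifier back) can be sketched with an in-memory store. `InMemoryKeyStore` is hypothetical, omits the model, module and shape arguments, and keeps keys in a dict rather than on disk.

```python
import uuid
from typing import Dict


class InMemoryKeyStore:
    """Hypothetical store mapping uid -> serialized public key bytes."""

    def __init__(self):
        self._keys: Dict[str, bytes] = {}

    def add_key(self, key: bytes) -> Dict[str, str]:
        uid = str(uuid.uuid4())  # random identifier, practically collision-free
        self._keys[uid] = key
        return {"uid": uid}  # same shape as the documented return value

    def load_key(self, uid: str) -> bytes:
        return self._keys[uid]


store = InMemoryKeyStore()
first = store.add_key(b"key-1")
second = store.add_key(b"key-2")
```

The client keeps the returned uid and sends it with later `compute` requests so the server knows which public key to use.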
method check_inputs
Check that the given configuration exists in the compiled models folder.
Args:
model_name (str): name of the model
module_name (Optional[str]): name of the module in the model
input_shape (Optional[str]): input shape of the module
Raises:
ValueError: if the given configuration does not exist.
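A minimal sketch of such a check, assuming (not confirmed by this reference) that compiled artifacts are laid out as `base_path/model_name/module_name/input_shape/`. The function name and the directory names in the usage (`my_model`, `fc1`, `shape_a`) are hypothetical.

```python
import tempfile
from pathlib import Path
from typing import Optional


def check_inputs_sketch(base_path: Path, model_name: str,
                        module_name: Optional[str] = None,
                        input_shape: Optional[str] = None) -> None:
    """Raise ValueError unless base_path/model_name[/module_name[/input_shape]] exists."""
    path = base_path / model_name
    if not path.is_dir():
        raise ValueError(f"Model {model_name!r} does not exist")
    if module_name is None:
        return
    path = path / module_name
    if not path.is_dir():
        raise ValueError(f"Module {module_name!r} does not exist in model {model_name!r}")
    if input_shape is None:
        return
    if not (path / input_shape).is_dir():
        raise ValueError(f"Input shape {input_shape!r} does not exist for module {module_name!r}")


with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    (base / "my_model" / "fc1" / "shape_a").mkdir(parents=True)
    check_inputs_sketch(base, "my_model", "fc1", "shape_a")  # exists: no error
    existing_accepted = True
    try:
        check_inputs_sketch(base, "my_model", "fc2")
        missing_rejected = False
    except ValueError as err:
        missing_rejected = True
        print(err)
```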
method compute
Compute the circuit over encrypted input.
Args:
model_input (bytes): input of the circuit
uid (str): uid of the public key to use
model_name (str): model name
module_name (str): name of the module in the model
input_shape (str): input shape of said module
Returns:
bytes: the result of the circuit
method dump_key
Dump a public key to a stream.
Args:
key_bytes (bytes): the serialized public key to dump
uid (Union[str, uuid.UUID]): uid of the public key to dump
method get_circuit
Get the circuit based on the model name, module name and input shape.
Args:
model_name (str): name of the model
module_name (str): name of the module in the model
input_shape (str): input shape of the module
Returns:
FHEModelServer: an FHE model server for the given module of the given model and the given input shape
method get_client
Get the client.
Args:
model_name (str): name of the model
module_name (str): name of the module in the model
input_shape (str): input shape of the module
Returns:
Path: the path to the correct client
Raises:
ValueError: if the client couldn't be found
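Resolving the client path can be sketched under the same assumed layout, with each compiled module and input shape having its own `client.zip` (the file name used by `init_fhe_client`). The layout, function name and directory names are assumptions for illustration.

```python
import tempfile
from pathlib import Path


def get_client_sketch(base_path: Path, model_name: str, module_name: str,
                      input_shape: str) -> Path:
    """Return the path to client.zip for one compiled module and shape, or raise ValueError."""
    client_path = base_path / model_name / module_name / input_shape / "client.zip"
    if not client_path.is_file():
        raise ValueError(f"Could not find client for {model_name}/{module_name}/{input_shape}")
    return client_path


with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    shape_dir = base / "my_model" / "fc1" / "shape_a"
    shape_dir.mkdir(parents=True)
    (shape_dir / "client.zip").write_bytes(b"placeholder")
    found_name = get_client_sketch(base, "my_model", "fc1", "shape_a").name
    try:
        get_client_sketch(base, "my_model", "fc1", "shape_b")
        missing_rejected = False
    except ValueError:
        missing_rejected = True
```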
method list_modules
List all modules in a model.
Args:
model_name (str): name of the model
Returns: Dict[str, Dict[str, Dict]]
method list_shapes
List all input shapes of a module in a model.
Args:
model_name (str): name of the model
module_name (str): name of the module in the model
Returns: Dict[str, Dict]
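The nested return types above fit a directory walk over the assumed `model_name/module_name/input_shape/` layout: modules map to their shapes, and each shape maps to a (here empty) dict of details. The sketch below is hypothetical, including the layout and the empty inner dicts.

```python
import tempfile
from pathlib import Path
from typing import Dict


def list_shapes_sketch(base_path: Path, model_name: str, module_name: str) -> Dict[str, Dict]:
    """Map each input-shape directory of a module to a placeholder dict."""
    module_dir = base_path / model_name / module_name
    return {shape_dir.name: {} for shape_dir in sorted(module_dir.iterdir()) if shape_dir.is_dir()}


def list_modules_sketch(base_path: Path, model_name: str) -> Dict[str, Dict[str, Dict]]:
    """Map each module directory of a model to its input shapes."""
    model_dir = base_path / model_name
    return {
        module_dir.name: list_shapes_sketch(base_path, model_name, module_dir.name)
        for module_dir in sorted(model_dir.iterdir())
        if module_dir.is_dir()
    }


with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    for module, shape in [("fc1", "shape_a"), ("fc1", "shape_b"), ("fc2", "shape_a")]:
        (base / "my_model" / module / shape).mkdir(parents=True)
    modules = list_modules_sketch(base, "my_model")
    shapes = list_shapes_sketch(base, "my_model", "fc1")

print(modules)
```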
method load_key
Load a public key from the key path in the file system.
Args:
uid (Union[str, uuid.UUID]): uid of the public key to load
Returns:
bytes: the bytes of the public key
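`dump_key` and `load_key` form a pair: one writes serialized key bytes under a uid, the other reads them back. The sketch below assumes one file per uid in a key directory; the file layout and function names are hypothetical. Converting the uid with `str()` lets both a `str` and a `uuid.UUID` resolve to the same file, matching the `Union[str, uuid.UUID]` parameter type.

```python
import tempfile
import uuid
from pathlib import Path
from typing import Union


def dump_key_sketch(key_dir: Path, key_bytes: bytes, uid: Union[str, uuid.UUID]) -> None:
    """Write serialized public key bytes to key_dir/<uid>."""
    key_dir.mkdir(parents=True, exist_ok=True)
    (key_dir / str(uid)).write_bytes(key_bytes)


def load_key_sketch(key_dir: Path, uid: Union[str, uuid.UUID]) -> bytes:
    """Read back the bytes written by dump_key_sketch."""
    return (key_dir / str(uid)).read_bytes()


with tempfile.TemporaryDirectory() as tmp:
    key_dir = Path(tmp) / "keys"
    uid = uuid.uuid4()
    dump_key_sketch(key_dir, b"serialized-public-key", uid)
    loaded_by_str = load_key_sketch(key_dir, str(uid))
    loaded_by_uuid = load_key_sketch(key_dir, uid)
```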