concrete.ml.torch.hybrid_model
Implement the conversion of a torch model to hybrid FHE/torch inference.
MAX_BITWIDTH_BACKWARD_COMPATIBLE
tuple_to_underscore_str
Convert a tuple to a string representation.
Args:
tup
(Tuple): a tuple to change into string representation
Returns:
str
: a string representing the tuple
underscore_str_to_tuple
Convert a string representation of a tuple to a tuple.
Args:
tup
(str): a string representing the tuple
Returns:
Tuple
: the tuple recovered from the string representation
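The two helpers above are inverses of each other. A minimal sketch of such a round trip follows, assuming a simple encoding that joins integer elements with underscores. The names `to_underscore_str` and `from_underscore_str` are hypothetical stand-ins, and the exact string format used by the library may differ.

```python
from typing import Tuple


def to_underscore_str(tup: Tuple[int, ...]) -> str:
    """Hypothetical stand-in: join the elements of a tuple with underscores."""
    return "_".join(str(item) for item in tup)


def from_underscore_str(text: str) -> Tuple[int, ...]:
    """Hypothetical stand-in: invert to_underscore_str for non-negative integers."""
    if not text:
        return ()
    return tuple(int(item) for item in text.split("_"))


shape = (1, 3, 224, 224)
encoded = to_underscore_str(shape)      # string form of the shape
decoded = from_underscore_str(encoded)  # back to the original tuple
```

A string of this kind can serve as a path or request component where a tuple literal cannot.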
convert_conv1d_to_linear
Convert all Conv1D layers in a module or a Conv1D layer itself to nn.Linear.
Args:
layer_or_module
(nn.Module or Conv1D): The module which will be recursively searched for Conv1D layers, or a Conv1D layer itself.
Returns:
nn.Module or nn.Linear
: The updated module with Conv1D layers converted to Linear layers, or the Conv1D layer converted to a Linear layer.
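The conversion can be illustrated without any framework. This sketch assumes that Conv1D refers to the GPT-2 style layer from Hugging Face transformers, which stores its weight with shape (in_features, out_features) and computes x @ W + b, while nn.Linear stores (out_features, in_features) and computes x @ W.T + b. Under that assumption, converting a layer amounts to transposing the weight and keeping the bias. The helper names are hypothetical, and plain lists stand in for tensors.

```python
from typing import List

Matrix = List[List[float]]


def transpose(matrix: Matrix) -> Matrix:
    """Swap rows and columns."""
    return [list(column) for column in zip(*matrix)]


def conv1d_forward(x: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """y = x @ W + b, with W of shape (in_features, out_features)."""
    return [
        sum(x[i] * weight[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def linear_forward(x: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """y = x @ W.T + b, with W of shape (out_features, in_features)."""
    return [
        sum(x[i] * weight[j][i] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


conv_weight = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # (in=2, out=3)
bias = [0.5, -0.5, 0.0]
linear_weight = transpose(conv_weight)            # (out=3, in=2)

x = [1.0, 2.0]
conv_out = conv1d_forward(x, conv_weight, bias)
linear_out = linear_forward(x, linear_weight, bias)  # same values as conv_out
```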
HybridFHEMode
Simple enum for different modes of execution of HybridModel.
RemoteModule
A wrapper class for the modules to be evaluated remotely with FHE.
__init__
forward
Forward pass of the remote module.
To change the behavior of this forward function one must change the fhe_local_mode attribute. Choices are:
disable: forward using torch module
remote: forward with fhe client-server
simulate: forward with local fhe simulation
calibrate: forward for calibration
Args:
x
(torch.Tensor): The input tensor.
Returns:
(torch.Tensor)
: The output tensor.
Raises:
ValueError
: if fhe_local_mode is not supported
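The mode-based dispatch described for forward can be sketched as follows. `SketchMode`, `HANDLERS` and `forward_sketch` are hypothetical stand-ins, not the library's classes, and each handler only returns a label naming the path that would run instead of performing any computation.

```python
from enum import Enum
from typing import Callable, Dict


class SketchMode(Enum):
    """Hypothetical stand-in for the four modes listed above."""

    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


# Hypothetical handlers: each returns a label naming the path that would run.
HANDLERS: Dict[SketchMode, Callable[[str], str]] = {
    SketchMode.DISABLE: lambda x: f"torch module({x})",
    SketchMode.REMOTE: lambda x: f"fhe client-server({x})",
    SketchMode.SIMULATE: lambda x: f"local fhe simulation({x})",
    SketchMode.CALIBRATE: lambda x: f"calibration({x})",
}


def forward_sketch(x: str, fhe_local_mode: object) -> str:
    """Route the input to the handler registered for the current mode."""
    handler = HANDLERS.get(fhe_local_mode)
    if handler is None:
        # Mirrors the documented behavior: unsupported modes raise ValueError.
        raise ValueError(f"{fhe_local_mode} is not a supported mode")
    return handler(x)


result = forward_sketch("x", SketchMode.SIMULATE)
```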
init_fhe_client
Set the client keys.
Args:
path_to_client
(str): Path where the client.zip is located.
path_to_keys
(str): Path where keys are located.
Raises:
ValueError
: if anything goes wrong with the server.
remote_call
Call the remote server to get the private module inference.
Args:
x
(torch.Tensor): The input tensor.
Returns:
torch.Tensor
: The result of the FHE computation
HybridFHEModel
Convert a model to a hybrid model.
This is done by replacing the targeted modules with RemoteModules. The model is modified in place.
Args:
model
(nn.Module): The model to modify (in-place modification)
module_names
(Union[str, List[str]]): The module name(s) to replace with FHE server.
server_remote_address
: The remote address of the FHE server
model_name
(str): Model name identifier
verbose
(int): Whether logs should be printed when interacting with the FHE server
__init__
compile_model
Compiles the specific layers to FHE.
Args:
x
(torch.Tensor): The input tensor for the model. This is used to run the model once for calibration.
n_bits
(int): The bit precision for quantization during FHE model compilation. Default is 8.
rounding_threshold_bits
(int): The number of bits to use for rounding threshold during FHE model compilation. Default is 8.
p_error
(float): Error allowed for each table look-up in the circuit.
configuration
(Configuration): A concrete Configuration object specifying the FHE encryption parameters. If not specified, a default configuration is used.
init_client
Initialize client for all remote modules.
Args:
path_to_clients
(Optional[Path]): Path to the client.zip files.
path_to_keys
(Optional[Path]): Path to the keys folder.
publish_to_hub
Allow the user to push the model and FHE required files to HF Hub.
save_and_clear_private_info
Save the PyTorch model to the provided path, along with the corresponding FHE circuit.
Args:
path
(Path): The directory where the model and the FHE circuit will be saved.
via_mlir
(bool): Whether the FHE circuits should be serialized using the via_mlir option. This is useful for cross-platform use (compile on one architecture and run on another).
set_fhe_mode
Set Hybrid FHE mode for all remote modules.
Args:
hybrid_fhe_mode
(Union[str, HybridFHEMode]): Hybrid FHE mode to set to all remote modules.
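Because the argument may be either a string or an enum member, a setter of this kind typically normalizes the value once and then applies it to every remote module. A minimal sketch under that assumption follows. `SketchMode`, `SketchRemote`, `normalize_mode` and `set_mode_for_all` are hypothetical names, and the initial mode chosen here is arbitrary.

```python
from enum import Enum
from typing import List, Union


class SketchMode(Enum):
    """Hypothetical stand-in for the mode enum."""

    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


class SketchRemote:
    """Hypothetical stand-in for a remote module holding a mode attribute."""

    def __init__(self) -> None:
        self.fhe_local_mode = SketchMode.DISABLE  # arbitrary initial value


def normalize_mode(mode: Union[str, SketchMode]) -> SketchMode:
    """Accept either an enum member or its string value."""
    if isinstance(mode, SketchMode):
        return mode
    return SketchMode(mode)  # Enum lookup raises ValueError for unknown strings


def set_mode_for_all(modules: List[SketchRemote], mode: Union[str, SketchMode]) -> None:
    """Apply the same normalized mode to every remote module."""
    normalized = normalize_mode(mode)
    for module in modules:
        module.fhe_local_mode = normalized


remotes = [SketchRemote(), SketchRemote()]
set_mode_for_all(remotes, "simulate")
```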
LoggerStub
Placeholder type for a typical logger like the one from loguru.
info
Placeholder function for logger.info.
Args:
msg
(str): the message to output
HybridFHEModelServer
Hybrid FHE Model Server.
This class serves FHE models serialized using HybridFHEModel.
__init__
add_key
Add public key.
Arguments:
key
(bytes): public key
model_name
(str): model name
module_name
(str): name of the module in the model
input_shape
(str): input shape of said module
Returns: Dict[str, str] - uid: a personal uid
check_inputs
Check that the given configuration exists in the compiled models folder.
Args:
model_name
(str): name of the model
module_name
(Optional[str]): name of the module in the model
input_shape
(Optional[str]): input shape of the module
Raises:
ValueError
: if the given configuration does not exist.
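The check narrows from model to module to input shape, and the two optional arguments allow it to stop early. A minimal sketch of that logic follows, using an in-memory nested dictionary in place of the compiled models folder. The registry layout, its contents and the name `check_inputs_sketch` are assumptions made for illustration.

```python
from typing import Dict, Optional

# Hypothetical registry: model name -> module name -> input shape -> metadata.
Registry = Dict[str, Dict[str, Dict[str, dict]]]


def check_inputs_sketch(
    registry: Registry,
    model_name: str,
    module_name: Optional[str] = None,
    input_shape: Optional[str] = None,
) -> None:
    """Raise ValueError if the requested configuration is not registered."""
    if model_name not in registry:
        raise ValueError(f"Model '{model_name}' does not exist")
    if module_name is None:
        return  # only the model was requested
    if module_name not in registry[model_name]:
        raise ValueError(f"Module '{module_name}' does not exist in '{model_name}'")
    if input_shape is None:
        return  # only the model and module were requested
    if input_shape not in registry[model_name][module_name]:
        raise ValueError(f"Shape '{input_shape}' does not exist for '{module_name}'")


# Hypothetical contents, for illustration only.
registry: Registry = {"my_model": {"layer_a": {"1_8": {}}}}
check_inputs_sketch(registry, "my_model", "layer_a", "1_8")  # passes silently
```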
compute
Compute the circuit over encrypted input.
Arguments:
model_input
(bytes): input of the circuit
uid
(str): uid of the public key to use
model_name
(str): model name
module_name
(str): name of the module in the model
input_shape
(str): input shape of said module
Returns:
bytes
: the result of the circuit
dump_key
Dump a public key to a stream.
Args:
key_bytes
(bytes): stream to dump the public serialized key to
uid
(Union[str, uuid.UUID]): uid of the public key to dump
get_circuit
Get circuit based on model name, module name and input shape.
Args:
model_name
(str): name of the model
module_name
(str): name of the module in the model
input_shape
(str): input shape of the module
Returns:
FHEModelServer
: an FHE model server for the given module of the given model and the given shape
get_client
Get client.
Args:
model_name
(str): name of the model
module_name
(str): name of the module in the model
input_shape
(str): input shape of the module
Returns:
Path
: the path to the correct client
Raises:
ValueError
: if client couldn't be found
list_modules
List all modules in a model.
Args:
model_name
(str): name of the model
Returns: Dict[str, Dict[str, Dict]]
list_shapes
List all input shapes of a module in a model.
Args:
model_name
(str): name of the model
module_name
(str): name of the module in the model
Returns: Dict[str, Dict]
load_key
Load a public key from the key path in the file system.
Args:
uid
(Union[str, uuid.UUID]): uid of the public key to load
Returns:
bytes
: the bytes of the public key
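The relationship between add_key, dump_key and load_key can be sketched with a small file-backed store. Everything here is an assumption made for illustration: the class name `SketchKeyStore`, the `.key` file naming, and the use of a temporary directory as the key path. The actual server may organize its storage differently.

```python
import tempfile
import uuid
from pathlib import Path
from typing import Dict, Union


class SketchKeyStore:
    """Hypothetical stand-in: stores serialized public keys as files named by uid."""

    def __init__(self, key_dir: Path) -> None:
        self.key_dir = key_dir
        self.key_dir.mkdir(parents=True, exist_ok=True)

    def _path(self, uid: Union[str, uuid.UUID]) -> Path:
        # Hypothetical naming scheme: one file per uid.
        return self.key_dir / f"{uid}.key"

    def add_key(self, key: bytes) -> Dict[str, str]:
        """Store a key under a fresh uid and return that uid to the caller."""
        uid = str(uuid.uuid4())
        self.dump_key(key, uid)
        return {"uid": uid}

    def dump_key(self, key_bytes: bytes, uid: Union[str, uuid.UUID]) -> None:
        """Write the serialized key to the file for this uid."""
        self._path(uid).write_bytes(key_bytes)

    def load_key(self, uid: Union[str, uuid.UUID]) -> bytes:
        """Read the serialized key back from the file for this uid."""
        return self._path(uid).read_bytes()


with tempfile.TemporaryDirectory() as tmp:
    store = SketchKeyStore(Path(tmp) / "keys")
    response = store.add_key(b"serialized-public-key")
    loaded = store.load_key(response["uid"])
    # A uuid.UUID object resolves to the same file as its string form.
    loaded_via_uuid = store.load_key(uuid.UUID(response["uid"]))
```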