module concrete.ml.torch.hybrid_model
Implement the conversion of a torch model to a hybrid FHE/torch model for inference.
Global Variables
MAX_BITWIDTH_BACKWARD_COMPATIBLE
function tuple_to_underscore_str
tuple_to_underscore_str(tup: Tuple) → str
Convert a tuple to a string representation.
Args:
tup (Tuple): a tuple to change into string representation
Returns:
str: a string representing the tuple
function underscore_str_to_tuple
underscore_str_to_tuple(tup: str) → Tuple
Convert a string representation of a tuple to a tuple.
Args:
tup (str): a string representing the tuple
Returns:
Tuple: the tuple represented by the string
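A minimal sketch of one reversible string encoding for integer shape tuples, assuming the string must avoid parentheses and spaces (for example, to serve as a folder or URL component). The `po_`/`_pc` markers and the helper names `encode_shape` and `decode_shape` are hypothetical and are not necessarily what Concrete ML uses.

```python
import ast
from typing import Tuple


def encode_shape(tup: Tuple) -> str:
    """Encode a tuple of ints without parentheses or spaces (hypothetical scheme)."""
    # "(1, 3, 224)" -> "po_1, 3, 224)" -> "po_1, 3, 224_pc" -> "po_1_3_224_pc"
    return str(tup).replace("(", "po_").replace(")", "_pc").replace(", ", "_")


def decode_shape(text: str) -> Tuple:
    """Invert encode_shape; intended for tuples of integers only."""
    # Replace the multi-character markers first so that the bare "_" rule does not eat them.
    literal = text.replace("po_", "(").replace("_pc", ")").replace("_", ", ")
    return ast.literal_eval(literal)  # parses a Python literal without executing code


shape = (1, 3, 224)
encoded = encode_shape(shape)
print(encoded)  # po_1_3_224_pc
assert decode_shape(encoded) == shape
```

Using `ast.literal_eval` rather than `eval` keeps the decoding step from running arbitrary code if the string comes from an untrusted request.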
function convert_conv1d_to_linear
convert_conv1d_to_linear(layer_or_module)
Convert all Conv1D layers in a module or a Conv1D layer itself to nn.Linear.
Args:
layer_or_module (nn.Module or Conv1D): The module which will be recursively searched for Conv1D layers, or a Conv1D layer itself.
Returns:
nn.Module or nn.Linear: The updated module with Conv1D layers converted to Linear layers, or the Conv1D layer converted to a Linear layer.
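The `Conv1D` layer from Hugging Face `transformers` behaves like a linear layer whose weight is stored transposed: it computes `x @ W + b` with `W` of shape `(in_features, out_features)`, whereas `nn.Linear` computes `x @ W.T + b` with `W` of shape `(out_features, in_features)`. The pure-Python sketch below (plain lists, no torch; all helper names are hypothetical) checks that transposing the weight is enough to preserve the output. It illustrates the arithmetic only, not Concrete ML's implementation.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def conv1d_forward(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    # Conv1D convention: weight is (in_features, out_features); y = x @ W + b
    return [[v + b for v, b in zip(row, bias)] for row in matmul(x, weight)]


def linear_forward(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    # nn.Linear convention: weight is (out_features, in_features); y = x @ W.T + b
    return [[v + b for v, b in zip(row, bias)] for row in matmul(x, transpose(weight))]


x = [[1.0, 2.0], [3.0, -1.0]]                   # batch of 2 samples, in_features = 2
conv_w = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]   # shape (in=2, out=3)
bias = [0.1, 0.2, 0.3]

linear_w = transpose(conv_w)                    # shape (out=3, in=2)
assert conv1d_forward(x, conv_w, bias) == linear_forward(x, linear_w, bias)
```

The bias is shared unchanged between both conventions; only the weight layout differs.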
class HybridFHEMode
Simple enum for the different execution modes of HybridFHEModel.
class RemoteModule
A wrapper class for the modules to be evaluated remotely with FHE.
method __init__
__init__(
module: Optional[Module] = None,
server_remote_address: Optional[str] = None,
module_name: Optional[str] = None,
model_name: Optional[str] = None,
verbose: int = 0
)
method forward
forward(x: Tensor) → Union[Tensor, QuantTensor]
Forward pass of the remote module.
To change the behavior of this forward function, set the fhe_local_mode attribute. The choices are:
disable: forward using the torch module
remote: forward with the FHE client-server
simulate: forward with local FHE simulation
calibrate: forward for calibration
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor.
Raises:
ValueError: if fhe_local_mode is not supported
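A minimal sketch of a forward pass that dispatches on the mode attribute. `Mode` is a hypothetical stand-in for `HybridFHEMode` that mirrors the four string values listed above, and each branch only returns a tag naming the path that would run; the real `RemoteModule.forward` does actual computation and may be structured differently.

```python
from enum import Enum


class Mode(str, Enum):
    # Hypothetical mirror of the documented modes; subclassing str lets plain strings compare equal.
    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


class RemoteModuleSketch:
    """Hypothetical wrapper illustrating mode-based dispatch only."""

    def __init__(self) -> None:
        self.fhe_local_mode = Mode.CALIBRATE  # starting mode chosen for this sketch

    def forward(self, x):
        # Placeholder branches: return a tag plus the untouched input.
        if self.fhe_local_mode == Mode.DISABLE:
            return ("torch", x)
        if self.fhe_local_mode == Mode.REMOTE:
            return ("remote", x)
        if self.fhe_local_mode == Mode.SIMULATE:
            return ("simulate", x)
        if self.fhe_local_mode == Mode.CALIBRATE:
            return ("calibrate", x)
        raise ValueError(f"{self.fhe_local_mode!r} is not a supported mode")


module = RemoteModuleSketch()
module.fhe_local_mode = Mode("simulate")
print(module.forward(1.0))
```

Ending with an explicit `raise` mirrors the documented `ValueError` for unsupported modes instead of silently falling through.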
method init_fhe_client
init_fhe_client(
path_to_client: Optional[Path] = None,
path_to_keys: Optional[Path] = None
)
Set the client's keys.
Args:
path_to_client (Optional[Path]): Path where the client.zip is located.
path_to_keys (Optional[Path]): Path where the keys are located.
Raises:
ValueError: if anything goes wrong with the server.
method remote_call
remote_call(x: Tensor) → Tensor
Call the remote server to get the private module inference.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The result of the FHE computation
class HybridFHEModel
Convert a model to a hybrid model.
This is done by replacing the targeted modules with RemoteModules. The model is modified in place.
Args:
model (nn.Module): The model to modify (in-place modification).
module_names (Union[str, List[str]]): The module name(s) to replace with the FHE server.
server_remote_address: The remote address of the FHE server.
model_name (str): Model name identifier.
verbose (int): Whether logs should be printed when interacting with the FHE server.
method __init__
__init__(
model: Module,
module_names: Union[str, List[str]],
server_remote_address=None,
model_name: str = 'model',
verbose: int = 0
)
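Replacing modules by name in place generally means resolving the parent of each dotted name and re-binding the attribute on that parent. The sketch below does this on plain Python objects; `Node`, `Wrapper` and `replace_modules` are hypothetical names (a torch model would use `nn.Module.get_submodule` and `setattr` in the same spirit). It also shows one way to accept either a single name or a list, as the `module_names` parameter does.

```python
from typing import List, Union


class Node:
    """Hypothetical stand-in for an nn.Module that holds child modules as attributes."""

    def __init__(self, **children) -> None:
        for name, child in children.items():
            setattr(self, name, child)


class Wrapper:
    """Hypothetical stand-in for RemoteModule: keeps a reference to the wrapped module."""

    def __init__(self, module, module_name: str) -> None:
        self.module = module
        self.module_name = module_name


def replace_modules(model, module_names: Union[str, List[str]]) -> List[str]:
    # Normalize a single name into a list of names.
    names = [module_names] if isinstance(module_names, str) else list(module_names)
    for name in names:
        *parent_path, attr = name.split(".")
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)
        # Re-bind the attribute on the parent, so the original model object is modified in place.
        setattr(parent, attr, Wrapper(getattr(parent, attr), name))
    return names


fc = Node()
model = Node(encoder=Node(fc=fc), head=Node())
replace_modules(model, "encoder.fc")
assert isinstance(model.encoder.fc, Wrapper) and model.encoder.fc.module is fc
```

Keeping a reference to the original module inside the wrapper is what allows a non-FHE fallback path to reuse it.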
method compile_model
compile_model(
x: Tensor,
n_bits: Union[int, Dict[str, int]] = 8,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
configuration: Optional[Configuration] = None
)
Compile the targeted layers to FHE.
Args:
x (torch.Tensor): The input tensor for the model. This is used to run the model once for calibration.
n_bits (Union[int, Dict[str, int]]): The bit precision for quantization during FHE model compilation. Default is 8.
rounding_threshold_bits (Optional[int]): The number of bits to use for rounding threshold during FHE model compilation. Default is None.
p_error (Optional[float]): Error allowed for each table look-up in the circuit.
configuration (Optional[Configuration]): A concrete Configuration object specifying the FHE encryption parameters. If not specified, a default configuration is used.
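The role of the calibration pass and of `n_bits` can be illustrated with generic uniform (affine) quantization: observe the value range on calibration data, then map that range onto 2**n_bits integer levels. This is a simplified stand-in with hypothetical function names, not Concrete ML's exact quantizer, which this page does not describe.

```python
from typing import List, Tuple


def calibrate(values: List[float], n_bits: int) -> Tuple[float, int]:
    """Derive a scale and zero point from the observed value range (hypothetical helper)."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / (2**n_bits - 1) if hi > lo else 1.0
    zero_point = round(-lo / scale)
    return scale, zero_point


def quantize(v: float, scale: float, zero_point: int, n_bits: int) -> int:
    q = round(v / scale) + zero_point
    return max(0, min(2**n_bits - 1, q))  # clamp to the unsigned n_bits range


def dequantize(q: int, scale: float, zero_point: int) -> float:
    return (q - zero_point) * scale


calibration_data = [-1.0, -0.25, 0.0, 0.5, 1.0]
scale, zp = calibrate(calibration_data, n_bits=8)
for v in calibration_data:
    q = quantize(v, scale, zp, n_bits=8)
    print(v, q, dequantize(q, scale, zp))
```

Fewer bits give a coarser grid (a larger `scale`), which is the accuracy side of the trade-off against FHE cost.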
method init_client
init_client(
path_to_clients: Optional[Path] = None,
path_to_keys: Optional[Path] = None
)
Initialize client for all remote modules.
Args:
path_to_clients (Optional[Path]): Path to the client.zip files.
path_to_keys (Optional[Path]): Path to the keys folder.
method publish_to_hub
publish_to_hub()
Push the model and the files required for FHE to the Hugging Face Hub.
method save_and_clear_private_info
save_and_clear_private_info(path: Path, via_mlir=False)
Save the PyTorch model to the provided path and also save the corresponding FHE circuit.
Args:
path (Path): The directory where the model and the FHE circuit will be saved.
via_mlir (bool): Whether FHE circuits should be serialized using the via_mlir option, which is useful for cross-platform use (compile on one architecture and run on another).
method set_fhe_mode
set_fhe_mode(hybrid_fhe_mode: Union[str, HybridFHEMode])
Set Hybrid FHE mode for all remote modules.
Args:
hybrid_fhe_mode (Union[str, HybridFHEMode]): Hybrid FHE mode to set to all remote modules.
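Because `set_fhe_mode` accepts either a string or an enum member, one straightforward approach is to coerce the value once and then assign it to every remote module. The sketch assumes a hypothetical `Mode` enum carrying the four string values listed for `RemoteModule.forward`, and plain objects with an `fhe_local_mode` attribute; it is not the library's code.

```python
from enum import Enum
from types import SimpleNamespace
from typing import Dict, Union


class Mode(Enum):
    # Hypothetical mirror of HybridFHEMode's documented values.
    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


def set_fhe_mode(remote_modules: Dict[str, SimpleNamespace], mode: Union[str, Mode]) -> Mode:
    # Mode(...) accepts a member or its value, and raises ValueError for unknown strings.
    mode = Mode(mode)
    for module in remote_modules.values():
        module.fhe_local_mode = mode
    return mode


modules = {
    "encoder.fc": SimpleNamespace(fhe_local_mode=Mode.CALIBRATE),
    "head": SimpleNamespace(fhe_local_mode=Mode.CALIBRATE),
}
set_fhe_mode(modules, "simulate")
assert all(m.fhe_local_mode is Mode.SIMULATE for m in modules.values())
```

Validating the string before touching any module means an invalid value leaves every module in its previous mode.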
class LoggerStub
Placeholder type for a typical logger like the one from loguru.
method info
info(msg: str)
Placeholder function for logger.info.
Args:
msg (str): the message to output
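Since `LoggerStub` only stands in for an object with an `info(msg)` method, any logger that exposes such a method fits, such as loguru's `logger` or the standard library's `logging.Logger`. The sketch expresses the same idea with `typing.Protocol` (a choice made here, not necessarily how the library declares it) plus a hypothetical in-memory logger.

```python
from typing import List, Protocol


class InfoLogger(Protocol):
    def info(self, msg: str) -> None: ...


class ListLogger:
    """Hypothetical logger that records messages instead of printing them."""

    def __init__(self) -> None:
        self.messages: List[str] = []

    def info(self, msg: str) -> None:
        self.messages.append(msg)


def announce(logger: InfoLogger, model_name: str) -> None:
    # Any object with an info(str) method satisfies InfoLogger structurally; no inheritance needed.
    logger.info(f"Serving model {model_name}")


log = ListLogger()
announce(log, "model")
print(log.messages)
```

Structural typing keeps the server decoupled from any particular logging library.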
class HybridFHEModelServer
Hybrid FHE Model Server.
This class serves FHE models serialized using HybridFHEModel.
method __init__
__init__(key_path: Path, model_dir: Path, logger: Optional[LoggerStub])
method add_key
add_key(key: bytes, model_name: str, module_name: str, input_shape: str)
Add public key.
Arguments:
key (bytes): public key
model_name (str): model name
module_name (str): name of the module in the model
input_shape (str): input shape of said module
Returns:
Dict[str, str]: a dictionary of the form {"uid": uid} holding the personal uid assigned to the key
method check_inputs
check_inputs(
model_name: str,
module_name: Optional[str],
input_shape: Optional[str]
)
Check that the given configuration exists in the compiled models folder.
Args:
model_name (str): name of the model
module_name (Optional[str]): name of the module in the model
input_shape (Optional[str]): input shape of the module
Raises:
ValueError: if the given configuration does not exist.
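A minimal sketch of such a check, assuming compiled artifacts are stored as `model_dir/<model_name>/<module_name>/<input_shape>/`; that layout and the name `check_inputs_sketch` are assumptions, not taken from this page. Because `module_name` and `input_shape` are optional, the check stops at the deepest level the caller provides.

```python
import tempfile
from pathlib import Path
from typing import Optional


def check_inputs_sketch(
    model_dir: Path,
    model_name: str,
    module_name: Optional[str] = None,
    input_shape: Optional[str] = None,
) -> None:
    """Raise ValueError if the requested folder does not exist (assumed layout)."""
    path = model_dir
    for label, value in (("model", model_name), ("module", module_name), ("input shape", input_shape)):
        if value is None:
            break  # stop at the deepest level the caller specified
        path = path / value
        if not path.is_dir():
            raise ValueError(f"Unknown {label} {value!r} (looked for {path})")


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "model" / "encoder.fc" / "po_1_8_pc").mkdir(parents=True)
    check_inputs_sketch(root, "model", "encoder.fc", "po_1_8_pc")  # exists: no error
    try:
        check_inputs_sketch(root, "model", "head")
    except ValueError as err:
        print(err)
```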
method compute
compute(
model_input: bytes,
uid: str,
model_name: str,
module_name: str,
input_shape: str
)
Compute the circuit over encrypted input.
Arguments:
model_input (bytes): input of the circuit
uid (str): uid of the public key to use
model_name (str): model name
module_name (str): name of the module in the model
input_shape (str): input shape of said module
Returns:
bytes: the result of the circuit
method dump_key
dump_key(key_bytes: bytes, uid: Union[UUID, str]) → None
Dump a public key to the key path in the file system.
Args:
key_bytes (bytes): the serialized public key to dump
uid (Union[str, uuid.UUID]): uid of the public key to dump
method get_circuit
get_circuit(model_name, module_name, input_shape)
Get circuit based on model name, module name and input shape.
Args:
model_name (str): name of the model
module_name (str): name of the module in the model
input_shape (str): input shape of the module
Returns:
FHEModelServer: an FHE model server for the given module of the given model and the given input shape
method get_client
get_client(model_name: str, module_name: str, input_shape: str)
Get the path to the client of the given module of the given model for the given input shape.
Args:
model_name (str): name of the model
module_name (str): name of the module in the model
input_shape (str): input shape of the module
Returns:
Path: the path to the correct client
Raises:
ValueError: if the client could not be found
method list_modules
list_modules(model_name: str)
List all modules in a model.
Args:
model_name (str): name of the model
Returns: Dict[str, Dict[str, Dict]]
method list_shapes
list_shapes(model_name: str, module_name: str)
List all input shapes of a module in a model.
Args:
model_name (str): name of the model
module_name (str): name of the module in the model
Returns: Dict[str, Dict]
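Assuming compiled artifacts are laid out as `model_dir/<model_name>/<module_name>/<input_shape>/` (a layout this page does not confirm), listing modules and shapes amounts to listing subdirectories. The nested dictionaries below loosely follow the documented `Dict[str, Dict[str, Dict]]` and `Dict[str, Dict]` return annotations; the contents of the innermost dictionaries are not documented here, so the sketch leaves them empty. Both function names are hypothetical.

```python
import tempfile
from pathlib import Path
from typing import Dict


def list_shapes_sketch(model_dir: Path, model_name: str, module_name: str) -> Dict[str, Dict]:
    module_path = model_dir / model_name / module_name
    return {p.name: {} for p in sorted(module_path.iterdir()) if p.is_dir()}


def list_modules_sketch(model_dir: Path, model_name: str) -> Dict[str, Dict[str, Dict]]:
    model_path = model_dir / model_name
    return {
        p.name: list_shapes_sketch(model_dir, model_name, p.name)
        for p in sorted(model_path.iterdir())
        if p.is_dir()
    }


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for module, shape in [("encoder.fc", "po_1_8_pc"), ("encoder.fc", "po_4_8_pc"), ("head", "po_1_8_pc")]:
        (root / "model" / module / shape).mkdir(parents=True)
    listing = list_modules_sketch(root, "model")
    print(listing)
    # {'encoder.fc': {'po_1_8_pc': {}, 'po_4_8_pc': {}}, 'head': {'po_1_8_pc': {}}}
```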
method load_key
load_key(uid: Union[str, UUID]) → bytes
Load a public key from the key path in the file system.
Args:
uid (Union[str, uuid.UUID]): uid of the public key to load
Returns:
bytes: the bytes of the public key
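The key-handling methods can be pictured as a small file-backed store: `add_key` assigns a fresh uid, `dump_key` writes the serialized key to a file named after that uid under the key path, and `load_key` reads it back. This is a minimal sketch under those assumptions; the class name `KeyStoreSketch` is hypothetical, and the documented `add_key` also receives model, module and shape names that the sketch ignores.

```python
import tempfile
import uuid
from pathlib import Path
from typing import Dict, Union


class KeyStoreSketch:
    """Hypothetical file-backed store for serialized public keys."""

    def __init__(self, key_path: Path) -> None:
        self.key_path = key_path
        self.key_path.mkdir(parents=True, exist_ok=True)

    def dump_key(self, key_bytes: bytes, uid: Union[uuid.UUID, str]) -> None:
        # One file per key, named after its uid.
        (self.key_path / str(uid)).write_bytes(key_bytes)

    def load_key(self, uid: Union[str, uuid.UUID]) -> bytes:
        return (self.key_path / str(uid)).read_bytes()

    def add_key(self, key: bytes) -> Dict[str, str]:
        uid = str(uuid.uuid4())  # random identifier handed back to the client
        self.dump_key(key, uid)
        return {"uid": uid}


with tempfile.TemporaryDirectory() as tmp:
    store = KeyStoreSketch(Path(tmp) / "keys")
    uid = store.add_key(b"serialized-public-key")["uid"]
    loaded = store.load_key(uid)
    assert loaded == b"serialized-public-key"
```

Because the uid becomes a file name, a real server would want to validate it (for example with `uuid.UUID(uid)`, which raises `ValueError` on malformed input) before touching the file system.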