module concrete.ml.torch.hybrid_model
Implement the conversion of a torch model to a hybrid FHE/torch inference.
Global Variables
MAX_BITWIDTH_BACKWARD_COMPATIBLE
function tuple_to_underscore_str
tuple_to_underscore_str(tup: Tuple) → str
Convert a tuple to a string representation.
Args:
tup(Tuple): a tuple to change into string representation
Returns:
str: a string representing the tuple
function underscore_str_to_tuple
underscore_str_to_tuple(tup: str) → Tuple
Convert a string representation of a tuple back to a tuple.
Args:
tup(str): a string representing the tuple
Returns:
Tuple: the tuple represented by the string
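The two helpers above are inverses of each other. The exact string encoding Concrete ML uses is not documented here, so the sketch below assumes a simple scheme (integer dimensions joined by underscores, as in `1_3_224`) purely to illustrate the round trip. The names `shape_to_str` and `str_to_shape` are hypothetical.

```python
from typing import Tuple


def shape_to_str(shape: Tuple[int, ...]) -> str:
    """Hypothetical encoding: join integer dimensions with underscores."""
    return "_".join(str(dim) for dim in shape)


def str_to_shape(text: str) -> Tuple[int, ...]:
    """Inverse of shape_to_str: split on underscores and parse each dimension."""
    if not text:
        return ()
    return tuple(int(part) for part in text.split("_"))


encoded = shape_to_str((1, 3, 224))
print(encoded)                # 1_3_224
print(str_to_shape(encoded))  # (1, 3, 224)
```

One plausible motivation for such an encoding is that the result is safe to use as a directory name or URL parameter, which fits the server methods below taking `input_shape` as a `str`.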
function convert_conv1d_to_linear
convert_conv1d_to_linear(layer_or_module)
Convert all Conv1D layers in a module, or a single Conv1D layer, to nn.Linear.
Args:
layer_or_module(nn.Module or Conv1D): The module which will be recursively searched for Conv1D layers, or a Conv1D layer itself.
Returns:
nn.Module or nn.Linear: The updated module with Conv1D layers converted to Linear layers, or the Conv1D layer converted to a Linear layer.
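Assuming `Conv1D` here refers to the GPT-2-style layer from Hugging Face transformers, the conversion is a weight transpose: that layer stores its weight as `(in_features, out_features)` and computes `x @ W + b`, whereas `nn.Linear` stores `(out_features, in_features)` and computes `x @ W.T + b`. The dependency-free sketch below checks that equivalence on a tiny example; `conv1d_forward`, `linear_forward` and `transpose` are illustrative helpers, not library functions.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(col) for col in zip(*m)]


def conv1d_forward(x: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """Conv1D-style layer: weight has shape (in_features, out_features); y = x @ W + b."""
    out_features = len(weight[0])
    return [
        sum(x[i] * weight[i][j] for i in range(len(x))) + bias[j]
        for j in range(out_features)
    ]


def linear_forward(x: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """nn.Linear-style layer: weight has shape (out_features, in_features); y = x @ W.T + b."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) + b for row, b in zip(weight, bias)]


# Conv1D weight for 2 inputs -> 3 outputs, stored as (in=2, out=3)
conv_w = [[1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0]]
bias = [0.5, 0.0, -0.5]
x = [1.0, -1.0]

linear_w = transpose(conv_w)  # (out=3, in=2); the bias is reused unchanged
print(conv1d_forward(x, conv_w, bias))    # [-2.5, -3.0, -3.5]
print(linear_forward(x, linear_w, bias))  # [-2.5, -3.0, -3.5]
```

In PyTorch terms this corresponds to copying the transposed Conv1D weight into a new `nn.Linear`; the library's own implementation may handle additional details.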
class HybridFHEMode
Simple enum for the different modes of execution of HybridFHEModel.
class RemoteModule
A wrapper class for the modules to be evaluated remotely with FHE.
method __init__
__init__(
module: Optional[Module] = None,
server_remote_address: Optional[str] = None,
module_name: Optional[str] = None,
model_name: Optional[str] = None,
verbose: int = 0
)
method forward
forward(x: Tensor) → Union[Tensor, QuantTensor]
Forward pass of the remote module.
To change the behavior of this forward function, set the fhe_local_mode attribute. The choices are:
disable: forward using torch module
remote: forward with fhe client-server
simulate: forward with local fhe simulation
calibrate: forward for calibration
Args:
x(torch.Tensor): The input tensor.
Returns:
Union[torch.Tensor, QuantTensor]: The output tensor.
Raises:
ValueError: if fhe_local_mode is not supported
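The mode-based dispatch described above can be sketched as follows. This is a hypothetical illustration, not the library's code: `Mode` stands in for the mode enum, the three callables stand in for the torch module, the client-server call and the local simulation, and the calibration branch assumes that calibration records the input and then runs the clear module.

```python
from enum import Enum
from typing import Callable, List, Union


class Mode(str, Enum):
    """Hypothetical stand-in for the four modes listed above."""
    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


def forward_sketch(
    mode: Union[str, Mode],
    x: float,
    run_clear: Callable[[float], float],
    run_remote: Callable[[float], float],
    run_simulated: Callable[[float], float],
    calibration_inputs: List[float],
) -> float:
    """Route x to the execution path selected by mode."""
    mode = Mode(mode)  # lookup by value raises ValueError for an unknown mode string
    if mode is Mode.DISABLE:
        return run_clear(x)
    if mode is Mode.REMOTE:
        return run_remote(x)
    if mode is Mode.SIMULATE:
        return run_simulated(x)
    # CALIBRATE (assumed behavior): keep the input for later compilation, run in the clear
    calibration_inputs.append(x)
    return run_clear(x)


def double(v: float) -> float:
    return 2 * v


seen: List[float] = []
print(forward_sketch("calibrate", 3.0, double, double, double, seen))  # 6.0
print(seen)  # [3.0]
```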
method init_fhe_client
init_fhe_client(
path_to_client: Optional[Path] = None,
path_to_keys: Optional[Path] = None
)
Set the client's keys.
Args:
path_to_client(Optional[Path]): Path where the client.zip is located.
path_to_keys(Optional[Path]): Path where the keys are located.
Raises:
ValueError: if anything goes wrong with the server.
method remote_call
remote_call(x: Tensor) → Tensor
Call the remote server to get the private module inference.
Args:
x(torch.Tensor): The input tensor.
Returns:
torch.Tensor: The result of the FHE computation
class HybridFHEModel
Convert a model to a hybrid model.
This is done by replacing the targeted modules with RemoteModule instances. The model is modified in place.
Args:
model(nn.Module): The model to modify (in-place modification)
module_names(Union[str, List[str]]): The name(s) of the module(s) to execute on the FHE server
server_remote_address(Optional[str]): The remote address of the FHE server
model_name(str): Model name identifier
verbose(int): Whether logs should be printed when interacting with the FHE server
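The in-place replacement described above can be sketched without PyTorch. This is a hypothetical illustration, not the library's code: it assumes module names are dotted attribute paths such as `encoder.fc1` (the form PyTorch's `named_modules()` yields), uses `SimpleNamespace` objects in place of `nn.Module` instances, and a `make_wrapper` callback in place of constructing a RemoteModule.

```python
from types import SimpleNamespace
from typing import Any, Callable, List, Union


def replace_modules_in_place(
    model: Any,
    module_names: Union[str, List[str]],
    make_wrapper: Callable[[Any, str], Any],
) -> None:
    """Swap each dotted-name submodule of model for make_wrapper(original, name)."""
    if isinstance(module_names, str):  # accept a single name or a list, like module_names
        module_names = [module_names]
    for name in module_names:
        *parent_path, attr = name.split(".")
        parent = model
        for part in parent_path:  # walk down to the object that owns the target
            parent = getattr(parent, part)
        original = getattr(parent, attr)
        setattr(parent, attr, make_wrapper(original, name))  # mutates model in place


model = SimpleNamespace(encoder=SimpleNamespace(fc1="linear-1", fc2="linear-2"))
replace_modules_in_place(model, "encoder.fc1", lambda m, n: ("remote", n, m))
print(model.encoder.fc1)  # ('remote', 'encoder.fc1', 'linear-1')
print(model.encoder.fc2)  # linear-2
```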
method __init__
__init__(
model: Module,
module_names: Union[str, List[str]],
server_remote_address=None,
model_name: str = 'model',
verbose: int = 0
)
method compile_model
compile_model(
x: Tensor,
n_bits: Union[int, Dict[str, int]] = 8,
rounding_threshold_bits: Optional[int] = None,
p_error: Optional[float] = None,
configuration: Optional[Configuration] = None
)
Compile the specific layers to FHE.
Args:
x(torch.Tensor): The input tensor for the model. It is used to run the model once for calibration.
n_bits(Union[int, Dict[str, int]]): The bit precision for quantization during FHE model compilation. Default is 8.
rounding_threshold_bits(Optional[int]): The number of bits to use for the rounding threshold during FHE model compilation. Default is None.
p_error(Optional[float]): Error allowed for each table look-up in the circuit. Default is None.
configuration(Optional[Configuration]): A Concrete Configuration object specifying the FHE encryption parameters. If not specified, a default configuration is used.
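To make the roles of `x` and `n_bits` more concrete: the calibration pass observes the range of values flowing through a layer, and `n_bits` fixes how many integer levels that range is mapped to. The sketch below uses plain asymmetric uniform quantization as an illustrative assumption; Concrete ML's actual quantizers (including the per-layer dictionary form of `n_bits`) are more involved, and `calibrate`, `quantize` and `dequantize` are hypothetical names.

```python
from typing import List, Tuple


def calibrate(values: List[float], n_bits: int) -> Tuple[float, int]:
    """Derive (scale, zero_point) mapping the observed range onto 0 .. 2**n_bits - 1."""
    lo, hi = min(values), max(values)
    lo, hi = min(lo, 0.0), max(hi, 0.0)  # keep 0.0 inside the range so it is representable
    levels = 2**n_bits - 1
    scale = (hi - lo) / levels if hi > lo else 1.0
    zero_point = round(-lo / scale)
    return scale, zero_point


def quantize(v: float, scale: float, zero_point: int, n_bits: int) -> int:
    """Map a float to an integer level, clipped to the unsigned n-bit range."""
    q = round(v / scale) + zero_point
    return max(0, min(2**n_bits - 1, q))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Map an integer level back to an approximate float."""
    return (q - zero_point) * scale


scale, zero_point = calibrate([-1.0, 0.0, 2.0], n_bits=2)
print(scale, zero_point)  # 1.0 1
q = quantize(1.4, scale, zero_point, n_bits=2)
print(q, dequantize(q, scale, zero_point))  # 2 1.0
```

Fewer bits mean coarser levels (here 1.4 comes back as 1.0), which is the accuracy trade-off `n_bits` controls.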
method init_client
init_client(
path_to_clients: Optional[Path] = None,
path_to_keys: Optional[Path] = None
)
Initialize the client for all remote modules.
Args:
path_to_clients(Optional[Path]): Path to the client.zip files.
path_to_keys(Optional[Path]): Path to the keys folder.
method publish_to_hub
publish_to_hub()
Push the model and the files required for FHE to the Hugging Face Hub.
method save_and_clear_private_info
save_and_clear_private_info(path: Path, via_mlir=False)
Save the PyTorch model to the provided path and also save the corresponding FHE circuit.
Args:
path(Path): The directory where the model and the FHE circuit will be saved.
via_mlir(bool): If True, serialize the FHE circuits via MLIR, which is useful for cross-platform use (compile on one architecture and run on another).
method set_fhe_mode
set_fhe_mode(hybrid_fhe_mode: Union[str, HybridFHEMode])
Set the Hybrid FHE mode for all remote modules.
Args:
hybrid_fhe_mode(Union[str, HybridFHEMode]): Hybrid FHE mode to set to all remote modules.
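Because `set_fhe_mode` accepts either a string or an enum member, one way to support both is to normalize through the enum's lookup by value, which returns members unchanged and raises `ValueError` for unknown strings. The sketch below is hypothetical: `Mode` and `RemoteStub` stand in for HybridFHEMode and RemoteModule, and the string values are assumed to match the modes listed for `RemoteModule.forward`.

```python
from enum import Enum
from typing import List, Union


class Mode(str, Enum):
    """Hypothetical stand-in for HybridFHEMode."""
    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


class RemoteStub:
    """Hypothetical stand-in for a remote module that only tracks its mode."""

    def __init__(self) -> None:
        self.fhe_local_mode = Mode.DISABLE  # arbitrary initial mode for the sketch


def set_mode_for_all(modules: List[RemoteStub], mode: Union[str, Mode]) -> None:
    """Validate the mode once, then apply it to every remote module."""
    mode = Mode(mode)  # accepts "simulate" or Mode.SIMULATE; ValueError otherwise
    for module in modules:
        module.fhe_local_mode = mode


modules = [RemoteStub(), RemoteStub()]
set_mode_for_all(modules, "simulate")
print([m.fhe_local_mode.value for m in modules])  # ['simulate', 'simulate']
```

Validating before the loop means an invalid mode leaves every module untouched instead of updating only some of them.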
class LoggerStub
Placeholder type for a typical logger like the one from loguru.
method info
info(msg: str)
Placeholder function for logger.info.
Args:
msg(str): the message to output
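A placeholder logger type like this can be expressed as a structural protocol, so that loguru's logger, the standard library's `logging.Logger`, or a test double can all be passed wherever an object with an `info(msg)` method is expected. The sketch below uses `typing.Protocol` to illustrate the idea; `InfoLogger`, `ListLogger` and `announce` are hypothetical names, not part of the library.

```python
import logging
from typing import List, Protocol


class InfoLogger(Protocol):
    """Structural type: any object with an info(msg) method qualifies."""

    def info(self, msg: str) -> None: ...


class ListLogger:
    """Minimal logger that records messages, e.g. for tests."""

    def __init__(self) -> None:
        self.messages: List[str] = []

    def info(self, msg: str) -> None:
        self.messages.append(msg)


def announce(logger: InfoLogger, model_name: str) -> None:
    """Log through whatever logger was injected."""
    logger.info(f"serving model {model_name}")


recorder = ListLogger()
announce(recorder, "model")
announce(logging.getLogger("server"), "model")  # a stdlib logger also fits the protocol
print(recorder.messages)  # ['serving model model']
```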
class HybridFHEModelServer
Hybrid FHE Model Server.
This is a class object to serve FHE models serialized using HybridFHEModel.
method __init__
__init__(key_path: Path, model_dir: Path, logger: Optional[LoggerStub])
method add_key
add_key(key: bytes, model_name: str, module_name: str, input_shape: str)
Add a public key.
Args:
key(bytes): public key
model_name(str): model name
module_name(str): name of the module in the model
input_shape(str): input shape of said module
Returns:
Dict[str, str]: a dictionary whose "uid" entry is the personal uid assigned to the stored key
method check_inputs
check_inputs(
model_name: str,
module_name: Optional[str],
input_shape: Optional[str]
)
Check that the given configuration exists in the compiled models folder.
Args:
model_name(str): name of the model
module_name(Optional[str]): name of the module in the model
input_shape(Optional[str]): input shape of the module
Raises:
ValueError: if the given configuration does not exist.
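A check like this can be sketched as a walk down a directory tree. The layout used below, `<model_dir>/<model_name>/<module_name>/<input_shape>/`, is an assumption made for illustration (it is not stated in this reference), and `check_inputs_sketch` is a hypothetical name. Optional levels that are `None` are simply not checked.

```python
import tempfile
from pathlib import Path
from typing import Optional


def check_inputs_sketch(
    model_dir: Path,
    model_name: str,
    module_name: Optional[str] = None,
    input_shape: Optional[str] = None,
) -> None:
    """Raise ValueError unless each given level exists under model_dir (assumed layout)."""
    path = model_dir / model_name
    if not path.is_dir():
        raise ValueError(f"model {model_name!r} does not exist")
    if module_name is not None:
        path = path / module_name
        if not path.is_dir():
            raise ValueError(f"module {module_name!r} does not exist in {model_name!r}")
        if input_shape is not None:
            path = path / input_shape
            if not path.is_dir():
                raise ValueError(f"input shape {input_shape!r} does not exist for {module_name!r}")


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "model" / "fc1" / "1_8").mkdir(parents=True)
    check_inputs_sketch(root, "model", "fc1", "1_8")  # passes silently
    try:
        check_inputs_sketch(root, "model", "fc2")
    except ValueError as err:
        print(err)  # module 'fc2' does not exist in 'model'
```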
method compute
compute(
model_input: bytes,
uid: str,
model_name: str,
module_name: str,
input_shape: str
)
Compute the circuit over encrypted input.
Args:
model_input(bytes): input of the circuit
uid(str): uid of the public key to use
model_name(str): model name
module_name(str): name of the module in the model
input_shape(str): input shape of said module
Returns:
bytes: the result of the circuit
method dump_key
dump_key(key_bytes: bytes, uid: Union[UUID, str]) → None
Dump a public key to the key path in the file system.
Args:
key_bytes(bytes): the serialized public key to dump
uid(Union[str, uuid.UUID]): uid of the public key to dump
method get_circuit
get_circuit(model_name, module_name, input_shape)
Get the circuit based on model name, module name and input shape.
Args:
model_name(str): name of the model
module_name(str): name of the module in the model
input_shape(str): input shape of the module
Returns:
FHEModelServer: an FHE model server for the given module of the given model and input shape
method get_client
get_client(model_name: str, module_name: str, input_shape: str)
Get the client.
Args:
model_name(str): name of the model
module_name(str): name of the module in the model
input_shape(str): input shape of the module
Returns:
Path: the path to the correct client
Raises:
ValueError: if client couldn't be found
method list_modules
list_modules(model_name: str)
List all modules in a model.
Args:
model_name(str): name of the model
Returns: Dict[str, Dict[str, Dict]]
method list_shapes
list_shapes(model_name: str, module_name: str)
List all input shapes available for a module in a model.
Args:
model_name(str): name of the model
module_name(str): name of the module in the model
Returns: Dict[str, Dict]
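The nested return types of `list_modules` and `list_shapes` can be illustrated with a directory-walking sketch. It assumes a hypothetical `<model_dir>/<model_name>/<module_name>/<input_shape>/` layout and uses an empty dict as a placeholder for per-shape metadata; neither the layout nor the function names come from this reference.

```python
import tempfile
from pathlib import Path
from typing import Dict


def list_modules_sketch(model_dir: Path, model_name: str) -> Dict[str, Dict[str, Dict]]:
    """Map each module directory of a model to the shapes found under it."""
    model_path = model_dir / model_name
    return {
        module.name: list_shapes_sketch(model_dir, model_name, module.name)
        for module in sorted(model_path.iterdir())
        if module.is_dir()
    }


def list_shapes_sketch(model_dir: Path, model_name: str, module_name: str) -> Dict[str, Dict]:
    """Map each input-shape directory of a module to a placeholder metadata dict."""
    module_path = model_dir / model_name / module_name
    return {shape.name: {} for shape in sorted(module_path.iterdir()) if shape.is_dir()}


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for rel in ("model/fc1/1_8", "model/fc1/4_8", "model/fc2/1_16"):
        (root / rel).mkdir(parents=True)
    print(list_modules_sketch(root, "model"))
    # {'fc1': {'1_8': {}, '4_8': {}}, 'fc2': {'1_16': {}}}
```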
method load_key
load_key(uid: Union[str, UUID]) → bytes
Load a public key from the key path in the file system.
Args:
uid(Union[str, uuid.UUID]): uid of the public key to load
Returns:
bytes: the bytes of the public key
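The key-handling methods above (`add_key`, `dump_key`, `load_key`) can be sketched as a small file-backed store keyed by uid. This is a hypothetical illustration: `KeyStoreSketch` is not a library class, its `add_key` omits the model, module and shape arguments of the real method, and storing one file per uid under `key_path` is an assumption.

```python
import tempfile
import uuid
from pathlib import Path
from typing import Dict, Union


class KeyStoreSketch:
    """Hypothetical file-backed key store: one file per uid under key_path."""

    def __init__(self, key_path: Path) -> None:
        self.key_path = key_path
        self.key_path.mkdir(parents=True, exist_ok=True)

    def _key_file(self, uid: Union[str, uuid.UUID]) -> Path:
        # Parsing as a UUID rejects values such as "../secret" before they reach the file system
        return self.key_path / str(uuid.UUID(str(uid)))

    def dump_key(self, key_bytes: bytes, uid: Union[uuid.UUID, str]) -> None:
        self._key_file(uid).write_bytes(key_bytes)

    def load_key(self, uid: Union[str, uuid.UUID]) -> bytes:
        return self._key_file(uid).read_bytes()

    def add_key(self, key: bytes) -> Dict[str, str]:
        uid = str(uuid.uuid4())  # fresh random identifier returned to the client
        self.dump_key(key, uid)
        return {"uid": uid}


with tempfile.TemporaryDirectory() as tmp:
    store = KeyStoreSketch(Path(tmp) / "keys")
    response = store.add_key(b"serialized-public-key")
    print(store.load_key(response["uid"]))  # b'serialized-public-key'
```

Since the uid arrives from the client, validating it before using it as a file name, as `_key_file` does here, is a sensible precaution for any server of this shape.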