MLPRegressor classes. The built-in fully connected neural network (FCNN) models train easily with a single call to .fit(), which automatically quantizes the weights and activations.
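To make "quantize the weights" concrete, here is a minimal sketch of generic symmetric uniform quantization to n bits, written in plain Python. The helpers quantize_uniform and dequantize are hypothetical and only illustrate the idea. The library's own quantization scheme (for example, how scales and zero points are chosen during training) may differ.

```python
def quantize_uniform(values, n_bits):
    """Hypothetical helper: symmetric uniform quantization to signed n_bits integers.

    Returns integer codes and a scale such that value ~= code * scale.
    """
    q_max = 2 ** (n_bits - 1) - 1  # largest code magnitude, e.g. 3 for n_bits=3
    max_abs = max(abs(v) for v in values)
    scale = max_abs / q_max if max_abs > 0 else 1.0
    # Round to the nearest code and clamp to the representable range.
    codes = [max(-q_max, min(q_max, round(v / scale))) for v in values]
    return codes, scale


def dequantize(codes, scale):
    """Map integer codes back to approximate float values."""
    return [c * scale for c in codes]


weights = [0.8, -0.35, 0.1, -0.9]
codes, scale = quantize_uniform(weights, n_bits=3)
print(codes)  # [3, -1, 0, -3]
print(dequantize(codes, scale))
```

With only 3 bits, each weight collapses onto one of seven levels, which is why low-precision models need training that accounts for this rounding.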
While NeuralNetClassifier and NeuralNetRegressor provide scikit-learn-like models, their architecture is somewhat restricted to make training easy and robust. If you need more advanced models, you can convert custom neural networks, as described in the FHE-friendly models documentation.
To build an FCNN, instantiate one of the NeuralNetClassifier and NeuralNetRegressor classes and configure the parameters passed to its constructor. Note that some parameters need to be prefixed by module__, while others don't: parameters related to the model itself, i.e. the underlying nn.Module, must have the prefix, while parameters related to training options do not.
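The prefix rule can be illustrated with a small routing function. This is a sketch of the convention only, in the style of skorch's double-underscore parameter routing; split_constructor_params is a hypothetical helper, not part of the library. In practice, a dictionary like params below would simply be passed as keyword arguments to the classifier or regressor constructor.

```python
def split_constructor_params(params, prefix="module__"):
    """Hypothetical helper: separate nn.Module parameters from training options.

    Keys starting with `prefix` go to the module (with the prefix stripped);
    every other key is treated as a training option.
    """
    module_params = {}
    training_params = {}
    for key, value in params.items():
        if key.startswith(prefix):
            module_params[key[len(prefix):]] = value
        else:
            training_params[key] = value
    return module_params, training_params


params = {
    "module__n_layers": 3,
    "module__n_outputs": 2,
    "module__input_dim": 10,
    "max_epochs": 10,
    "lr": 0.001,
}
module_params, training_params = split_constructor_params(params)
print(module_params)    # {'n_layers': 3, 'n_outputs': 2, 'input_dim': 10}
print(training_params)  # {'max_epochs': 10, 'lr': 0.001}
```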
module__n_layers: number of layers in the FCNN, must be at least 1
module__n_outputs: number of outputs (classes or targets)
module__input_dim: dimensionality of the input
n_w_bits: number of bits for weights (default 3)
n_a_bits: number of bits for activations and inputs (default 3)
max_epochs: number of epochs to train the network (default 10)
verbose: whether to log loss/metrics during training (default False)
lr: learning rate (default 0.001)
module__n_hidden_neurons_multiplier: the number of hidden neurons is automatically set proportional to the dimensionality of the input (i.e. the value of module__input_dim). This parameter controls the proportionality factor and defaults to 4, a value that gives good accuracy while avoiding accumulator overflow.
The n_hidden_neurons_multiplier parameter influences training accuracy because it controls the number of non-zero neurons allowed in each layer. Increasing n_hidden_neurons_multiplier improves accuracy, but the chosen value must respect precision limitations to avoid overflow in the accumulator. The default value is a good compromise that avoids overflow in most cases, but if you get overflow errors you may want to lower this parameter to reduce the breadth of the network. A value of 1 should be completely safe with respect to overflow.
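Why does a wider network risk overflow? Each neuron accumulates a sum of products of weights and activations, and the worst-case magnitude of that sum grows with the number of non-zero terms. The following rough estimate assumes signed n_w_bits weights and unsigned n_a_bits activations; worst_case_accumulator_bits is a hypothetical helper, and the library's actual bit-width accounting may differ.

```python
def worst_case_accumulator_bits(n_w_bits, n_a_bits, n_terms):
    """Hypothetical helper: bits needed to hold sum(w_i * a_i) in the worst case.

    Assumes signed weights in [-2**(n_w_bits - 1), 2**(n_w_bits - 1) - 1]
    and unsigned activations in [0, 2**n_a_bits - 1].
    """
    max_abs_weight = 2 ** (n_w_bits - 1)
    max_activation = 2 ** n_a_bits - 1
    max_abs_sum = n_terms * max_abs_weight * max_activation
    # One sign bit plus enough bits to represent the largest magnitude.
    return 1 + max_abs_sum.bit_length()


input_dim = 10
for multiplier in (1, 4, 16):
    n_hidden = multiplier * input_dim  # non-zero inputs feeding the next layer
    bits = worst_case_accumulator_bits(n_w_bits=3, n_a_bits=3, n_terms=n_hidden)
    print(f"multiplier={multiplier}: {n_hidden} terms -> {bits} bits")
# multiplier=1: 10 terms -> 10 bits
# multiplier=4: 40 terms -> 12 bits
# multiplier=16: 160 terms -> 14 bits
```

The required width grows roughly with log2 of the number of terms, so each fourfold increase of the multiplier costs about two extra accumulator bits under these assumptions.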