concrete.ml.quantization.quantized_module_passes
Optimization passes for QuantizedModules.
PowerOfTwoScalingRoundPBSAdapter
Detect neural network patterns that can be optimized with round PBS.
__init__
property num_ignored_valid_patterns
Get the number of optimizable patterns that were ignored.
Patterns may be ignored when the number of rounding bits was set manually through the compilation function.
Returns:
result
(int): number of patterns that could be optimized but were not
compute_op_predecessors
Compute the predecessors for each QuantizedOp in a QuantizedModule.
Stores, for each quantized op, the list of quantized ops that produce its inputs. Currently, only the first input of each operation is considered, since it is usually the encrypted input.
Returns:
result
(PredecessorsType): a dictionary containing a hierarchy of op predecessors
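The following is a minimal sketch of the predecessor computation described above, run on a toy graph. It does not use the Concrete ML API. The `Op` dataclass and this standalone `compute_op_predecessors` function are hypothetical stand-ins, and predecessors are stored by op name rather than as QuantizedOp objects. As in the description, only the first input of each op is followed.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


# Hypothetical, simplified stand-in for a quantized op: a name, an op type,
# the names of the tensors it consumes, and the tensor it produces.
@dataclass
class Op:
    name: str
    op_type: str
    inputs: List[str]
    output: str


def compute_op_predecessors(ops: List[Op]) -> Dict[str, List[Optional[str]]]:
    """Map each op name to the op producing its first input (None for graph inputs)."""
    # Index which op produces each tensor
    producer = {op.output: op.name for op in ops}
    predecessors: Dict[str, List[Optional[str]]] = {}
    for op in ops:
        # Only the first input is followed, since it is usually the encrypted one
        first_input = op.inputs[0] if op.inputs else None
        predecessors[op.name] = [producer.get(first_input)]
    return predecessors
```

On a Gemm -> Relu -> Quant -> Gemm chain, the second Gemm's predecessor is the quantizer, and the first Gemm has no predecessor because its input is a graph input.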
detect_patterns
Detect the patterns that can be optimized with roundPBS in the QuantizedModule.
Args:
predecessors
(PredecessorsType): Module predecessor operation list
Returns:
result
(PatternDict): list of optimizable patterns
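Below is a simplified sketch of pattern detection. It assumes a hypothetical representation in which the predecessor map uses op names and the op types are held in a separate dictionary. For each Gemm/Conv op, the sketch walks back along first-input predecessors until it reaches another Gemm/Conv op, and it keeps only the chains that actually reach one. The `max_depth` limit is an illustrative assumption, not a documented parameter.

```python
from typing import Dict, List, Optional

# Hypothetical inputs: op name -> op type, and op name -> [first-input predecessor or None]
OpTypes = Dict[str, str]
Preds = Dict[str, List[Optional[str]]]

LINEAR_OPS = {"Gemm", "Conv"}


def detect_patterns(op_types: OpTypes, preds: Preds, max_depth: int = 3) -> Dict[str, List[str]]:
    """For each Gemm/Conv op, collect the ops back to the previous Gemm/Conv op."""
    patterns: Dict[str, List[str]] = {}
    for name, op_type in op_types.items():
        if op_type not in LINEAR_OPS:
            continue
        path: List[str] = []
        current = preds[name][0]
        # Walk backwards along first-input predecessors, nearest op first
        while current is not None and len(path) < max_depth:
            path.append(current)
            if op_types[current] in LINEAR_OPS:
                break
            current = preds[current][0]
        # Keep only chains that end on another Gemm/Conv op
        if path and op_types[path[-1]] in LINEAR_OPS:
            patterns[name] = path
    return patterns
```

With a depth of 3, the walk covers the longest chain named in this module (quantizer, optional Relu, then the producing Gemm/Conv).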
match_path_pattern
Determine if a pattern has the structure that makes it viable for roundPBS.
Args:
predecessors
(PredecessorsType): Module predecessor operation list
nodes_in_path
(List[QuantizedOp]): list of quantized ops in the pattern
input_producer_of_path
(Optional[QuantizedOp]): operation that produces the input
Returns:
result
(bool): whether the pattern can be optimized
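The sketch below shows only the structural part of such a check. It accepts a path whose intermediate ops are a quantizer, optionally preceded by a Relu, and whose input comes from a Gemm/Conv op. The op-type strings "Quant" and "Relu" are assumptions about how the exported graph names these nodes. The sketch does not inspect whether weights and biases are Brevitas QAT, and it is not the Concrete ML implementation.

```python
from typing import List, Optional

LINEAR_OPS = {"Gemm", "Conv"}


def match_path_pattern(path_types: List[str], producer_type: Optional[str]) -> bool:
    """Check that the ops between two Gemm/Conv layers are Quant, or Relu followed by Quant.

    path_types lists the intermediate op types, nearest to the consuming Gemm/Conv first.
    producer_type is the type of the op feeding the path (None if it is a graph input).
    """
    # The path must start from the output of another Gemm/Conv op
    if producer_type not in LINEAR_OPS:
        return False
    # Walking backwards: a quantizer, optionally preceded by a Relu
    return path_types in (["Quant"], ["Quant", "Relu"])
```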
process
Analyze an ONNX graph and detect Gemm/Conv patterns that can use RoundPBS.
We want to detect a Gemm/Conv node whose weights/bias are Brevitas QAT, and whose input is produced by a Brevitas QAT node that is applied to the output of another Gemm/Conv node. Optionally, a Relu can be placed before this input quantization node.
Nothing will be done if rounding is already specified.
Returns:
result
(PatternDict): a dictionary containing each Conv/Gemm node to which round PBS can be applied, based on power-of-two scaling factors
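The following illustrates why power-of-two scaling factors matter. Suppose the integer accumulator `acc` of a Gemm/Conv has scale s_in * s_w and the following quantizer has scale s_out. The requantized integer is then round(acc * s_in * s_w / s_out). When s_out / (s_in * s_w) = 2^k for an integer k, this reduces to round(acc / 2^k), which removes the k least-significant bits of the accumulator. That is the kind of operation round PBS can perform. The sketch checks this condition and emulates the rounding on plain Python integers. The function names and the tolerance are hypothetical, and the sketch illustrates the idea rather than Concrete ML's exact criterion.

```python
import math
from typing import Optional


def rounding_bits_from_scales(
    input_scale: float, weight_scale: float, output_scale: float, tol: float = 1e-9
) -> Optional[int]:
    """Return k if output_scale / (input_scale * weight_scale) == 2**k for an integer k >= 0."""
    ratio = output_scale / (input_scale * weight_scale)
    k = math.log2(ratio)
    k_int = round(k)
    # Only non-negative integer powers of two (within tolerance) qualify
    if k_int < 0 or abs(k - k_int) > tol:
        return None
    return k_int


def requantize_with_shift(acc: int, k: int) -> int:
    """Divide an integer accumulator by 2**k, rounding half up, using only a shift."""
    # Adding 2**(k-1) before the (floor) right shift turns truncation into rounding
    return (acc + (1 << (k - 1))) >> k if k > 0 else acc
```

For example, scales 0.5 and 0.25 feeding a quantizer with scale 1.0 give a ratio of 8 = 2^3. An accumulator value of 13 (13 / 8 = 1.625) then becomes 2.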
process_patterns
Configure the rounding bits of roundPBS for the optimizable operations.
Args:
valid_paths
(PatternDict): list of optimizable patterns
Returns:
result
(PatternDict): list of patterns actually optimized with roundPBS
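The next sketch shows how configuring the patterns relates to the `num_ignored_valid_patterns` property. Valid patterns for which rounding was already set manually are skipped and counted, and the rest receive their rounding setting. It assumes a hypothetical per-op dictionary of manual settings and represents each pattern only by the number of bits it could round off. Neither assumption reflects the actual Concrete ML data structures.

```python
from typing import Dict, Optional, Tuple


def process_patterns(
    valid_paths: Dict[str, int], manual_rounding: Dict[str, Optional[int]]
) -> Tuple[Dict[str, int], int]:
    """Assign rounding bits to each valid pattern unless rounding was set manually.

    valid_paths maps a Gemm/Conv op name to the number of bits it could round off.
    manual_rounding maps op names to user-provided rounding settings (None means unset).
    Returns the patterns actually configured and the number of ignored valid patterns.
    """
    optimized: Dict[str, int] = {}
    num_ignored = 0
    for op_name, n_bits in valid_paths.items():
        if manual_rounding.get(op_name) is not None:
            # The user already chose rounding for this op: leave it untouched
            num_ignored += 1
            continue
        optimized[op_name] = n_bits
    return optimized, num_ignored
```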