A torch-to-numpy module.



class NumpyModule

General interface to transform a torch.nn.Module into a numpy module.


Args:

  • model (Union[nn.Module, onnx.ModelProto]): A fully trained torch model along with its parameters, or the ONNX graph of the model.

  • dummy_input (Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]], optional): Sample tensors for all the module inputs, used in the ONNX export to get a simple-to-manipulate network representation. Defaults to None.

  • debug_onnx_output_file_path (Optional[Union[Path, str]], optional): An optional path indicating where to save the ONNX file exported by torch, for debugging. Defaults to None.
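To illustrate what an "equivalent numpy function" means, the sketch below hand-writes the NumPy forward pass of a tiny torch model, nn.Linear(4, 3) followed by ReLU. It is not how NumpyModule builds its forward function (the class works from the model's ONNX export instead); the weights, shapes and the `numpy_forward` name are hypothetical, and NumPy arrays stand in for the trained torch parameters.

```python
import numpy as np

# Hypothetical stand-ins for the trained parameters of torch.nn.Linear(4, 3),
# which stores its weight with shape (out_features, in_features).
rng = np.random.default_rng(0)
weight = rng.normal(size=(3, 4)).astype(np.float32)
bias = rng.normal(size=(3,)).astype(np.float32)


def numpy_forward(x: np.ndarray) -> np.ndarray:
    """NumPy-only equivalent of relu(linear(x)) for the layer above."""
    # nn.Linear computes x @ W^T + b
    y = x @ weight.T + bias
    # ReLU: clamp negative activations to zero
    return np.maximum(y, 0.0)


# A sample input with the expected shape, playing the role of dummy_input.
dummy_input = np.ones((2, 4), dtype=np.float32)
out = numpy_forward(dummy_input)
print(out.shape)  # (2, 3)
```

Once the parameters are plain arrays, inference needs nothing but NumPy, which is the property a numpy module is meant to provide.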

method __init__

__init__(
    model: Union[Module, ModelProto],
    dummy_input: Optional[Union[Tensor, Tuple[Tensor, ...]]] = None,
    debug_onnx_output_file_path: Optional[Union[Path, str]] = None
)
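Because dummy_input may be either a single tensor or a tuple of tensors, it is convenient to normalize it to a tuple with one entry per forward input; torch.onnx.export accepts its `args` parameter in the same two forms. The helper below is a hypothetical sketch of that normalization, not part of the documented API, and uses NumPy arrays in place of torch.Tensor so it runs without torch.

```python
from typing import Tuple, Union

import numpy as np

# NumPy arrays stand in for torch.Tensor so the sketch runs without torch.
ArrayLike = np.ndarray


def as_input_tuple(
    dummy_input: Union[ArrayLike, Tuple[ArrayLike, ...]]
) -> Tuple[ArrayLike, ...]:
    """Hypothetical helper: accept one sample input or a tuple of them and
    always return a tuple with one entry per forward() input."""
    if isinstance(dummy_input, tuple):
        return dummy_input
    return (dummy_input,)


single = as_input_tuple(np.zeros((1, 4)))
multi = as_input_tuple((np.zeros((1, 4)), np.zeros((1, 2))))
print(len(single), len(multi))  # 1 2
```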

property onnx_model

Get the ONNX model.



Returns:

  • _onnx_model (onnx.ModelProto): the ONNX model

method forward

forward(*args: ndarray) → Union[ndarray, Tuple[ndarray, ...]]

Apply a forward pass on args using only the equivalent numpy function.


Args:

  • *args: the inputs of the forward function


Returns:

  • Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]: result of the forward pass on the given inputs
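The return type Union[numpy.ndarray, Tuple[numpy.ndarray, ...]] suggests that a single-output model yields a plain array while a multi-output model yields a tuple. The sketch below models that behavior with a hypothetical NumpyForwardWrapper that stores a NumPy-only forward function returning a tuple and unwraps it when there is exactly one output; whether NumpyModule unwraps in exactly this way is an assumption.

```python
from typing import Callable, Tuple, Union

import numpy as np


class NumpyForwardWrapper:
    """Hypothetical minimal analogue of a numpy module (not the real class).

    It stores a NumPy-only forward function that always returns a tuple of
    outputs, and exposes forward(*args) with the documented return type.
    """

    def __init__(self, numpy_forward: Callable[..., Tuple[np.ndarray, ...]]):
        self.numpy_forward = numpy_forward

    def forward(self, *args: np.ndarray) -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
        outputs = self.numpy_forward(*args)
        # Assumption: single-output models are unwrapped to a plain array.
        return outputs[0] if len(outputs) == 1 else outputs


# One input, one output: forward returns an ndarray.
single = NumpyForwardWrapper(lambda x: (x + 1,))
# Two inputs, two outputs: forward returns a tuple of ndarrays.
multi = NumpyForwardWrapper(lambda x, y: (x + y, x * y))

a = np.array([1.0, 2.0])
b = np.array([3.0, 4.0])
r1 = single.forward(a)     # array([2., 3.])
r2 = multi.forward(a, b)   # (array([4., 6.]), array([3., 8.]))
```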
