vollo_torch.fx

vollo_torch.fx.save(module: torch.nn.modules.module.Module, input: torch.Tensor, archive: Union[str, pathlib.Path])

Trace and save the Torch FX representation of a PyTorch model to an archive file.

This wraps torch.fx.GraphModule.to_folder so that the saved code uses only relative paths, which makes the archive more portable.

Parameters:
  • module – PyTorch model to save.

  • input – Input tensor for the model, needed to trace the model with concrete shapes.

  • archive – Path of the archive file to save the model to. The archive will be a valid .tar.gz file.
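The archive is documented as a valid .tar.gz file, so Python's standard tarfile module can open it for inspection. The following is a minimal sketch. The helper name list_archive_members is hypothetical and not part of vollo_torch. The member names you will see come from torch.fx.GraphModule.to_folder and are not specified here.

```python
import tarfile
from pathlib import Path
from typing import List, Union


def list_archive_members(archive: Union[str, Path]) -> List[str]:
    """Return the names of all members in a gzip-compressed tar archive.

    Hypothetical helper for inspecting an archive written by
    vollo_torch.fx.save(); it does not depend on vollo_torch itself.
    """
    # "r:gz" opens the file for reading with gzip decompression and raises
    # tarfile.ReadError if the file is not a valid .tar.gz archive.
    with tarfile.open(archive, mode="r:gz") as tar:
        return tar.getnames()
```

For example, calling list_archive_members("model.tar.gz") on an archive produced by save() lists the files that to_folder generated.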

vollo_torch.fx.load(archive: Union[str, pathlib.Path]) → Tuple[torch.nn.modules.module.Module, torch.Size]

Load a PyTorch model that has been saved to an archive file with save().

Parameters:

archive – Path of an archive file previously written by save().

Returns:

A tuple (module, input_shape) where module is the loaded PyTorch model and input_shape is the shape of the input that the model was traced with.