vollo_torch.fx

vollo_torch.fx.save(module: Module, input: Tensor, archive: Union[str, Path])

Trace and save the Torch FX representation of a PyTorch model to an archive file.

This wraps torch.fx.GraphModule.to_folder so that the saved code uses only relative paths, which makes the archive more portable.

Parameters:
  • module – PyTorch model to save.

  • input – Input tensor for the model, needed to trace the model with concrete shapes.

  • archive – Archive file to save the model as. The archive will be a valid .tar.gz file.
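As a rough illustration of the portability point above, here is a minimal standard-library sketch of packing a generated folder into a .tar.gz whose member names are relative to that folder. pack_folder is a hypothetical helper, not part of vollo_torch, and this is not how save() is actually implemented.

```python
import tarfile
from pathlib import Path
from typing import List, Union


def pack_folder(folder: Union[str, Path], archive: Union[str, Path]) -> List[str]:
    """Pack every file under `folder` into a gzip-compressed tar archive.

    Hypothetical helper (not part of vollo_torch). Returns the member
    names written, which are all relative to `folder`.
    """
    folder = Path(folder)
    names: List[str] = []
    with tarfile.open(archive, "w:gz") as tar:
        for path in sorted(folder.rglob("*")):
            if not path.is_file():
                continue  # directories are recreated implicitly on extraction
            # Store the path relative to `folder` instead of the absolute path
            # on this machine, so the archive can be unpacked anywhere.
            arcname = path.relative_to(folder).as_posix()
            tar.add(path, arcname=arcname)
            names.append(arcname)
    return names
```

Passing arcname is what keeps machine-specific absolute paths out of the archive; without it, tarfile records each file under the path it was added from.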

vollo_torch.fx.load(archive: Union[str, Path]) → Tuple[Module, Size]

Load a PyTorch model that has been saved to an archive with save().

Parameters:

archive – Archive file that a model has been saved as.

Returns:

A tuple (module, input_shape) where module is the loaded PyTorch model and input_shape is the shape of the input that the model was traced with.
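Because the archive written by save() is an ordinary .tar.gz file, its contents can also be inspected or unpacked with the standard library. The sketch below is a hypothetical helper (extract_fx_archive is not part of vollo_torch, and this is not how load() is implemented): it extracts an archive into a fresh temporary folder after rejecting any member with an absolute path or a ".." component, which a portable archive should never contain.

```python
import tarfile
import tempfile
from pathlib import Path, PurePosixPath
from typing import Union


def extract_fx_archive(archive: Union[str, Path]) -> Path:
    """Extract a .tar.gz archive into a new temporary directory.

    Hypothetical helper (not part of vollo_torch). Members with absolute
    paths or ".." components are rejected before anything is written.
    """
    with tarfile.open(archive, "r:gz") as tar:
        for member in tar.getmembers():
            path = PurePosixPath(member.name)  # tar member names use "/"
            if path.is_absolute() or ".." in path.parts:
                raise ValueError(f"unsafe member path in archive: {member.name}")
        dest = Path(tempfile.mkdtemp(prefix="fx_archive_"))
        if hasattr(tarfile, "data_filter"):
            # Python versions with extraction filters: also block links and
            # special files that would point outside `dest`.
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)
    return dest
```

To get the model itself, use load(); since the returned input_shape is a torch.Size, a matching test input can be built with torch.randn(input_shape).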