vollo_torch.fx
- vollo_torch.fx.save(module: torch.nn.modules.module.Module, input: torch.Tensor, archive: Union[str, pathlib.Path])
Trace and save the Torch FX representation of a PyTorch model to an archive file.
This wraps torch.fx.GraphModule.to_folder so that the saved code uses only relative paths, making it more portable.
- Parameters:
module – PyTorch model to save.
input – Input tensor for the model, needed to trace the model with concrete shapes.
archive – Archive file to save the model as. The archive will be a valid .tar.gz file.
- vollo_torch.fx.load(archive: Union[str, pathlib.Path]) → Tuple[torch.nn.modules.module.Module, torch.Size]
Load a PyTorch model that has been saved to an archive with save().
- Parameters:
archive – Archive file that a model has been saved as.
- Returns:
A tuple (module, input_shape) where module is the loaded PyTorch model and input_shape is the shape of the input that the model was traced with.
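save() is documented as producing a valid .tar.gz file whose saved code uses only relative paths. The sketch below illustrates those two properties using only Python's standard tarfile module. It is not vollo_torch's implementation: the helpers pack_folder and archive_members are hypothetical. pack_folder packs a folder (for example, one written by torch.fx.GraphModule.to_folder) with member names relative to that folder. archive_members lists an archive's contents and rejects any member that is absolute or contains "..", which could be used to check an archive before passing it to load().

```python
# Hypothetical illustration only; these helpers are not part of vollo_torch.
import tarfile
from pathlib import Path, PurePosixPath
from typing import List, Union


def pack_folder(folder: Union[str, Path], archive: Union[str, Path]) -> None:
    """Pack every file under `folder` into a gzip-compressed tar archive.

    Member names are stored relative to `folder`, so the archive does not
    record the directory it was created in.
    """
    folder = Path(folder)
    with tarfile.open(archive, "w:gz") as tar:
        for path in sorted(folder.rglob("*")):
            if path.is_file():
                # arcname drops the folder prefix and uses "/" separators
                tar.add(path, arcname=path.relative_to(folder).as_posix())


def archive_members(archive: Union[str, Path]) -> List[str]:
    """List the members of a .tar.gz archive, rejecting non-relative paths."""
    if not tarfile.is_tarfile(archive):
        raise ValueError(f"{archive} is not a tar archive")
    # "r:gz" also checks that the archive is gzip-compressed
    with tarfile.open(archive, "r:gz") as tar:
        names = tar.getnames()
    for name in names:
        member = PurePosixPath(name)
        # Absolute paths or ".." components would escape the extraction folder
        if member.is_absolute() or ".." in member.parts:
            raise ValueError(f"non-relative member path: {name}")
    return names


if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp, "model")
        src.mkdir()
        (src / "module.py").write_text("# generated code\n")
        out = Path(tmp, "model.tar.gz")
        pack_folder(src, out)
        print(archive_members(out))  # ['module.py']
```

Storing names relative to the packed folder is what lets the archive be extracted anywhere; the absolute location of the original folder never appears in it.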