vollo_torch.nn

class vollo_torch.nn.PaddedConv1d(in_channels, out_channels, kernel_size, dilation=1, split_into=1)

Bases: Module

torch.nn.Conv1d with left-padding, such that the output sequence length is the same as the input sequence length.

The convolution can be split into multiple nodes in the compiler by setting split_into > 1.
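To illustrate the padding rule, here is a minimal plain-PyTorch sketch. It is not the vollo_torch implementation. It assumes stride 1 and torch.nn.Conv1d's (batch, channels, length) layout, and the class name LeftPaddedConv1dSketch is hypothetical. Padding only the start of the sequence by dilation * (kernel_size - 1) keeps the length unchanged and makes each output depend only on current and past inputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeftPaddedConv1dSketch(nn.Module):
    """Hypothetical stand-in for PaddedConv1d: pad on the left only, then convolve."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        # No built-in padding here; forward() pads manually on the left.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
        # With stride 1 the kernel spans dilation * (kernel_size - 1) + 1 samples,
        # so this much left-padding keeps output length == input length.
        self.left_pad = dilation * (kernel_size - 1)

    def forward(self, x):
        # x: (batch, in_channels, length); pad only the start of the last axis.
        x = F.pad(x, (self.left_pad, 0))
        return self.conv(x)


conv = LeftPaddedConv1dSketch(in_channels=4, out_channels=8, kernel_size=3, dilation=2)
x = torch.randn(2, 4, 10)
y = conv(x)
print(y.shape)
```

Because the padding is only on the left, perturbing the last input timestep leaves every earlier output unchanged.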

class vollo_torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False)

Bases: Module

torch.nn.LSTM with zeros as the initial states h_0 and c_0, which discards the final states h_n and c_n that torch.nn.LSTM would otherwise return.
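The documented behaviour can be approximated with plain torch.nn.LSTM, which uses zero initial states when (h_0, c_0) is omitted. This is a sketch for intuition only; ZeroStateLSTMSketch is a hypothetical name, not part of vollo_torch.

```python
import torch
import torch.nn as nn


class ZeroStateLSTMSketch(nn.Module):
    """Hypothetical illustration of the documented behaviour using torch.nn.LSTM."""

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
                            bias=bias, batch_first=batch_first)

    def forward(self, x):
        # Omitting (h_0, c_0) makes torch.nn.LSTM start from zero states.
        output, (h_n, c_n) = self.lstm(x)
        # Only the per-timestep output is returned; h_n and c_n are discarded.
        return output


model = ZeroStateLSTMSketch(input_size=16, hidden_size=32)
x = torch.randn(5, 16)  # unbatched input: (sequence_length, input_size)
out = model(x)
print(out.shape)
```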

class vollo_torch.nn.RecurrentStateLSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False)

Bases: Module

A version of torch.nn.LSTM which is stateful across forward passes.

The h_0 and c_0 initial states are initialized to zeros. The output h_n and c_n on each forward pass are stored internally, and are used as the initial h_0 and c_0 for the next forward pass.

Note

This module must be streaming transformed.
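The carry-over of state between forward passes can be sketched in eager PyTorch as follows. StatefulLSTMSketch is hypothetical and does not show how vollo_torch implements the module; detaching the stored states is a choice made in this sketch only, so that gradients do not flow back through earlier calls.

```python
import torch
import torch.nn as nn


class StatefulLSTMSketch(nn.Module):
    """Hypothetical sketch of an LSTM that carries (h_n, c_n) across forward passes."""

    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)
        self.state = None  # None: torch.nn.LSTM uses zero initial states

    def forward(self, x):
        output, (h_n, c_n) = self.lstm(x, self.state)
        # Store the final states as the initial states of the next call
        # (detached in this sketch so earlier calls are not backpropagated through).
        self.state = (h_n.detach(), c_n.detach())
        return output


torch.manual_seed(0)
model = StatefulLSTMSketch(input_size=8, hidden_size=16)
x = torch.randn(10, 8)      # unbatched: (sequence_length, input_size)
first = model(x[:4])        # starts from zero states
second = model(x[4:])       # continues from the stored (h_n, c_n)
```

Feeding a sequence in two chunks then gives the same outputs as feeding it in one pass from zero states.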

class vollo_torch.nn.LSTMCell(input_size, hidden_size, batch_size=None)

Bases: Module

A version of torch.nn.LSTMCell that only takes the hidden state as an explicit input and output.

The cell state is initialised as zeros, kept internally, and updated each time this module is called. Because the internal cell state has a batch dimension when inputs are batched, this module must be instantiated with the batch_size argument when using batched inputs.

Warning

This class is experimental and is likely to change in future versions.

forward(x, hidden_state)

Applies torch.nn.LSTMCell to the input and hidden state, returning the updated hidden state and internally updating its cell state.

Parameters:
  • x (torch.Tensor) – of shape (batch_size, input_size) or (input_size).

  • hidden_state (torch.Tensor) – of shape (batch_size, hidden_size) or (hidden_size).

Returns:

Hidden state of shape (batch_size, hidden_size) or (hidden_size).

Return type:

torch.Tensor

reset()

Reset the cell state back to the initial state (zeros).
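A minimal eager sketch of this interface, built on torch.nn.LSTMCell, is shown below. The class name and the buffer name c are hypothetical; it only mirrors the documented contract (explicit hidden state, internal cell state sized by batch_size, reset() back to zeros).

```python
import torch
import torch.nn as nn


class InternalCellStateLSTMCellSketch(nn.Module):
    """Hypothetical sketch: hidden state is explicit, cell state is kept internally."""

    def __init__(self, input_size, hidden_size, batch_size=None):
        super().__init__()
        self.cell = nn.LSTMCell(input_size, hidden_size)
        # The internal cell state's shape depends on whether inputs are batched.
        shape = (hidden_size,) if batch_size is None else (batch_size, hidden_size)
        self.register_buffer("c", torch.zeros(shape))

    def forward(self, x, hidden_state):
        h, c = self.cell(x, (hidden_state, self.c))
        self.c = c.detach()  # update the internal cell state
        return h             # only the hidden state is returned

    def reset(self):
        # Back to the initial (zero) cell state.
        self.c = torch.zeros_like(self.c)


cell = InternalCellStateLSTMCellSketch(input_size=4, hidden_size=8)
h = torch.zeros(8)
for x in torch.randn(3, 4):  # three unbatched time steps
    h = cell(x, h)           # cell.c is updated internally on each call
cell.reset()                 # cell state is zeros again; h is managed by the caller

batched = InternalCellStateLSTMCellSketch(input_size=4, hidden_size=8, batch_size=2)
hb = batched(torch.randn(2, 4), torch.zeros(2, 8))
```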

class vollo_torch.nn.Ones

Bases: Module

Instances of this class behave similarly to torch.ones, but can prevent torch.fx from tracing through them.

class vollo_torch.nn.Zeros

Bases: Module

Instances of this class behave similarly to torch.zeros, but can prevent torch.fx from tracing through them.
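One general torch.fx mechanism for keeping a module opaque is a custom Tracer whose is_leaf_module returns True for it, so the module appears as a single call_module node. The sketch below shows that mechanism with hypothetical names (OnesSketch, KeepOnesOpaqueTracer); it is an assumption for illustration and does not claim to be how vollo_torch.nn.Ones or Zeros work.

```python
import torch
import torch.fx
import torch.nn as nn


class OnesSketch(nn.Module):
    """Hypothetical module returning a tensor of ones shaped like its input."""

    def forward(self, x):
        return torch.ones_like(x)


class KeepOnesOpaqueTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        # Record OnesSketch as one call_module node instead of tracing inside it.
        return isinstance(m, OnesSketch) or super().is_leaf_module(m, module_qualified_name)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.ones = OnesSketch()

    def forward(self, x):
        return x + self.ones(x)


model = Model()
default_graph = torch.fx.symbolic_trace(model).graph  # traces into OnesSketch
leaf_graph = KeepOnesOpaqueTracer().trace(model)       # keeps OnesSketch opaque
print([n.op for n in default_graph.nodes])
print([n.op for n in leaf_graph.nodes])
```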

class vollo_torch.nn.Scan(step)

Bases: Module

A scan operator that can express a subset of for loops.

Unlike a for loop, a Scan is recognised by the Vollo compiler as the repeated application of a single function, and the compiler can also Streaming Transform a Scan along its axis.

This module needs to be Streaming Transformed to be compiled to a Vollo program.

Parameters:

step (nn.Module) – The PyTorch module to call for each subtensor (along the specified axis) of the Scan operator’s input tensor. The step module’s forward method should take an input tensor and an input “state” tensor, and return an updated “state” tensor.

Example

>>> import numpy as np
>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>>
>>> import vollo_compiler
>>> import vollo_torch
>>>
>>> _ = torch.manual_seed(0)
>>>
>>> class ScannerLSTMStep(nn.Module):
...     def __init__(self, input_size, hidden_size):
...         super().__init__()
...         self.cell = vollo_torch.nn.LSTMCell(input_size, hidden_size)
...
...     def forward(self, x, state):
...         return self.cell(x, state)
...
>>> class ScannerLSTM(nn.Module):
...     def __init__(self, input_size, hidden_size):
...         super().__init__()
...         self.step = ScannerLSTMStep(input_size, hidden_size)
...         self.h_0 = nn.Parameter(torch.zeros(hidden_size))
...
...         self.scan = vollo_torch.nn.Scan(self.step)
...
...     def forward(self, x):
...         return self.scan(x, self.h_0)
...
>>> input_size, hidden_size, sequence_length = (128, 128, 100)
>>> model = ScannerLSTM(input_size, hidden_size)
>>> test_input = torch.randn(sequence_length, input_size)
>>> model, torch_output = vollo_torch.fx.prepare_shape(model, test_input)
>>> nnir = vollo_torch.fx.nnir.to_nnir(model)
>>> input_streaming_dimension = 0
>>> nnir, output_streaming_dimension = nnir.streaming_transform(input_streaming_dimension)
>>> program = nnir.to_program(vollo_compiler.Config(num_cores=6, block_size=32))
>>> vm = program.to_vm(bit_accurate=False)
>>> vm_output = vm.run_timesteps(test_input.numpy(), input_streaming_dimension, output_streaming_dimension)
>>> np.testing.assert_allclose(torch_output, vm_output, atol=1e-5, rtol=1e-5)

Example

>>> class SSMStep(nn.Module):
...     def __init__(self, input_size, hidden_size):
...         super().__init__()
...         self.linear_A = nn.Linear(hidden_size, hidden_size, bias=False)
...         self.linear_B = nn.Linear(input_size, hidden_size, bias=False)
...
...     def forward(self, u, state):
...         return self.linear_A(state) + self.linear_B(u)
...
>>> class SSM(nn.Module):
...     "An S5-like SSM layer: https://arxiv.org/abs/2208.04933"
...
...     def __init__(self, input_size, hidden_size):
...         super().__init__()
...         self.step = SSMStep(input_size, hidden_size)
...         self.state_0 = nn.Parameter(torch.zeros(hidden_size))
...
...         self.scan = vollo_torch.nn.Scan(self.step)
...
...         self.linear_C = nn.Linear(hidden_size, input_size, bias=False)
...         self.linear_D = nn.Linear(input_size, input_size, bias=False)
...
...     def forward(self, u):
...         return self.linear_C(self.scan(u, self.state_0)) + self.linear_D(u)
...
>>> input_size, hidden_size, sequence_length = (128, 128, 100)
>>> model = SSM(input_size, hidden_size)
>>> test_input = torch.randn(sequence_length, input_size)
>>> model, torch_output = vollo_torch.fx.prepare_shape(model, test_input)
>>> nnir = vollo_torch.fx.nnir.to_nnir(model)
>>> input_streaming_dimension = 0
>>> nnir, output_streaming_dimension = nnir.streaming_transform(input_streaming_dimension)
>>> program = nnir.to_program(vollo_compiler.Config(num_cores=6, block_size=32))
>>> vm = program.to_vm(bit_accurate=False)
>>> vm_output = vm.run_timesteps(test_input.numpy(), input_streaming_dimension, output_streaming_dimension)
>>> np.testing.assert_allclose(torch_output, vm_output, atol=1e-5, rtol=1e-5)

Warning

This class is experimental and is likely to change in future versions.

forward(xs, state, axis=0)
Parameters:
  • xs (torch.Tensor) – Input tensor where one of its axes is the sequence axis.

  • state (torch.Tensor) – Initial state tensor of the Scan.

  • axis (int) – Sequence axis along which xs should be sliced and sequentially passed as inputs to the step module.

Returns:

Concatenation of the sequence of output state tensors along the sequence axis.

Return type:

torch.Tensor
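For intuition, the semantics of forward can be written as an eager Python loop. scan_reference and CumulativeSumStep are hypothetical names, and the sketch assumes, as in the LSTM and SSM examples, that each step's state has no sequence axis, so the per-step states are stacked along axis. It is a reference for checking behaviour, not what the compiler generates.

```python
import torch
import torch.nn as nn


def scan_reference(step, xs, state, axis=0):
    """Hypothetical eager reference for Scan.forward: loop over `axis`, threading the state."""
    outputs = []
    for x in xs.unbind(dim=axis):
        # Each step consumes one slice and the previous state, and returns the new state.
        state = step(x, state)
        outputs.append(state)
    # Join the per-step states back together along the sequence axis.
    return torch.stack(outputs, dim=axis)


class CumulativeSumStep(nn.Module):
    def forward(self, x, state):
        return state + x


xs = torch.arange(6.0).reshape(3, 2)  # sequence length 3, feature size 2
ys = scan_reference(CumulativeSumStep(), xs, torch.zeros(2))
print(ys)  # same values as torch.cumsum(xs, dim=0)
```

With this step module the scan reduces to a running sum, which makes the threading of the state easy to check by hand.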

class vollo_torch.nn.RMSNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)

Bases: Module

RMS Normalisation layer based on https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.

Note that, unlike torch.nn.RMSNorm, the default value of eps is 1e-05 rather than None, since Vollo requires a fixed value when compiling to a program. Also note that only a 1D normalized_shape is supported.
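The normalisation follows the torch.nn.RMSNorm formula y = x / sqrt(mean(x^2) + eps) * weight, computed over the last dimension for a 1D normalized_shape. The function below is a hypothetical reference sketch of that formula, not the vollo_torch code path; weight stands in for the learnable scale used when elementwise_affine=True.

```python
import torch


def rms_norm_reference(x, weight=None, eps=1e-05):
    """Hypothetical reference: RMS-normalise over the last dimension (1D normalized_shape)."""
    # Root mean square over the normalised (last) dimension, with eps for stability.
    rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
    y = x / rms
    # elementwise_affine=True corresponds to a learnable per-feature scale.
    return y if weight is None else y * weight


x = torch.randn(4, 16) * 3.0
y = rms_norm_reference(x)
print(torch.sqrt(torch.mean(y * y, dim=-1)))  # each value is close to 1
```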

extra_repr() → str

Extra information about the module.