vollo_torch.nn
- class vollo_torch.nn.PaddedConv1d(in_channels, out_channels, kernel_size, dilation=1, split_into=1)
Bases:
Module
torch.nn.Conv1d with left-padding such that the output sequence length is the same as the input sequence length.
Can be split into multiple nodes in the compiler by setting split_into > 1.
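The length-preserving behaviour can be illustrated with a minimal, framework-free sketch for a single channel. `left_padded_conv1d` is a hypothetical helper written for this illustration, not part of vollo_torch, and the padding amount of `(kernel_size - 1) * dilation` is an assumption inferred from the description above rather than taken from the implementation.

```python
def left_padded_conv1d(xs, kernel, dilation=1):
    """Single-channel 1-D cross-correlation with zeros padded on the left only.

    Hypothetical illustration: pads (len(kernel) - 1) * dilation zeros so that
    len(output) == len(xs).
    """
    pad = (len(kernel) - 1) * dilation  # assumed padding amount
    padded = [0.0] * pad + list(xs)
    out = []
    for t in range(len(xs)):
        # Tap k reads padded[t + k * dilation], which is the original sample
        # at t - (K - 1 - k) * dilation, so no future sample is ever used.
        out.append(sum(w * padded[t + k * dilation] for k, w in enumerate(kernel)))
    return out


xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = left_padded_conv1d(xs, kernel=[0.5, 0.5], dilation=2)
print(len(xs), len(ys))
```

Because all padding is on the left, each output element depends only on the current and earlier inputs.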
- class vollo_torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False)
Bases:
Module
torch.nn.LSTM with zeros as the h_0 and c_0 initial states, and which discards the output h_n and c_n states.
- class vollo_torch.nn.RecurrentStateLSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False)
Bases:
Module
A version of torch.nn.LSTM which is stateful across forward passes.
The h_0 and c_0 initial states are initialized to zeros. The output h_n and c_n on each forward pass are stored internally, and are used as the initial h_0 and c_0 for the next forward pass.
Note
This module must be streaming transformed.
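The effect of carrying state across forward passes can be sketched with a toy recurrence in plain Python. `StatefulRecurrence` is a hypothetical stand-in (a decaying running sum, not an LSTM and not the Vollo API); it only shows that feeding a sequence in chunks gives the same outputs as feeding it whole when the final state of one pass seeds the next.

```python
class StatefulRecurrence:
    """Toy stateful layer (hypothetical): state <- decay * state + x."""

    def __init__(self, decay=0.5):
        self.decay = decay
        self.state = 0.0  # initial state is zero, analogous to h_0 and c_0

    def forward(self, xs):
        outputs = []
        state = self.state
        for x in xs:
            state = self.decay * state + x
            outputs.append(state)
        self.state = state  # stored and reused by the next forward pass
        return outputs


sequence = [1.0, 2.0, 3.0, 4.0]

# One forward pass over the whole sequence.
whole = StatefulRecurrence().forward(sequence)

# Two forward passes over consecutive chunks of the same sequence.
layer = StatefulRecurrence()
chunked = layer.forward(sequence[:2]) + layer.forward(sequence[2:])
print(whole == chunked)
```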
- class vollo_torch.nn.LSTMCell(input_size, hidden_size, batch_size=None)
Bases:
Module
torch.nn.LSTMCell but only takes the hidden state as an explicit input/output.
The cell state is initialised as zeros, kept internal, and updated each time this module is called. Since the cell state is kept internal and may have a batch size, this module needs to be instantiated with a batch size argument if using a batch.
Warning
This class is experimental and is likely to change in future versions.
- forward(x, hidden_state)
Applies torch.nn.LSTMCell to the input and hidden state, returning the updated hidden state and internally updating its cell state.
- Parameters:
x (torch.Tensor) – of shape (batch_size, input_size) or (input_size).
hidden_state (torch.Tensor) – of shape (batch_size, hidden_size) or (hidden_size).
- Returns:
Hidden state of shape (batch_size, hidden_size) or (hidden_size).
- Return type:
torch.Tensor
- reset()
Reset the cell state back to the initial state (zeros).
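The split between an explicit hidden state and an internal cell state, and the role of reset(), can be sketched with a scalar cell in plain Python. `ScalarLSTMCell` and `run` are hypothetical names for this illustration only; the cell uses one shared weight and no biases, so all three gates coincide, which is a simplification and not how vollo_torch.nn.LSTMCell is parameterised.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


class ScalarLSTMCell:
    """Scalar LSTM-style cell (hypothetical): hidden state is passed in and
    out explicitly, while the cell state lives inside the object."""

    def __init__(self, weight=0.5):
        self.weight = weight
        self.cell_state = 0.0  # initialised as zero

    def reset(self):
        self.cell_state = 0.0

    def forward(self, x, hidden_state):
        pre = self.weight * (x + hidden_state)
        gate = sigmoid(pre)         # input, forget and output gates coincide here
        candidate = math.tanh(pre)
        self.cell_state = gate * self.cell_state + gate * candidate
        return gate * math.tanh(self.cell_state)


def run(cell, xs):
    """Feed xs through the cell, starting from a zero hidden state."""
    hidden_state = 0.0
    outputs = []
    for x in xs:
        hidden_state = cell.forward(x, hidden_state)
        outputs.append(hidden_state)
    return outputs


cell = ScalarLSTMCell()
inputs = [1.0, -1.0, 0.5]
first = run(cell, inputs)
carried = run(cell, inputs)  # cell state left over from the first run
cell.reset()
fresh = run(cell, inputs)    # starts from a zero cell state again
print(fresh == first, carried == first)
```

Without the reset, the second sequence starts from the cell state left behind by the first, so its outputs differ even though the hidden state passed in is zero.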
- class vollo_torch.nn.Ones
Bases:
Module
Instances of this class behave similarly to torch.ones, but with the ability to make torch.fx not trace through them.
- class vollo_torch.nn.Zeros
Bases:
Module
Instances of this class behave similarly to torch.zeros, but with the ability to make torch.fx not trace through them.
- class vollo_torch.nn.Scan(step)
Bases:
Module
A scan operator that can express a subset of for loops.
Unlike a for loop, a Scan is recognised by the Vollo compiler as the repeated application of some function, and the compiler can furthermore Streaming Transform a Scan along its axis.
This module needs to be Streaming Transformed to be compiled to a Vollo program.
- Parameters:
step (nn.Module) – The PyTorch module to call for each subtensor (along the specified axis) of the Scan operator’s input tensor. The step module’s forward method should take an input tensor and an input “state” tensor, and return an updated “state” tensor.
Example
>>> import numpy as np
>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>>
>>> import vollo_compiler
>>> import vollo_torch
>>>
>>> _ = torch.manual_seed(0)
>>>
>>> class ScannerLSTMStep(nn.Module):
...     def __init__(self, input_size, hidden_size):
...         super().__init__()
...         self.cell = vollo_torch.nn.LSTMCell(input_size, hidden_size)
...
...     def forward(self, x, state):
...         return self.cell(x, state)
...
>>> class ScannerLSTM(nn.Module):
...     def __init__(self, input_size, hidden_size):
...         super().__init__()
...         self.step = ScannerLSTMStep(input_size, hidden_size)
...         self.h_0 = nn.Parameter(torch.zeros(hidden_size))
...
...         self.scan = vollo_torch.nn.Scan(self.step)
...
...     def forward(self, x):
...         return self.scan(x, self.h_0)
...
>>> input_size, hidden_size, sequence_length = (128, 128, 100)
>>> model = ScannerLSTM(input_size, hidden_size)
>>> test_input = torch.randn(sequence_length, input_size)
>>> model, torch_output = vollo_torch.fx.prepare_shape(model, test_input)
>>> nnir = vollo_torch.fx.nnir.to_nnir(model)
>>> input_streaming_dimension = 0
>>> nnir, output_streaming_dimension = nnir.streaming_transform(input_streaming_dimension)
>>> program = nnir.to_program(vollo_compiler.Config(num_cores=6, block_size=32))
>>> vm = program.to_vm(bit_accurate=False)
>>> vm_output = vm.run_timesteps(test_input.numpy(), input_streaming_dimension, output_streaming_dimension)
>>> np.testing.assert_allclose(torch_output, vm_output, atol=1e-5, rtol=1e-5)
Example
>>> class SSMStep(nn.Module):
...     def __init__(self, input_size, hidden_size):
...         super().__init__()
...         self.linear_A = nn.Linear(hidden_size, hidden_size, bias=False)
...         self.linear_B = nn.Linear(input_size, hidden_size, bias=False)
...
...     def forward(self, u, state):
...         return self.linear_A(state) + self.linear_B(u)
...
>>> class SSM(nn.Module):
...     "An S5-like SSM layer: https://arxiv.org/abs/2208.04933"
...
...     def __init__(self, input_size, hidden_size):
...         super().__init__()
...         self.step = SSMStep(input_size, hidden_size)
...         self.state_0 = nn.Parameter(torch.zeros(hidden_size))
...
...         self.scan = vollo_torch.nn.Scan(self.step)
...
...         self.linear_C = nn.Linear(hidden_size, input_size, bias=False)
...         self.linear_D = nn.Linear(input_size, input_size, bias=False)
...
...     def forward(self, u):
...         return self.linear_C(self.scan(u, self.state_0)) + self.linear_D(u)
...
>>> input_size, hidden_size, sequence_length = (128, 128, 100)
>>> model = SSM(input_size, hidden_size)
>>> test_input = torch.randn(sequence_length, input_size)
>>> model, torch_output = vollo_torch.fx.prepare_shape(model, test_input)
>>> nnir = vollo_torch.fx.nnir.to_nnir(model)
>>> input_streaming_dimension = 0
>>> nnir, output_streaming_dimension = nnir.streaming_transform(input_streaming_dimension)
>>> program = nnir.to_program(vollo_compiler.Config(num_cores=6, block_size=32))
>>> vm = program.to_vm(bit_accurate=False)
>>> vm_output = vm.run_timesteps(test_input.numpy(), input_streaming_dimension, output_streaming_dimension)
>>> np.testing.assert_allclose(torch_output, vm_output, atol=1e-5, rtol=1e-5)
Warning
This class is experimental and is likely to change in future versions.
- forward(xs, state, axis=0)
- Parameters:
xs (torch.Tensor) – Input tensor where one of its axes is the sequence axis.
state (torch.Tensor) – Initial state tensor of the Scan.
axis (int) – Sequence axis along which xs should be sliced and sequentially passed as inputs to the step module.
- Returns:
Concatenation of the sequence of output state tensors along the sequence axis.
- Return type:
torch.Tensor
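The semantics described for forward can be sketched in plain Python, using a list in place of a tensor sliced along axis 0. `scan`, `running_sum_step` and `ssm_step` are hypothetical names for this illustration; the sketch assumes one output state is produced per input slice, as the Returns description suggests, and is not the vollo_torch implementation.

```python
def scan(step, xs, state):
    """Reference semantics (assumed): call step on each slice of xs in order,
    threading the state through, and collect every updated state."""
    outputs = []
    for x in xs:
        state = step(x, state)
        outputs.append(state)
    return outputs


def running_sum_step(x, state):
    # The updated state is the previous state plus the current input.
    return state + x


def ssm_step(u, state, a=0.5, b=2.0):
    # Scalar analogue of a linear state-space update: a * state + b * u.
    return a * state + b * u


sums = scan(running_sum_step, [1, 2, 3, 4], 0)
ssm_states = scan(ssm_step, [1.0, 0.0, 0.0], 0.0)
print(sums, ssm_states)
```

The step function has the same shape as the step module's forward method: it takes an input and a state, and returns the updated state.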
- class vollo_torch.nn.RMSNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)
Bases:
Module
RMS Normalisation layer based on https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.
Note that, unlike torch.nn.RMSNorm, the default value of eps is 1e-05 rather than None, since Vollo requires a fixed value when compiling to a program. Also note that only a 1D normalized_shape is supported.
- extra_repr() → str
Extra information about the module.
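The normalisation itself can be sketched in plain Python for a single 1D vector. `rms_norm` is a hypothetical helper for this illustration; it follows the formula in the torch.nn.RMSNorm documentation, y = x / sqrt(mean(x^2) + eps) * weight, and assumes that vollo_torch.nn.RMSNorm computes the same quantity.

```python
import math


def rms_norm(x, weight=None, eps=1e-5):
    """RMS normalisation of a 1D list of floats (hypothetical sketch).

    weight plays the role of the elementwise affine scale; None means all ones.
    """
    mean_square = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_square + eps)  # fixed eps, as Vollo requires
    if weight is None:
        weight = [1.0] * len(x)
    return [v * scale * w for v, w in zip(x, weight)]


y = rms_norm([3.0, 4.0])
mean_square_out = sum(v * v for v in y) / len(y)
print(y, mean_square_out)
```

With unit weights, the output has a mean square very close to 1; the small deviation comes from eps.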