Using LinearOperator Objects
LinearOperator objects share (mostly) the same API as torch.Tensor objects.
Under the hood, these objects use __torch_function__ so that calling a function from the torch or torch.linalg namespace on a LinearOperator dispatches to the operator's own efficient linear algebra implementation.
Supported functions include:
torch.add
torch.cat
torch.clone
torch.diagonal
torch.dim
torch.div
torch.expand
torch.logdet
torch.matmul
torch.numel
torch.permute
torch.prod
torch.squeeze
torch.sub
torch.sum
torch.transpose
torch.unsqueeze
torch.linalg.cholesky
torch.linalg.eigh
torch.linalg.eigvalsh
torch.linalg.solve
torch.linalg.svd
Each of these functions returns either a torch.Tensor or a new LinearOperator object, depending on the function.
For example:
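import torch
from linear_operator.operators import RootLinearOperator, ToeplitzLinearOperator
# A = RootLinearOperator(...)
# B = ToeplitzLinearOperator(...)
# d = ... # a dense right-hand side (torch.Tensor)
C = torch.matmul(A, B) # A new LinearOperator representing the product of A and B
torch.linalg.solve(C, d) # A torch.Tensor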
For more examples, see the examples folder.
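The dispatch mechanism described above can be illustrated with a toy class that follows the pattern from the PyTorch documentation on extending torch with __torch_function__. This is a sketch, not LinearOperator's implementation: ToyDiagOperator, implements, and _HANDLED_FUNCTIONS are hypothetical names. As in the example above, torch.matmul of two structured operators returns a new operator, while torch.linalg.solve returns a plain torch.Tensor.

```python
import torch

# Hypothetical registry: torch function -> structure-aware override.
_HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """Register the decorated function as the override for `torch_function`."""
    def decorator(func):
        _HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator


class ToyDiagOperator:
    """Toy diagonal operator that stores only its diagonal, shape (..., n)."""

    def __init__(self, diag):
        self.diag = diag

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # torch calls this when a ToyDiagOperator is passed to a torch function.
        if kwargs is None:
            kwargs = {}
        if func not in _HANDLED_FUNCTIONS:
            return NotImplemented
        return _HANDLED_FUNCTIONS[func](*args, **kwargs)

    def to_dense(self):
        return torch.diag_embed(self.diag)


@implements(torch.matmul)
def _matmul(a, b):
    # Only handles ToyDiagOperator @ (ToyDiagOperator or Tensor).
    if isinstance(b, ToyDiagOperator):
        # diag(a) @ diag(b) = diag(a * b): the result is a new structured operator.
        return ToyDiagOperator(a.diag * b.diag)
    # diag(a) @ X only scales the rows of X, so the result is a plain tensor.
    return a.diag.unsqueeze(-1) * b


@implements(torch.linalg.solve)
def _solve(a, b):
    # Solving diag(a) X = B is an elementwise division, O(n k) instead of O(n^3).
    # Assumes B has shape (..., n, k).
    return b / a.diag.unsqueeze(-1)


# Usage
A = ToyDiagOperator(torch.tensor([1., 2., 4.]))
B = ToyDiagOperator(torch.tensor([3., 5., 7.]))
C = torch.matmul(A, B)  # a ToyDiagOperator
x = torch.linalg.solve(A, torch.tensor([[2.], [4.], [8.]]))  # a torch.Tensor
```

A real structured library needs many more overrides (reflected operands, batching, autograd), which is why this toy handles only the two calls it demonstrates.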
Batch Support and Broadcasting
LinearOperator objects operate naturally in batch mode.
For example, to represent a batch of three 100 x 100 diagonal matrices:
d = torch.randn(3, 100)
D = DiagLinearOperator(d) # Represents an operator of size 3 x 100 x 100
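For intuition, the batch above stands for the dense tensor that torch.diag_embed builds in the plain-torch sketch below. A diagonal representation in principle needs only the 300 entries of d, whereas the dense batch holds 30,000 numbers, almost all of them zeros. This illustrates the math only; it makes no claim about how DiagLinearOperator stores its data.

```python
import torch

# Dense stand-in for a batch of three 100 x 100 diagonal matrices.
d = torch.randn(3, 100)
dense = torch.diag_embed(d)  # places each row of d on a diagonal

print(dense.shape)               # torch.Size([3, 100, 100])
print(d.numel(), dense.numel())  # 300 30000

# The diagonals of the dense batch are exactly the rows of d.
assert torch.equal(torch.diagonal(dense, dim1=-2, dim2=-1), d)
```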
These objects fully support broadcasted operations:
D @ torch.randn(100, 2) # Returns a tensor of size 3 x 100 x 2
D2 = DiagLinearOperator(torch.randn([2, 1, 100])) # Represents an operator of size 2 x 1 x 100 x 100
D2 + D # Represents an operator of size 2 x 3 x 100 x 100
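The shapes in the comments follow PyTorch's standard broadcasting rules applied to the batch dimensions, with the last two dimensions treated as the matrix. The plain-torch sketch below checks them against dense equivalents and shows the cheaper diagonal formulas (row scaling for the product, adding diagonals for the sum) that give the same results. It illustrates the arithmetic, not the LinearOperator API.

```python
import torch

d = torch.randn(3, 100)      # batch shape (3,)
d2 = torch.randn(2, 1, 100)  # batch shape (2, 1)
X = torch.randn(100, 2)

# Batch dimensions (2, 1) and (3,) broadcast to (2, 3).
print(torch.broadcast_shapes(d2.shape[:-1], d.shape[:-1]))  # torch.Size([2, 3])

# Dense equivalents of D @ X and D2 + D.
DX = torch.diag_embed(d) @ X
S = torch.diag_embed(d2) + torch.diag_embed(d)
print(DX.shape)  # torch.Size([3, 100, 2])
print(S.shape)   # torch.Size([2, 3, 100, 100])

# Diagonal structure gives the same results without building dense matrices.
assert torch.allclose(DX, d.unsqueeze(-1) * X)
assert torch.allclose(S, torch.diag_embed(d2 + d))
```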
Indexing
LinearOperator objects can be indexed in ways similar to torch Tensors. This includes:
Integer indexing (get a row, column, or batch)
Slice indexing (get a subset of rows, columns, or batches)
LongTensor indexing (get a set of individual entries by index)
Ellipses (support indexing operations with arbitrary batch dimensions)
D = DiagLinearOperator(torch.randn(2, 3, 100)) # Represents an operator of size 2 x 3 x 100 x 100
D[-1] # Returns a 3 x 100 x 100 operator
D[..., :10, -5:] # Returns a 2 x 3 x 10 x 5 operator
D[..., torch.LongTensor([0, 1, 2, 3]), torch.LongTensor([0, 1, 2, 3])] # Returns a 2 x 3 x 4 tensor
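The result shapes above match what the same indices produce on a dense tensor. The plain-torch sketch below checks them against a dense stand-in built with torch.diag_embed; it illustrates indexing semantics, not how LinearOperator indexes internally. It also shows that paired LongTensor indices select individual entries, which for a diagonal matrix with equal row and column indices are diagonal entries.

```python
import torch

# Dense stand-in for DiagLinearOperator(torch.randn(2, 3, 100)).
d = torch.randn(2, 3, 100)
dense = torch.diag_embed(d)  # 2 x 3 x 100 x 100

print(dense[-1].shape)  # torch.Size([3, 100, 100])

block = dense[..., :10, -5:]
print(block.shape)  # torch.Size([2, 3, 10, 5])

# torch.tensor with integer values produces an int64 (Long) index tensor.
idx = torch.tensor([0, 1, 2, 3])
entries = dense[..., idx, idx]
print(entries.shape)  # torch.Size([2, 3, 4])

# Row index == column index here, so these are diagonal entries of d.
assert torch.equal(entries, d[..., :4])

# Rows 0-9 and columns 95-99 never meet on the diagonal: the block is all zeros.
assert torch.count_nonzero(block).item() == 0
```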
Composition and Decoration
LinearOperators can be composed with one another in various ways. This includes:
Addition (LinearOpA + LinearOpB)
Matrix multiplication (LinearOpA @ LinearOpB)
Concatenation (torch.cat([LinearOpA, LinearOpB], dim=-2))
Kronecker product (torch.kron(LinearOpA, LinearOpB))
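Composition can stay cheap when a structure is preserved by the operation. The plain-torch sketch below uses dense diagonal matrices (built with torch.diag_embed) as a stand-in and checks a few such identities; it illustrates the underlying math, not LinearOperator's internal implementation.

```python
import torch

# Dense diagonal matrices as stand-ins for structured operators.
a, b, c = torch.rand(3), torch.rand(4), torch.rand(3)
A, B, C = torch.diag_embed(a), torch.diag_embed(b), torch.diag_embed(c)

# Sums and products of diagonal matrices are again diagonal.
assert torch.allclose(A + C, torch.diag_embed(a + c))
assert torch.allclose(A @ C, torch.diag_embed(a * c))

# The Kronecker product of two diagonal matrices is diagonal, with the
# Kronecker product of the diagonal vectors on its diagonal: a 12 x 12
# matrix described by 12 numbers.
K = torch.kron(A, B)
assert K.shape == (12, 12)
assert torch.allclose(K, torch.diag_embed(torch.kron(a, b)))

# Concatenating along dim=-2 stacks rows; the 6 x 3 result is not square.
R = torch.cat([A, C], dim=-2)
print(R.shape)  # torch.Size([6, 3])
```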
In addition, there are many ways to “decorate” LinearOperator objects. This includes:
Multiplying elementwise by a constant (torch.mul(2., LinearOpA))
Summing over batches (torch.sum(LinearOpA, dim=-3))
Multiplying elementwise over batches (torch.prod(LinearOpA, dim=-3))
See the documentation for a full list of supported composition and decoration operations.
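For a batch of diagonal matrices, each decoration above has a simple closed form in terms of the diagonals. The plain-torch sketch below checks these identities on a dense stand-in; it illustrates the semantics of the dense torch calls, not LinearOperator's implementation. As the list states, torch.prod over dim=-3 is an elementwise product across the batch.

```python
import torch

# Dense stand-in for a batch of five 4 x 4 diagonal operators.
# The + 0.5 keeps entries away from zero so the product stays well scaled.
d = torch.rand(5, 4) + 0.5
batch = torch.diag_embed(d)  # 5 x 4 x 4

# Multiplying by a constant scales every entry, hence the diagonal.
scaled = 2. * batch
assert torch.allclose(scaled, torch.diag_embed(2. * d))

# Summing over the batch dimension (dim=-3) adds the five matrices.
summed = torch.sum(batch, dim=-3)  # 4 x 4
assert torch.allclose(summed, torch.diag_embed(d.sum(dim=-2)))

# torch.prod over dim=-3 multiplies the matrices elementwise (a Hadamard
# product). For diagonal matrices this happens to equal the matrix product.
hadamard = torch.prod(batch, dim=-3)  # 4 x 4
assert torch.allclose(hadamard, torch.diag_embed(d.prod(dim=-2)))
```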