import torch
import numpy as np
from pytorch_complex_tensor import ComplexTensor
from pylops_gpu.LinearOperator import LinearOperator
from pylops_gpu.utils.complex import conj, reshape, flatten
from pylops_gpu.utils.torch2numpy import numpytype_from_torchtype, \
torchtype_from_numpytype
class MatrixMult(LinearOperator):
    r"""Matrix multiplication.

    Simple wrapper to :py:func:`torch.matmul` for
    an input matrix :math:`\mathbf{A}`.

    Parameters
    ----------
    A : :obj:`torch.Tensor` or :obj:`pytorch_complex_tensor.ComplexTensor` or :obj:`numpy.ndarray`
        Matrix. A numpy array is cast to ``dtype`` and moved to ``device``;
        tensors are used as given. Only a ``ComplexTensor`` enables the
        complex code path.
    dims : :obj:`tuple`, optional
        Number of samples for each other dimension of model
        (model/data will be reshaped and ``A`` applied multiple times
        to each column of the model/data).
    device : :obj:`str`, optional
        Device to be used
    togpu : :obj:`tuple`, optional
        Move model and data from cpu to gpu prior to applying ``matvec`` and
        ``rmatvec``, respectively (only when ``device='gpu'``)
    tocpu : :obj:`tuple`, optional
        Move data and model from gpu to cpu after applying ``matvec`` and
        ``rmatvec``, respectively (only when ``device='gpu'``)
    dtype : :obj:`torch.dtype` or :obj:`np.dtype`, optional
        Type of elements in input array (used to cast ``A`` when it is
        provided as a numpy array).

    Attributes
    ----------
    shape : :obj:`tuple`
        Operator shape
    explicit : :obj:`bool`
        Operator contains a matrix that can be solved explicitly
        (``True``) or not (``False``)

    Notes
    -----
    Refer to :class:`pylops.basicoperators.MatrixMult` for
    implementation details.

    """
    def __init__(self, A, dims=None, device='cpu',
                 togpu=(False, False), tocpu=(False, False),
                 dtype=torch.float32):
        if isinstance(A, (torch.Tensor, ComplexTensor)):
            self.A = A
        else:
            # A provided as a numpy array: cast it to the numpy equivalent
            # of ``dtype`` (converted once) and move it to ``device``
            dtype = numpytype_from_torchtype(dtype)
            self.A = torch.from_numpy(A.astype(dtype)).to(device)
        # Only a ComplexTensor activates the complex code path: a numpy
        # input always becomes a plain torch.Tensor above.
        # NOTE(review): with a real ``dtype`` (the default) ``astype`` drops
        # the imaginary part of a complex numpy ``A`` - confirm intended.
        self.complex = isinstance(self.A, ComplexTensor)

        if dims is None:
            self.reshape = False
            self.shape = A.shape
        else:
            if isinstance(dims, int):
                dims = (dims, )
            self.reshape = True
            # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is the
            # documented replacement and gives the same integer dtype
            self.dims = np.array(dims, dtype=int)
            ncols = int(np.prod(self.dims))
            self.shape = (A.shape[0] * ncols, A.shape[1] * ncols)
            # 2D shapes used to view model (index 0) and data (index 1)
            # as matrices with ``ncols`` columns, each multiplied by A
            self.newshape = ((int(self.A.shape[1]), ncols),
                             (int(self.A.shape[0]), ncols))
        if self.complex:
            # conjugate transpose of A, used by _rmatvec
            self.Ac = conj(self.A).t()
        self.device = device
        self.togpu = togpu
        self.tocpu = tocpu
        self.dtype = torchtype_from_numpytype(dtype)
        self.explicit = True
        self.Op = None

    def _matvec(self, x):
        """Forward: apply ``A`` to ``x`` (column by column if ``dims``)."""
        if self.reshape:
            x = reshape(x, self.newshape[0]) if self.complex else \
                torch.reshape(x, self.newshape[0])
        elif self.complex:
            # NOTE(review): ComplexTensor inputs are transposed around
            # ``mm`` - presumably due to its real/imag layout; confirm
            x = x.t()
        if self.complex:
            y = self.A.mm(x)
            if not self.reshape:
                y = y.t()
        else:
            y = self.A.matmul(x)
        if self.reshape:
            y = flatten(y) if self.complex else y.view(-1)
        return y

    def _rmatvec(self, x):
        """Adjoint: apply the (conjugate) transpose of ``A`` to ``x``."""
        if self.reshape:
            x = reshape(x, self.newshape[1]) if self.complex else \
                torch.reshape(x, self.newshape[1])
        elif self.complex:
            x = x.t()
        if self.complex:
            y = self.Ac.mm(x)
            if not self.reshape:
                y = y.t()
        else:
            # NOTE(review): ``.t()`` is a plain transpose; if ``A`` is a
            # native complex torch.Tensor (not a ComplexTensor) this is not
            # the adjoint - confirm such inputs are not expected.
            y = self.A.t().matmul(x)
        if self.reshape:
            y = flatten(y) if self.complex else y.view(-1)
        return y

    def inv(self):
        r"""Return the inverse of :math:`\mathbf{A}`.

        Note that this is the inverse of the stored matrix ``A`` only, not
        of the full (possibly ``dims``-expanded) operator.

        Returns
        ----------
        Ainv : :obj:`torch.Tensor`
            Inverse matrix.

        """
        Ainv = torch.inverse(self.A)
        return Ainv
def aslinearoperator(A, device='cpu'):
    """Return ``A`` as a :class:`pylops_gpu.LinearOperator`.

    Parameters
    ----------
    A : :class:`pylops_gpu.LinearOperator` or :obj:`torch.Tensor`
        Object to wrap. A LinearOperator is passed through unchanged;
        anything else is wrapped in :class:`MatrixMult`.
    device : :obj:`str`, optional
        Device forwarded to :class:`MatrixMult` when wrapping is needed.

    Returns
    -------
    Op : :class:`pylops_gpu.LinearOperator`
        ``A`` itself, or ``MatrixMult(A, device=device)``.

    """
    if not isinstance(A, LinearOperator):
        A = MatrixMult(A, device=device)
    return A