# Source code for pylops_gpu.TorchOperator

import torch

from torch.utils.dlpack import from_dlpack, to_dlpack
from pylops.utils import deps

# cupy is an optional dependency: it is imported only when
# ``deps.cupy_enabled`` is set. Otherwise ``cp`` is left as ``None``, which
# means the DLPack (non-'cpu') code paths of ``_TorchOperator`` that call into
# ``cp`` cannot be used.
if deps.cupy_enabled:
    import cupy as cp
else:
    cp = None


class _TorchOperator(torch.autograd.Function):
    """Wrapper class for PyLops operators into Torch functions

    The flag pylops is used to discriminate pylops and pylops-gpu operators;
    the former one requires the input to be converted into numpy.ndarray and
    the output to be converted back to torch.Tensor

    """
    @staticmethod
    def forward(ctx, x, forw, adj, pylops, device):
        ctx.forw = forw
        ctx.adj = adj
        ctx.pylops = pylops
        ctx.device = device

        # prepare input
        if ctx.pylops:
            if ctx.device == 'cpu':
                # bring x to cpu and numpy
                x = x.cpu().detach().numpy()
            else:
                # pass x to cupy using DLPack
                x = cp.fromDlpack(to_dlpack(x))

        # apply forward operator
        y = ctx.forw(x)

        # prepare output
        if ctx.pylops:
            if ctx.device == 'cpu':
                # move y to torch and device
                y = torch.from_numpy(y)
            else:
                # move y to torch and device
                y = from_dlpack(y.toDlpack())
        return y

    @staticmethod
    def backward(ctx, y):
        # prepare input
        if ctx.pylops:
            if ctx.device == 'cpu':
                y = y.cpu().detach().numpy()
            else:
                # pass x to cupy using DLPack
                y = cp.fromDlpack(to_dlpack(y))

        # apply adjoint operator
        x = ctx.adj(y)

        # prepare output
        if ctx.pylops:
            if ctx.device == 'cpu':
                x = torch.from_numpy(x)
            else:
                x = from_dlpack(x.toDlpack())
        return x, None, None, None, None


class TorchOperator():
    """Wrap a PyLops operator into a Torch function.

    This class can be used to wrap a pylops (or pylops-gpu) operator into a
    torch function. Doing so, users can mix native torch functions (e.g.
    basic linear algebra operations, neural networks, etc.) and pylops
    operators.

    Since all operators in PyLops are linear operators, a Torch function is
    simply implemented by using the forward operator for its forward pass
    and the adjoint operator for its backward (gradient) pass.

    Parameters
    ----------
    Op : :obj:`pylops_gpu.LinearOperator` or :obj:`pylops.LinearOperator`
        PyLops operator
    batch : :obj:`bool`, optional
        Input has single sample (``False``) or batch of samples (``True``).
        If ``batch==False`` the input must be a 1-d Torch tensor,
        if ``batch==True`` the input must be a 2-d Torch tensor with
        batches along the first dimension
    pylops : :obj:`bool`, optional
        ``Op`` is a pylops operator (``True``) or a pylops-gpu
        operator (``False``)
    device : :obj:`str`, optional
        Device to be used for output vectors when ``Op`` is a pylops operator

    Returns
    -------
    y : :obj:`torch.Tensor`
        Output array resulting from the application of the operator to ``x``.

    """
    def __init__(self, Op, batch=False, pylops=False, device='cpu'):
        self.pylops = pylops
        self.device = device
        if not batch:
            # single sample: plain matrix-vector products
            self.matvec = Op.matvec
            self.rmatvec = Op.rmatvec
        else:
            # batch of samples: matrix-matrix products, called with
            # ``kfirst=True``
            self.matvec = lambda x: Op.matmat(x, kfirst=True)
            self.rmatvec = lambda x: Op.rmatmat(x, kfirst=True)
        self.Top = _TorchOperator.apply

    def apply(self, x):
        """Apply forward pass to input vector

        Parameters
        ----------
        x : :obj:`torch.Tensor`
            Input array

        Returns
        -------
        y : :obj:`torch.Tensor`
            Output array resulting from the application of the operator to
            ``x``.

        """
        return self.Top(x, self.matvec, self.rmatvec,
                        self.pylops, self.device)