Source code for pylops_gpu.utils.torch2numpy

import torch
import numpy as np



def numpytype_from_strtype(strtype):
    """Convert str into equivalent numpy type

    Parameters
    ----------
    strtype : :obj:`str`
        String type

    Returns
    -------
    numpytype : :obj:`numpy.dtype`
        Numpy equivalent type

    """
    numpytype = np.dtype(strtype)
    return numpytype


[docs]def numpytype_from_torchtype(torchtype): """Convert torch type into equivalent numpy type Parameters ---------- torchtype : :obj:`torch.dtype` Torch type Returns ------- numpytype : :obj:`numpy.dtype` Numpy equivalent type """ if isinstance(torchtype, torch.dtype): numpytype = torch.scalar_tensor(1, dtype=torchtype).numpy().dtype else: # in case it is already a numpy dtype numpytype = torchtype return numpytype
[docs]def torchtype_from_numpytype(numpytype): """Convert torch type into equivalent numpy type Parameters ---------- numpytype : :obj:`numpy.dtype` Numpy type Returns ------- torchtype : :obj:`torch.dtype` Torch equivalent type Notes ----- Given limitations of torch to handle complex numbers, complex numpy types are casted into equivalent real types and the equivalent torch type is returned. """ if isinstance(numpytype, torch.dtype): # in case it is already a torch dtype torchtype = numpytype else: torchtype = \ torch.from_numpy(np.real(np.ones(1, dtype=numpytype_from_strtype(numpytype)))).dtype return torchtype