PyLops-GPU¶
Note
This library is under early development.
Expect things to constantly change until version v1.0.0.
This library is an extension of PyLops to run operators on GPUs.
Just as numpy and scipy lie at the core of the parent project PyLops, PyLops-GPU is built on top of PyTorch and takes advantage of the same optimized tensor computations used for deep learning on GPUs and CPUs. As a result, linear operators can be applied directly on the GPU.
Here is a simple example showing how a diagonal operator can be created, applied and inverted using PyLops:
import numpy as np
from pylops import Diagonal
n = int(1e6)
x = np.ones(n)
d = np.arange(n) + 1.
Dop = Diagonal(d)
# y = Dx
y = Dop*x
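The inversion mentioned above can be performed with the usual PyLops shorthand, where dividing the data by the operator invokes a least-squares solver under the hood:
# xinv ~= D^-1 y, solved by an iterative solver under the hood
xinv = Dop / y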
and similarly using PyLops-GPU:
import numpy as np
import torch
from pylops_gpu.utils.backend import device
from pylops_gpu import Diagonal
dev = device() # will return 'gpu' if GPU is available
n = int(1e6)
x = torch.ones(n, dtype=torch.float64).to(dev)
d = (torch.arange(0, n, dtype=torch.float64) + 1.).to(dev)
Dop = Diagonal(d, device=dev)
# y = Dx
y = Dop*x
Running these two snippets in Google Colab with a GPU enabled gives a speed-up of more than 50x for the forward pass.
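If you want to reproduce such a comparison yourself, a rough timing sketch like the one below can be used. The names Dop_cpu/x_cpu and Dop_gpu/x_gpu are placeholders for the operators and vectors created in the two snippets above, and the exact numbers will depend on the GPU at hand.
import time
import torch

def time_forward(op, x, nrep=10):
    """Average wall-clock time (in seconds) of nrep forward passes y = op * x."""
    t0 = time.time()
    for _ in range(nrep):
        _ = op * x
    if torch.is_tensor(x) and x.is_cuda:
        torch.cuda.synchronize()  # wait for all queued GPU kernels to finish
    return (time.time() - t0) / nrep

# speedup = time_forward(Dop_cpu, x_cpu) / time_forward(Dop_gpu, x_gpu)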
As a by-product of implementing PyLops linear operators in PyTorch, we can easily chain our operators with any nonlinear mathematical operation (e.g., log, sin, tan, pow, …) as well as with operators from the torch.nn submodule, and obtain Automatic Differentiation (AD) for the entire chain. Since the gradient of a linear operator is simply its adjoint, we have implemented a single class, pylops_gpu.TorchOperator, which can wrap any linear operator from the PyLops and PyLops-GPU libraries and return a torch.autograd.Function object.
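For instance, a minimal sketch of such a chain wraps a Diagonal operator with pylops_gpu.TorchOperator and mixes it with torch.sin; the apply method is the one used in the tutorials below, and the default CPU device is assumed.
import torch
import pylops_gpu

n = 10
d = torch.arange(1., n + 1, dtype=torch.float64)
x = torch.ones(n, dtype=torch.float64, requires_grad=True)

# wrap the linear operator so that it joins the autograd graph
Dop = pylops_gpu.TorchOperator(pylops_gpu.Diagonal(d))

# nonlinear operation chained with a linear operator, reduced to a scalar
y = Dop.apply(torch.sin(x)).sum()
y.backward()

# the AD gradient is the adjoint of Diagonal applied to the chain-rule term,
# i.e. d * cos(x)
print(x.grad)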
History¶
PyLops-GPU was initially written by, and is currently maintained by, Equinor. It is an extension of PyLops for large-scale optimization with GPU-powered linear operators, developed to be tailored to our needs and as a contribution to the free software community.
Installation¶
You will need Python 3.5 or greater to get started.
Dependencies¶
Our mandatory dependencies are limited to:
We advise using the Anaconda Python distribution
to ensure that these dependencies are installed via the Conda
package manager.
Step-by-step installation for users¶
Python environment¶
Stable releases on PyPI and Conda coming soon…
To install the latest source from github:
>> pip install git+https://git@github.com/equinor/pylops-gpu.git@master
or just clone the repository
>> git clone https://github.com/equinor/pylops-gpu.git
or download the zip file from the repository (green button in the top right corner of the main github repo page) and install PyLops-GPU from terminal using the command:
>> make install
Step-by-step installation for developers¶
Fork and clone the repository by executing the following in your terminal:
>> git clone https://github.com/your_name_here/pylops-gpu.git
The first time you clone the repository run the following command:
>> make dev-install
If you prefer to build a new Conda environment just for PyLops-GPU, run the following command:
>> make dev-install_conda
To ensure that everything has been set up correctly, run the tests:
>> make tests
Make sure no tests fail; this guarantees that the installation has been successful.
If using a Conda environment, remember to activate it every time you open a new bash shell by typing:
>> source activate pylops-gpu
Tutorials¶
01. Automatic Differentiation¶
This tutorial focuses on one of the two main benefits of re-implementing some of the PyLops linear operators within the PyTorch framework, namely the possibility to perform Automatic Differentiation (AD) on chains of operators, which can be:
- native PyTorch mathematical operations (e.g., torch.log, torch.sin, torch.tan, torch.pow, …)
- neural network operators in torch.nn
- PyLops and/or PyLops-GPU linear operators
This opens up many opportunities, such as easily including linear regularization terms to nonlinear cost functions or using linear preconditioners with nonlinear modelling operators.
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.autograd import gradcheck
import pylops_gpu
from pylops_gpu.utils.backend import device
dev = device()
plt.close('all')
np.random.seed(10)
torch.manual_seed(10)
Out:
<torch._C.Generator object at 0x7f28c50c47b0>
In this example we consider a simple multidimensional functional:
\[\mathbf{y} = \mathbf{A} \sin(\mathbf{x})\]
and we use AD to compute the gradient with respect to the input vector evaluated at \(\mathbf{x}=\mathbf{x}_0\): \(\mathbf{g} = d\mathbf{y} / d\mathbf{x} |_{\mathbf{x}=\mathbf{x}_0}\).
Let’s start by defining the Jacobian:
\[\begin{split}\textbf{J} = \begin{bmatrix} dy_1 / dx_1 & \ldots & dy_1 / dx_M \\ \ldots & \ldots & \ldots \\ dy_N / dx_1 & \ldots & dy_N / dx_M \end{bmatrix} = \begin{bmatrix} a_{11} \cos(x_1) & \ldots & a_{1M} \cos(x_M) \\ \ldots & \ldots & \ldots \\ a_{N1} \cos(x_1) & \ldots & a_{NM} \cos(x_M) \end{bmatrix} = \textbf{A} \cos(\mathbf{x})\end{split}\]
Since both input and output are multidimensional, PyTorch backward actually computes the product between the transposed Jacobian and a vector \(\mathbf{v}\): \(\mathbf{g}=\mathbf{J^T} \mathbf{v}\). To validate the correctness of the AD result, we can in this simple case also compute the Jacobian analytically and apply it to the same vector \(\mathbf{v}\) that we have provided to PyTorch backward.
nx, ny = 10, 6
x0 = torch.arange(nx, dtype=torch.double, requires_grad=True)
# Forward
A = torch.normal(0., 1., (ny, nx), dtype=torch.double)
Aop = pylops_gpu.TorchOperator(pylops_gpu.MatrixMult(A))
y = Aop.apply(torch.sin(x0))
# AD
v = torch.ones(ny, dtype=torch.double)
y.backward(v, retain_graph=True)
adgrad = x0.grad
# Analytical
J = (A * torch.cos(x0))
anagrad = torch.matmul(J.T, v)
print('Input: ', x0)
print('AD gradient: ', adgrad)
print('Analytical gradient: ', anagrad)
Out:
Input: tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=torch.float64,
requires_grad=True)
AD gradient: tensor([-0.0695, 0.6679, -0.1115, 5.8981, -0.2886, 0.3653, 2.6875, -2.1607,
-0.4924, 0.2005], dtype=torch.float64)
Analytical gradient: tensor([-0.0695, 0.6679, -0.1115, 5.8981, -0.2886, 0.3653, 2.6875, -2.1607,
-0.4924, 0.2005], dtype=torch.float64, grad_fn=<MvBackward>)
Similarly, we can use torch.autograd.gradcheck directly from PyTorch. Note that doubles must be used for this to succeed with very small eps and atol:
input = (torch.arange(nx, dtype=torch.double, requires_grad=True),
         Aop.matvec, Aop.rmatvec, Aop.pylops, Aop.device)
test = gradcheck(Aop.Top, input, eps=1e-6, atol=1e-4)
print(test)
Out:
True
Note that while the matrix-vector multiplication could have been performed using the native PyTorch operator torch.matmul, in this case we have shown that we are also able to use a PyLops-GPU operator wrapped in pylops_gpu.TorchOperator. As already mentioned, this gives us the ability to use much more complex linear operators provided by PyLops within a chain of mixed linear and nonlinear AD-enabled operators.
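To illustrate the point about torch.nn, a minimal sketch of such a mixed chain could look as follows; the layer and the loss are purely illustrative and reuse the Aop operator and nx defined above.
# a learnable torch.nn layer followed by the wrapped linear operator
layer = torch.nn.Linear(nx, nx).double()
xin = torch.ones(nx, dtype=torch.double, requires_grad=True)

z = torch.sin(layer(xin))     # nonlinear, learnable part (torch.nn + torch.sin)
y = Aop.apply(z)              # PyLops-GPU linear operator via TorchOperator
loss = (y ** 2).sum()

loss.backward()               # AD flows through the entire mixed chain
print(xin.grad.shape, layer.weight.grad.shape)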
Total running time of the script: ( 0 minutes 0.010 seconds)
02. Post-stack inversion¶
This tutorial focuses on extending post-stack seismic inversion to GPU processing. We refer to the equivalent PyLops tutorial for a more detailed description of the theory.
# sphinx_gallery_thumbnail_number = 2
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.signal import filtfilt
from pylops.utils.wavelets import ricker
import pylops_gpu
from pylops_gpu.utils.backend import device
dev = device()
plt.close('all')
np.random.seed(10)
torch.manual_seed(10)
Out:
<torch._C.Generator object at 0x7f28c50c47b0>
We start with a 1d example. A synthetic profile of acoustic impedance is created and data are modelled using both the dense and the linear-operator version of the pylops_gpu.avo.poststack.PoststackLinearModelling operator. Both model and wavelet are created as numpy arrays and converted into torch tensors (note that we enforce float32 for optimal performance on GPU).
# model
nt0 = 301
dt0 = 0.004
t0 = np.arange(nt0)*dt0
vp = 1200 + np.arange(nt0) + \
     filtfilt(np.ones(5)/5., 1, np.random.normal(0, 80, nt0))
rho = 1000 + vp + \
      filtfilt(np.ones(5)/5., 1, np.random.normal(0, 30, nt0))
vp[131:] += 500
rho[131:] += 100
m = np.log(vp*rho)
# smooth model
nsmooth = 100
mback = filtfilt(np.ones(nsmooth)/float(nsmooth), 1, m)
# wavelet
ntwav = 41
wav, twav, wavc = ricker(t0[:ntwav//2+1], 20)
# convert to torch tensors
m = torch.from_numpy(m.astype('float32'))
mback = torch.from_numpy(mback.astype('float32'))
wav = torch.from_numpy(wav.astype('float32'))
# dense operator
PPop_dense = \
    pylops_gpu.avo.poststack.PoststackLinearModelling(wav / 2, nt0=nt0,
                                                      explicit=True)
# lop operator
PPop = pylops_gpu.avo.poststack.PoststackLinearModelling(wav / 2, nt0=nt0)
# data
d_dense = PPop_dense * m.flatten()
d = PPop * m.flatten()
# add noise
dn_dense = d_dense + \
    torch.from_numpy(np.random.normal(0, 2e-2, d_dense.shape).astype('float32'))
We can now estimate the acoustic profile from the band-limited data using either the dense or the linear operator.
# solve dense
minv_dense = \
    pylops_gpu.avo.poststack.PoststackInversion(d, wav / 2, m0=mback, explicit=True,
                                                simultaneous=False)[0]
# solve lop
minv = \
    pylops_gpu.avo.poststack.PoststackInversion(d_dense, wav / 2, m0=mback,
                                                explicit=False,
                                                simultaneous=False,
                                                **dict(niter=500))[0]
# solve noisy
mn = \
    pylops_gpu.avo.poststack.PoststackInversion(dn_dense, wav / 2, m0=mback,
                                                explicit=True, epsI=1e-4,
                                                epsR=1e0, **dict(niter=100))[0]
fig, axs = plt.subplots(1, 2, figsize=(6, 7), sharey=True)
axs[0].plot(d_dense, t0, 'k', lw=4, label='Dense')
axs[0].plot(d, t0, '--r', lw=2, label='Lop')
axs[0].plot(dn_dense, t0, '-.g', lw=2, label='Noisy')
axs[0].set_title('Data')
axs[0].invert_yaxis()
axs[0].axis('tight')
axs[0].legend(loc=1)
axs[1].plot(m, t0, 'k', lw=4, label='True')
axs[1].plot(mback, t0, '--b', lw=4, label='Back')
axs[1].plot(minv_dense, t0, '--m', lw=2, label='Inv Dense')
axs[1].plot(minv, t0, '--r', lw=2, label='Inv Lop')
axs[1].plot(mn, t0, '--g', lw=2, label='Inv Noisy')
axs[1].set_title('Model')
axs[1].axis('tight')
axs[1].legend(loc=1)

Out:
<matplotlib.legend.Legend object at 0x7f28aab2a7b8>
We now move to a 2d example. First, the model is loaded and the data are generated.
# model
inputfile = '../testdata/avo/poststack_model.npz'
model = np.load(inputfile)
m = np.log(model['model'][:, ::3])
x, z = model['x'][::3]/1000., model['z']/1000.
nx, nz = len(x), len(z)
# smooth model
nsmoothz, nsmoothx = 60, 50
mback = filtfilt(np.ones(nsmoothz)/float(nsmoothz), 1, m, axis=0)
mback = filtfilt(np.ones(nsmoothx)/float(nsmoothx), 1, mback, axis=1)
# convert to torch tensors
m = torch.from_numpy(m.astype('float32'))
mback = torch.from_numpy(mback.astype('float32'))
# dense operator
PPop_dense = \
    pylops_gpu.avo.poststack.PoststackLinearModelling(wav / 2, nt0=nz,
                                                      spatdims=nx, explicit=True)
# lop operator
PPop = pylops_gpu.avo.poststack.PoststackLinearModelling(wav / 2, nt0=nz,
                                                         spatdims=nx)
# data
d = (PPop_dense * m.flatten()).reshape(nz, nx)
n = torch.from_numpy(np.random.normal(0, 1e-1, d.shape).astype('float32'))
dn = d + n
Finally, we perform different types of inversion:
# dense inversion with noise-free data
minv_dense = \
    pylops_gpu.avo.poststack.PoststackInversion(d, wav / 2, m0=mback,
                                                explicit=True,
                                                simultaneous=False)[0]
# dense inversion with noisy data
minv_dense_noisy = \
    pylops_gpu.avo.poststack.PoststackInversion(dn, wav / 2, m0=mback,
                                                explicit=True, epsI=4e-2,
                                                simultaneous=False)[0]
# spatially regularized lop inversion with noisy data
minv_lop_reg = \
    pylops_gpu.avo.poststack.PoststackInversion(dn, wav / 2, m0=minv_dense_noisy,
                                                explicit=False,
                                                epsR=5e1, epsI=1e-2,
                                                **dict(niter=80))[0]
fig, axs = plt.subplots(2, 4, figsize=(15, 9))
axs[0][0].imshow(d, cmap='gray',
extent=(x[0], x[-1], z[-1], z[0]),
vmin=-0.4, vmax=0.4)
axs[0][0].set_title('Data')
axs[0][0].axis('tight')
axs[0][1].imshow(dn, cmap='gray',
extent=(x[0], x[-1], z[-1], z[0]),
vmin=-0.4, vmax=0.4)
axs[0][1].set_title('Noisy Data')
axs[0][1].axis('tight')
axs[0][2].imshow(m, cmap='gist_rainbow',
extent=(x[0], x[-1], z[-1], z[0]),
vmin=m.min(), vmax=m.max())
axs[0][2].set_title('Model')
axs[0][2].axis('tight')
axs[0][3].imshow(mback, cmap='gist_rainbow',
extent=(x[0], x[-1], z[-1], z[0]),
vmin=m.min(), vmax=m.max())
axs[0][3].set_title('Smooth Model')
axs[0][3].axis('tight')
axs[1][0].imshow(minv_dense, cmap='gist_rainbow',
extent=(x[0], x[-1], z[-1], z[0]),
vmin=m.min(), vmax=m.max())
axs[1][0].set_title('Noise-free Inversion')
axs[1][0].axis('tight')
axs[1][1].imshow(minv_dense_noisy, cmap='gist_rainbow',
extent=(x[0], x[-1], z[-1], z[0]),
vmin=m.min(), vmax=m.max())
axs[1][1].set_title('Trace-by-trace Noisy Inversion')
axs[1][1].axis('tight')
axs[1][2].imshow(minv_lop_reg, cmap='gist_rainbow',
extent=(x[0], x[-1], z[-1], z[0]),
vmin=m.min(), vmax=m.max())
axs[1][2].set_title('Regularized Noisy Inversion - lop ')
axs[1][2].axis('tight')
fig, ax = plt.subplots(1, 1, figsize=(3, 7))
ax.plot(m[:, nx//2], z, 'k', lw=4, label='True')
ax.plot(mback[:, nx//2], z, '--r', lw=4, label='Back')
ax.plot(minv_dense[:, nx//2], z, '--b', lw=2, label='Inv Dense')
ax.plot(minv_dense_noisy[:, nx//2], z, '--m', lw=2, label='Inv Dense noisy')
ax.plot(minv_lop_reg[:, nx//2], z, '--g', lw=2, label='Inv Lop regularized')
ax.set_title('Model')
ax.invert_yaxis()
ax.axis('tight')
ax.legend()
plt.tight_layout()
Finally, if you want to run this code on GPUs, take a look at the following notebook: the speed-up grows as the problem size increases.
Total running time of the script: ( 0 minutes 4.761 seconds)
PyLops-GPU API¶
Linear operators¶
Templates¶
LinearOperator (shape, dtype[, Op, explicit, …]): Common interface for performing matrix-vector products.
TorchOperator (Op[, batch, pylops, device]): Wrap a PyLops operator into a Torch function.
Basic operators¶
MatrixMult (A[, dims, device, togpu, tocpu, …]): Matrix multiplication.
Identity (N[, M, inplace, complex, device, …]): Identity operator.
Diagonal (diag[, dims, dir, device, togpu, …]): Diagonal operator.
VStack (ops[, device, togpu, tocpu, dtype]): Vertical stacking.
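As a quick illustration of how these basic operators compose, here is a small sketch (default CPU device and float32 tensors assumed) that stacks an Identity and a Diagonal operator vertically and applies the result to a vector:
import torch
import pylops_gpu

n = 5
x = torch.arange(1., n + 1)           # float32 by default
d = torch.full((n,), 2.)

Iop = pylops_gpu.Identity(n)          # y = x
Dop = pylops_gpu.Diagonal(d)          # y = d * x
Vop = pylops_gpu.VStack([Iop, Dop])   # y = [x; d * x]

y = Vop * x
print(y)   # first n entries equal x, last n entries equal 2 * x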
Smoothing and derivatives¶
FirstDerivative (N[, dims, dir, sampling, …]): First derivative.
SecondDerivative (N[, dims, dir, sampling, …]): Second derivative.
Laplacian (dims[, dirs, weights, sampling, …]): Laplacian.
Signal processing¶
Convolve1D (N, h[, offset, dims, dir, …]): 1D convolution operator.
Solvers¶
Low-level solvers¶
cg (A, y[, x, niter, tol]): Conjugate gradient.
cgls (A, y[, x, niter, damp, tol]): Conjugate gradient least squares.
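A minimal sketch of how these low-level solvers are meant to be used is given below; note that the module path (pylops_gpu.optimization.cg) and the unpacking of the first return value as the estimated model are assumptions, so check the API reference for the exact signature.
import torch
import pylops_gpu
from pylops_gpu.optimization.cg import cg  # assumed module path

n = 5
d = torch.full((n,), 2.)
Dop = pylops_gpu.Diagonal(d)      # symmetric positive-definite system
y = Dop * torch.ones(n)

# assumed: the estimated model is returned as the first output
xinv = cg(Dop, y, niter=10)[0]
print(xinv)   # should be close to a vector of ones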
Least-squares¶
leastsquares.NormalEquationsInversion (Op, …): Inversion of normal equations.
Sparsity¶
sparsity.FISTA (Op, data, niter[, eps, …]): Fast Iterative Soft Thresholding Algorithm (FISTA).
sparsity.SplitBregman (Op, RegsL1, data[, …]): Split Bregman for mixed L2-L1 norms.
Applications¶
Geophysical subsurface characterization¶
poststack.PoststackInversion (data, wav[, …]): Post-stack linearized seismic inversion.
PyLops-GPU Utilities¶
Alongside its Linear Operators and Solvers, PyLops-GPU also contains a number of auxiliary routines.
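One such routine is the device helper used throughout the tutorials; a minimal usage sketch follows.
import numpy as np
import torch
from pylops_gpu.utils.backend import device

dev = device()   # 'gpu' if a GPU is available, otherwise 'cpu'

# move a numpy model onto the detected device as a float32 torch tensor
m = torch.from_numpy(np.ones(10, dtype='float32')).to(dev)
print(dev, m.dtype)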
Contributing¶
Contributions are welcome and greatly appreciated!
Follow the instructions in our main repository.
Changelog¶
Version 0.0.1¶
Released on: 03/05/2021
- Added pylops_gpu.optimization.sparsity.FISTA and pylops_gpu.optimization.sparsity.SplitBregman solvers
- Modified pylops_gpu.TorchOperator to work with cupy arrays
- Modified pylops_gpu.avo.poststack._PoststackLinearModelling to use the code written in the pylops library whilst still dealing with torch arrays
- Allowed passing numpy dtypes to operators (automatic conversion to torch types)
Roadmap¶
Coming soon…
Contributors¶
- Matteo Ravasi, mrava87
- Francesco Picetti, fpicetti