{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# FISTA\n\nThis example shows how to use the\n:py:func:`pylops_gpu.optimization.sparsity.FISTA` solver.\n\nThis solver can be used when the model to retrieve is assumed to have\na sparse representation in a certain domain. FISTA (Fast Iterative\nShrinkage-Thresholding Algorithm) solves an unconstrained problem\nwith an L1 regularization term:\n\n\\begin{align}J = ||\\mathbf{d} - \\mathbf{Op} \\mathbf{x}||_2^2 + \\epsilon ||\\mathbf{x}||_1\\end{align}\n\n\n"
      ]
    },
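    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the three ingredients of FISTA concrete (a gradient step on the data term, a soft-thresholding step for the L1 term and a momentum update), the next cell is a minimal NumPy sketch of the textbook Beck-Teboulle iteration for\n$J = ||\\mathbf{d} - \\mathbf{A} \\mathbf{x}||_2^2 + \\epsilon ||\\mathbf{x}||_1$\nwith a dense matrix $\\mathbf{A}$. It is an illustration under these assumptions, not the code inside\n:py:func:`pylops_gpu.optimization.sparsity.FISTA`; the helpers `soft_threshold` and `fista_sketch` are hypothetical, and the step size is taken as the inverse of the largest eigenvalue of $\\mathbf{A}^T \\mathbf{A}$. The cell uses its own random generator, so the global NumPy seed used by the examples is left untouched.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\ndef soft_threshold(x, thresh):\n    # Proximal operator of thresh * ||x||_1: shrink every entry towards zero\n    return np.sign(x) * np.maximum(np.abs(x) - thresh, 0.0)\n\n\ndef fista_sketch(A, d, niter, eps):\n    # Hypothetical helper minimising ||d - A x||_2^2 + eps * ||x||_1\n    # Step size 1 / L, where L is the largest eigenvalue of A^T A\n    alpha = 1.0 / np.linalg.norm(A, 2) ** 2\n    x = np.zeros(A.shape[1])\n    z = x.copy()  # extrapolated (momentum) point\n    t = 1.0  # momentum parameter\n    for _ in range(niter):\n        # Gradient step on half the data term, evaluated at z\n        grad_step = z + alpha * A.T @ (d - A @ z)\n        # Shrinkage step; the threshold eps * alpha / 2 accounts for the halving\n        x_new = soft_threshold(grad_step, eps * alpha / 2)\n        # Momentum update of the Beck-Teboulle scheme\n        t_new = (1 + np.sqrt(1 + 4 * t ** 2)) / 2\n        z = x_new + ((t - 1) / t_new) * (x_new - x)\n        x, t = x_new, t_new\n    return x\n\n\n# Small underdetermined test problem with a sparse model\nrng = np.random.default_rng(1)\nN, M = 15, 20\nA = rng.standard_normal((N, M))\nx_true = rng.random(M)\nx_true[x_true < 0.9] = 0\nd = A @ x_true\n\nx_est = fista_sketch(A, d, niter=1000, eps=0.5)\nprint('true:', np.round(x_true, 2))\nprint('FISTA sketch:', np.round(x_est, 2))"
      ]
    },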
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport pylops\nimport pylops_gpu\n\nfrom pylops_gpu.utils.backend import device\n\ndev = device()\nprint('PyLops-gpu working on %s...' % dev)\nplt.close('all')\n\ntorch.manual_seed(0)\nnp.random.seed(1)\ndtype = torch.float32"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's start with a simple example, where we create a dense mixing matrix\nand a sparse signal and use FISTA to recover such a signal.\nNote that the mixing matrix leads to an underdetermined system of equations\n($N < M$), so adding some prior information about\nthe sparsity of the desired model is essential to invert\nsuch a system.\n\n"
      ]
    },
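    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Why is a prior needed at all? With $N < M$ a full-rank matrix has a null space of dimension $M - N$, so infinitely many models explain the data exactly. The next cell is a small NumPy sketch, independent of PyLops-gpu and with a hand-picked 2-sparse model chosen purely for illustration: the minimum-norm solution given by `np.linalg.pinv` fits the data but spreads energy over (generically) all entries instead of recovering the two spikes.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n# Local generator, so the global seed used by the examples is untouched\nrng = np.random.default_rng(0)\nN, M = 15, 20\nA = rng.standard_normal((N, M))\n\n# Hand-picked 2-sparse model (illustrative values)\nx_true = np.zeros(M)\nx_true[[3, 11]] = [1.0, -0.5]\ny = A @ x_true\n\n# Minimum-norm solution: fits the data exactly but ignores sparsity\nx_mn = np.linalg.pinv(A) @ y\n\n# Number of non-negligible singular values gives the rank of A\ns = np.linalg.svd(A, compute_uv=False)\nnull_dim = M - np.sum(s > 1e-10)\n\nprint('data misfit of x_mn:', np.linalg.norm(A @ x_mn - y))\nprint('null-space dimension:', null_dim)\nprint('non-zeros in x_true:', np.count_nonzero(x_true))\nprint('non-zeros in x_mn:', np.count_nonzero(np.abs(x_mn) > 1e-6))"
      ]
    },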
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "N, M = 15, 20\n# Dense random mixing matrix with fewer rows than columns (N < M)\nA = np.random.randn(N, M).astype(np.float32)\nAop = pylops_gpu.MatrixMult(torch.from_numpy(A), device=dev)\n\n# Sparse model: keep only the entries above 0.9 (on average 10% of them)\nx = torch.from_numpy(np.random.rand(M).astype(np.float32))\nx[x < 0.9] = 0\ny = Aop * x\n\n# FISTA with L1 weight eps, run for up to maxit iterations\neps = 0.5\nmaxit = 1000\nx_fista = pylops_gpu.optimization.sparsity.FISTA(Aop, y, maxit, eps=eps,\n                                                 tol=1e-10)[0]\n\n# Compare the true and recovered models\nfig, ax = plt.subplots(1, 1, figsize=(8, 3))\nax.stem(x, linefmt='k', basefmt='k',\n        markerfmt='ko', label='True')\nax.stem(x_fista, linefmt='--r',\n        markerfmt='ro', label='FISTA')\nax.set_title('Model', size=15, fontweight='bold')\nax.legend()\nplt.tight_layout()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We now consider a more interesting problem, *wavelet deconvolution*\nof a signal that we assume to be composed of a train of spikes convolved\nwith a certain wavelet. We will see how solving such a problem with the\nconjugate gradient solver\n:py:func:`pylops_gpu.optimization.cg.cg` does\nnot produce the expected results even with the noise-free data used here\n(noisy data would only make things worse), whereas the\n:py:func:`pylops_gpu.optimization.sparsity.FISTA`\nsolver allows us to successfully retrieve the input spike train.\n\n"
      ]
    },
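    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick way to see why a solver without a prior struggles here is to look at the amplitude spectrum of the wavelet. The next cell builds a Ricker wavelet by hand with the textbook formula $w(t) = (1 - 2\\pi^2 f_0^2 t^2)\\, e^{-\\pi^2 f_0^2 t^2}$ (an assumption of this sketch, rather than a call to `pylops.utils.wavelets.ricker`) using the same sampling as the example below, and shows that the wavelet carries essentially no energy at 0 Hz or near the Nyquist frequency. Those components of the spike train are missing from the data, so a prior such as sparsity is needed to fill them in.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n# Same sampling as the deconvolution example: 61 samples at 4 ms\nnt, dt, f0 = 61, 0.004, 20.0\n# Time axis centred on zero, so the wavelet is symmetric\nt_w = (np.arange(nt) - nt // 2) * dt\n# Textbook Ricker formula with peak frequency f0\narg = (np.pi * f0 * t_w) ** 2\nw = (1 - 2 * arg) * np.exp(-arg)\n\n# One-sided amplitude spectrum of the wavelet\nfreqs = np.fft.rfftfreq(nt, d=dt)\namp = np.abs(np.fft.rfft(w))\npeak = freqs[np.argmax(amp)]\n\nprint('peak frequency: %.1f Hz' % peak)\nprint('relative amplitude at 0 Hz: %.1e' % (amp[0] / amp.max()))\nprint('relative amplitude at %.0f Hz: %.1e' % (freqs[-1], amp[-1] / amp.max()))"
      ]
    },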
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "nt = 61\ndt = 0.004\nt = np.arange(nt)*dt\n# Spike train: three reflections with different amplitudes\nx = np.zeros(nt, dtype=np.float32)\nx[10] = -.4\nx[int(nt/2)] = 1\nx[nt-20] = 0.5\nx = torch.from_numpy(x)\n\n# Ricker wavelet (20 Hz peak frequency) and convolution operator centred on it\nh, th, hcenter = pylops.utils.wavelets.ricker(t[:101], f0=20)\nh = torch.from_numpy(h.astype(np.float32))\nCop = pylops_gpu.signalprocessing.Convolve1D(nt, h=h, offset=int(hcenter),\n                                             dtype=dtype)\ny = Cop * x\n\n# Conjugate gradient inversion without any prior\nxls = pylops_gpu.optimization.cg.cg(Cop, y, niter=10, tol=1e-10)[0]\n\n# Sparsity-promoting inversion with FISTA\nxfista = \\\n    pylops_gpu.optimization.sparsity.FISTA(Cop, y, niter=400, eps=5e-1,\n                                           tol=1e-8)[0]\n\nfig, ax = plt.subplots(1, 1, figsize=(8, 3))\nax.plot(t, x, 'k', lw=8, label=r'$x$')\nax.plot(t, y, 'r', lw=4, label=r'$y=Ax$')\nax.plot(t, xls, '--g', lw=4, label=r'$x_{LS}$')\nax.plot(t, xfista, '--m', lw=4, label=r'$x_{FISTA}$')\nax.set_title('Deconvolution', fontsize=14, fontweight='bold')\nax.legend()\nplt.tight_layout()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}