IFT starting point:
$$d = Rs+n$$Typically, $s$ is a continuous field and $d$ a discrete data vector. In particular, $R$ is generally not invertible.
IFT aims at inverting the above non-invertible problem in the best possible way using Bayesian statistics.
NIFTy (Numerical Information Field Theory; German: *raffiniert*) is a Python framework in which IFT problems can be tackled easily.
Main Interfaces:
The Posterior is given by:
$$\mathcal P (s|d) \propto P(s,d) = \mathcal G(d-Rs,N) \,\mathcal G(s,S) \propto \mathcal G (m,D) $$where $$\begin{align} m &= Dj \\ D^{-1}&= (S^{-1} +R^\dagger N^{-1} R )\\ j &= R^\dagger N^{-1} d \end{align}$$
Let us implement this in NIFTy!
N_pixels = 512  # Number of pixels
sigma2 = .5  # Noise variance

def pow_spec(k):
    """Isotropic power-law spectrum P(k) = P0 * (1 + (k/k0)^2)^(-gamma/2)."""
    amplitude, knee, slope = .2, 5, 6
    return amplitude * (1. + (k / knee)**2)**(-slope / 2)
import matplotlib.pyplot as plt
import numpy as np
from nifty import (DiagonalOperator, EndomorphicOperator, FFTOperator, Field,
InvertibleOperatorMixin, PowerSpace, RGSpace,
create_power_operator, SmoothingOperator, DiagonalProberMixin, Prober)
class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator):
    """Wiener-filter propagator D with D^-1 = F^dagger Sh^-1 F + R^dagger N^-1 R.

    Only the inverse action is implemented (`_inverse_times`);
    InvertibleOperatorMixin provides the forward action by numerical
    inversion (conjugate gradient), preconditioned with T = F^dagger Sh F.
    """

    def __init__(self, R, N, Sh, default_spaces=None):
        """
        Parameters
        ----------
        R : response operator (signal space -> data space)
        N : noise covariance operator
        Sh : signal prior covariance in harmonic space
        default_spaces : passed through to the operator base class
        """
        # Build the harmonic transform *before* creating the preconditioner
        # closure. The original code captured the module-level global `fft`
        # (defined elsewhere in the file), which silently breaks whenever
        # that global does not match R.domain.
        fft = FFTOperator(domain=R.domain, target=Sh.domain)
        super(PropagatorOperator, self).__init__(
            default_spaces=default_spaces,
            preconditioner=lambda x: fft.adjoint_times(Sh.times(fft.times(x))))
        self.R = R
        self.N = N
        self.Sh = Sh
        self._domain = R.domain
        self.fft = fft

    def _inverse_times(self, x, spaces, x0=None):
        # D^-1 x = R^dagger N^-1 R x + F^dagger Sh^-1 F x
        return self.R.adjoint_times(self.N.inverse_times(self.R(x))) \
            + self.fft.adjoint_times(self.Sh.inverse_times(self.fft(x)))

    @property
    def domain(self):
        return self._domain

    @property
    def unitary(self):
        return False

    @property
    def symmetric(self):
        return False

    @property
    def self_adjoint(self):
        # D^-1 is a sum of self-adjoint terms, hence D is self-adjoint.
        return True
$D$ is defined via: $$D^{-1} = \mathcal F^\dagger S_h^{-1}\mathcal F + R^\dagger N^{-1} R.$$ In the end, we want to apply $D$ to $j$, i.e. we need the inverse action of $D^{-1}$. This is done numerically (algorithm: Conjugate Gradient).
One can define the condition number of a non-singular and normal matrix $A$: $$\kappa (A) := \frac{|\lambda_{\text{max}}|}{|\lambda_{\text{min}}|},$$ where $\lambda_{\text{max}}$ and $\lambda_{\text{min}}$ are the largest and smallest eigenvalue of $A$, respectively.
The larger $\kappa$ the slower Conjugate Gradient.
By default, conjugate gradient solves $D^{-1} m = j$ for $m$, where $D^{-1}$ can be badly conditioned. If one knows a non-singular matrix $T$ for which $TD^{-1}$ is better conditioned, one can solve the equivalent problem: $$\tilde A m = \tilde j,$$ where $\tilde A = T D^{-1}$ and $\tilde j = Tj$.
In our case $S^{-1}$ is responsible for the bad conditioning of $D^{-1}$, depending on the chosen power spectrum. Thus, we choose the preconditioner $$T = \mathcal F^\dagger S_h \mathcal F.$$
# One-dimensional regular-grid signal space with N_pixels cells.
s_space = RGSpace(N_pixels)
fft = FFTOperator(s_space)
h_space = fft.target[0]        # harmonic partner of s_space
p_space = PowerSpace(h_space)  # space of isotropic power spectra over h_space
# Operators
Sh = create_power_operator(h_space, power_spectrum=pow_spec)  # prior signal covariance (harmonic)
N = DiagonalOperator(s_space, diagonal=sigma2, bare=True)     # noise covariance, sigma2 per pixel
R = DiagonalOperator(s_space, diagonal=1.)                    # trivial (identity-like) response
D = PropagatorOperator(R=R, N=N, Sh=Sh)                       # Wiener-filter propagator
# Fields and data
sh = Field(p_space, val=pow_spec).power_synthesize(real_signal=True)  # random signal draw (harmonic)
s = fft.adjoint_times(sh)  # the same draw in position space
n = Field.from_random(domain=s_space, random_type='normal',
                      std=np.sqrt(sigma2), mean=0)
d = R(s) + n                             # synthetic data:  d = Rs + n
j = R.adjoint_times(N.inverse_times(d))  # information source:  j = R^dagger N^-1 d
m = D(j)                                 # Wiener-filter mean:  m = Dj
# Power spectra of the true signal and of the reconstruction.
s_power = sh.power_analyze()
m_power = fft(m).power_analyze()
# .get_full_data() gathers the (possibly distributed) array; .real drops
# numerically-zero imaginary parts.
s_power_data = s_power.val.get_full_data().real
m_power_data = m_power.val.get_full_data().real
# Get signal data and reconstruction data
s_data = s.val.get_full_data().real
m_data = m.val.get_full_data().real
d_data = d.val.get_full_data().real
# Overlay signal, noisy data and Wiener-filter reconstruction.
curves = [
    (s_data, 'k', dict(label="Signal", alpha=.5, linewidth=.5)),
    (d_data, 'k+', dict(label="Data")),
    (m_data, 'r', dict(label="Reconstruction")),
]
for values, fmt, opts in curves:
    plt.plot(values, fmt, **opts)
plt.title("Reconstruction")
plt.legend()
plt.show()
# Residuals w.r.t. the true signal; grey band marks the +/- 1-sigma noise level.
plt.figure()
residuals = [
    (s_data - s_data, 'k', dict(label="Signal", alpha=.5, linewidth=.5)),
    (d_data - s_data, 'k+', dict(label="Data")),
    (m_data - s_data, 'r', dict(label="Reconstruction")),
]
for values, fmt, opts in residuals:
    plt.plot(values, fmt, **opts)
noise_std = np.sqrt(sigma2)
plt.axhspan(-noise_std, noise_std, facecolor='0.9', alpha=.5)
plt.title("Residuals")
plt.legend()
plt.show()
# Compare true, signal-realization and reconstructed power spectra (log-log).
plt.loglog()
nyquist = int(N_pixels / 2)
plt.xlim(1, nyquist)
ymin = min(m_power_data)
plt.ylim(ymin, 1)
ks = np.arange(1, nyquist, .1)
plt.plot(ks, pow_spec(ks), label="True Power Spectrum", linewidth=.7, color='k')
plt.plot(s_power_data, 'k', label="Signal", alpha=.5, linewidth=.5)
plt.plot(m_power_data, 'r', label="Reconstruction")
# Flat noise power level expected in the reconstruction.
plt.axhline(sigma2 / N_pixels, color="k", linestyle='--', label="Noise level", alpha=.5)
plt.axhspan(sigma2 / N_pixels, ymin, facecolor='0.9', alpha=.5)
plt.title("Power Spectrum")
plt.legend()
plt.show()
# Operators
Sh = create_power_operator(h_space, power_spectrum=pow_spec)
N = DiagonalOperator(s_space, diagonal=sigma2, bare=True)
# R is defined below
# Fields
sh = Field(p_space, val=pow_spec).power_synthesize(real_signal=True)
s = fft.adjoint_times(sh)
n = Field.from_random(domain=s_space, random_type='normal',
                      std=np.sqrt(sigma2), mean=0)
# Mask out pixels [l, h): the response multiplies by the mask, so the masked
# samples never reach the data.
l = int(N_pixels * 0.2)
h = int(N_pixels * 0.2 * 4)
mask = Field(s_space, val=1)
mask.val[l:h] = 0                        # in-place edit of the field's data array
R = DiagonalOperator(s_space, diagonal=mask)
n.val[l:h] = 0                           # zero the noise in the gap too, so d is exactly 0 there
d = R(s) + n
D = PropagatorOperator(R=R, N=N, Sh=Sh)
j = R.adjoint_times(N.inverse_times(d))
m = D(j)                                 # Wiener-filter reconstruction from incomplete data
class DiagonalProber(DiagonalProberMixin, Prober):
    """Prober that statistically estimates the diagonal of an implicit operator.

    The original pure-passthrough ``__init__`` (which only forwarded
    ``*args``/``**kwargs`` to ``super().__init__``) added no behavior and has
    been removed; the inherited constructor is used directly.
    """
    pass
# Probe the diagonal of D with 200 random probes.
# Fix: `np.complex` was deprecated and removed in NumPy 1.24 -- it was a plain
# alias of the builtin `complex`, which is the drop-in replacement.
diagProber = DiagonalProber(domain=s_space, probe_dtype=complex, probe_count=200)
diagProber(D)
# weight(-1) divides out volume factors so the diagonal reads as a pointwise
# variance -- TODO confirm against the Prober/Field weighting convention.
m_var = Field(s_space, val=diagProber.diagonal.val).weight(-1)
# Power spectra of signal and reconstruction.
s_power = sh.power_analyze()
m_power = fft(m).power_analyze()
s_power_data = s_power.val.get_full_data().real
m_power_data = m_power.val.get_full_data().real
# Get signal data and reconstruction data
s_data = s.val.get_full_data().real
m_data = m.val.get_full_data().real
m_var_data = m_var.val.get_full_data().real
uncertainty = np.sqrt(np.abs(m_var_data))  # 1-sigma pointwise uncertainty
d_data = d.val.get_full_data().real
# Set lost data to NaN for proper plotting
# (relies on masked samples being exactly zero -- true here because the noise
# was zeroed inside the gap; a genuine sample of exactly 0 would also be hidden)
d_data[d_data == 0] = np.nan
# Signal together with the gappy data; the grey vertical band marks the mask.
fig = plt.figure(figsize=(15, 10))
for values, fmt, opts in [
        (s_data, 'k', dict(label="Signal", alpha=.5, linewidth=1)),
        (d_data, 'k+', dict(label="Data", alpha=1))]:
    plt.plot(values, fmt, **opts)
plt.axvspan(l, h, facecolor='0.8', alpha=.5)
plt.title("Incomplete Data")
plt.legend()
fig
# Reconstruction across the masked region with a 1-sigma uncertainty band.
fig = plt.figure(figsize=(15, 10))
for values, fmt, opts in [
        (s_data, 'k', dict(label="Signal", alpha=1, linewidth=1)),
        (d_data, 'k+', dict(label="Data", alpha=.5)),
        (m_data, 'r', dict(label="Reconstruction"))]:
    plt.plot(values, fmt, **opts)
plt.axvspan(l, h, facecolor='0.8', alpha=.5)
plt.fill_between(range(N_pixels), m_data - uncertainty, m_data + uncertainty,
                 facecolor='0')
plt.title("Reconstruction of incomplete data")
plt.legend()
fig
N_pixels = 256  # Number of pixels
sigma2 = 1000  # Noise variance

def pow_spec(k):
    """Power-law spectrum for the 2D example: P0=.2, knee k0=20, slope gamma=4."""
    amplitude, knee, slope = .2, 20, 4
    return amplitude * (1. + (k / knee)**2)**(-slope / 2)
# Two-dimensional signal space (N_pixels x N_pixels image).
s_space = RGSpace([N_pixels, N_pixels])
fft = FFTOperator(s_space)
h_space = fft.target[0]
p_space = PowerSpace(h_space)
# Operators
Sh = create_power_operator(h_space, power_spectrum=pow_spec)
N = DiagonalOperator(s_space, diagonal=sigma2, bare=True)
R = SmoothingOperator(s_space, sigma=.01)
# NOTE(review): this D (and the smoothing R above) is overwritten below once
# the masking response has been built -- these two lines look like dead
# leftovers from an earlier version; confirm before removing.
D = PropagatorOperator(R=R, N=N, Sh=Sh)
# Fields and data
sh = Field(p_space, val=pow_spec).power_synthesize(real_signal=True)
s = fft.adjoint_times(sh)
n = Field.from_random(domain=s_space, random_type='normal',
                      std=np.sqrt(sigma2), mean=0)
# Lose some data
l = int(N_pixels * 0.2)
h = int(N_pixels * 0.2 * 2)
mask = Field(s_space, val=1)
mask.val[l:h, l:h] = 0                   # square patch of lost pixels
R = DiagonalOperator(s_space, diagonal=mask)
n.val[l:h, l:h] = 0                      # no noise in the gap either, so d is exactly 0 there
D = PropagatorOperator(R=R, N=N, Sh=Sh)
d = R(s) + n
j = R.adjoint_times(N.inverse_times(d))
# Run Wiener filter
m = D(j)
# Uncertainty
# Probe the diagonal of D; only 10 probes here, so a rough estimate.
# Fix: `np.complex` was deprecated and removed in NumPy 1.24 -- it was a plain
# alias of the builtin `complex`, which is the drop-in replacement.
diagProber = DiagonalProber(domain=s_space, probe_dtype=complex, probe_count=10)
diagProber(D)
# weight(-1) divides out volume factors so the diagonal reads as a pointwise
# variance -- TODO confirm against the Prober/Field weighting convention.
m_var = Field(s_space, val=diagProber.diagonal.val).weight(-1)
# Get data
# Gather plain numpy arrays for plotting (see earlier cells for the pattern).
s_power = sh.power_analyze()
m_power = fft(m).power_analyze()
s_power_data = s_power.val.get_full_data().real
m_power_data = m_power.val.get_full_data().real
s_data = s.val.get_full_data().real
m_data = m.val.get_full_data().real
m_var_data = m_var.val.get_full_data().real
d_data = d.val.get_full_data().real
uncertainty = np.sqrt(np.abs(m_var_data))  # 1-sigma pointwise uncertainty
# Side-by-side panels of signal and data, sharing one color scale.
cm = ['magma', 'inferno', 'plasma', 'viridis'][1]
mi = np.min(s_data)
ma = np.max(s_data)
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
for ax, values, title in zip(axes.flat, [s_data, d_data], ["Signal", "Data"]):
    im = ax.imshow(values, interpolation='nearest', cmap=cm, vmin=mi, vmax=ma)
    ax.set_title(title)
# One common colorbar to the right of both panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
fig