"""Implementation of Kaiser-Squires Plus (KS+) mass mapping method.
This enhanced Kaiser-Squires method corrects for systematic effects including
missing data, field borders, and reduced shear approximation using sparsity
priors in the DCT domain and wavelet-based power spectrum constraints.
"""
import numpy as np
from scipy import fft
from scipy.ndimage import gaussian_filter
from ..base import MassMapper
from smpy.filters.starlet import (
starlet_transform_2d,
inverse_starlet_transform_2d,
compute_starlet_nscales_max,
starlet_nscales_support_aware,
)
class KSPlusMapper(MassMapper):
    """Implementation of Kaiser-Squires Plus mass mapping.

    The KS+ method extends the standard Kaiser-Squires approach with
    several enhancements for improved mass reconstruction accuracy:

    1. Correcting for missing data using DCT-domain sparsity
    2. Reducing field border effects through field extension
    3. Iteratively correcting for reduced shear approximation
    4. Preserving proper statistical properties using wavelet constraints

    Notes
    -----
    The KS+ algorithm implements the iterative inpainting scheme described
    in the literature, combining sparsity priors in the DCT domain with
    wavelet-based power spectrum preservation for robust mass map
    reconstruction in the presence of missing data and systematic effects.

    The binary mask encodes data availability only. Pixels are marked as
    data (1) if measurements exist and are valid; gaps (0) are used solely
    where data are missing or invalid. Zero shear values are valid data and
    must not be masked. Gap pixels are zero-filled before reconstruction so
    that non-finite placeholders (NaN/Inf) cannot contaminate the FFT/DCT
    steps.
    """

    @property
    def name(self):
        """Name identifier for the KS+ method.

        Returns
        -------
        method_name : `str`
            String identifier 'ks_plus'.
        """
        return "ks_plus"

    def create_maps(self, g1_grid, g2_grid):
        """Create convergence maps using Kaiser-Squires Plus inversion.

        Perform enhanced mass mapping reconstruction with iterative
        inpainting, reduced shear corrections, and wavelet-based
        power spectrum preservation.

        Parameters
        ----------
        g1_grid : `numpy.ndarray`
            First reduced shear component grid.
        g2_grid : `numpy.ndarray`
            Second reduced shear component grid.

        Returns
        -------
        kappa_e : `numpy.ndarray`
            E-mode convergence map (same shape as ``g1_grid``).
        kappa_b : `numpy.ndarray`
            B-mode convergence map (same shape as ``g1_grid``).

        Raises
        ------
        ValueError
            If smoothing is enabled but no ``sigma`` is configured.

        Notes
        -----
        The algorithm performs the following steps:

        1. Field extension to reduce border effects
        2. Iterative reduced shear correction
        3. DCT-domain inpainting with sparsity constraints
        4. Wavelet-based power spectrum preservation
        5. Optional Gaussian smoothing

        At least one reduced-shear pass is always performed, even if
        ``reduced_shear_iterations`` is configured as 0 or negative.
        """
        npix_dec, npix_ra = g1_grid.shape
        # method_config may be None; the other helpers guard for this too.
        config = self.method_config or {}

        # Data-availability mask (1 = valid measurement, 0 = gap).
        mask = self._create_mask(g1_grid, g2_grid)

        # Extend field to handle border effects.
        extension_size_dec, extension_size_ra = self._resolve_extension_sizes(
            config.get('extension_size', 'double'), npix_dec, npix_ra)
        g1_extended, g2_extended, mask_extended = self._extend_field(
            g1_grid, g2_grid, mask, extension_size_dec, extension_size_ra)
        data_pixels = mask_extended > 0

        # Reduced shear correction loop. A single pass is required to
        # produce any estimate at all, hence the lower bound of 1.
        max_iterations = max(1, int(config.get('reduced_shear_iterations', 3)))
        kappa_e_extended = None
        kappa_b_extended = None
        for k in range(max_iterations):
            # Correct for reduced shear: gamma = g (1 - kappa)
            g1_corrected = g1_extended.copy()
            g2_corrected = g2_extended.copy()
            if k > 0:  # first pass assumes kappa = 0
                reduction = 1 - kappa_e_extended[data_pixels]
                g1_corrected[data_pixels] *= reduction
                g2_corrected[data_pixels] *= reduction
            kappa_e_extended, kappa_b_extended = self._inpainting_reconstruction(
                g1_corrected, g2_corrected, mask_extended, config)

        # Extract the original (central) field from the extended maps.
        central = (
            slice(extension_size_dec, extension_size_dec + npix_dec),
            slice(extension_size_ra, extension_size_ra + npix_ra),
        )
        kappa_e = kappa_e_extended[central].copy()
        kappa_b = kappa_b_extended[central].copy()

        # Apply smoothing if configured
        smoothing_config = config.get('smoothing')
        if smoothing_config and smoothing_config.get('type'):
            sigma = smoothing_config.get('sigma')
            if sigma is None:
                raise ValueError("Smoothing enabled but 'sigma' parameter missing. Please specify smoothing.sigma in your configuration.")
            kappa_e = gaussian_filter(kappa_e, sigma=sigma)
            kappa_b = gaussian_filter(kappa_b, sigma=sigma)

        return kappa_e, kappa_b

    def _resolve_extension_sizes(self, extension_config, npix_dec, npix_ra):
        """Translate the ``extension_size`` option into per-axis padding.

        Parameters
        ----------
        extension_config : `str` or `int`
            Either ``'double'`` (pad half the field width on each side, i.e.
            double the field size) or a non-negative number of pixels.
        npix_dec, npix_ra : `int`
            Original field dimensions.

        Returns
        -------
        extension_size_dec, extension_size_ra : `int`
            Padding (in pixels) added on each side along dec and RA.
        """
        default = (npix_dec // 2, npix_ra // 2)
        if extension_config == 'double':
            return default
        try:
            size = int(extension_config)
        except (ValueError, TypeError):
            size = -1  # mark as invalid; handled below
        if size < 0:
            # Negative padding would corrupt the slicing in _extend_field.
            print(f"Warning: Invalid extension_size '{extension_config}', using default 'double'")
            return default
        return size, size

    def _create_mask(self, g1_grid, g2_grid):
        """Create binary mask from data presence/validity.

        Construct the KS+ data-availability mask with the convention
        M=1 for pixels that have valid measurements and M=0 for gaps.
        Zero shear values are considered valid measurements and must
        not be treated as gaps.

        Data presence is defined as:

        - Both shear components are finite (not NaN/Inf); and
        - If a per-pixel weight grid is available (``self._weight_grid``),
          weight > 0.

        Parameters
        ----------
        g1_grid : `numpy.ndarray`
            First shear component grid.
        g2_grid : `numpy.ndarray`
            Second shear component grid.

        Returns
        -------
        mask : `numpy.ndarray`
            Binary mask (dtype of ``g1_grid``) where 1 indicates data
            present and 0 indicates a gap.
        """
        finite_data = np.isfinite(g1_grid) & np.isfinite(g2_grid)
        weight_grid = getattr(self, '_weight_grid', None)
        if weight_grid is not None:
            positive_weight = np.isfinite(weight_grid) & (weight_grid > 0)
            return (finite_data & positive_weight).astype(g1_grid.dtype)
        return finite_data.astype(g1_grid.dtype)

    def _extend_field(self, g1_grid, g2_grid, mask, extension_size_dec, extension_size_ra):
        """Extend field to reduce border effects.

        Embed the input grids in the centre of larger zero-filled grids to
        minimize boundary artifacts during the reconstruction process.
        Gap pixels (``mask == 0``) are also set to zero, so non-finite
        placeholder values in the input cannot propagate NaN through the
        Fourier/DCT transforms.

        Parameters
        ----------
        g1_grid : `numpy.ndarray`
            First shear component grid.
        g2_grid : `numpy.ndarray`
            Second shear component grid.
        mask : `numpy.ndarray`
            Binary mask (1 = data, 0 = gap).
        extension_size_dec : `int`
            Size of extension in declination (vertical) pixels.
        extension_size_ra : `int`
            Size of extension in right ascension (horizontal) pixels.

        Returns
        -------
        g1_extended : `numpy.ndarray`
            Extended first shear component grid with zero padding.
        g2_extended : `numpy.ndarray`
            Extended second shear component grid with zero padding.
        mask_extended : `numpy.ndarray`
            Extended mask with zero padding.
        """
        npix_dec, npix_ra = g1_grid.shape
        new_shape = (npix_dec + 2 * extension_size_dec,
                     npix_ra + 2 * extension_size_ra)
        g1_extended = np.zeros(new_shape)
        g2_extended = np.zeros(new_shape)
        mask_extended = np.zeros(new_shape)

        # Insert original field in the centre; gaps become exact zeros.
        central = (
            slice(extension_size_dec, extension_size_dec + npix_dec),
            slice(extension_size_ra, extension_size_ra + npix_ra),
        )
        has_data = mask > 0
        g1_extended[central] = np.where(has_data, g1_grid, 0.0)
        g2_extended[central] = np.where(has_data, g2_grid, 0.0)
        mask_extended[central] = mask
        return g1_extended, g2_extended, mask_extended

    def _inpainting_reconstruction(self, g1_grid, g2_grid, mask, config):
        """Perform inpainting-based reconstruction.

        Execute the core KS+ inpainting algorithm using DCT-domain
        sparsity constraints and wavelet-based power spectrum preservation.

        Parameters
        ----------
        g1_grid : `numpy.ndarray`
            First shear component grid (gap pixels expected to be finite).
        g2_grid : `numpy.ndarray`
            Second shear component grid (gap pixels expected to be finite).
        mask : `numpy.ndarray`
            Binary mask indicating data locations.
        config : `dict`
            Configuration dictionary with algorithm parameters
            (``inpainting_iterations``, ``use_wavelet_constraints``,
            ``constrain_B``).

        Returns
        -------
        kappa_e : `numpy.ndarray`
            E-mode convergence map.
        kappa_b : `numpy.ndarray`
            B-mode convergence map.

        Notes
        -----
        The inpainting algorithm alternates between DCT thresholding
        and data consistency enforcement, with wavelet-based
        power spectrum constraints to preserve statistical properties.
        """
        # Initial KS inversion to estimate convergence
        kappa_e, kappa_b = self._standard_ks_inversion(g1_grid, g2_grid)

        # Loop-invariant configuration
        max_iterations = config.get('inpainting_iterations', 100)
        use_wavelet = config.get('use_wavelet_constraints', True)
        # KS+ matches power on the E-mode only by default; B is left free.
        constrain_B = bool(config.get('constrain_B', False))
        data_pixels = mask > 0

        # Initial threshold from the largest E-mode DCT coefficient
        lambda_max = np.max(np.abs(fft.dctn(kappa_e, norm='ortho')))

        for i in range(max_iterations):
            lambda_i = self._update_threshold(i, max_iterations, 0.0, lambda_max)

            # DCT hard thresholding (sparsity prior)
            kappa_e_dct = fft.dctn(kappa_e, norm='ortho')
            kappa_b_dct = fft.dctn(kappa_b, norm='ortho')
            kappa_e_dct[np.abs(kappa_e_dct) < lambda_i] = 0
            kappa_b_dct[np.abs(kappa_b_dct) < lambda_i] = 0
            kappa_e = fft.idctn(kappa_e_dct, norm='ortho')
            kappa_b = fft.idctn(kappa_b_dct, norm='ortho')

            # Wavelet-based power spectrum constraints
            if use_wavelet:
                kappa_e = self._apply_wavelet_constraints(kappa_e, mask)
                if constrain_B:
                    kappa_b = self._apply_wavelet_constraints(kappa_b, mask)

            # Enforce consistency with observed data: replace the model
            # shear by the observed shear on data pixels (mask == 1).
            gamma1, gamma2 = self._kappa_to_gamma(kappa_e, kappa_b)
            gamma1[data_pixels] = g1_grid[data_pixels]
            gamma2[data_pixels] = g2_grid[data_pixels]
            kappa_e, kappa_b = self._gamma_to_kappa(gamma1, gamma2)

        return kappa_e, kappa_b

    def _standard_ks_inversion(self, g1_grid, g2_grid):
        """Perform standard Kaiser-Squires inversion.

        Apply the classical Kaiser-Squires Fourier-domain inversion
        to convert shear components to convergence maps.

        Parameters
        ----------
        g1_grid : `numpy.ndarray`
            First shear component grid.
        g2_grid : `numpy.ndarray`
            Second shear component grid.

        Returns
        -------
        kappa_e : `numpy.ndarray`
            E-mode convergence map.
        kappa_b : `numpy.ndarray`
            B-mode convergence map.

        Notes
        -----
        Uses the standard KS relations in Fourier space::

            kappa_E = ((k1^2 - k2^2) * g1 + 2 * k1 * k2 * g2) / k^2
            kappa_B = ((k1^2 - k2^2) * g2 - 2 * k1 * k2 * g1) / k^2

        The k = 0 mode (field mean) is unconstrained and set to zero.
        """
        npix_dec, npix_ra = g1_grid.shape
        g1_hat = np.fft.fft2(g1_grid)
        g2_hat = np.fft.fft2(g2_grid)

        k1, k2 = np.meshgrid(np.fft.fftfreq(npix_ra), np.fft.fftfreq(npix_dec))
        k_squared = k1**2 + k2**2
        # Avoid division by zero; the numerator also vanishes at k = 0.
        k_squared[k_squared == 0] = np.finfo(float).eps

        kappa_e_hat = (1 / k_squared) * ((k1**2 - k2**2) * g1_hat + 2 * k1 * k2 * g2_hat)
        kappa_b_hat = (1 / k_squared) * ((k1**2 - k2**2) * g2_hat - 2 * k1 * k2 * g1_hat)

        kappa_e = np.real(np.fft.ifft2(kappa_e_hat))
        kappa_b = np.real(np.fft.ifft2(kappa_b_hat))
        return kappa_e, kappa_b

    def _kappa_to_gamma(self, kappa_e, kappa_b):
        """Convert convergence to shear using Fourier space relation.

        Apply the forward Kaiser-Squires transform to convert convergence
        maps to shear components for data consistency enforcement.

        Parameters
        ----------
        kappa_e : `numpy.ndarray`
            E-mode convergence map.
        kappa_b : `numpy.ndarray`
            B-mode convergence map.

        Returns
        -------
        gamma1 : `numpy.ndarray`
            First shear component grid.
        gamma2 : `numpy.ndarray`
            Second shear component grid.

        Notes
        -----
        With P1 = (k1^2 - k2^2) / k^2 and P2 = 2 k1 k2 / k^2, the forward
        operator is the exact adjoint/inverse of `_standard_ks_inversion`::

            gamma1 = P1 * kappa_E - P2 * kappa_B
            gamma2 = P2 * kappa_E + P1 * kappa_B

        so that a forward-backward cycle preserves both the E and the B
        mode. The E and B contributions are transformed separately and
        their real parts taken, mirroring the inverse transform. The
        k = 0 mode is set to zero.
        """
        npix_dec, npix_ra = kappa_e.shape
        kappa_e_hat = np.fft.fft2(kappa_e)
        kappa_b_hat = np.fft.fft2(kappa_b)

        k1, k2 = np.meshgrid(np.fft.fftfreq(npix_ra), np.fft.fftfreq(npix_dec))
        k_squared = k1**2 + k2**2
        nonzero_k = k_squared > 0
        # KS kernels; zero at k = 0 where the transform is undefined.
        p1 = np.divide(k1**2 - k2**2, k_squared,
                       out=np.zeros_like(k_squared), where=nonzero_k)
        p2 = np.divide(2 * k1 * k2, k_squared,
                       out=np.zeros_like(k_squared), where=nonzero_k)

        gamma1_hat = p1 * kappa_e_hat - p2 * kappa_b_hat
        gamma2_hat = p2 * kappa_e_hat + p1 * kappa_b_hat

        gamma1 = np.real(np.fft.ifft2(gamma1_hat))
        gamma2 = np.real(np.fft.ifft2(gamma2_hat))
        return gamma1, gamma2

    def _gamma_to_kappa(self, gamma1, gamma2):
        """Convert shear to convergence using Fourier space relation.

        Apply the inverse Kaiser-Squires transform to convert shear
        components back to convergence maps.

        Parameters
        ----------
        gamma1 : `numpy.ndarray`
            First shear component grid.
        gamma2 : `numpy.ndarray`
            Second shear component grid.

        Returns
        -------
        kappa_e : `numpy.ndarray`
            E-mode convergence map.
        kappa_b : `numpy.ndarray`
            B-mode convergence map.

        Notes
        -----
        Delegates to `_standard_ks_inversion` so the forward-backward
        cycle in the inpainting loop uses one consistent convention.
        """
        return self._standard_ks_inversion(gamma1, gamma2)

    def _apply_wavelet_constraints(self, kappa, mask):
        """Apply wavelet-based power spectrum constraints.

        Preserve statistical properties of the convergence field by
        rescaling wavelet coefficients in missing data regions so their
        standard deviation matches that of the observed regions.

        Parameters
        ----------
        kappa : `numpy.ndarray`
            Convergence map.
        mask : `numpy.ndarray`
            Binary mask (1 where data exists, 0 in gaps).

        Returns
        -------
        kappa_corrected : `numpy.ndarray`
            Convergence map with corrected power spectrum.

        Notes
        -----
        Uses a starlet decomposition; every scale except the coarsest is
        rescaled inside the gaps by ``std(observed) / std(gap)``, clipped
        to ``[0.1, 10]``. Scales with fewer than 16 samples on either side
        of the mask, or with non-finite statistics, are left untouched.
        """
        # Determine number of scales with support-aware cap and optional user override
        height, width = kappa.shape
        cfg_nscales = self.method_config.get('nscales') if self.method_config else None
        nscales_max = compute_starlet_nscales_max(height, width)
        nscales = starlet_nscales_support_aware(height, width, cfg_nscales)

        # Log selection once per mapper instance to avoid excessive output during iterations
        if not hasattr(self, '_wavelet_nscales_logged'):
            print(
                f"KS+: nscales chosen = {nscales} (user={cfg_nscales!r}, "
                f"safe_max={nscales_max}, image={height}x{width})"
            )
            if cfg_nscales is not None and int(cfg_nscales) != nscales:
                print(
                    "Warning: KS+ wavelet.nscales override was clipped to the safe maximum "
                    f"({nscales_max})."
                )
            self._wavelet_nscales_logged = True

        wavelet_bands = starlet_transform_2d(kappa, nscales)

        # Use the true binary mask at every scale (no wavelet of the mask)
        data_mask = mask.astype(bool)
        gap_mask = ~data_mask

        # Stability parameters (hard-coded defaults; not exposed to YAML)
        clip_min = 0.1
        clip_max = 10.0
        sigma_floor_abs = 1e-12
        sigma_floor_rel = 1e-6
        min_samples = 16

        # Process each scale except the coarsest
        for j in range(nscales - 1):
            band = wavelet_bands[j]
            obs_vals = band[data_mask]
            gap_vals = band[gap_mask]

            # Require minimum sample counts for stable statistics
            if obs_vals.size < min_samples or gap_vals.size < min_samples:
                continue

            std_obs = np.std(obs_vals)
            std_gap = np.std(gap_vals)
            std_band = np.std(band)
            if not (np.isfinite(std_obs) and np.isfinite(std_gap) and np.isfinite(std_band)):
                continue

            # Robust (strictly positive) floor on the denominator avoids huge ratios
            sigma_floor = max(sigma_floor_abs, sigma_floor_rel * std_band)
            denom = max(std_gap, sigma_floor)
            scale = float(np.clip(float(std_obs / denom), clip_min, clip_max))

            # Apply scaling inside gaps only
            band[gap_mask] = gap_vals * scale
            wavelet_bands[j] = band

        return inverse_starlet_transform_2d(wavelet_bands)

    def _update_threshold(self, iteration, max_iterations, lambda_min, lambda_max):
        """Update threshold using a stable exponential schedule.

        Compute the DCT coefficient threshold for the current iteration
        using exponential decay from the maximum value.

        Parameters
        ----------
        iteration : `int`
            Current iteration number (0-indexed).
        max_iterations : `int`
            Maximum number of iterations.
        lambda_min : `float`
            Minimum threshold value (kept for backward compatibility; unused).
        lambda_max : `float`
            Maximum threshold value (first iteration).

        Returns
        -------
        lambda_i : `float`
            Threshold for current iteration.

        Notes
        -----
        The previous schedule collapsed to zero when ``lambda_min=0`` after
        the first iteration. The ``exp`` schedule is used instead::

            lambda_i = lambda_max * exp(-i / tau)

        with ``tau`` configurable (default: ``max(1, max_iterations / 4)``).

        Configuration (optional, under method config):

        - threshold_schedule: 'exp' (default and currently the only
          supported value; any other value falls back to 'exp')
        - threshold_tau: positive finite float, decay constant in iterations;
          invalid values fall back to the default
        """
        cfg = self.method_config or {}
        default_tau = max(1.0, float(max_iterations) / 4.0)

        tau = cfg.get('threshold_tau')
        if tau is None:
            tau = default_tau
        else:
            try:
                tau = float(tau)
            except (TypeError, ValueError):
                tau = default_tau
            if not np.isfinite(tau) or tau <= 0:
                tau = default_tau

        return float(lambda_max) * float(np.exp(-float(iteration) / tau))