"""Starlet Transform implementation.
This module implements the isotropic undecimated wavelet transform (starlet
transform), which is used in the KS+ mass-inversion method to apply power
spectrum constraints when reconstructing convergence maps from shear data.
The implementation is based on the algorithm described in J.-L. Starck,
J. Fadili, and F. Murtagh, "The Undecimated Wavelet Decomposition and its
Reconstruction," IEEE Transactions on Image Processing, vol. 16, no. 2,
pp. 297-309, 2007.
This code draws from the CosmoStat implementation:
https://github.com/CosmoStat/cosmostat/blob/master/pycs/sparsity/sparse2d/starlet.py
(MIT Licensed)
"""
import numpy as np
from scipy.ndimage import convolve1d
from typing import Optional
[docs]
def b3spline_filter(step=1):
    """Create a B3-spline filter for the starlet transform.

    Generate the 1D B3-spline filter kernel used in the starlet wavelet
    transform with appropriate dilation for the à trous algorithm.

    Parameters
    ----------
    step : `int`, optional
        The dilation step for the à trous algorithm. Must be a positive
        integer (integral floats such as ``2.0`` are accepted).

    Returns
    -------
    kernel : `numpy.ndarray`
        The 1D B3-spline filter with ``step - 1`` zeros inserted between
        consecutive coefficients; its length is ``4 * step + 1``.

    Raises
    ------
    ValueError
        If ``step`` is not a positive integer.

    Notes
    -----
    The B3-spline filter coefficients are [1/16, 1/4, 3/8, 1/4, 1/16].
    For step > 1, zeros are inserted between coefficients according to
    the à trous algorithm.
    """
    # B3-spline filter coefficients (they sum to 1).
    h = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])

    # Validate up front: step <= 0 would otherwise surface as an obscure
    # numpy error ("slice step cannot be zero" / "negative dimensions").
    if step < 1 or step != int(step):
        raise ValueError(f"step must be a positive integer, got {step!r}")
    step = int(step)

    # Insert (step - 1) zeros ("holes") between coefficients. For step == 1
    # this is simply a copy of h.
    kernel = np.zeros(len(h) + (len(h) - 1) * (step - 1))
    kernel[::step] = h
    return kernel
[docs]
def apply_filter(data, kernel):
    """Apply separable convolution with the given kernel.

    Perform 2D convolution by applying the 1D kernel separately along
    axis 0 and then axis 1, using mirror boundary conditions.

    Parameters
    ----------
    data : `numpy.ndarray`
        Input 2D image. Integer or boolean input is promoted to
        ``float64`` before filtering so smoothed values are not truncated.
    kernel : `numpy.ndarray`
        1D convolution kernel.

    Returns
    -------
    smoothed : `numpy.ndarray`
        Smoothed 2D image. Floating-point and complex inputs keep their
        dtype; any other input dtype yields a ``float64`` result.

    Notes
    -----
    Uses ``scipy.ndimage`` ``mode='mirror'`` boundary conditions to handle
    edges for wavelet transforms.
    """
    data = np.asarray(data)
    # scipy.ndimage returns results in the input dtype by default, so an
    # integer image would have every smoothed value truncated (e.g. a unit
    # impulse would become all zeros). Promote non-inexact dtypes first.
    if not np.issubdtype(data.dtype, np.inexact):
        data = data.astype(np.float64)
    # Separable filtering: smooth along axis 0, then along axis 1.
    temp = convolve1d(data, kernel, axis=0, mode='mirror')
    return convolve1d(temp, kernel, axis=1, mode='mirror')
[docs]
def compute_starlet_nscales_max(height: int, width: int) -> int:
    """Compute the safe maximum ``nscales`` for a starlet transform.

    The starlet (isotropic undecimated, à trous) transform produces ``J``
    detail bands and one coarse residual. The number of scales is defined as
    ``nscales = J + 1``. The B3–spline à trous kernel support at detail level
    ``j`` (0-indexed) is ``L_j = 4 * 2^j + 1`` pixels. To avoid border-
    dominated coefficients, the coarsest detail (``j = J - 1``) must satisfy
    ``L_{J-1} <= N`` where ``N = min(height, width)``.

    Parameters
    ----------
    height : `int`
        Image height in pixels.
    width : `int`
        Image width in pixels.

    Returns
    -------
    nscales_max : `int`
        Maximum safe number of starlet scales (detail bands + coarse).
        Always at least 2 (one detail band plus the coarse residual), even
        for images too small to satisfy the support constraint.
    """
    N = int(min(height, width))
    # Guard against degenerate images (this also keeps bit_length below
    # from seeing negative numbers). Ensure at least one detail band.
    if N <= 1:
        return 2
    # 4 * 2**(J-1) + 1 <= N  <=>  2**(J-1) <= (N-1)/4
    #                        <=>  J - 1 <= floor(log2(N-1)) - 2.
    # For an integer m >= 1, floor(log2(m)) == m.bit_length() - 1, so this
    # is exact integer arithmetic with no float rounding near powers of two.
    # Clamp to at least one detail band.
    J_max = max(1, (N - 1).bit_length() - 2)
    return J_max + 1
[docs]
def starlet_nscales_support_aware(
    height: int, width: int, cfg_nscales: Optional[int] = None
) -> int:
    """Return a safe ``nscales`` for the starlet transform.

    The kernel-support limit for the B3–spline starlet is obtained from
    ``compute_starlet_nscales_max``. Without a user request, that limit is
    returned as-is; otherwise the request is bounded to it.

    Parameters
    ----------
    height : `int`
        Image height in pixels.
    width : `int`
        Image width in pixels.
    cfg_nscales : `int`, optional
        User-requested number of scales. When given, it is converted with
        ``int()`` and clipped to the inclusive range ``[2, nscales_max]``.

    Returns
    -------
    nscales : `int`
        Number of scales to use (detail bands + coarse residual).
    """
    ceiling = compute_starlet_nscales_max(height, width)
    if cfg_nscales is not None:
        requested = int(cfg_nscales)
        # Bound from above by the support-safe maximum, then from below by
        # the minimum of one detail band plus the coarse residual.
        bounded_above = min(requested, ceiling)
        return max(bounded_above, 2)
    # No override requested: use the largest support-safe value.
    return ceiling