Source code for smpy.plotting.plot

"""Plotting functions for SMPy mass and SNR maps.

Provides high-level plotting for pixel and RA/Dec coordinate systems
with consistent styling, scaling, overlays, and saving. Helpers live in
``smpy.plotting.utils`` to keep this module focused on public plotting
APIs and orchestration.
"""
# Standard Library
from __future__ import annotations

# Third Party
import matplotlib.pyplot as plt
import numpy as np

# Local
from smpy.utils import find_peaks2d
from smpy.plotting.utils import (
    add_colorbar,
    apply_axes_style,
    apply_ra_orientation,
    compute_pixel_extent,
    configure_labels,
    convert_center_to_scaled,
    create_normalization,
    overlay_xray_contours,
    peaks_to_plot_coords,
    propose_ticks,
    set_ticks,
)
import matplotlib.patheffects as patheffects


def plot_mass_map(
    data,
    scaled_boundaries,
    true_boundaries,
    config,
    output_name=None,
    return_handles=False,
    map_category="convergence",
    counts_overlay=None,
):
    """Plot a mass-like map (E/B mode) with styling and overlays.

    This is the public entry point: it reads ``config["coordinate_system"]``
    and forwards every argument unchanged to the matching private renderer.

    Parameters
    ----------
    data : `numpy.ndarray`
        2D convergence map data (E or B mode).
    scaled_boundaries : `dict`
        Scaled coordinate boundaries for plotting extent.
    true_boundaries : `dict`
        True coordinate boundaries for tick labels.
    config : `dict`
        Plot configuration settings including figsize, cmap, scaling,
        'coordinate_system', optional 'axis_reference' (pixel only), and
        optional x-ray contour settings under ``xray_contours``.
    output_name : `str`, optional
        Path for saving the plot file.
    return_handles : `bool`, optional
        If ``True``, return ``(fig, ax, im)`` instead of closing.
    map_category : `str`, optional
        Map category used for scaling and overlays. Options:
        'convergence', 'snr', 'counts'.
    counts_overlay : `numpy.ndarray`, optional
        If provided, overlays integer per-pixel counts (using existing
        counts labeling logic) on top of the rendered image. Used when
        ``general.overlay_counts_map: true`` for convergence plots.

    Returns
    -------
    handles : tuple, optional
        Returns ``(fig, ax, im)`` if ``return_handles=True``.

    Notes
    -----
    The coordinate system defaults to ``"radec"`` when the key is absent
    and is compared case-insensitively. Any value other than ``"radec"``
    is rendered with the pixel-coordinate renderer.
    """
    system_name = config.get("coordinate_system", "radec").lower()

    # Pick the coordinate-system specific renderer; both share one signature.
    renderer = _plot_radec if system_name == "radec" else _plot_pixel
    return renderer(
        data,
        scaled_boundaries,
        true_boundaries,
        config,
        output_name,
        return_handles,
        map_category,
        counts_overlay,
    )
def _saved_map_label(map_category):
    """Return the human-readable map name used in the "saved" message.

    'convergence' and 'counts' get their own label; every other category
    keeps the historical 'SNR' label.
    """
    category = str(map_category).lower()
    if category == "convergence":
        return "Convergence"
    if category == "counts":
        return "Counts"
    return "SNR"


def _cell_centers(lower, upper, count):
    """Return the centers of ``count`` equal-width cells spanning [lower, upper]."""
    return [lower + (k + 0.5) * (upper - lower) / count for k in range(count)]


def _draw_count_labels(ax, data, x_centers, y_centers, base_fontsize):
    """Write each cell of ``data`` as a rounded integer at its center.

    ``x_centers[j]`` / ``y_centers[i]`` give the plotting position of
    ``data[i, j]``. Non-finite cells (NaN, +/-inf) are skipped because they
    cannot be converted to an integer label.
    """
    height, width = data.shape
    # Labels are drawn smaller than axis text, but never below 6 pt.
    count_fontsize = max(6, int(base_fontsize * 0.6))
    # Black stroke keeps the white digits legible on light colormap regions.
    outline = [patheffects.withStroke(linewidth=1.8, foreground="black")]
    for i in range(height):
        for j in range(width):
            val = data[i, j]
            if not np.isfinite(val):
                continue
            ax.text(
                x_centers[j],
                y_centers[i],
                f"{int(round(val))}",
                color="white",
                ha="center",
                va="center",
                fontsize=count_fontsize,
                path_effects=outline,
            )


def _finalize_figure(fig, ax, im, output_name, return_handles, map_category):
    """Lay out, optionally save, then either hand back or close the figure.

    Returns ``(fig, ax, im)`` when ``return_handles`` is truthy; otherwise
    closes ``fig`` and returns ``None``.
    """
    fig.tight_layout()
    if output_name:
        fig.savefig(output_name)
        map_label = _saved_map_label(map_category)
        print(f"{map_label} map saved as PNG file: {output_name}")
    if return_handles:
        return fig, ax, im
    plt.close(fig)
    return None


def _plot_pixel(data, scaled_boundaries, true_boundaries, config, output_name,
                return_handles, map_category, counts_overlay):
    """Render pixel-coordinate plot with overlays and colorbar.

    Parameters
    ----------
    data : `numpy.ndarray`
        2D map to display.
    scaled_boundaries : `dict`
        Provides 'coord1_min', 'coord1_max', 'coord2_min', 'coord2_max'.
    true_boundaries : `dict`
        Forwarded to the center-conversion, peak-finding and contour helpers.
    config : `dict`
        Plot settings. Keys read here: 'figsize', 'fontsize',
        'axis_reference', 'scaling', 'vmin', 'vmax', 'cmap',
        'cluster_center', 'threshold', 'verbose', 'gridlines'.
    output_name : `str` or None
        If truthy, the figure is saved to this path.
    return_handles : `bool`
        If truthy, return ``(fig, ax, im)`` and leave the figure open.
    map_category : `str`
        'convergence', 'snr' or 'counts'; peak markers are disabled for
        'counts', and count labels are drawn from ``data`` itself.
    counts_overlay : `numpy.ndarray` or None
        Optional 2D counts drawn as labels on non-counts maps.

    Returns
    -------
    handles : tuple or None
        ``(fig, ax, im)`` if ``return_handles`` is truthy, else ``None``.
    """
    # Create figure/axes and apply local styling (no global rc changes)
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=config.get("figsize", (12, 8)))
    try:
        fontsize = int(config.get("fontsize", 15))
        apply_axes_style(ax, fontsize=fontsize)

        # Choose how axes are labeled in pixel mode
        axis_reference = str(config.get("axis_reference", "catalog")).lower()
        is_counts_map = str(map_category).lower() == "counts"

        # Build colormap normalization from config
        norm = create_normalization(
            config.get("scaling"),
            data,
            vmin=config.get("vmin"),
            vmax=config.get("vmax"),
            map_type=map_category,
        )

        # Determine image extent from chosen axis reference
        extent = compute_pixel_extent(data, scaled_boundaries, axis_reference)
        im = ax.imshow(
            data,
            cmap=config.get("cmap", "viridis"),
            norm=norm,
            extent=extent,
            origin="lower",
        )

        # Optional: mark cluster center; convert to map pixels if needed
        cx, cy = convert_center_to_scaled(
            config.get("cluster_center"),
            scaled_boundaries,
            true_boundaries,
            coord_system_type="pixel",
        )
        if cx is not None:
            if axis_reference == "map":
                # Rescale from the scaled-boundary range to map pixel indices
                height, width = data.shape
                x_min = scaled_boundaries["coord1_min"]
                x_max = scaled_boundaries["coord1_max"]
                y_min = scaled_boundaries["coord2_min"]
                y_max = scaled_boundaries["coord2_max"]
                cx = (cx - x_min) / (x_max - x_min) * width
                cy = (cy - y_min) / (y_max - y_min) * height
            ax.plot(cx, cy, "rx", markersize=10)

        # Optional: overlay peak markers above threshold (disabled for counts maps)
        threshold = config.get("threshold")
        if threshold is not None and not is_counts_map:
            X, Y, _, _ = find_peaks2d(
                data,
                threshold=threshold,
                verbose=bool(config.get("verbose", False)),
                true_boundaries=true_boundaries,
                scaled_boundaries=scaled_boundaries,
            )
            # Convert peak indices to appropriate plotting coordinates
            px, py = peaks_to_plot_coords(X, Y, data, scaled_boundaries, axis_reference)
            ax.scatter(px, py, s=100, facecolors="none", edgecolors="g", linewidth=1.5)

        # Optional: overlay DS9 x-ray contours for convergence and/or SNR maps
        overlay_xray_contours(
            ax=ax,
            data_shape=data.shape,
            scaled_boundaries=scaled_boundaries,
            true_boundaries=true_boundaries,
            config=config,
            map_category=map_category,
            coord_system_type="pixel",
            axis_reference=axis_reference,
        )

        # Count labels: the map itself for counts maps, else the optional overlay
        overlay_data = data if is_counts_map else counts_overlay
        if overlay_data is not None:
            _overlay_counts_text_pixel(ax, overlay_data, scaled_boundaries, axis_reference, fontsize)

        # Labels, title, optional grid
        configure_labels(
            ax,
            config,
            axis_reference=axis_reference,
            coord_system_type="pixel",
            fontsize=fontsize,
        )
        if config.get("gridlines", False):
            ax.grid(color="black")

        # Attach colorbar to the right
        add_colorbar(ax, im, tick_fontsize=fontsize)

        # Save and/or return figure
        return _finalize_figure(fig, ax, im, output_name, return_handles, map_category)
    except Exception:
        # pyplot keeps figures registered until closed; do not leak on failure.
        plt.close(fig)
        raise


def _overlay_counts_text_pixel(ax, data, scaled_boundaries, axis_reference, base_fontsize):
    """Draw integer count labels at pixel centers for pixel-coordinate plots.

    With ``axis_reference == "map"`` the centers are ``index + 0.5``;
    otherwise they are spread evenly across the scaled boundaries. A falsy
    ``axis_reference`` is treated as ``"catalog"``. Non-finite cells are
    left unlabeled.
    """
    height, width = data.shape
    axis_reference = str(axis_reference or "catalog").lower()

    # Compute x, y centers in plotting coordinates based on axis reference
    if axis_reference == "map":
        x_centers = [j + 0.5 for j in range(width)]
        y_centers = [i + 0.5 for i in range(height)]
    else:
        x_centers = _cell_centers(
            scaled_boundaries["coord1_min"], scaled_boundaries["coord1_max"], width
        )
        y_centers = _cell_centers(
            scaled_boundaries["coord2_min"], scaled_boundaries["coord2_max"], height
        )

    _draw_count_labels(ax, data, x_centers, y_centers, base_fontsize)


def _plot_radec(data, scaled_boundaries, true_boundaries, config, output_name,
                return_handles, map_category, counts_overlay):
    """Render RA/Dec plot with astronomical orientation and ticks.

    Parameters
    ----------
    data : `numpy.ndarray`
        2D map to display.
    scaled_boundaries : `dict`
        Provides 'coord1_min', 'coord1_max', 'coord2_min', 'coord2_max';
        used as the image extent and for positioning overlays.
    true_boundaries : `dict`
        Same keys as ``scaled_boundaries``; used to propose tick values,
        which are then interpolated into scaled coordinates.
    config : `dict`
        Plot settings. Keys read here: 'figsize', 'fontsize', 'scaling',
        'vmin', 'vmax', 'cmap', 'cluster_center', 'threshold', 'verbose',
        'gridlines'.
    output_name : `str` or None
        If truthy, the figure is saved to this path.
    return_handles : `bool`
        If truthy, return ``(fig, ax, im)`` and leave the figure open.
    map_category : `str`
        'convergence', 'snr' or 'counts'; peak markers are disabled for
        'counts', and count labels are drawn from ``data`` itself.
    counts_overlay : `numpy.ndarray` or None
        Optional 2D counts drawn as labels on non-counts maps.

    Returns
    -------
    handles : tuple or None
        ``(fig, ax, im)`` if ``return_handles`` is truthy, else ``None``.
    """
    # Create figure/axes and apply local styling (no global rc changes)
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=config.get("figsize", (12, 8)))
    try:
        fontsize = int(config.get("fontsize", 15))
        apply_axes_style(ax, fontsize=fontsize)
        is_counts_map = str(map_category).lower() == "counts"

        x_min = scaled_boundaries["coord1_min"]
        x_max = scaled_boundaries["coord1_max"]
        y_min = scaled_boundaries["coord2_min"]
        y_max = scaled_boundaries["coord2_max"]

        # Build colormap normalization from config
        norm = create_normalization(
            config.get("scaling"),
            data,
            vmin=config.get("vmin"),
            vmax=config.get("vmax"),
            map_type=map_category,
        )

        # Draw image using scaled RA/Dec extents
        im = ax.imshow(
            data,
            cmap=config.get("cmap", "viridis"),
            norm=norm,
            extent=[x_min, x_max, y_min, y_max],
            origin="lower",
        )

        # Optional: overlay peak markers above threshold (disabled for counts maps)
        threshold = config.get("threshold")
        if threshold is not None and not is_counts_map:
            X, Y, _, _ = find_peaks2d(
                data,
                threshold=threshold,
                verbose=bool(config.get("verbose", False)),
                true_boundaries=true_boundaries,
                scaled_boundaries=scaled_boundaries,
            )
            # Map peak pixel indices to the center of their cell in scaled coords
            ra_peaks = [x_min + (x + 0.5) * (x_max - x_min) / data.shape[1] for x in X]
            dec_peaks = [y_min + (y + 0.5) * (y_max - y_min) / data.shape[0] for y in Y]
            ax.scatter(ra_peaks, dec_peaks, s=100, facecolors="none", edgecolors="g", linewidth=1.5)

        # Optional: overlay DS9 x-ray contours for convergence and/or SNR maps
        overlay_xray_contours(
            ax=ax,
            data_shape=data.shape,
            scaled_boundaries=scaled_boundaries,
            true_boundaries=true_boundaries,
            config=config,
            map_category=map_category,
            coord_system_type="radec",
        )

        # Count labels: the map itself for counts maps, else the optional overlay
        overlay_data = data if is_counts_map else counts_overlay
        if overlay_data is not None:
            _overlay_counts_text_radec(ax, overlay_data, scaled_boundaries, fontsize)

        # Optional: mark cluster center in RA/Dec coordinates
        ra_center, dec_center = convert_center_to_scaled(
            config.get("cluster_center"),
            scaled_boundaries,
            true_boundaries,
            coord_system_type="radec",
        )
        if ra_center is not None:
            ax.plot(ra_center, dec_center, "rx", markersize=10)

        # Generate ticks: propose in true coordinate space, then map to scaled space
        ra_ticks_true, ra_labels = propose_ticks(
            true_boundaries["coord1_min"], true_boundaries["coord1_max"], 5
        )
        dec_ticks_true, dec_labels = propose_ticks(
            true_boundaries["coord2_min"], true_boundaries["coord2_max"], 5
        )
        scaled_x = np.interp(
            ra_ticks_true,
            [true_boundaries["coord1_min"], true_boundaries["coord1_max"]],
            [x_min, x_max],
        )
        scaled_y = np.interp(
            dec_ticks_true,
            [true_boundaries["coord2_min"], true_boundaries["coord2_max"]],
            [y_min, y_max],
        )
        set_ticks(ax, scaled_x, scaled_y, ra_labels, dec_labels)

        # Labels, title, optional grid
        configure_labels(ax, config, coord_system_type="radec", fontsize=fontsize)
        if config.get("gridlines", False):
            ax.grid(color="black")

        # Astronomical convention: RA increases to the left
        apply_ra_orientation(ax)

        # Attach colorbar to the right
        add_colorbar(ax, im, tick_fontsize=fontsize)

        # Save and/or return figure
        return _finalize_figure(fig, ax, im, output_name, return_handles, map_category)
    except Exception:
        # pyplot keeps figures registered until closed; do not leak on failure.
        plt.close(fig)
        raise


def _overlay_counts_text_radec(ax, data, scaled_boundaries, base_fontsize):
    """Draw integer count labels at pixel centers for RA/Dec plots (scaled coordinates).

    Cell centers are spread evenly across the scaled boundaries on each
    axis. Non-finite cells are left unlabeled.
    """
    height, width = data.shape
    x_centers = _cell_centers(
        scaled_boundaries["coord1_min"], scaled_boundaries["coord1_max"], width
    )
    y_centers = _cell_centers(
        scaled_boundaries["coord2_min"], scaled_boundaries["coord2_max"], height
    )
    _draw_count_labels(ax, data, x_centers, y_centers, base_fontsize)
def plot_snr_map(
    data,
    scaled_boundaries,
    true_boundaries,
    config,
    output_name=None,
    return_handles=False,
    counts_overlay=None,
):
    """Plot an SNR map with styling and overlays.

    Thin wrapper around :func:`plot_mass_map` that fixes
    ``map_category="snr"`` and forwards everything else unchanged.

    Parameters
    ----------
    data, scaled_boundaries, true_boundaries, config, output_name, \
return_handles, counts_overlay
        See :func:`plot_mass_map`.

    Returns
    -------
    handles : tuple, optional
        Whatever :func:`plot_mass_map` returns for the same arguments.
    """
    # The SNR category is what distinguishes this call from a convergence plot.
    return plot_mass_map(
        data,
        scaled_boundaries,
        true_boundaries,
        config,
        output_name=output_name,
        return_handles=return_handles,
        map_category="snr",
        counts_overlay=counts_overlay,
    )