Source code for smpy.config

"""Configuration management for SMPy.

This module provides configuration loading, validation, and management
for SMPy mass mapping operations.
"""

import copy
import os
from pathlib import Path
from typing import Dict, Optional

import yaml


class Config:
    """Manage configuration dictionaries for SMPy mass mapping analysis.

    Handle loading, merging, and validating configuration dictionaries
    from YAML files and user parameters. All configurations use a
    consistent nested structure for clean architecture and reliable
    error handling.

    Parameters
    ----------
    config_dict : `dict`, optional
        Configuration dictionary. If `None`, creates empty config.

    Notes
    -----
    The Config class uses a consistent nested configuration structure:

    - **General settings**: ``config['general']`` - Input/output,
      coordinate system, analysis settings
    - **Method-specific**: ``config['methods'][method_name]`` - Parameters
      for each mapping method
    - **Plotting settings**: ``config['plotting']`` - Visualization
      parameters
    - **SNR settings**: ``config['snr']`` - Signal-to-noise map generation
      parameters

    Configuration access follows the fail-fast principle:

    - Required parameters use direct access:
      ``config['section']['parameter']``
    - Optional parameters use ``.get()``:
      ``config['section'].get('parameter', default)``
    - Missing required config raises immediate ``KeyError`` for clear
      debugging

    Examples
    --------
    Load default configuration for Kaiser-Squires:

    >>> config = Config.from_defaults('kaiser_squires')
    >>> config.show_config()

    Load existing user configuration:

    >>> config = Config.from_file('my_config.yaml')
    >>> config.show_config(section='general')

    Access method-specific parameters:

    >>> cfg_dict = config.to_dict()
    >>> smoothing = cfg_dict['methods']['kaiser_squires']['smoothing']

    Save current configuration:

    >>> config.save_config('output_config.yaml')

    Update configuration programmatically:

    >>> config.update_from_kwargs(
    ...     data='catalog.fits',
    ...     coord_system='radec',
    ...     pixel_scale=0.168
    ... )
    """

    # Accepted values, shared by update_from_kwargs(), from_defaults()
    # and validate() so the three cannot drift apart.
    _VALID_METHODS = ('kaiser_squires', 'aperture_mass', 'ks_plus')
    _VALID_COORD_SYSTEMS = ('radec', 'pixel')
    _VALID_AXIS_REFERENCES = ('catalog', 'map')

    # coordinate system -> (required key inside config['general'][system],
    #                       marker set in config['general'] when the user
    #                       passed that value through update_from_kwargs)
    _COORD_REQUIREMENTS = {
        'radec': ('resolution', '_pixel_scale_set_by_user'),
        'pixel': ('downsample_factor', '_downsample_factor_set_by_user'),
    }

    # Routing tables for update_from_kwargs: (keyword name, config key).
    # They are split into groups because the special-cased keywords are
    # handled in between, and processing order decides the order in which
    # new keys are appended to the config (visible in the YAML dumps).
    _GENERAL_OUTPUT_KWARGS = (
        ('method', 'method'),
        ('output_dir', 'output_directory'),
        ('output_base_name', 'output_base_name'),
    )
    _GENERAL_FLAG_KWARGS = (
        ('create_snr', 'create_snr'),
        ('create_counts_map', 'create_counts_map'),
        ('overlay_counts_map', 'overlay_counts_map'),
        ('save_fits', 'save_fits'),
        ('save_plots', 'save_plots'),
    )
    _GENERAL_COLUMN_KWARGS = (
        ('g1_col', 'g1_col'),
        ('g2_col', 'g2_col'),
        ('weight_col', 'weight_col'),
        ('print_timing', 'print_timing'),
    )
    _KS_PLUS_KWARGS = (
        ('inpainting_iterations', 'inpainting_iterations'),
        ('reduced_shear_iterations', 'reduced_shear_iterations'),
        ('use_wavelet_constraints', 'use_wavelet_constraints'),
        ('constrain_B', 'constrain_B'),
        ('threshold_schedule', 'threshold_schedule'),
        ('threshold_tau', 'threshold_tau'),
        ('extension_size', 'extension_size'),
    )
    _APERTURE_FILTER_KWARGS = (
        ('filter_type', 'type'),
        ('filter_scale', 'scale'),
    )

    def __init__(self, config_dict: Optional[Dict] = None):
        """Initialize Config with optional configuration dictionary.

        Parameters
        ----------
        config_dict : `dict`, optional
            Configuration dictionary. If None, creates empty config.
            The dictionary is stored by reference, not copied.
        """
        self.config = config_dict if config_dict is not None else {}

    @classmethod
    def from_file(cls, path):
        """Load configuration from YAML file.

        Parameters
        ----------
        path : `str` or `pathlib.Path`
            Path to YAML configuration file.

        Returns
        -------
        config : `Config`
            Configuration instance loaded from file. An empty file yields
            an empty configuration.

        Raises
        ------
        ValueError
            If the top-level YAML document is not a mapping.
        """
        # YAML is UTF-8 by specification; do not depend on the locale.
        with open(path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)

        # safe_load returns None for an empty document (handled by
        # __init__); anything else that is not a mapping cannot be used
        # as a nested configuration.
        if config_dict is not None and not isinstance(config_dict, dict):
            raise ValueError(
                f"Configuration file must contain a mapping at the top "
                f"level: {path}"
            )
        return cls(config_dict)

    @classmethod
    def from_defaults(cls, method='kaiser_squires'):
        """Load default configuration for specified method.

        Load configuration from the default.yaml file and return the
        nested structure as-is. This provides consistent configuration
        structure regardless of loading method.

        Parameters
        ----------
        method : `str`, optional
            Method name ('kaiser_squires', 'aperture_mass', or 'ks_plus').
            Default is 'kaiser_squires'.

        Returns
        -------
        config : `Config`
            Configuration instance with default settings in nested
            structure.

        Raises
        ------
        FileNotFoundError
            If the default configuration file cannot be found.
        ValueError
            If the specified method is not supported.

        Notes
        -----
        Returns the full nested configuration structure:

        - **General settings**: ``config['general']``
        - **Method-specific**: ``config['methods'][method_name]``
        - **Plotting settings**: ``config['plotting']``
        - **SNR settings**: ``config['snr']``

        Examples
        --------
        Load Kaiser-Squires defaults:

        >>> config = Config.from_defaults('kaiser_squires')
        >>> smoothing = config.to_dict()['methods']['kaiser_squires']['smoothing']
        >>> print(smoothing['type'])
        gaussian

        Load KS+ defaults:

        >>> config = Config.from_defaults('ks_plus')
        >>> ks_config = config.to_dict()['methods']['ks_plus']
        >>> print(ks_config['inpainting_iterations'])
        100
        """
        # Reject unsupported methods up front, as the contract above
        # promises, instead of deferring the failure to validate().
        if method not in cls._VALID_METHODS:
            raise ValueError(
                f"Invalid method '{method}'. "
                f"Must be one of: {list(cls._VALID_METHODS)}"
            )

        # Load default.yaml shipped next to this module
        defaults_path = Path(__file__).parent / 'configs' / 'default.yaml'
        if not defaults_path.exists():
            raise FileNotFoundError(
                f"Default config file not found: {defaults_path}"
            )

        with open(defaults_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # Set the method in general section
        config['general']['method'] = method
        return cls(config)

    def update_from_kwargs(self, **kwargs):
        """Update configuration from keyword arguments.

        This method maps simple keyword arguments to the nested
        configuration structure expected by SMPy. Keywords that are not
        recognised are ignored.

        Parameters
        ----------
        **kwargs
            Keyword arguments to convert to config structure. For
            ``pixel_scale``, ``downsample_factor``, ``pixel_axis_reference``,
            ``fontsize`` and ``wavelet_nscales`` a value of `None` means
            "not provided". ``wavelet`` and ``filter`` are merged into the
            existing settings only when they are dictionaries.

        Raises
        ------
        ValueError
            If ``coord_system`` is not 'radec' or 'pixel', or
            ``pixel_axis_reference`` is not 'catalog' or 'map'.
        KeyError
            If ``smoothing`` is given but ``config['general']['method']``
            is not set (neither already present nor passed as ``method``).
        """
        # Handle data/input_path
        if 'data' in kwargs:
            self._ensure_section('general')['input_path'] = kwargs['data']

        # Handle coordinate system
        if 'coord_system' in kwargs:
            coord_system = kwargs['coord_system']
            if coord_system not in self._VALID_COORD_SYSTEMS:
                raise ValueError(f"Invalid coord_system: {coord_system}")
            general = self._ensure_section('general')
            general['coordinate_system'] = coord_system
            # Marker read by validate(): an explicit coordinate system
            # requires its scale parameter to be explicit as well.
            general['_coord_system_set_by_user'] = True

        # Handle pixel_scale (for radec system)
        if kwargs.get('pixel_scale') is not None:
            radec = self._ensure_section('general', 'radec')
            radec['resolution'] = kwargs['pixel_scale']
            self.config['general']['_pixel_scale_set_by_user'] = True

        # Handle downsample_factor (for pixel system)
        if kwargs.get('downsample_factor') is not None:
            pixel = self._ensure_section('general', 'pixel')
            pixel['downsample_factor'] = kwargs['downsample_factor']
            self.config['general']['_downsample_factor_set_by_user'] = True

        # Handle pixel_axis_reference (for pixel plotting)
        axis_ref = kwargs.get('pixel_axis_reference')
        if axis_ref is not None:
            if axis_ref not in self._VALID_AXIS_REFERENCES:
                raise ValueError(
                    "pixel_axis_reference must be 'catalog' or 'map'"
                )
            pixel = self._ensure_section('general', 'pixel')
            pixel['pixel_axis_reference'] = axis_ref

        # Handle method, output directory and output base name. This must
        # run before 'smoothing', which reads the method chosen here.
        self._copy_kwargs(kwargs, self._GENERAL_OUTPUT_KWARGS, 'general')

        # Handle smoothing parameter (stored under the active method)
        if 'smoothing' in kwargs:
            method = self.config['general']['method']
            smoothing = self._ensure_section('methods', method, 'smoothing')
            smoothing['sigma'] = kwargs['smoothing']

        # Handle create_snr, counts maps and save_* switches
        self._copy_kwargs(kwargs, self._GENERAL_FLAG_KWARGS, 'general')

        # Handle mode: a single mode string is normalised to a list
        if 'mode' in kwargs:
            mode_value = kwargs['mode']
            if isinstance(mode_value, str):
                mode_value = [mode_value]
            self._ensure_section('general')['mode'] = mode_value

        # Handle data columns and timing output
        self._copy_kwargs(kwargs, self._GENERAL_COLUMN_KWARGS, 'general')

        # Handle plotting verbosity and fontsize
        if 'verbose' in kwargs:
            self._ensure_section('plotting')['verbose'] = kwargs['verbose']
        if kwargs.get('fontsize') is not None:
            self._ensure_section('plotting')['fontsize'] = kwargs['fontsize']

        # Handle KS+ specific parameters.
        # note: wavelet constraint stability parameters are internal
        # defaults (not exposed); min_threshold_fraction is deprecated
        # and not supported.
        self._copy_kwargs(kwargs, self._KS_PLUS_KWARGS, 'methods', 'ks_plus')

        # Handle KS+ wavelet settings (merged into any existing ones)
        if isinstance(kwargs.get('wavelet'), dict):
            wavelet = self._ensure_section('methods', 'ks_plus', 'wavelet')
            wavelet.update(kwargs['wavelet'])
        if kwargs.get('wavelet_nscales') is not None:
            ks_plus = self._ensure_section('methods', 'ks_plus')
            ks_plus['nscales'] = kwargs['wavelet_nscales']

        # Handle aperture mass filter parameters
        if 'filter' in kwargs:
            self._ensure_section('methods', 'aperture_mass')
            if isinstance(kwargs['filter'], dict):
                filter_cfg = self._ensure_section(
                    'methods', 'aperture_mass', 'filter'
                )
                filter_cfg.update(kwargs['filter'])

        # Handle individual filter parameters
        self._copy_kwargs(
            kwargs, self._APERTURE_FILTER_KWARGS,
            'methods', 'aperture_mass', 'filter',
        )

    def _copy_kwargs(self, kwargs, routes, *section):
        """Copy each supplied keyword into ``section`` under its config key.

        Parameters
        ----------
        kwargs : `dict`
            Keyword arguments received by `update_from_kwargs`.
        routes : iterable of (`str`, `str`)
            Pairs of (keyword name, key to store it under).
        *section : `str`
            Nested keys of the target section; created only when at least
            one of the keywords is present.
        """
        for kwarg, key in routes:
            if kwarg in kwargs:
                self._ensure_section(*section)[key] = kwargs[kwarg]

    def _ensure_section(self, *keys):
        """Ensure nested dictionary sections exist.

        Parameters
        ----------
        *keys : `str`
            Nested keys to ensure exist

        Returns
        -------
        section : `dict`
            The innermost section, ready to be assigned into.
        """
        current = self.config
        for key in keys:
            # An empty YAML section ("radec:") loads as None; treat it as
            # missing so the caller's assignment does not hit a TypeError.
            if current.get(key) is None:
                current[key] = {}
            current = current[key]
        return current

    def validate(self):
        """Validate configuration for required parameters.

        Expects nested configuration structure only. The coordinate-system
        checks run only when ``general.input_path`` is set, i.e. for a
        real run rather than a freshly loaded set of defaults.

        Raises
        ------
        ValueError
            If required parameters are missing or invalid
        KeyError
            If the configuration has no ``general`` section.
        """
        # Access nested structure directly (fail fast if absent)
        general = self.config['general']

        if general.get('input_path'):
            if 'coordinate_system' not in general:
                raise ValueError(
                    "Required parameter 'coordinate_system' missing from config"
                )

            # Tolerate None / non-string values here; they are reported
            # below as an invalid coordinate system.
            coord_system = str(general.get('coordinate_system') or '').lower()
            requirement = self._COORD_REQUIREMENTS.get(coord_system)
            if requirement is None:
                raise ValueError(self._missing_coord_param_message(coord_system))

            param_key, user_flag = requirement
            if general.get('_coord_system_set_by_user', False):
                # Coordinate system chosen explicitly: a value inherited
                # from the defaults does not count, the user must have
                # supplied the matching parameter too.
                provided = bool(general.get(user_flag, False))
            else:
                provided = param_key in (general.get(coord_system) or {})
            if not provided:
                raise ValueError(self._missing_coord_param_message(coord_system))

            # Validate optional axis reference (pixel system only)
            if coord_system == 'pixel':
                axis_ref = (general.get('pixel') or {}).get('pixel_axis_reference')
                if (axis_ref is not None
                        and axis_ref not in self._VALID_AXIS_REFERENCES):
                    raise ValueError(
                        "'pixel_axis_reference' must be 'catalog' or 'map' "
                        "when provided"
                    )

        # Validate optional x-ray contour plotting settings when provided
        plotting = self.config.get('plotting') or {}
        xray_cfg = plotting.get('xray_contours')
        if xray_cfg is not None:
            self._validate_xray_contours(xray_cfg)

        # Validate method
        method = general.get('method', 'kaiser_squires')
        if method not in self._VALID_METHODS:
            raise ValueError(
                f"Invalid method '{method}'. "
                f"Must be one of: {list(self._VALID_METHODS)}"
            )

    @staticmethod
    def _validate_xray_contours(xray_cfg):
        """Check the types and ranges of ``plotting.xray_contours``.

        Parameters
        ----------
        xray_cfg : `object`
            Value stored under ``plotting.xray_contours``; must be a
            dictionary. Every entry is optional.

        Raises
        ------
        ValueError
            If the value is not a dictionary or an entry is invalid.
        """
        if not isinstance(xray_cfg, dict):
            raise ValueError(
                "'plotting.xray_contours' must be a dictionary when provided"
            )

        ctr_file = xray_cfg.get('ctr_file')
        if ctr_file is not None and not isinstance(ctr_file, str):
            raise ValueError(
                "'plotting.xray_contours.ctr_file' must be a string or null"
            )

        for flag_key in ('show_on_convergence', 'show_on_snr'):
            flag_value = xray_cfg.get(flag_key)
            if flag_value is not None and not isinstance(flag_value, bool):
                raise ValueError(
                    f"'plotting.xray_contours.{flag_key}' must be a boolean "
                    f"when provided"
                )

        alpha = xray_cfg.get('alpha')
        if alpha is not None:
            try:
                alpha_value = float(alpha)
            except (TypeError, ValueError) as exc:
                raise ValueError(
                    "'plotting.xray_contours.alpha' must be a number"
                ) from exc
            if not (0.0 <= alpha_value <= 1.0):
                raise ValueError(
                    "'plotting.xray_contours.alpha' must be between 0 and 1"
                )

        linewidth = xray_cfg.get('linewidth')
        if linewidth is not None:
            try:
                linewidth_value = float(linewidth)
            except (TypeError, ValueError) as exc:
                raise ValueError(
                    "'plotting.xray_contours.linewidth' must be a number"
                ) from exc
            if linewidth_value <= 0:
                raise ValueError(
                    "'plotting.xray_contours.linewidth' must be greater than 0"
                )

    def _missing_coord_param_message(self, coord_system: str) -> str:
        """Create a unified, actionable error message for missing parameters.

        Parameters
        ----------
        coord_system : `str`
            The coordinate system specified in the configuration.
            Expected values are 'radec' or 'pixel'.

        Returns
        -------
        message : `str`
            A clear error message that explains what parameter is missing
            and how to provide it via the Python API or YAML
            configuration. Any other coordinate system yields a message
            stating that the coordinate system itself is invalid.
        """
        if coord_system == 'radec':
            return (
                "Missing required parameter for coordinate_system='radec'. "
                "Provide 'pixel_scale' (API: pixel_scale=..., "
                "YAML: general.radec.resolution)."
            )
        if coord_system == 'pixel':
            return (
                "Missing required parameter for coordinate_system='pixel'. "
                "Provide 'downsample_factor' (API: downsample_factor=..., "
                "YAML: general.pixel.downsample_factor)."
            )
        return (
            "Invalid coordinate_system specified. Expected 'radec' or 'pixel'."
        )

    def validate_file_existence(self):
        """Validate that input files exist on disk.

        Expects nested configuration structure only.

        Raises
        ------
        FileNotFoundError
            If input file does not exist
        KeyError
            If the configuration has no ``general`` section.
        """
        # Access nested structure directly
        input_path = self.config['general'].get('input_path')

        # Nothing to check when no input has been configured
        if not input_path:
            return

        # Normalise str / bytes / pathlib.Path to str so the prefix test
        # below also works for Path objects.
        path_text = os.fsdecode(input_path)

        # Skip placeholder paths used by tests
        if path_text.startswith('/some/fake'):
            return

        if not os.path.exists(path_text):
            raise FileNotFoundError(
                f"Input file not found: {input_path}\n"
                f"Please check that the file exists and the path is correct."
            )

    def to_dict(self) -> Dict:
        """Return configuration as dictionary.

        Returns
        -------
        config : `dict`
            Deep copy of the configuration dictionary; modifying it does
            not affect this instance.
        """
        return copy.deepcopy(self.config)

    def show_config(self, section=None):
        """Print current configuration in YAML format.

        Parameters
        ----------
        section : `str`, optional
            Show only specific section ('general', 'plotting', 'snr',
            'methods'). If `None`, shows entire configuration. An unknown
            section prints a notice instead.
        """
        if not section:
            # Show entire config
            config_to_show = self.config
        elif section in self.config:
            # Keep the section name as the top-level key of the output
            config_to_show = {section: self.config[section]}
        else:
            print(f"Section '{section}' not found")
            return

        # Print as YAML, preserving the stored key order
        print(yaml.dump(config_to_show, default_flow_style=False, sort_keys=False))

    def save_config(self, path):
        """Save current configuration to YAML file.

        Parameters
        ----------
        path : `str` or `pathlib.Path`
            Path to save configuration file
        """
        with open(path, 'w', encoding='utf-8') as f:
            yaml.dump(self.config, f, default_flow_style=False, sort_keys=False)
        print(f"Configuration saved to: {path}")