# Source code for qens.config

"""
Configuration dataclass - runtime parameters for the QENS analysis pipeline.

Holds every tunable knob in one place: file lists, Q range, energy window,
binning, MCMC settings and output directory. 

Save/restore as JSON for reuse.

"""
from __future__ import annotations


import json
from dataclasses import dataclass, field, asdict
from typing import List



@dataclass
class Config:
    """Runtime parameters for the analysis pipeline.

    Values are validated in :meth:`__post_init__`, so an inconsistent
    configuration raises ``ValueError`` as soon as the instance is built,
    including when it is built through :meth:`from_json`.

    Attributes
    ----------
    files_to_fit : list[str]
        Filenames to load (relative to ``data_dir`` passed to
        :func:`qens.io.load_dataset`).
    primary_file : str
        Which of the loaded files is the *target* spectrum to fit.
    resolution_file : str | None
        Frozen-sample file used as resolution function. If None, the
        loader auto-picks any T ≤ ``frozen_temp_threshold`` incoherent
        file.
    frozen_temp_threshold : int
        Files at T ≤ this temperature (K) are treated as resolution refs.
    q_min, q_max : float
        Q range in Å⁻¹ over which the fit is performed.
    energy_window : float
        Half-width of the ω window in meV used for the joint fit
        (paper found ±1.25 meV needed for benzene anisotropy).
    n_q_bins : int
        Number of Q-bins for the joint S(Q,ω) fit.
    n_walkers : int
        Number of emcee walkers (must be even, ≥ 2 × n_dim). Only the
        evenness and a minimum of 4 are checked here; the ``n_dim``
        relation is not known to this class.
    n_warmup, n_keep : int
        Burn-in and production steps per walker.
    thin : int
        Chain thinning factor.
    n_map_starts : int
        Random starts for the MAP search.
    random_seed : int
        Master seed for reproducibility.
    save_dir : str
        Output directory for figures, samples, summaries.

    Raises
    ------
    ValueError
        If any of the checks in :meth:`__post_init__` fails.
    """

    # data
    files_to_fit: List[str] = field(default_factory=list)
    primary_file: str = ""
    resolution_file: str | None = None
    frozen_temp_threshold: int = 270  # default threshold

    # fit window
    q_min: float = 0.30
    q_max: float = 2.50
    energy_window: float = 1.25

    # binning
    n_q_bins: int = 12

    # sampling
    n_walkers: int = 32
    n_warmup: int = 500
    n_keep: int = 2000
    thin: int = 5
    n_map_starts: int = 30

    # misc
    random_seed: int = 42
    save_dir: str = "qens_results"

    # validation on init
    def __post_init__(self):
        """Reject parameter combinations the pipeline cannot work with.

        Raises
        ------
        ValueError
            On an empty or inverted Q range, a non-positive energy window,
            an odd or too-small walker count, fewer than two Q bins, a
            thinning factor below one, a negative warm-up length, or a
            production length below one.
        """
        # The Q range must be a non-empty, correctly ordered interval.
        if self.q_min >= self.q_max:
            raise ValueError(
                f"q_min ({self.q_min}) must be < q_max ({self.q_max})"
            )
        # The energy window is a half-width, so it must be strictly positive.
        if self.energy_window <= 0:
            raise ValueError(
                f"energy_window must be > 0, got {self.energy_window}"
            )
        # emcee requirement: an even number of walkers, at least 4.
        if self.n_walkers < 4 or self.n_walkers % 2:
            raise ValueError("n_walkers must be even and ≥ 4")
        if self.n_q_bins < 2:
            raise ValueError("n_q_bins must be ≥ 2")
        # thin must be >= 1 to avoid an error or an infinite loop downstream.
        if self.thin < 1:
            raise ValueError("thin must be ≥ 1")
        if self.n_warmup < 0 or self.n_keep < 1:
            raise ValueError("n_warmup ≥ 0 and n_keep ≥ 1 required")

    # serialisation
    def to_dict(self) -> dict:
        """Return the configuration as a plain ``dict`` (field name -> value)."""
        return asdict(self)

    def to_json(self, path: str) -> None:
        """Write the configuration to ``path`` as indented JSON.

        The file is opened as UTF-8 explicitly so the result does not depend
        on the platform's default locale encoding.
        """
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    # deserialisation
    @classmethod
    def from_json(cls, path: str) -> "Config":
        """Build a configuration from the JSON file at ``path``.

        The file is read as UTF-8 (the encoding JSON interchange uses), so
        hand-edited files containing non-ASCII text load the same way on
        every platform. The decoded object is passed to the constructor as
        keyword arguments, so the usual validation runs and keys that are
        not fields of this class raise ``TypeError``.
        """
        with open(path, encoding="utf-8") as f:
            return cls(**json.load(f))

    def __repr__(self) -> str:
        """Return a multi-line listing with one ``name = value`` per field."""
        lines = ["Config("]
        for k, v in self.to_dict().items():
            lines.append(f" {k} = {v!r},")
        lines.append(")")
        return "\n".join(lines)