Source code for qens.plotting

"""
Plots

Each function here takes already processed inputs (datasets, samples,
HWHM table) and writes out a single figure. None of them recomputes the
physics — they're presentation only. The forward-model evaluations are
done by :mod:'qens.models' and called by the inference layer.


All figures use the same colour scheme:


    elastic component     — blue   ''#1565c0''
    quasi-elastic         — orange ''#e67e22''
    fit / MAP             — red    ''#c0392b''
    posterior fan / data  — slate  ''#2471a3''
    resolution            — grey   ''#888''

    
Save by passing ''save_path''; if None the figure is returned for the
user to handle.


"""


from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap, LogNorm
from scipy.ndimage import gaussian_filter
from scipy.optimize import nnls

from .models import (ce_hwhm,
                     fickian_hwhm,
                     get_model,
                     ss_hwhm)

__all__ = ["plot_overview",
           "plot_sqw_map",
           "plot_per_q_fits",
           "plot_hwhm_vs_q2",
           "plot_posteriors",
           "plot_joint_posterior"]


_SQW_CMAP = LinearSegmentedColormap.from_list("qens",
                                             ["#0a0e1a", "#0c2d6b", "#1565c0", "#42a5f5","#e3f2fd", "#ff8f00", "#e65100"],N=512)




def _despine(ax):
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)




def _save(fig, save_path):
    if save_path:
        fig.savefig(save_path, bbox_inches="tight", dpi=150)


# overview of all loaded files

[docs] def plot_overview(dataset: dict, save_path: str | None = None): """ One line elastic-peak summary per loaded file. Useful sanity check after :func:'qens.preprocessing.assign_resolution'. """ names = sorted(dataset) n = len(names) ncols = min(4, n) nrows = (n + ncols - 1) // ncols fig, axes = plt.subplots(nrows, ncols, figsize=(4.5 * ncols, 3.5 * nrows), squeeze=False) axes = axes.flatten() for ax, name in zip(axes, names): d = dataset[name] good = d["good"] n_lo = max(2, len(good) // 5) avg = np.nanmean([d["data"][good[j]] for j in range(n_lo)], axis=0) avg = np.where(np.isfinite(avg), avg, 0.0) ewin = min(0.5 * d["ei"], 1.2) m = (d["e"] >= -ewin) & (d["e"] <= ewin) peak = avg[m].max() y = avg[m] / peak if peak > 0 else avg[m] col = ("#2ca02c" if d["temp"] <= 270 else "#c0392b" if d["kind"] == "inc" else "#2471a3") ax.plot(d["e"][m], y, color=col, lw=1.5) ax.axvline(0, color="#aaa", lw=0.7, ls=":") ax.set_title(name, fontsize=8, color=col) ax.set_xlabel("ω (meV)", fontsize=8) ax.tick_params(labelsize=7) ax.grid(alpha=0.18) info = (f"E₀={d.get('e0', 0):+.3f} meV\n" f"FWHM={d.get('fwhm_res', 0)*1000:.0f} µeV " f"[{d.get('res_source', '?')}]") ax.text(0.03, 0.96, info, transform=ax.transAxes, va="top", fontsize=6.5, bbox=dict(boxstyle="round", fc="white", alpha=0.8)) _despine(ax) for ax in axes[n:]: ax.axis("off") fig.suptitle("loaded datasets green=frozen red=INC blue=COH", fontsize=10) fig.tight_layout() _save(fig, save_path) return fig, axes
# 2 D S(Q,ω) map for one dataset
[docs] def plot_sqw_map(d: dict, ewin: float = 1.2, save_path: str | None = None): """ Single S(Q,ω) heatmap, log-scaled. """ g = d["good"] qg = d["q"][g] e = d["e"] em = (e >= -ewin) & (e <= ewin) img = d["data"][np.ix_(g, em)] img = np.where(np.isfinite(img) & (img > 0), img, np.nan) qs = np.argsort(qg) ism = gaussian_filter(np.where(np.isfinite(img[qs]), img[qs], 0.0), sigma=[1.5, 0.8]) ism[ism <= 0] = np.nan vmin = max(np.nanpercentile(ism, 2), 1e-8) vmax = np.nanpercentile(ism, 99) fig, ax = plt.subplots(figsize=(8, 6)) fig.patch.set_facecolor("#0d1117") im = ax.pcolormesh(e[em], qg[qs], ism, cmap=_SQW_CMAP, norm=LogNorm(vmin=vmin, vmax=vmax), shading="auto", rasterized=True) ax.axvline(0, color="white", lw=1, ls="--", alpha=0.4) ax.set_xlabel("ω (meV)", color="white", fontsize=12) ax.set_ylabel("Q (Å⁻¹)", color="white", fontsize=12) ax.set_title(f"S(Q,ω) — {d.get('name','?')}", color="white", fontsize=11) ax.tick_params(colors="white") ax.set_facecolor("#0d1117") for sp in ax.spines.values(): sp.set_edgecolor("#555") cb = fig.colorbar(im, ax=ax, pad=0.02, fraction=0.035) cb.set_label("S(Q,ω)", color="white") cb.ax.yaxis.set_tick_params(color="white") plt.setp(cb.ax.yaxis.get_ticklabels(), color="white") fig.tight_layout() if save_path: fig.savefig(save_path, bbox_inches="tight", facecolor="#0d1117", dpi=150) return fig, ax
# per Q bin model vs data panel
[docs] def plot_per_q_fits(data_bins, sigma_res, params, model: str = "anisotropic_rotor", save_path: str | None = None, **extras, ): """ Grid of subplots, one per Q-bin, showing data vs forward-model fit. """ from .fitting import _resolve_kernel_for_bin fm = get_model(model) n = len(data_bins) ncols = 4 nrows = (n + ncols - 1) // ncols fig, axes = plt.subplots(nrows, ncols, figsize=(3.4 * ncols, 2.8 * nrows), squeeze=False) axes = axes.flatten() for ax, k, (omega, spec, errs, q) in zip(axes, range(n), data_bins): sr_k = _resolve_kernel_for_bin(sigma_res, k) shape = fm.predict(omega, q, params, sr_k, **{**fm.extras, **extras}) basis = np.column_stack([shape, np.ones_like(shape)]) amp, _ = nnls(basis / errs[:, None], spec / errs) fit = amp[0] * shape + amp[1] chi2r = np.sum(((spec - fit) / errs) ** 2) / max(len(spec) - 2, 1) ax.errorbar(omega, spec, yerr=errs, fmt=".", color="#222", ms=2.5, elinewidth=0.5, alpha=0.6) ax.plot(omega, fit, "-", color="#c0392b", lw=1.6, label=f"χ²ᵣ={chi2r:.2f}") ax.axhline(amp[1], color="#7f8c8d", lw=0.6, ls=":", alpha=0.6) ax.set_title(f"Q = {q:.2f} Å⁻¹", fontsize=8) ax.legend(fontsize=7, loc="upper right") ax.tick_params(labelsize=7) for ax in axes[n:]: ax.axis("off") pretty = ", ".join(f"{n}={v:.3f}" for n, v in zip(fm.param_names, params)) fig.suptitle(f"forward-model fit per Q-bin model={model} {pretty}", fontsize=10) fig.tight_layout() _save(fig, save_path) return fig, axes
# Γ(Q) vs Q^2 (legacy approach)
[docs] def plot_hwhm_vs_q2(q_centres, hwhm, hwhm_err, model_results: dict | None = None, samples: np.ndarray | None = None, map_params: tuple | None = None, res_hwhm_uev: float = 50, save_path: str | None = None, ): """ Γ(Q) vs Q^2 with optional model curves and posterior fan. """ q_fine = np.linspace(max(q_centres.min() * 0.85, 0.05), q_centres.max() * 1.10, 350) q2f = q_fine ** 2 fig, ax = plt.subplots(figsize=(8, 5.5)) if samples is not None and map_params is not None and len(map_params) >= 2: # CE posterior fan, only meaningful for CE shaped models d_s, l_s = samples[:, 0], np.abs(samples[:, 1]) rng = np.random.default_rng(0) idx = rng.choice(len(d_s), min(300, len(d_s)), replace=False) try: fan = np.array([ce_hwhm(q_fine, d_s[i], l_s[i]) * 1000 for i in idx]) ax.fill_between(q2f, np.percentile(fan, 2.5, axis=0), np.percentile(fan, 97.5, axis=0), alpha=0.18, color="#2471a3", label=f"95% posterior n={len(d_s)}") except Exception: pass if model_results: for name, res in model_results.items(): if "error" in res: continue D = res.get("D", 0.30) try: if name == "ce": L = res.get("l", 2.5) y = ce_hwhm(q_fine, D, L) * 1000 label = (rf"CE $D={D:.4f}$ $\ell={L:.3f}$") elif name == "fickian": y = fickian_hwhm(q_fine, D) * 1000 label = rf"Fickian $D={D:.4f}$" elif name == "ss": ts = res.get("tau_s", 1.0) y = ss_hwhm(q_fine, D, ts) * 1000 label = rf"SS $D={D:.4f}$ $\tau_s={ts:.3f}$" else: continue ax.plot(q2f, y, "-", lw=2.0, label=label) except Exception: continue ax.errorbar(q_centres ** 2, hwhm * 1000, yerr=2 * hwhm_err * 1000, fmt="o", color="#111", ms=6, capsize=3.5, elinewidth=1.4, label=r"data $\pm 2\sigma$") ax.axhline(res_hwhm_uev, color="#888", ls=":", lw=1.3, label=rf"resolution HWHM = {res_hwhm_uev:.0f} µeV") ax.axhspan(0, res_hwhm_uev * 1.1, alpha=0.04, color="#888") ax.set_xlabel(r"$Q^2$ (Å$^{-2}$)", fontsize=12) ax.set_ylabel(r"$\Gamma(Q)$ (µeV)", fontsize=12) ax.set_title(r"$\Gamma(Q)$ vs $Q^2$ — deviation from linearity is " "non-Fickian", fontsize=10) ax.legend(fontsize=9, loc="upper left") ax.set_xlim(left=-0.05) ax.set_ylim(bottom=-10) ax.grid(alpha=0.18) _despine(ax) fig.tight_layout() _save(fig, save_path) return fig, ax
# 1D marginal posteriors
[docs] def plot_posteriors(samples: np.ndarray, model: str = "anisotropic_rotor", reference_values: dict | None = None, derived: dict | None = None, save_path: str | None = None, ): """ Histograms with median, 95% CI band and reference lines. Parameters ---------- samples : ndarray, shape (n, n_params) model : str Registered model name (provides parameter names). reference_values : dict, optional ''{param_name: [(value, label), ...]}'' to plot vertical reference lines (e.g. literature values). derived : dict, optional ''{label: callable(samples)->array}'' for extra panels. """ fm = get_model(model) panels: list[tuple[str, np.ndarray]] = [(n, samples[:, i]) for i, n in enumerate(fm.param_names)] if derived: for name, fn in derived.items(): panels.append((name, np.asarray(fn(samples)))) n = len(panels) cols = min(4, n) rows = (n + cols - 1) // cols fig, axes = plt.subplots(rows, cols, figsize=(4.0 * cols, 3.5 * rows), squeeze=False) axes = axes.flatten() palette = ["#c0392b", "#7f8c8d", "#1e8449", "#2471a3", "#8e44ad", "#e67e22", "#16a085", "#d35400"] for ax, (name, arr), col in zip(axes, panels, palette * (1 + n // len(palette))): med = float(np.median(arr)) lo, hi = np.percentile(arr, [2.5, 97.5]) cnt, _, _ = ax.hist(arr, bins=60, density=True, color=col, alpha=0.78, edgecolor="white", lw=0.3) pk = cnt.max() ax.axvspan(lo, hi, alpha=0.16, color=col) ax.axvline(med, color="black", lw=1.8, label=f"median = {med:.4f}") if reference_values and name in reference_values: for val, lab in reference_values[name]: ax.axvline(val, color="#1565c0", lw=1.7, ls=":", label=f"{lab} = {val}") ax.set_xlabel(name, fontsize=10) ax.set_ylabel("density", fontsize=9) ax.set_title(name, fontsize=10, color=col, fontweight="bold") ax.legend(fontsize=7.5) ax.set_ylim(0, pk * 1.25) ax.grid(alpha=0.18) _despine(ax) for ax in axes[n:]: ax.axis("off") fig.suptitle(f"posteriors — {model} ({len(samples)} samples)", fontsize=11) fig.tight_layout() _save(fig, save_path) return fig, axes
# joint posterior for a pair of parameters
[docs] def plot_joint_posterior(samples: np.ndarray, indices: tuple[int, int] = (0, 1), labels: tuple[str, str] = ("p₁", "p₂"), map_point: tuple[float, float] | None = None, reference_point: tuple[float, float] | None = None, save_path: str | None = None, ): """ Scatter+contour for any two parameter columns of ''samples''. """ i, j = indices x = samples[:, i] y = samples[:, j] fig, ax = plt.subplots(figsize=(7, 6)) n_plot = min(len(x), 3000) idx = np.random.default_rng(0).choice(len(x), n_plot, replace=False) ax.scatter(x[idx], y[idx], c="#2471a3", alpha=0.10, s=4, rasterized=True) if map_point is not None: ax.scatter([map_point[0]], [map_point[1]], color="#f1c40f", marker="o", s=120, edgecolor="black", lw=1.0, zorder=5, label=f"MAP ({map_point[0]:.3f}, {map_point[1]:.3f})") if reference_point is not None: ax.scatter([reference_point[0]], [reference_point[1]], color="#c0392b", marker="*", s=240, edgecolor="black", lw=1.0, zorder=5, label=f"reference " f"({reference_point[0]:.3f}, {reference_point[1]:.3f})") try: from scipy.stats import gaussian_kde xy = np.vstack([x, y]) kde = gaussian_kde(xy) xg = np.linspace(x.min(), x.max(), 100) yg = np.linspace(y.min(), y.max(), 100) X, Y = np.meshgrid(xg, yg) Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape) zf = np.sort(Z.ravel())[::-1] cdf = np.cumsum(zf) / zf.sum() l68 = zf[np.searchsorted(cdf, 0.68)] l95 = zf[np.searchsorted(cdf, 0.95)] ax.contour(X, Y, Z, levels=[l95, l68], colors=["#2471a3", "#c0392b"], linewidths=[1.2, 2.0]) except Exception: pass ax.set_xlabel(labels[0], fontsize=12) ax.set_ylabel(labels[1], fontsize=12) ax.set_title(f"joint posterior ({labels[0]}, {labels[1]})", fontsize=11) ax.legend(fontsize=9) ax.grid(alpha=0.18) _despine(ax) fig.tight_layout() _save(fig, save_path) return fig, ax