# Source code for pymultifit.fitters.utilities_f

"""Created on Aug 18 23:52:19 2024"""

# Names exported by a star-import of this module. ``_plot_fit`` is listed
# explicitly because underscore-prefixed names are otherwise skipped by
# ``from module import *``.
__all__ = ["parameter_logic", "sanity_check", "_plot_fit"]

from typing import List, Tuple, Union, Optional, Callable

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from mpyez.backend.uPlotting import LinePlot # type: ignore
from mpyez.ezPlotting import plot_xy # type: ignore

from .. import ListOrNdArray

# SAFEGUARD:
# Type aliases for this module.
# xy_tuple: a pair of NumPy arrays; same shape as the return annotation of
# ``sanity_check``.
xy_tuple = Tuple[np.ndarray, np.ndarray]
# indexType: a single int, a list of ints, or None -- the forms documented for
# the ``selected_models`` argument of ``parameter_logic``.
indexType = Union[int, List[int], None]


def sanity_check(
    x_values: ListOrNdArray,
    y_values: ListOrNdArray,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Coerce the x/y inputs to float NumPy arrays.

    Parameters
    ----------
    x_values : list of float or np.ndarray
        The x-values; converted with ``np.asarray(..., dtype=float)``.
    y_values : list of float or np.ndarray
        The y-values; converted with ``np.asarray(..., dtype=float)``.

    Returns
    -------
    x_values : np.ndarray
        The x-values as a float NumPy array.
    y_values : np.ndarray
        The y-values as a float NumPy array.
    """
    # The generator is consumed left to right, so x is converted before y,
    # exactly as two consecutive assignments would do.
    x_array, y_array = (
        np.asarray(a=series, dtype=float) for series in (x_values, y_values)
    )
    return x_array, y_array
def parameter_logic(
    par_array: np.ndarray,
    n_par: int,
    selected_models,
) -> np.ndarray:
    """
    Pick the parameter rows of the requested model components.

    The input is reshaped into rows of ``n_par`` entries (one row per
    component) and the rows named by ``selected_models`` are returned.

    Parameters
    ----------
    par_array : np.ndarray
        Array of fitted parameters; it is reshaped to ``(-1, n_par)``.
    n_par : int
        The number of parameters per fit (e.g., amplitude, mu, sigma, etc.).
    selected_models : int, list of int, or None
        Components to extract, using 1-based indexing.

        - If None, every component is returned.
        - If int or list of int, only the specified components are returned.

    Returns
    -------
    np.ndarray
        The rows of the reshaped array for the selected components. An int
        selection yields a single 1D row; a list yields a 2D array.
    """
    if selected_models is None:
        row_selector = slice(None)
    else:
        # Callers count components from 1; NumPy rows count from 0.
        row_selector = np.array(selected_models) - 1

    per_component = par_array.reshape(-1, n_par)
    return per_component[row_selector]
def _resolve_plot_labels(data_label, fit_label):
    """Return the ``(data, fit)`` legend labels, applying defaults where unset."""
    default_data, default_fit = "Data", "Total fit"

    if data_label is None:
        resolved_data, pair_fit = default_data, None
    elif isinstance(data_label, str):
        resolved_data, pair_fit = data_label, None
    elif len(data_label) == 1:
        # A one-element sequence carries only the data label.
        resolved_data, pair_fit = data_label[0], None
    elif len(data_label) == 2:
        # Legacy form: both labels packed together as (data, fit).
        resolved_data, pair_fit = data_label
    else:
        raise ValueError(
            "data_label must be None, a string, or a sequence of one or two labels; "
            f"got a sequence of length {len(data_label)}."
        )

    # An explicit fit_label always wins over the packed one and the default.
    if fit_label is not None:
        resolved_fit = fit_label
    elif pair_fit is not None:
        resolved_fit = pair_fit
    else:
        resolved_fit = default_fit
    return resolved_data, resolved_fit


def _plot_fit(
    x_values: np.ndarray,
    y_values: np.ndarray,
    parameters: np.ndarray,
    n_fits: int,
    class_name: str,
    _n_fitter: Callable,
    _n_plotter: Callable,
    show_individuals: bool = False,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    data_label: Optional[str] = None,
    fit_label: Optional[str] = None,
    axis: Optional[Axes] = None,
) -> Axes:
    """
    Base function to plot the fitted models.

    Parameters
    ----------
    x_values : array-like
        The x-axis values.
    y_values : array-like
        The observed data values corresponding to `x_values`.
    parameters : array-like
        The optimized parameters from the fitting process. They are unpacked
        into `_n_fitter` as positional arguments.
    n_fits : int
        The number of fits performed; used in the default title.
    class_name : str
        The name of the fitting model class used; used in the default title.
    _n_fitter : callable
        A function that evaluates the fitted model, called as
        ``_n_fitter(x_values, *parameters)``.
    _n_plotter : callable
        A function that plots individual model components, called as
        ``_n_plotter(plotter=...)`` when `show_individuals` is True.
    show_individuals : bool, optional
        Whether to show individually fitted models or not.
    x_label : str, optional
        The label for the x-axis. Defaults to ``"X"``.
    y_label : str, optional
        The label for the y-axis. Defaults to ``"Y"``.
    title : str, optional
        The title for the plot. Defaults to ``"{n_fits} {class_name} fit"``.
    data_label : str, optional
        The label for the data. Defaults to ``"Data"``. For backward
        compatibility a sequence of one (``[data]``) or two
        (``[data, fit]``) labels is also accepted.
    fit_label : str, optional
        The label for the total-fit curve. Defaults to ``"Total fit"``.
        Takes precedence over a fit label packed into `data_label`.
    axis : Axes, optional
        Axes to plot instead of the entire figure. Defaults to None.

    Returns
    -------
    Axes
        The axes the data and the fit were drawn on.

    Raises
    ------
    RuntimeError
        If `parameters` is None, i.e. no fit has been performed.
    ValueError
        If `data_label` is a sequence whose length is not 1 or 2.
    """
    if parameters is None:
        raise RuntimeError("Fit not performed yet. Call fit() first.")

    dl, tt = _resolve_plot_labels(data_label, fit_label)

    plotter = plot_xy(
        x_data=x_values,
        y_data=y_values,
        data_label=dl,
        axis=axis,
        plot_dictionary=LinePlot(alpha=0.75),
    )

    # Draw the total fit on the handle returned by the data plot.
    plot_xy(
        x_data=x_values,
        y_data=_n_fitter(x_values, *parameters),
        x_label=x_label,
        y_label=y_label,
        plot_title=title,
        data_label=tt,
        plot_dictionary=LinePlot(color="k"),
        axis=plotter,
    )

    if show_individuals:
        _n_plotter(plotter=plotter)

    # plot_xy may hand back a list of axes; decorate and return the first one.
    plotter2: Axes = plotter[0] if isinstance(plotter, list) else plotter
    plotter2.set_xlabel(x_label if x_label else "X")
    plotter2.set_ylabel(y_label if y_label else "Y")
    plotter2.set_title(title if title else f"{n_fits} {class_name} fit")
    plt.tight_layout()

    return plotter2