Source code for pymultifit.fitters.utilities_f

"""Created on Aug 18 23:52:19 2024"""

# Names exported by `from <module> import *`. `_plot_fit` is listed explicitly because its
# leading underscore would otherwise exclude it from a star import.
__all__ = ['parameter_logic', 'sanity_check', '_plot_fit']

from typing import List, Tuple, Union, Optional, Callable

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from mpyez.backend.uPlotting import LinePlot
from mpyez.ezPlotting import plot_xy

# SAFEGUARD: type aliases shared by the helpers in this module.
# Array-like x/y data input: a plain list of floats or a NumPy array.
xy_values = Union[List[float], np.ndarray]
# An (x, y) pair of NumPy arrays, the return type of `sanity_check`.
xy_tuple = Tuple[np.ndarray, np.ndarray]
# Model selector for `parameter_logic`: one 1-based index, a list of 1-based indices,
# or None to select every model.
indexType = Union[int, List[int], None]


def sanity_check(x_values: xy_values, y_values: xy_values) -> xy_tuple:
    """
    Convert list or tuple inputs to NumPy arrays if necessary.

    Parameters
    ----------
    x_values : list of float, tuple of float or np.ndarray
        Input x-values. Lists and tuples are converted to a NumPy array; any other
        object (e.g. an existing ``np.ndarray``) is returned unchanged.
    y_values : list of float, tuple of float or np.ndarray
        Input y-values, converted by the same rule as `x_values`.

    Returns
    -------
    x_values : np.ndarray
        The x-values as a NumPy array.
    y_values : np.ndarray
        The y-values as a NumPy array.
    """
    def _as_array(values):
        # Only plain Python sequences are converted. Arrays (and any other object) are
        # passed through as the very same object, so existing ndarrays are never copied.
        return np.array(values) if isinstance(values, (list, tuple)) else values

    return _as_array(x_values), _as_array(y_values)
def parameter_logic(par_array: np.ndarray, n_par: int, selected_models: indexType) -> np.ndarray:
    """
    Group a parameter array per model and extract the selected models.

    Parameters
    ----------
    par_array : np.ndarray
        Fitted parameters laid out model by model, i.e. `n_par` consecutive entries per
        model. The array is reshaped (C order) into rows of `n_par`. A multi-dimensional
        input is therefore flattened first, so its columns end up interleaved in the rows.
    n_par : int
        The number of parameters per model (e.g., amplitude, mu, sigma, etc.).
    selected_models : int, list of int, or None
        1-based indices of the model components to extract.

        - If None, all components are returned, shape ``(n_models, n_par)``.
        - If a list of int, the listed components are returned in the given order,
          shape ``(len(selected_models), n_par)``. An empty list gives ``(0, n_par)``.
        - If a single int, that component's row is returned as a 1D array of shape
          ``(n_par,)``.

    Returns
    -------
    np.ndarray
        The selected parameter rows.

    Raises
    ------
    IndexError
        If any requested index is smaller than 1, or larger than the number of models.
    ValueError
        If the size of `par_array` is not a multiple of `n_par`.
    """
    grouped = par_array.reshape(-1, n_par)
    if selected_models is None:
        return grouped

    requested = np.asarray(selected_models)
    if requested.size == 0:
        # np.asarray([]) defaults to float64, which NumPy rejects as an index; an empty
        # selection should simply produce an empty (0, n_par) result.
        requested = requested.astype(np.intp)
    if np.any(requested < 1):
        # Without this check, 0 (or a negative value) would become a negative index
        # after the 1-based shift and silently select models from the end.
        raise IndexError(f"Model indices are 1-based; got {selected_models!r}.")
    return grouped[requested - 1]
def _resolve_data_labels(data_label: Optional[Union[List[str], str]]) -> Tuple[str, str]:
    """Return the ``(data, total-fit)`` legend labels derived from `data_label`."""
    if data_label is None:
        return 'Data', 'Total fit'
    if isinstance(data_label, str):
        return data_label, 'Total fit'
    labels = list(data_label)
    if len(labels) == 1:
        # Unwrap the single entry: the label itself must be passed on, not the list.
        return labels[0], 'Total fit'
    if len(labels) == 2:
        return labels[0], labels[1]
    raise ValueError("`data_label` must be a string or a sequence of one or two strings, "
                     f"got {len(labels)} entries.")


def _plot_fit(x_values: xy_values, y_values: xy_values, parameters: xy_values, n_fits: int, class_name: str,
              _n_fitter: Callable, _n_plotter: Callable, show_individuals: bool = False,
              x_label: Optional[str] = None, y_label: Optional[str] = None, title: Optional[str] = None,
              data_label: Optional[Union[List[str], str]] = None, axis: Optional[Axes] = None):
    """
    Base function to plot the fitted models.

    Parameters
    ----------
    x_values : array-like
        The x-axis values.
    y_values : array-like
        The observed data values corresponding to `x_values`.
    parameters : tuple or list
        The optimized parameters from the fitting process.
    n_fits : int
        The number of fits performed.
    class_name : str
        The name of the fitting model class used.
    _n_fitter : callable
        A function that evaluates the fitted model given `x_values` and `parameters`.
    _n_plotter : callable
        A function that plots individual model components if `show_individuals` is True.
        It is called with the keyword argument ``plotter``.
    show_individuals : bool, optional
        Whether to show individually fitted models or not.
    x_label : str, optional
        The label for the x-axis. Defaults to ``'X'``.
    y_label : str, optional
        The label for the y-axis. Defaults to ``'Y'``.
    title : str, optional
        The title for the plot. Defaults to ``'<n_fits> <class_name> fit'``.
    data_label : str or list of str, optional
        Legend label(s). A string, or a one-element list, labels the data. A two-element
        list labels the data and the total fit, respectively. Defaults to ``'Data'`` and
        ``'Total fit'``.
    axis : Axes, optional
        Axes to plot instead of the entire figure. Defaults to None.

    Returns
    -------
    plotter
        The plotter handle for the drawn plot, as returned by `plot_xy`. It is used here
        like a matplotlib Axes (``set_xlabel``/``set_ylabel``/``set_title``).

    Raises
    ------
    RuntimeError
        If `parameters` is None, i.e. no fit has been performed yet.
    ValueError
        If `data_label` is a sequence with other than one or two entries.
    """
    if parameters is None:
        raise RuntimeError("Fit not performed yet. Call fit() first.")

    data_lbl, fit_lbl = _resolve_data_labels(data_label)

    # Observed data first, then the total model drawn in black on the same plotter.
    plotter = plot_xy(x_data=x_values, y_data=y_values, data_label=data_lbl, axis=axis,
                      plot_dictionary=LinePlot(alpha=0.75))
    plot_xy(x_data=x_values, y_data=_n_fitter(x_values, *parameters), x_label=x_label, y_label=y_label,
            plot_title=title, data_label=fit_lbl, plot_dictionary=LinePlot(color='k'), axis=plotter)

    if show_individuals:
        _n_plotter(plotter=plotter)

    # Explicit labels/title with fallbacks, applied after plot_xy has drawn everything.
    plotter.set_xlabel(x_label if x_label else 'X')
    plotter.set_ylabel(y_label if y_label else 'Y')
    plotter.set_title(title if title else f'{n_fits} {class_name} fit')
    # NOTE(review): acts on pyplot's *current* figure. Confirm this is the figure holding
    # `plotter` when a caller supplies an `axis` from another figure.
    plt.tight_layout()

    return plotter