"""Created on Aug 18 23:52:19 2024"""
__all__ = ["parameter_logic", "sanity_check", "_plot_fit"]
from typing import List, Tuple, Union, Optional, Callable
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from mpyez.backend.uPlotting import LinePlot # type: ignore
from mpyez.ezPlotting import plot_xy # type: ignore
from .. import OneDArray
# Module-level type aliases used in annotations.
# A pair of NumPy arrays.
xy_tuple = Tuple[np.ndarray, np.ndarray]
# Selector for model components: a single index, a list of indices, or None.
indexType = Optional[Union[int, List[int]]]
[docs]
def sanity_check(x_values: OneDArray, y_values: OneDArray) -> Tuple[OneDArray, OneDArray]:
    """
    Coerce the x and y inputs into floating-point NumPy arrays.

    Parameters
    ----------
    x_values : list of float or np.ndarray
        Independent-variable samples. Any array-like is accepted.
    y_values : list of float or np.ndarray
        Dependent-variable samples. Any array-like is accepted.

    Returns
    -------
    x_values : np.ndarray
        `x_values` as a float ndarray.
    y_values : np.ndarray
        `y_values` as a float ndarray.

    Notes
    -----
    Conversion uses ``np.asarray(..., dtype=float)``. An input that is already
    a float ndarray is returned as the same object, without a copy.
    """
    converted = tuple(np.asarray(values, dtype=float) for values in (x_values, y_values))
    return converted
[docs]
def parameter_logic(par_array: OneDArray, n_par: int, selected_models) -> OneDArray:
    """
    Group a flat parameter array by model component and select components.

    Parameters
    ----------
    par_array : array-like
        Parameters for all model components, stored component by component
        with `n_par` consecutive entries per component. The array is reshaped
        to ``(-1, n_par)``, so its size must be a multiple of `n_par`. Lists
        are accepted and converted with ``np.asarray``.
    n_par : int
        The number of parameters per component (e.g., amplitude, mu, sigma).
    selected_models : int, list of int, or None
        Which model components to extract, using 1-based indexing.
        - If None, all components are returned.
        - If int, that single component is returned as a 1D row.
        - If list of int, the listed components are returned as a 2D array,
          in the order given.

    Returns
    -------
    np.ndarray
        The selected parameter rows: 2D of shape ``(k, n_par)`` for None or
        a list selection, 1D of length `n_par` for a single int.

    Raises
    ------
    ValueError
        If any entry of `selected_models` is less than 1, or if the size of
        `par_array` is not a multiple of `n_par` (raised by ``reshape``).
    IndexError
        If an entry of `selected_models` exceeds the number of components.
    """
    # One row per model component.
    per_component = np.asarray(par_array).reshape(-1, n_par)
    if selected_models is None:
        return per_component

    # Convert the caller's 1-based indices to 0-based row positions.
    indices = np.asarray(selected_models) - 1
    # An index below 1 would become negative here and silently wrap around
    # to a component counted from the end, so reject it explicitly.
    if np.any(indices < 0):
        raise ValueError(
            f"selected_models uses 1-based indexing; got {selected_models!r}."
        )
    return per_component[indices]
[docs]
def _resolve_plot_labels(data_label, fit_label) -> Tuple[str, str]:
    """
    Work out the legend labels for the data curve and the total-fit curve.

    Parameters
    ----------
    data_label : str, sequence of str, or None
        None gives ``"Data"``. A string is used as-is. A one-element sequence
        supplies the data label. A two-element sequence supplies the data
        label and the total-fit label, in that order.
    fit_label : str or None
        Label for the total-fit curve. When not None it takes precedence.
        Otherwise the second element of a two-element `data_label` is used,
        or ``"Total fit"`` if there is none.

    Returns
    -------
    tuple of (str, str)
        ``(data_curve_label, total_fit_label)``.

    Raises
    ------
    ValueError
        If `data_label` is a sequence with zero or more than two elements.
    """
    if data_label is None or isinstance(data_label, str):
        labels = ["Data" if data_label is None else data_label]
    else:
        labels = list(data_label)

    if not 1 <= len(labels) <= 2:
        raise ValueError(
            "data_label must be None, a string, or a sequence of one or two "
            f"labels; got {len(labels)} labels."
        )

    fallback_fit = labels[1] if len(labels) == 2 else "Total fit"
    return labels[0], fit_label if fit_label is not None else fallback_fit


def _plot_fit(
    x_values: OneDArray,
    y_values: OneDArray,
    parameters: OneDArray,
    n_fits: int,
    class_name: str,
    _n_fitter: Callable,
    _n_plotter: Callable,
    show_individuals: bool = False,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    data_label: Optional[str] = None,
    fit_label: Optional[str] = None,
    axis: Optional[Axes] = None,
):
    """
    Base function to plot the data together with the fitted model.

    Parameters
    ----------
    x_values : array-like
        The x-axis values.
    y_values : array-like
        The observed data values corresponding to `x_values`.
    parameters : tuple or list
        The optimized parameters from the fitting process. They are unpacked
        into `_n_fitter` as ``_n_fitter(x_values, *parameters)``.
    n_fits : int
        The number of fits performed. Used in the default title.
    class_name : str
        The name of the fitting model class. Used in the default title.
    _n_fitter : callable
        Evaluates the total fitted model as ``_n_fitter(x_values, *parameters)``.
    _n_plotter : callable
        Called as ``_n_plotter(plotter=...)`` to draw the individual model
        components when `show_individuals` is True.
    show_individuals : bool, optional
        Whether to also draw the individually fitted components.
    x_label : str, optional
        The label for the x-axis. Defaults to ``"X"``.
    y_label : str, optional
        The label for the y-axis. Defaults to ``"Y"``.
    title : str, optional
        The plot title. Defaults to ``"{n_fits} {class_name} fit"``.
    data_label : str or sequence of str, optional
        Legend label for the data (default ``"Data"``). A two-element
        sequence is read as ``(data label, total-fit label)``.
    fit_label : str, optional
        Legend label for the total-fit curve. Defaults to ``"Total fit"``.
        Takes precedence over a fit label given through `data_label`.
    axis : Axes, optional
        Axes to draw on instead of creating a new figure. Defaults to None.

    Returns
    -------
    Axes
        The axes holding the plot. If the plotting backend returns a list of
        axes, the first one is returned.

    Raises
    ------
    RuntimeError
        If `parameters` is None, meaning no fit has been performed yet.
    ValueError
        If `data_label` is a sequence with zero or more than two labels.
    """
    if parameters is None:
        raise RuntimeError("Fit not performed yet. Call fit() first.")

    data_curve_label, fit_curve_label = _resolve_plot_labels(data_label, fit_label)

    # Observed data, drawn with alpha=0.75.
    plotter = plot_xy(
        x_data=x_values,
        y_data=y_values,
        data_label=data_curve_label,
        axis=axis,
        plot_dictionary=LinePlot(alpha=0.75),
    )
    # Total fitted model on the same x grid, drawn in black on the same axes.
    plot_xy(
        x_data=x_values,
        y_data=_n_fitter(x_values, *parameters),
        x_label=x_label,
        y_label=y_label,
        plot_title=title,
        data_label=fit_curve_label,
        plot_dictionary=LinePlot(color="k"),
        axis=plotter,
    )
    if show_individuals:
        _n_plotter(plotter=plotter)

    # plot_xy may hand back a list of axes; the first one is used for labels.
    plotter2: Axes = plotter[0] if isinstance(plotter, list) else plotter
    plotter2.set_xlabel(x_label if x_label else "X")
    plotter2.set_ylabel(y_label if y_label else "Y")
    plotter2.set_title(title if title else f"{n_fits} {class_name} fit")

    # Lay out the figure that actually owns these axes. plt.tight_layout()
    # acts on the *current* figure, which may differ when the caller passed
    # their own `axis`. Fall back to it if the owner has no tight_layout
    # (e.g. a SubFigure).
    getattr(plotter2.figure, "tight_layout", plt.tight_layout)()
    return plotter2