"""Created on Jul 18 00:16:01 2024"""
from itertools import chain
from typing import Any, List, Optional, Tuple, Sequence
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from mpyez.backend.uPlotting import LinePlot # type: ignore
from mpyez.ezPlotting import plot_xy # type: ignore
from numpy.typing import NDArray
from scipy.optimize import Bounds, curve_fit
from ..utilities_f import parameter_logic, _plot_fit, sanity_check
from ... import epsilon, ListOrNdArray, Params_
class BaseFitter:
    """
    The base class for multi-fitting functionality.

    The composite model is the sum of ``n_fits`` evaluations of :meth:`fitter`, each with its own
    block of ``n_par`` parameters. Subclasses are expected to override :meth:`fitter` and to set:

    * ``n_par``  -- total number of parameters per sub-model,
    * ``pn_par`` -- number of leading (primary) parameters that must always be supplied,
    * ``sn_par`` -- mapping of secondary parameter names to their default values.
    """

    # Filled in by ``fit``. The class-level ``None`` defaults let the accessors raise the documented
    # ``RuntimeError`` (rather than ``AttributeError``) when no fit has been performed yet, even for
    # instances whose ``__init__`` did not run through this base class.
    params: Optional[np.ndarray] = None
    covariance: Optional[np.ndarray] = None

    def __init__(
        self,
        x_values: "ListOrNdArray",
        y_values: "ListOrNdArray",
        max_iterations: int = 1000,
    ):
        """
        Store the data to be fitted.

        Parameters
        ----------
        x_values: ListOrNdArray
            The x data. Passed through ``sanity_check`` before being stored.
        y_values: ListOrNdArray
            The y data. Passed through ``sanity_check`` before being stored.
        max_iterations: int, optional
            Forwarded to ``scipy.optimize.curve_fit`` as ``maxfev``. Defaults to 1000.
        """
        x_values, y_values = sanity_check(x_values=x_values, y_values=y_values)
        self.x_values: np.ndarray = x_values
        self.y_values: np.ndarray = y_values
        self.max_iterations = max_iterations

        # Model layout; the base class describes a model with no parameters.
        self.n_par: int = 0
        self.pn_par: int = self.n_par
        self.sn_par: dict = {}
        self.n_fits: int = 0

    def _adjust_parameters(self, p0: "Params_"):
        """
        Adjust input parameters to include defaults for secondary parameters if missing.

        Parameters
        ----------
        p0: List[List[float]]
            A list of initial guesses for the parameters.

        Returns
        -------
        adjusted_p0: List[List[float]]
            Adjusted parameter list with default values for missing secondary parameters.

        Raises
        ------
        ValueError
            If a parameter set has fewer than ``pn_par`` entries, or if the adjusted set does not
            contain exactly ``n_par`` entries.
        """
        adjusted_p0 = []
        for params in p0:
            if len(params) < self.pn_par:
                raise ValueError(f"Each parameter set must have at least {self.pn_par} primary parameters.")

            primary = list(params[: self.pn_par])

            # Start from the defaults and overwrite them, in declaration order, with whatever
            # secondary values were supplied. ``zip`` stops at the shorter input, so values beyond
            # the known secondary parameters are ignored.
            secondary = dict(self.sn_par)
            secondary.update(zip(self.sn_par, params[self.pn_par:]))

            adjusted = primary + list(secondary.values())
            if len(adjusted) != self.n_par:
                raise ValueError(f"Adjusted parameter set must have {self.n_par} total parameters.")
            adjusted_p0.append(adjusted)
        return adjusted_p0

    def _covariance(self):
        """
        Return the covariance matrix of the fitted model.

        Returns
        -------
        np.ndarray
            An array containing the covariance matrix of the fitted model.

        Raises
        ------
        RuntimeError
            If the fit has not been performed yet (i.e., ``self.covariance`` is ``None``).
        """
        if self.covariance is None:
            raise RuntimeError("Fit not performed yet. Call fit() first.")
        return self.covariance

    def _fit_preprocessing(self, p0, frozen):
        """
        Process frozen parameters and adjust bounds.

        Parameters
        ----------
        p0: Sequences_
            A list of initial guesses for the parameters of the models.
            For example, [(1, 1, 0), (3, 3, 2)].
        frozen: List[bool]
            A list of booleans indicating which parameters are frozen.
            For example, [False, False, True] for 3 parameters. ``None`` freezes nothing.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray]
            Adjusted lower and upper bounds, and flattened initial guesses, all as float arrays of
            length ``n_par * n_fits``.

        Raises
        ------
        ValueError
            If ``frozen`` does not have exactly ``n_par`` entries.
        """
        total = self.n_par * self.n_fits

        try:
            lb, ub = self.fit_boundaries()
        except NotImplementedError:
            # No boundaries defined by the model: fall back to an unbounded problem.
            lb, ub = [-np.inf], [np.inf]

        # Cast to float *before* writing frozen bounds: integer bounds would truncate
        # ``p0 +/- epsilon`` and collapse the interval. ``np.resize`` cycles the per-model
        # bounds so that every sub-model receives the same set.
        lb = np.resize(np.asarray(lb, dtype=float), total)
        ub = np.resize(np.asarray(ub, dtype=float), total)

        if frozen is None:
            frozen = [False] * self.n_par
        if len(frozen) != self.n_par:
            raise ValueError("The length of 'frozen' must match the number of parameters per model.")

        # Repeat the per-model mask for all sub-models; ``np.tile`` repeats for any sequence type,
        # whereas ``frozen * n_fits`` would multiply element-wise for an ndarray.
        frozen_mask = np.tile(np.asarray(frozen, dtype=bool), self.n_fits)

        # ``np.array`` copies, so the caller's ``p0`` is never aliased by the returned array.
        p0_flat = np.array(p0, dtype=float).ravel()

        if frozen_mask.any():
            # A frozen parameter is pinned by squeezing its bounds to a window of +/- epsilon
            # around the initial guess.
            lb[frozen_mask] = p0_flat[frozen_mask] - epsilon
            ub[frozen_mask] = p0_flat[frozen_mask] + epsilon

        return lb, ub, p0_flat

    def _n_fitter(self, x: NDArray, *params: "Params_") -> np.ndarray:
        r"""
        Perform N-fitting by summing over multiple parameter sets.

        Parameters
        ----------
        x : np.ndarray
            Input array of values for which the composite function is evaluated.
        params : tuple
            All parameters to be fitted, ``self.n_fits * self.n_par`` values in total, where:

            - ``self.n_fits`` is the number of individual fits.
            - ``self.n_par`` is the number of parameters per fit.

        Returns
        -------
        np.ndarray
            An array containing the composite fitted values for the input ``x``.
        """
        # The shape is passed positionally: the ``newshape`` keyword is deprecated in NumPy >= 2.1.
        parameters = np.reshape(np.array(params), (self.n_fits, self.n_par))
        y = np.zeros_like(x, dtype=float)
        for par in parameters:
            y += self.fitter(x=x, params=par.tolist())
        return y

    def _params(self) -> np.ndarray:
        r"""
        Return the fitted parameters of the fitted model.

        Returns
        -------
        np.ndarray
            The parameters obtained after performing the fit.

        Raises
        ------
        RuntimeError
            If the fit has not been performed yet (i.e., ``self.params`` is ``None``).
        """
        if self.params is None:
            raise RuntimeError("Fit not performed yet. Call fit() first.")
        return self.params

    def _plot_individual_fitter(self, plotter):
        r"""
        Plot individual fits from the composite fitter.

        Parameters
        ----------
        plotter : matplotlib.axes.Axes
            The axis object where the plots will be rendered.

        Raises
        ------
        RuntimeError
            If the fit has not been performed yet.

        Notes
        -----
        - The fitted parameters are reshaped into (``self.n_fits``, ``self.n_par``).
        - Each plot is labeled with the class name (without ``Fitter``) and the 1-based index of the fit,
          along with the formatted parameters.
        """
        x = self.x_values
        per_model = np.reshape(self._params(), (self.n_fits, self.n_par))

        # The first color of the cycle is skipped.
        # NOTE(review): presumably it is reserved for the data / total fit -- confirm.
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"][1:]
        model_name = self.__class__.__name__.replace("Fitter", "")

        for index, par in enumerate(per_model):
            # NOTE(review): ``_format_param`` is not defined on this class as visible here;
            # presumably it is supplied by subclasses -- confirm.
            formatted = ", ".join(self._format_param(value) for value in par)
            plot_xy(
                x_data=x,
                y_data=self.fitter(x=x, params=list(par)),
                data_label=f"{model_name} {index + 1}({formatted})",
                plot_dictionary=LinePlot(line_style="--", color=colors[index % len(colors)]),
                axis=plotter,
                x_label="",
                y_label="",
                plot_title="",
            )

    def _standard_errors(self) -> NDArray:
        r"""
        Return the standard errors of the fitted parameters.

        Returns
        -------
        np.ndarray
            The square root of the diagonal of the covariance matrix.

        Raises
        ------
        RuntimeError
            If the fit has not been performed yet (i.e., ``self.covariance`` is ``None``).
        """
        if self.covariance is None:
            raise RuntimeError("Fit not performed yet. Call fit() first.")
        return np.sqrt(np.diag(self.covariance))

    def dry_run(self, axis=None):
        """
        Plot the x and y data for a quick visual inspection of the data.

        Parameters
        ----------
        axis
            The axis to plot the data on.
        """
        plot_xy(x_data=self.x_values, y_data=self.y_values, axis=axis)

    def fit(self, p0: "Params_", frozen: Optional[List[bool]] = None):
        """
        Fit the data.

        The results are stored in ``self.params`` and ``self.covariance``.

        Parameters
        ----------
        p0: Sequences_
            A list of initial guesses for the parameters of the models.
            For example, [(1, 1, 0), (3, 3, 2)].
        frozen: List[bool]
            A list of booleans indicating whether each parameter is frozen.
            For example, [False, False, True] for 3 parameters.

        Raises
        ------
        ValueError
            If ``p0`` is empty, or if the guesses / ``frozen`` mask are inconsistent with the model.
        """
        # ``len`` rather than truthiness: ``p0`` may be an ndarray.
        if len(p0) == 0:
            raise ValueError("'p0' must contain at least one parameter set.")
        self.n_fits = len(p0)

        # If the guesses do not already cover every parameter, fill in secondary defaults.
        n_supplied = sum(len(params) for params in p0)
        if n_supplied != self.n_par * self.n_fits:
            p0 = self._adjust_parameters(p0)

        lb, ub, p0_flat = self._fit_preprocessing(p0=p0, frozen=frozen)

        self.params, self.covariance, *_ = curve_fit(
            f=self._n_fitter,
            xdata=self.x_values,
            ydata=self.y_values,
            p0=p0_flat,
            maxfev=self.max_iterations,
            bounds=Bounds(lb=lb, ub=ub),
        )

    def _fit_boundaries(self) -> Tuple[Sequence[float], Sequence[float]]:
        """
        Defines the internal distribution boundaries to be used by fitter.

        Returns
        -------
        Tuple[Sequence[float], Sequence[float]]
            Lower and upper bounds for one sub-model: ``-inf`` / ``+inf`` for each of the ``n_par`` parameters.
        """
        lb = [-np.inf] * self.n_par
        ub = [np.inf] * self.n_par
        return lb, ub

    def fit_boundaries(self) -> Tuple[Sequence[float], Sequence[float]]:
        """
        Defines the distribution boundaries to be used by fitter.

        Returns
        -------
        Tuple[Sequence[float], Sequence[float]]
            Lower and upper bounds for one sub-model, as given by :meth:`_fit_boundaries`.
        """
        return self._fit_boundaries()

    @staticmethod
    def fitter(x, params: "Params_") -> np.ndarray:
        """
        Fitter function for multi-fitting.

        Parameters
        ----------
        x: np.ndarray
            The x-array on which the fitting is to be performed.
        params: Params_
            An array of parameters to fit.

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide the actual model.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")

    def get_fitted_curve(self) -> np.ndarray:
        """
        Get the fitted values of the model.

        Returns
        -------
        np.ndarray
            An array of fitted values, evaluated on ``self.x_values``.

        Raises
        ------
        RuntimeError
            If the fit has not been performed yet.
        """
        if self.params is None:
            raise RuntimeError("Fit not performed yet. Call fit() first.")
        # Unpack so the call matches the ``*params`` signature of ``_n_fitter``.
        return self._n_fitter(self.x_values, *self.params)

    def get_model_parameters(self, select: Optional[Tuple[int, Any]] = None, errors: bool = False):
        r"""
        Extract specific parameter values or their uncertainties from the fitting process.

        This method allows for retrieving the fitted parameters or their corresponding standard errors for specific
        submodels, or for all sub-models if no selection is provided. The selection itself is delegated to
        ``parameter_logic``.

        Parameters
        ----------
        select : list of int or None, optional
            A list of indices specifying which sub-models to extract parameters for. Indexing starts at 1.
            If ``None``, parameters for all sub-models are returned. Defaults to None.
        errors : bool, optional
            If ``True``, both the parameter values and their standard errors are returned.
            Defaults to ``False``.

        Returns
        -------
        np.ndarray or tuple of np.ndarray
            * If ``errors`` is ``False``:
                - A 2D array of shape `(n_parameters, selected_models)` with parameter values for the selected models.
            * If ``errors`` is ``True``: A tuple of two 2D arrays:
                - The first array contains the parameter values of shape `(n_parameters, selected_models)`.
                - The second array contains the standard errors of the parameters, with the same shape.

        Raises
        ------
        RuntimeError
            If the fit has not been performed yet.
        ValueError
            If the input ``select`` is not a valid list of indices or is incompatible with the model structure.
        """

        def _select(flat_values):
            # One row per selected sub-model comes back; transpose to (n_parameters, selected_models).
            chosen = parameter_logic(
                par_array=flat_values,
                n_par=self.n_par,
                selected_models=select,
            )
            return chosen[:, range(self.n_par)].T

        if not errors:
            return _select(self.get_value_error_pair(mean_values=True, std_values=False))

        pairs = self.get_value_error_pair(mean_values=True, std_values=True)
        # The reshape doubles as a consistency check against the (n_fits, n_par) model layout.
        pairs = pairs.reshape(self.n_fits, self.n_par, 2)
        return _select(pairs[:, :, 0].flatten()), _select(pairs[:, :, 1].flatten())

    def get_value_error_pair(self, mean_values: bool = True, std_values: bool = False) -> np.ndarray:
        r"""
        Retrieve the value/error pairs for the fitted parameters.

        This method provides the fitted parameter values and their corresponding standard errors as a combined array or
        individually based on the input flags.

        Parameters
        ----------
        mean_values : bool, optional.
            If ``True``, include the values of the fitted parameters.
            Defaults to ``True``.
        std_values : bool, optional.
            If ``True``, include the standard errors of the fitted parameters.
            Defaults to ``False``.

        Returns
        -------
        np.ndarray
            - If ``mean_values`` and ``std_values`` are both ``True``: A 2D array of shape (n_parameters, 2),
              where each row is ``[value, error]``.
            - If ``mean_values`` is ``True`` and ``std_values`` is ``False``: A 1D array of parameter values.
            - If ``std_values`` is ``True`` and ``mean_values`` is ``False``: A 1D array of standard errors.

        Raises
        ------
        ValueError
            If both ``mean_values`` and ``std_values`` are ``False``.
        RuntimeError
            If the fit has not been performed yet.
        """
        # Validate the flags first so a bad call is reported as such, regardless of fit state.
        if not (mean_values or std_values):
            raise ValueError("Either 'mean_values' or 'std_values' must be True.")

        pairs: np.ndarray = np.column_stack([self._params(), self._standard_errors()])
        if mean_values and std_values:
            return pairs
        if mean_values:
            return pairs[:, 0]
        return pairs[:, 1]

    def plot_fit(
        self,
        show_individuals: bool = False,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        data_label: Optional[str] = None,
        fit_label: Optional[str] = None,
        title: Optional[str] = None,
        axis: Optional["Axes"] = None,
    ):
        """
        Plot the fitted models.

        Parameters
        ----------
        show_individuals: bool, optional.
            Whether to show individually fitted models or not.
        x_label: str, optional
            The label for the x-axis.
        y_label: str, optional
            The label for the y-axis.
        data_label: str, optional
            The label for the data.
        fit_label: str, optional
            The label for the total fit.
        title: str, optional
            The title for the plot.
        axis: Axes, optional
            Axes to plot instead of the entire figure. Defaults to None.

        Returns
        -------
        plotter
            The plotter handle for the drawn plot.

        Raises
        ------
        RuntimeError
            If the fit has not been performed yet.
        """
        return _plot_fit(
            x_values=self.x_values,
            y_values=self.y_values,
            # Goes through the accessor so plotting before fitting fails with a clear message.
            parameters=self._params(),
            n_fits=self.n_fits,
            class_name=self.__class__.__name__,
            _n_fitter=self._n_fitter,
            _n_plotter=self._plot_individual_fitter,
            show_individuals=show_individuals,
            x_label=x_label,
            y_label=y_label,
            title=title,
            data_label=data_label,
            fit_label=fit_label,
            axis=axis,
        )