"""Created on Jul 18 00:16:01 2024"""
from itertools import chain
from typing import Any, Iterable, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from mpyez.backend.uPlotting import LinePlot
from mpyez.ezPlotting import plot_xy
from scipy.optimize import Bounds, curve_fit
from ..utilities_f import parameter_logic
from ... import listOfTuplesOrArray, epsilon
[docs]
class BaseFitter:
"""The base class for multi-fitting functionality."""
def __init__(self, x_values: Union[list, np.ndarray], y_values: Union[list, np.ndarray],
max_iterations: int = 1000):
self.x_values = x_values
self.y_values = y_values
self.max_iterations = max_iterations
self.n_par = None
self.pn_par = self.n_par
self.sn_par = {}
self.n_fits = None
self.params = None
self.covariance = None
[docs]
def _adjust_parameters(self, p0: List[Tuple[float, ...]]):
"""
Adjust input parameters to include defaults for secondary parameters if missing.
Parameters
----------
p0: List[List[float]]
A list of initial guesses for the parameters.
Returns
-------
adjusted_p0: List[List[float]]
Adjusted parameter list with default values for missing secondary parameters.
"""
adjusted_p0 = []
for params in p0:
# Too few parameters
if len(params) < self.pn_par:
raise ValueError(f"Each parameter set must have at least {self.pn_par} primary parameters.")
primary_params = params[:self.pn_par]
provided_secondary_params = params[self.pn_par:]
secondary_params = dict(self.sn_par)
for key, value in zip(self.sn_par.keys(), provided_secondary_params):
secondary_params[key] = value
adjusted_params = list(primary_params) + list(secondary_params.values())
if len(adjusted_params) != self.n_par:
raise ValueError(f"Adjusted parameter set must have {self.n_par} total parameters.")
adjusted_p0.append(adjusted_params)
return adjusted_p0
[docs]
def _covariance(self):
"""
Store the covariance matrix of the fitted model.
Returns
-------
np.ndarray
An array containing the covariance matrix of the fitted model.
"""
if self.covariance is None:
raise RuntimeError("Fit not performed yet. Call fit() first.")
return self.covariance
[docs]
def _fit_preprocessing(self, p0: listOfTuplesOrArray, frozen: List[bool]):
"""
Process frozen parameters and adjust bounds.
Parameters
----------
p0: listOfTuplesOrArray
A list of initial guesses for the parameters of the models.
For example: [(1, 1, 0), (3, 3, 2)].
frozen: List[bool]
A list of booleans indicating which parameters are frozen.
For example: [False, False, True] for 3 parameters.
Returns
-------
Tuple[np.ndarray, np.ndarray, np.ndarray]
Adjusted lower and upper bounds, and flattened initial guesses.
"""
# Get initial boundaries
try:
lb, ub = self.fit_boundaries()
except NotImplementedError:
# if they're not implemented, self impose -inf + inf boundaries
lb = [-np.inf] * self.n_fits
ub = [np.inf] * self.n_fits
# Resize bounds to match total parameters
lb = np.resize(a=lb, new_shape=self.n_par * self.n_fits)
ub = np.resize(a=ub, new_shape=self.n_par * self.n_fits)
# Validate frozen length
if frozen is None:
frozen = [False] * self.n_par
if len(frozen) != self.n_par:
raise ValueError("The length of 'frozen' must match the number of parameters per model.")
# Repeat frozen mask for all models
frozen = frozen * self.n_fits
# Flatten initial guesses
p0_flat = np.array(p0).flatten()
# Adjust bounds for frozen parameters
for i, is_frozen in enumerate(frozen):
if is_frozen:
lb[i] = p0_flat[i] - epsilon
ub[i] = p0_flat[i] + epsilon
return lb, ub, p0_flat
[docs]
def _n_fitter(self, x: np.ndarray, *params: Iterable) -> np.ndarray:
r"""
Perform N-fitting by summing over multiple parameter sets.
Parameters
----------
x : np.ndarray
Input array of values for which the composite function is evaluated.
params : tuple
A tuple with all parameters to be fitted in an array of size (``self.n_fits, self.n_par``) where:
- ``self.n_fits`` is the number of individual fits.
- ``self.n_par`` is the number of parameters per fit.
Returns
-------
np.ndarray
An array containing the composite fitted values for the input ``x``.
"""
y = np.zeros_like(a=x, dtype=float)
params = np.reshape(a=np.array(params), newshape=(self.n_fits, self.n_par))
for par in params:
y += self.fitter(x=x, params=par)
return y
[docs]
def _params(self) -> np.ndarray:
r"""
Store the fitted parameters of the fitted model.
Returns
-------
np.ndarray
The parameters obtained after performing the fit.
Raises
------
RuntimeError
If the fit has not been performed yet (i.e., ``self.params`` is ``None``).
Notes
-----
This method assumes that the fitting process assigns values to ``self.params``.
"""
if self.params is None:
raise RuntimeError("Fit not performed yet. Call fit() first.")
return self.params
[docs]
def _plot_individual_fitter(self, plotter):
r"""
Plot individual fits from the composite fitter.
Parameters
----------
plotter : matplotlib.axes.Axes
The axis object where the plots will be rendered.
Notes
-----
- `self.params` must contain the fitted parameters reshaped into (``self.n_fits``, ``self.n_par``).
- Each plot will be labeled with the class name and the index of the fit, along with the formatted parameters.
"""
x = self.x_values
params = np.reshape(a=self.params, newshape=(self.n_fits, self.n_par))
colors = plt.rcParams['axes.prop_cycle'].by_key()['color'][1:]
for i, par in enumerate(params):
color = colors[i % len(colors)]
plot_xy(x_data=x, y_data=self.fitter(x=x, params=par),
data_label=f'{self.__class__.__name__.replace("Fitter", "")} {i + 1}('
f'{", ".join(self._format_param(i) for i in par)})',
plot_dictionary=LinePlot(line_style='--', color=color), axis=plotter, x_label='', y_label='',
plot_title='')
[docs]
def _standard_errors(self) -> np.ndarray:
r"""
Store the standard errors of the fitted parameters.
Returns
-------
np.ndarray
An array containing the standard errors of the fitted parameters.
Raises
------
RuntimeError
If the fit has not been performed yet (i.e., ``self.covariance`` is ``None``).
"""
if self.covariance is None:
raise RuntimeError("Fit not performed yet. Call fit() first.")
return np.sqrt(np.diag(self.covariance))
[docs]
def dry_run(self, axis=None):
"""
Plot the x and y data for a quick visual inspection of the data.
Parameters
----------
axis:
The axis to plot the data on.
"""
plot_xy(x_data=self.x_values, y_data=self.y_values, axis=axis)
[docs]
def fit(self, p0: listOfTuplesOrArray, frozen: List[bool] = None):
"""
Fit the data.
Parameters
----------
p0: listOfTuplesOrArray
A list of initial guesses for the parameters of the models.
For example: [(1, 1, 0), (3, 3, 2)].
frozen: List[bool]
A list of booleans indicating whether each parameter is frozen.
For example: [False, False, True] for 3 parameters.
"""
self.n_fits = len(p0)
len_guess = len(list(chain(*p0)))
total_pars = self.n_par * self.n_fits
if len_guess != total_pars:
p0 = self._adjust_parameters(p0)
lb, ub, p0_flat = self._fit_preprocessing(p0=p0, frozen=frozen)
self.params, self.covariance, *_ = curve_fit(f=self._n_fitter,
xdata=self.x_values, ydata=self.y_values,
p0=p0_flat, maxfev=self.max_iterations,
bounds=Bounds(lb=lb, ub=ub))
[docs]
@staticmethod
def fit_boundaries():
"""Defines the distribution boundaries to be used by fitter."""
raise NotImplementedError("This method should be implemented by subclasses.")
[docs]
@staticmethod
def fitter(x: np.ndarray, params: Tuple[float, Any]):
"""
Fitter function for multi-fitting.
Parameters
----------
x: np.ndarray
The x-array on which the fitting is to be performed.
params: List[float]
A list of parameters to fit.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
[docs]
def get_fit_values(self) -> np.ndarray:
"""
Get the fitted values of the model.
Returns
-------
np.ndarray
An array of fitted values.
"""
if self.params is None:
raise RuntimeError('Fit not performed yet. Call fit() first.')
return self._n_fitter(self.x_values, self.params)
[docs]
def get_model_parameters(self, select: Tuple[int, Any] = None, errors: bool = False):
r"""
Extract specific parameter values or their uncertainties from the fitting process.
This method allows for retrieving the fitted parameters or their corresponding standard errors for specific
sub-models, or for all sub-models if no selection is provided.
Parameters
----------
select : list of int or None, optional
A list of indices specifying which sub-models to extract parameters for. Indexing starts at 1.
If ``None``, parameters for all sub-models are returned. Defaults to None.
errors : bool, optional
If ``True``, both the parameter values and their standard errors are returned.
Defaults to ``False``.
Returns
-------
np.ndarray or tuple of np.ndarray
* If ``errors`` is ``False``:
- A 2D array of shape `(n_parameters, selected_models)` with parameter values for the selected models.
* If ``errors`` is ``True``: A tuple of two 2D arrays:
- The first array contains the parameter values of shape `(n_parameters, selected_models)`.
- The second array contains the standard errors of the parameters, with the same shape.
Notes
-----
- The ``select`` parameter allows filtering by specific sub-model indices. If ``None``, use all sub-models.
- When ``errors`` is ``True``, both the parameter means and their uncertainties are returned as separate arrays.
Raises
------
ValueError
If the input ``select`` is not a valid list of indices or is incompatible with the model structure.
"""
parameter_mean = self.get_value_error_pair(mean_values=True, std_values=errors)
if not errors:
selected = parameter_logic(par_array=parameter_mean, n_par=self.n_par, selected_models=select)
return selected[:, range(self.n_par)].T
else:
par_list = parameter_mean.reshape(self.n_fits, self.n_par, 2)
mean = parameter_logic(par_array=par_list[:, :, 0].flatten(), n_par=self.n_par, selected_models=select)
std_ = parameter_logic(par_array=par_list[:, :, 1].flatten(), n_par=self.n_par, selected_models=select)
return mean[:, range(self.n_par)].T, std_[:, range(self.n_par)].T
[docs]
def get_value_error_pair(self, mean_values: bool = True, std_values: bool = False) -> np.ndarray:
r"""
Retrieve the value/error pairs for the fitted parameters.
This method provides the fitted parameter values and their corresponding standard errors as a combined array or
individually based on the input flags.
Parameters
----------
mean_values : bool, optional
If ``True``, return only the values of the fitted parameters.
Defaults to ``True``.
std_values : bool, optional
If ``True``, return only the standard errors of the fitted parameters.
Defaults to ``False``.
Returns
-------
np.ndarray
- If ``mean_values`` and ``std_values`` are both ``True``: A 2D array of shape (n_parameters, 2),
where each row is ``[value, error]``.
- If ``mean_values`` is ``True`` and ``std_values`` is ``False``: A 1D array of parameter values.
- If ``std_values`` is ``True`` and ``mean_values`` is ``False``: A 1D array of standard errors.
- If both flags are ``False``: An error message.
Raises
------
ValueError
If both ``mean_values`` and ``std_values`` are ``False``.
"""
pairs = np.column_stack([self._params(), self._standard_errors()])
if mean_values and std_values:
return pairs
elif mean_values:
return pairs[:, 0]
elif std_values:
return pairs[:, 1]
else:
raise ValueError("Either 'mean_values' or 'std_values' must be True.")
[docs]
def plot_fit(self, show_individual: bool = False,
x_label: Optional[str] = None, y_label: Optional[str] = None, title: Optional[str] = None,
data_label: Optional[str] = None, axis: Optional[Axes] = None):
"""
Plot the fitted models.
Parameters
----------
show_individual: bool, optional
Whether to show individually fitted models or not.
x_label: str, optional
The label for the x-axis.
y_label: str, optional
The label for the y-axis.
title: str, optional
The title for the plot.
data_label: str, optional
The label for the data.
axis: Axes, optional
Axes to plot instead of the entire figure. Defaults to None.
Returns
-------
plotter
The plotter handle for the drawn plot.
"""
if self.params is None:
raise RuntimeError("Fit not performed yet. Call fit() first.")
plotter = plot_xy(x_data=self.x_values, y_data=self.y_values,
data_label=data_label if data_label else 'Data', axis=axis)
plot_xy(x_data=self.x_values, y_data=self._n_fitter(self.x_values, *self.params),
x_label=x_label, y_label=y_label, plot_title=title, data_label='Total Fit',
plot_dictionary=LinePlot(color='k'), axis=plotter)
if show_individual:
self._plot_individual_fitter(plotter=plotter)
plotter.set_xlabel(x_label if x_label else 'X')
plotter.set_ylabel(y_label if y_label else 'Y')
plotter.set_title(title if title else f'{self.n_fits} {self.__class__.__name__} fit')
plt.tight_layout()
return plotter