"""Created on Aug 10 23:08:38 2024"""
import itertools
from typing import Optional, Tuple, Union, List, Callable
import matplotlib.pyplot as plt
import numpy as np
from mpyez.backend.uPlotting import LinePlot
from mpyez.ezPlotting import plot_xy
from scipy.optimize import curve_fit
from .utilities_f import sanity_check
from .. import GAUSSIAN, LAPLACE, LINE, LOG_NORMAL, NORMAL, SKEW_NORMAL
from ..distributions import GaussianDistribution, LaplaceDistribution, line, LogNormalDistribution, \
SkewNormalDistribution
class _Line:
    """
    Internal adapter that gives the straight-line model a ``pdf`` method.

    An instance stores a slope and an intercept so the line can be evaluated through the same
    ``Model(*params).pdf(x)`` call pattern used for the other entries of ``model_dict``.
    Not intended for use outside this module.
    """

    def __init__(self, slope: float, intercept: float):
        self.slope, self.intercept = slope, intercept

    def pdf(self, x: np.ndarray) -> np.ndarray:
        """
        Evaluate the line at the given points.

        Parameters
        ----------
        x: np.ndarray
            Points at which the line is evaluated.

        Returns
        -------
        np.ndarray
            Result of ``line`` called with ``x`` and the stored slope and intercept.
        """
        coefficients = {"slope": self.slope, "intercept": self.intercept}
        return line(x=x, **coefficients)
# Registry of supported models: maps each model identifier to a two-item list
# ``[model_class, n_params]``. The class is instantiated as ``model_class(*params)`` and evaluated
# through its ``pdf`` method; ``n_params`` is how many consecutive fit parameters it consumes.
model_dict = {LINE: [_Line, 2],
GAUSSIAN: [GaussianDistribution, 3],
LOG_NORMAL: [LogNormalDistribution, 3],
SKEW_NORMAL: [SkewNormalDistribution, 4],
LAPLACE: [LaplaceDistribution, 3]}
[docs]
class MixedDataFitter:
r"""
Class to fit a mixture of different models to data.
:param x_values: The x-values for the data.
:param y_values: The y-values for the data.
:param model_list: List of models to fit (e.g., `LINE`, `GAUSSIAN`, `LOG_NORMAL`)
:param max_iterations: The maximum number of iterations for fitting procedure.
"""
def __init__(self, x_values: Union[List, np.ndarray], y_values: Union[List, np.ndarray],
             model_list: List[str], max_iterations: int = 1000):
    """
    Store the data and fit settings, validate the model list and build the composite model.

    :param x_values: The x-values for the data.
    :param y_values: The y-values for the data.
    :param model_list: List of models to fit.
    :param max_iterations: The maximum number of iterations for fitting procedure.
    :raise ValueError: If any entry of ``model_list`` is not a recognized model.
    """
    # Keep the data in whatever form ``sanity_check`` hands back.
    self.x_values, self.y_values = sanity_check(x_values=x_values, y_values=y_values)

    self.model_list = model_list
    self.max_iterations = max_iterations

    # Fit results stay ``None`` until ``fit`` fills them in.
    self.params = None
    self.covariance = None

    # Reject unknown model names before the composite model is built.
    self._validate_models()
    self.model_function = self._create_model_function()
def __repr__(self):
return (f"{self.__class__.__name__}(x_values={self.x_values}, y_values={self.y_values}, "
f"model_list={self.model_list}, max_iterations={self.max_iterations})")
[docs]
def _create_model_function(self) -> Callable:
    """
    Build the callable that evaluates the sum of all models in ``self.model_list``.

    :return: A function ``f(x, *params)`` for fitting, where ``params`` is the flat concatenation of
             every component's parameters in model-list order.
    """

    def _composite_model(x: np.ndarray, *params) -> np.ndarray:
        """
        Evaluate the summed model.

        Parameters
        ----------
        x : np.ndarray
            The x-values where the model is evaluated.
        params : tuple
            Flat parameter values; each component consumes its own consecutive slice.

        Returns
        -------
        y : np.ndarray
            Sum of every component's ``pdf`` evaluated at ``x``.
        """
        total = np.zeros_like(x, dtype=float)
        start = 0
        # ``self.model_list`` is looked up on every call rather than captured once.
        for name in self.model_list:
            component_cls, width = model_dict[name]
            stop = start + width
            component = component_cls(*params[start:stop])
            total += component.pdf(x=x)
            start = stop
        return total

    return _composite_model
[docs]
def _expected_param_count(self) -> int:
    """
    Total number of fit parameters required by the models in ``self.model_list``.

    :return: Sum of the per-model parameter counts registered in ``model_dict``.
    """
    return sum(model_dict[name][1] for name in self.model_list)
[docs]
def _get_bounds(self) -> Tuple[np.ndarray, np.ndarray]:
    """
    Assemble the lower and upper parameter bounds for the models in ``self.model_list``.

    Bounds are concatenated in model-list order. A model name that matches none of the known
    groups contributes no bounds.

    :returns: Lower and upper bounds for the parameters.
    """
    inf = np.inf
    three_parameter_models = [GAUSSIAN, NORMAL, LOG_NORMAL, LAPLACE]

    lows, highs = [], []
    for name in self.model_list:
        if name in three_parameter_models:
            low, high = [0, -inf, 0], [inf, inf, inf]
        elif name == LINE:
            low, high = [-inf, -inf], [inf, inf]
        elif name == SKEW_NORMAL:
            low, high = [0, -inf, -inf, 0], [inf, inf, inf, inf]
        else:
            continue
        lows.extend(low)
        highs.extend(high)

    return np.array(lows), np.array(highs)
[docs]
def _plot_individual_fitter(self, plotter):
    """
    Draw every fitted component as its own dashed curve on ``plotter``.

    Each curve is labelled with the capitalized model name, its 1-based position in the model list
    and its fitted parameters as rendered by ``self.format_param``.

    :param plotter: The plotting axis object
    """
    x_grid = self.x_values
    # The first colour of the active matplotlib cycle is skipped
    # (presumably it is already taken by the data curve -- confirm against ``plot_xy``).
    palette = plt.rcParams['axes.prop_cycle'].by_key()['color'][1:]

    offset = 0
    for position, name in enumerate(self.model_list):
        curve_color = palette[position % len(palette)]
        component_cls, width = model_dict[name]
        component_pars = self.params[offset:offset + width]
        component_y = component_cls(*component_pars).pdf(x_grid)

        formatted = ", ".join(self.format_param(value) for value in component_pars)
        plot_xy(x_data=x_grid, y_data=component_y,
                x_label='', y_label='', plot_title='',
                data_label=f'{name.capitalize()} {position + 1}({formatted})',
                plot_dictionary=LinePlot(line_style='--', color=curve_color),
                axis=plotter)

        offset += width
def _validate_models(self):
    """
    Check that every entry of ``self.model_list`` is a supported model.

    :raise ValueError: If any model in the model list is not recognized.
    """
    allowed_models = {GAUSSIAN, LINE, LOG_NORMAL, SKEW_NORMAL, LAPLACE}
    for name in self.model_list:
        if name not in allowed_models:
            raise ValueError(f"All models must be one of {allowed_models}.")
[docs]
def fit(self, p0: Union[List, np.ndarray]):
    """
    Fit the composite model to the stored data with ``scipy.optimize.curve_fit``.

    The optimal parameters and their covariance matrix are stored in ``self.params`` and
    ``self.covariance``.

    :param p0: Initial guess for the fitted parameters. It may be grouped per model, e.g.
        ``[[a1, b1], [a2, b2, c2]]``, or given as a flat sequence such as ``[a1, b1, a2, b2, c2]``
        (including a 1-D ``np.ndarray``). Grouped entries are flattened by one level.
    :raises ValueError: If the length of the initial guess is not equal to the expected parameter count.
    """
    # Flatten one level, but keep bare scalars as they are so that a flat guess is accepted too.
    flat_p0 = []
    for entry in p0:
        try:
            values = iter(entry)
        except TypeError:
            flat_p0.append(entry)
        else:
            flat_p0.extend(values)

    expected = self._expected_param_count()
    if len(flat_p0) != expected:
        raise ValueError(
            f"Initial parameters length {len(flat_p0)} does not match expected count {expected}.")

    # ``*_`` absorbs any extra return values beyond (popt, pcov).
    self.params, self.covariance, *_ = curve_fit(f=self.model_function,
                                                 xdata=self.x_values, ydata=self.y_values,
                                                 p0=flat_p0, maxfev=self.max_iterations,
                                                 bounds=self._get_bounds())
[docs]
def plot_fit(self, show_individuals: bool = False,
             x_label: Optional[str] = None, y_label: Optional[str] = None, title: Optional[str] = None,
             data_label: Optional[str] = None, figure_size: tuple = (12, 6)) -> tuple:
    """
    Draw the data together with the total fitted curve and, on request, each fitted component.

    :param show_individuals: Whether to plot individual fitted functions, by default False.
    :param x_label: The label for the x-axis of the plot. Falls back to ``'X'``.
    :param y_label: The label for the y-axis of the plot. Falls back to ``'Y'``.
    :param title: The title for the plot. Falls back to ``'<ClassName> fit'``.
    :param data_label: The label for the data to be plotted. Falls back to ``'Data'``.
    :param figure_size: The size of the figure. Default is (12,6).
    :return: A tuple of figure and axes object for the drawn plot
    :raises ValueError: Raised if the plotting function is called before the fitting is done.
    """
    ready = self.y_values is not None and self.params is not None
    if not ready:
        raise ValueError("Data must be fitted before plotting.")

    fig, ax = plt.subplots(figsize=figure_size)

    # The object returned by the first ``plot_xy`` call is reused as the axis for every later draw.
    plotter = plot_xy(self.x_values, self.y_values, data_label=data_label or 'Data', axis=ax)

    total_fit = self.model_function(self.x_values, *self.params)
    plot_xy(x_data=self.x_values, y_data=total_fit,
            data_label='Total Fit', plot_dictionary=LinePlot(color='k'), axis=plotter)

    if show_individuals:
        self._plot_individual_fitter(plotter=plotter)

    plotter.set_xlabel(x_label or 'X')
    plotter.set_ylabel(y_label or 'Y')
    plotter.set_title(title or f'{self.__class__.__name__} fit')
    plotter.legend(loc='best')
    fig.tight_layout()

    return fig, plotter
[docs]
def get_fit_values(self) -> np.ndarray:
    """
    Evaluate the fitted composite model at the stored x-values.

    :return: The y-values from the fitted model
    :raises RuntimeError: If the model has not been fitted yet.
    """
    fitted_params = self.params
    if fitted_params is None:
        raise RuntimeError("Fit not performed yet. Call fit() first.")

    evaluate = self.model_function
    return evaluate(self.x_values, *fitted_params)
[docs]
def get_parameters(self, model: Optional[str] = None, get_errors: bool = False):
    """
    Extracts parameters (and error) values for a specific model, or for all models if no model is specified.

    Errors are the square roots of the diagonal of the fit covariance matrix.

    :param model: Model name to extract parameters for. If unspecified, extracts parameters for all models.
        Defaults to ``None``.
    :param get_errors: If ``True``, includes the errors in the returned output. Defaults to ``False``.

    :return:
        - ``get_errors=False``: the result of ``self._parameter_extractor`` for the fitted parameters,
          or only its entry for ``model`` (``[]`` when that model is absent).
        - ``get_errors=True`` and ``model=None``: ``{"parameters": ..., "errors": ...}`` holding the
          extractor results for all models.
        - ``get_errors=True`` with a ``model``: ``{"parameters": ..., "errors": ...}`` where each value
          is a tuple with one list per parameter of that model; the i-th list collects the i-th
          value of every fitted occurrence of the model.
    :raises RuntimeError: If the model has not been fitted yet.
    :raises KeyError: If ``get_errors=True`` and ``model`` is not a key of ``model_dict``.
    """
    # Fail with a clear message instead of passing ``None`` into the extractor or ``np.diag``.
    if self.params is None:
        raise RuntimeError("Fit not performed yet. Call fit() first.")

    parameters = self._parameter_extractor(self.params)
    if not get_errors:
        return parameters if model is None else parameters.get(model, [])

    errors = self._parameter_extractor(np.sqrt(np.diag(self.covariance)))
    if model is None:
        # Return combined dictionary for all models
        return {"parameters": parameters, "errors": errors}

    # The parameter count depends only on the model, so look it up once.
    _, n_par = model_dict[model]

    output = {}
    for key, extracted in (("parameters", parameters), ("errors", errors)):
        # Concatenate the values of every occurrence, then regroup them by parameter position.
        flattened = [item for chunk in extracted.get(model, []) for item in chunk.tolist()]
        output[key] = tuple(flattened[i::n_par] for i in range(n_par))

    return output