from __future__ import annotations
import copy
import warnings
from collections import Counter
from collections.abc import Iterable, MutableSequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias
import mkstd
import numpy as np
import pandas as pd
from pydantic import (
Field,
PrivateAttr,
RootModel,
ValidationInfo,
ValidatorFunctionWrapHandler,
model_validator,
)
from .constants import (
CRITERIA,
ESTIMATED_PARAMETERS,
ITERATION,
MODEL_HASH,
MODEL_ID,
MODEL_SUBSPACE_PETAB_PROBLEM,
PREDECESSOR_MODEL_HASH,
ROOT_PATH,
TYPE_PATH,
Criterion,
)
from .model import (
Model,
ModelHash,
VirtualModelBase,
)
if TYPE_CHECKING:
import petab
from .problem import Problem
# `Models` can be constructed from actual `Model`s,
# or `ModelHash`s, or the `str` of a model hash.
ModelLike: TypeAlias = Model | ModelHash | str
# `Models` is only defined further down in this module, so the alias is written
# as a single string forward reference. Writing `"Models" | Iterable[...]`
# would be evaluated at import time, where `str | <generic alias>` is not a
# supported operation.
ModelsLike: TypeAlias = "Models | Iterable[Model | ModelHash | str]"
# Access a model by list index, model hash, slice of indices, model hash
# string, or an iterable of these things.
ModelIndex: TypeAlias = int | ModelHash | slice | str | Iterable

# Names exported by `from <this module> import *`.
__all__ = [
    "_ListDict",
    "Models",
    "models_from_yaml_list",
    "models_to_yaml_list",
    "ModelsStandard",
]
class _ListDict(RootModel, MutableSequence):
"""Acts like a ``list`` and a ``dict``.
Not all methods are implemented -- feel free to request anything that you
think makes sense for a ``list`` or ``dict`` object.
The context is a list of objects that may have some metadata (e.g. a hash)
associated with each of them. The objects can be operated on like a list,
or requested like a dict, by their metadata (hash).
Mostly based on ``UserList`` and ``UserDict``, but some methods are
currently not yet implemented.
https://github.com/python/cpython/blob/main/Lib/collections/__init__.py
The typing is currently based on PEtab Select objects. Hence, objects are
in ``_models``, and metadata (model hashes) are in ``_hashes``.
"""
root: list[Model] = Field(default_factory=list)
"""The list of models."""
_hashes: list[ModelHash] = PrivateAttr(default_factory=list)
"""The list of model hashes."""
_problem: Problem | None = PrivateAttr(default=None)
"""The PEtab Select problem that all models belong to.
If this is provided, then you can add models by hashes.
"""
@model_validator(mode="wrap")
def _check_kwargs(
kwargs: dict[str, list[ModelLike] | Problem] | list[ModelLike],
handler: ValidatorFunctionWrapHandler,
info: ValidationInfo,
) -> Models:
"""Handle `Models` creation from different sources."""
_models = []
_problem = None
if isinstance(kwargs, list):
_models = kwargs
elif isinstance(kwargs, dict):
# Identify the models
if "models" in kwargs and "root" in kwargs:
raise ValueError("Provide only one of `root` and `models`.")
_models = kwargs.get("models") or kwargs.get("root") or []
# Identify the PEtab Select problem
if "problem" in kwargs and "_problem" in kwargs:
raise ValueError(
"Provide only one of `problem` and `_problem`."
)
_problem = kwargs.get("problem") or kwargs.get("_problem")
# Distribute model constructor kwargs to each model dict
if model_kwargs := kwargs.get("model_kwargs"):
for _model_index, _model in enumerate(_models):
if not isinstance(_model, dict):
raise ValueError(
"`model_kwargs` are only intended to be used when "
"constructing models from a YAML file."
)
_models[_model_index] = {**_model, **model_kwargs}
models = handler(_models)
models._problem = _problem
return models
@model_validator(mode="after")
def _check_typing(self: RootModel) -> RootModel:
"""Fix model typing."""
models0 = self._models
self.root = []
# This also converts all model hashes into models.
self.extend(models0)
return self
@property
def _models(self) -> list[Model]:
return self.root
def __repr__(self) -> str:
"""Get the model hashes that can regenerate these models.
N.B.: some information, e.g. criteria, will be lost if the hashes are
used to reproduce the set of models.
"""
return repr(self._hashes)
# skipped __lt__, __le__
def __eq__(self, other) -> bool:
other_hashes = Models(models=other)._hashes
same_length = len(self._hashes) == len(other_hashes)
same_hashes = set(self._hashes) == set(other_hashes)
return same_length and same_hashes
# skipped __gt__, __ge__, __cast
def __contains__(self, item: ModelLike) -> bool:
match item:
case Model():
return item in self._models
case ModelHash() | str():
return item in self._hashes
case VirtualModelBase():
return False
case _:
raise TypeError(f"Unexpected type: `{type(item)}`.")
def __len__(self) -> int:
return len(self._models)
def __getitem__(
self, key: ModelIndex | Iterable[ModelIndex]
) -> Model | Models:
try:
match key:
case int():
return self._models[key]
case ModelHash() | str():
return self._models[self._hashes.index(key)]
case slice():
return self.__class__(self._models[key])
case Iterable():
# TODO sensible to yield here?
return [self[key_] for key_ in key]
case _:
raise TypeError(f"Unexpected type: `{type(key)}`.")
except ValueError as err:
raise KeyError from err
def _model_like_to_model(self, model_like: ModelLike) -> Model:
"""Get the model that corresponds to a model-like object.
Args:
model_like:
Something that uniquely identifies a model; a model or a model
hash.
Returns:
The model.
"""
match model_like:
case Model():
model = model_like
case ModelHash() | str():
model = self._problem.model_hash_to_model(model_like)
case _:
raise TypeError(f"Unexpected type: `{type(model_like)}`.")
return model
def __setitem__(self, key: ModelIndex, item: ModelLike) -> None:
match key:
case int():
pass
case ModelHash() | str():
if key in self._hashes:
key = self._hashes.index(key)
else:
key = len(self)
case slice():
for key_, item_ in zip(
range(*key.indices(len(self))), item, strict=True
):
self[key_] = item_
case Iterable():
for key_, item_ in zip(key, item, strict=True):
self[key_] = item_
case _:
raise TypeError(f"Unexpected type: `{type(key)}`.")
item = self._model_like_to_model(model_like=item)
if key < len(self):
self._models[key] = item
self._hashes[key] = item.hash
else:
# Key doesn't exist, e.g., instead of
# models[1] = model1
# the user did something like
# models[model1_hash] = model1
# to add a new model.
self.append(item)
def _update(self, index: int, item: ModelLike) -> None:
"""Update the models by adding a new model, with possible replacement.
If the instance contains a model with a matching hash, that model
will be replaced.
Args:
index:
The index where the model will be inserted, if it doesn't
already exist.
item:
A model or a model hash.
"""
model = self._model_like_to_model(item)
if model.hash in self:
warnings.warn(
(
f"A model with hash `{model.hash}` already exists "
"in this collection of models. The previous model will be "
"overwritten."
),
RuntimeWarning,
stacklevel=2,
)
self[model.hash] = model
else:
self._models.insert(index, None)
self._hashes.insert(index, None)
# Re-use __setitem__ logic
self[index] = item
def __delitem__(self, key: ModelIndex) -> None:
try:
match key:
case ModelHash() | str():
key = self._hashes.index(key)
case slice():
for key_ in range(*key.indices(len(self))):
del self[key_]
case Iterable():
for key_ in key:
del self[key_]
case _:
raise TypeError(f"Unexpected type: `{type(key)}`.")
except ValueError as err:
raise KeyError from err
del self._models[key]
del self._hashes[key]
def __add__(
self, other: ModelLike | ModelsLike, left: bool = True
) -> Models:
match other:
case Models():
new_models = other._models
case Model():
new_models = [other]
case ModelHash() | str():
# Assumes the models belong to the same PEtab Select problem.
new_models = [self._problem.model_hash_to_model(other)]
case Iterable():
# Assumes the models belong to the same PEtab Select problem.
new_models = Models(
models=other, _problem=self._problem
)._models
case _:
raise TypeError(f"Unexpected type: `{type(other)}`.")
models = self._models + new_models
if not left:
models = new_models + self._models
return Models(models=models, _problem=self._problem)
def __radd__(self, other: ModelLike | ModelsLike) -> Models:
return self.__add__(other=other, left=False)
def __iadd__(self, other: ModelLike | ModelsLike) -> Models:
return self.__add__(other=other)
# skipped __mul__, __rmul__, __imul__
def __copy__(self) -> Models:
return Models(models=self._models, _problem=self._problem)
def append(self, item: ModelLike) -> None:
self._update(index=len(self), item=item)
def insert(self, index: int, item: ModelLike):
self._update(index=len(self), item=item)
# def pop(self, index: int = -1):
# model = self._models[index]
# # Re-use __delitem__ logic
# del self[index]
# return model
# def remove(self, item: ModelLike):
# # Re-use __delitem__ logic
# if isinstance(item, Model):
# item = item.hash
# del self[item]
# skipped clear, copy, count
def index(self, item: ModelLike, *args) -> int:
if isinstance(item, Model):
item = item.hash
return self._hashes.index(item, *args)
# skipped reverse, sort
def extend(self, other: Iterable[ModelLike]) -> None:
# Re-use append and therein __setitem__ logic
for model_like in other:
self.append(model_like)
def __iter__(self):
return iter(self._models)
def __next__(self):
raise NotImplementedError
# `dict` methods.
def get(
self,
key: ModelIndex | Iterable[ModelIndex],
default: ModelLike | None = None,
) -> Model | Models:
try:
return self[key]
except KeyError:
return default
def values(self) -> Models:
"""Get the models. DEPRECATED."""
warnings.warn(
"`models.values()` is deprecated. Use `models` instead.",
DeprecationWarning,
stacklevel=2,
)
return self
[docs]
class Models(_ListDict):
    """A collection of models."""

    def set_problem(self, problem: Problem) -> None:
        """Set the PEtab Select problem for this set of models."""
        self._problem = problem

    def lint(self):
        """Lint the models, e.g. check all hashes are unique.

        Currently raises an exception when invalid.

        Raises:
            ValueError:
                If any model hash occurs more than once.
        """
        duplicates = [
            model_hash
            for model_hash, count in Counter(self._hashes).items()
            if count > 1
        ]
        if duplicates:
            raise ValueError(
                "Multiple models exist with the same hash. "
                f"Model hashes: `{duplicates}`."
            )

    @staticmethod
    def from_yaml(
        filename: TYPE_PATH,
        model_subspace_petab_problem: petab.Problem = None,
        problem: Problem = None,
    ) -> Models:
        """Load models from a YAML file.

        Args:
            filename:
                Location of the YAML file.
            model_subspace_petab_problem:
                A preloaded copy of the PEtab problem. N.B.:
                all models should share the same PEtab problem if this is
                provided (e.g. all models belong to the same model subspace,
                or all model subspaces have the same
                ``model_subspace_petab_yaml`` in the model space file(s)).
            problem:
                The PEtab Select problem. N.B.: all models should belong to the
                same PEtab Select problem if this is provided.

        Returns:
            The models.
        """
        # Handle single-model files, for backwards compatibility.
        # This is a deliberate best-effort attempt: any ordinary failure
        # means the file is treated as a list of models below. Interpreter
        # exits and keyboard interrupts are not swallowed.
        try:
            model = Model.from_yaml(
                filename=filename,
                model_subspace_petab_problem=model_subspace_petab_problem,
            )
            return Models(models=[model], _problem=problem)
        except Exception:  # noqa: S110
            pass

        return ModelsStandard.load_data(
            filename=filename,
            _problem=problem,
            model_kwargs={
                ROOT_PATH: Path(filename).parent,
                MODEL_SUBSPACE_PETAB_PROBLEM: model_subspace_petab_problem,
            },
        )

    def to_yaml(
        self,
        filename: TYPE_PATH,
        relative_paths: bool = True,
    ) -> None:
        """Save models to a YAML file.

        Args:
            filename:
                Location of the YAML file.
            relative_paths:
                Whether to rewrite the paths in each model (e.g. the path to the
                model's PEtab problem) relative to the ``filename`` location.
        """
        _models = self._models
        if relative_paths:
            root_path = Path(filename).parent
            # Work on copies so the models in this collection keep their paths.
            _models = copy.deepcopy(_models)
            for _model in _models:
                _model.set_relative_paths(root_path=root_path)
        ModelsStandard.save_data(data=Models(_models), filename=filename)

    def get_criterion(
        self,
        criterion: Criterion,
        as_dict: bool = False,
        relative: bool = False,
    ) -> list[float] | dict[ModelHash, float]:
        """Get the criterion value for all models.

        Args:
            criterion:
                The criterion.
            as_dict:
                Whether to return a dictionary, with model hashes for keys.
            relative:
                Whether to compute criterion values relative to the
                smallest criterion value.

        Returns:
            The criterion values. Empty if there are no models.
        """
        result = [model.get_criterion(criterion=criterion) for model in self]
        # `min` fails on an empty sequence, so skip when there are no models.
        if relative and result:
            result = list(np.array(result) - min(result))
        if as_dict:
            result = dict(zip(self._hashes, result, strict=False))
        return result

    def _getattr(
        self,
        attr: str,
        key: Any = None,
        use_default: bool = False,
        default: Any = None,
    ) -> list[Any]:
        """Get an attribute of each model.

        Args:
            attr:
                The name of the attribute (e.g. ``MODEL_ID``).
            key:
                The key of the attribute, if you want to further subset.
                For example, if ``attr=ESTIMATED_PARAMETERS``, this could
                be a specific parameter ID.
            use_default:
                Whether to use a default value for models that are missing
                ``attr`` or ``key``.
            default:
                Value to use for models that do not have ``attr`` or ``key``,
                if ``use_default==True``.

        Returns:
            The list of attribute values.
        """
        # FIXME remove when model is `dataclass`
        values = []
        for model in self:
            try:
                value = getattr(model, attr)
            except Exception:
                if not use_default:
                    raise
                value = default

            if key is not None:
                try:
                    value = value[key]
                except Exception:
                    if not use_default:
                        raise
                    value = default

            values.append(value)
        return values

    @property
    def df(self) -> pd.DataFrame:
        """Get a dataframe of model attributes.

        There is one row per model. Model IDs and hashes are required; all
        other columns fall back to ``None`` for models without a value.
        """
        return pd.DataFrame(
            {
                MODEL_ID: self._getattr(MODEL_ID),
                MODEL_HASH: self._getattr(MODEL_HASH),
                Criterion.NLLH: self._getattr(
                    CRITERIA, Criterion.NLLH, use_default=True
                ),
                Criterion.AIC: self._getattr(
                    CRITERIA, Criterion.AIC, use_default=True
                ),
                Criterion.AICC: self._getattr(
                    CRITERIA, Criterion.AICC, use_default=True
                ),
                Criterion.BIC: self._getattr(
                    CRITERIA, Criterion.BIC, use_default=True
                ),
                ITERATION: self._getattr(ITERATION, use_default=True),
                PREDECESSOR_MODEL_HASH: self._getattr(
                    PREDECESSOR_MODEL_HASH, use_default=True
                ),
                ESTIMATED_PARAMETERS: self._getattr(
                    ESTIMATED_PARAMETERS, use_default=True
                ),
            }
        )

    @property
    def hashes(self) -> list[ModelHash]:
        """The model hashes (the internal list, not a copy)."""
        return self._hashes
[docs]
def models_from_yaml_list(
    model_list_yaml: TYPE_PATH,
    petab_problem: petab.Problem = None,
    allow_single_model: bool = True,
    problem: Problem = None,
) -> Models:
    """Deprecated. Use `petab_select.Models.from_yaml` instead.

    Args:
        model_list_yaml:
            Location of the YAML file.
        petab_problem:
            A preloaded copy of the PEtab problem, forwarded as
            ``model_subspace_petab_problem``.
        allow_single_model:
            Ignored; single-model files are always accepted.
        problem:
            The PEtab Select problem.

    Returns:
        The models.
    """
    warnings.warn(
        (
            "Use `petab_select.Models.from_yaml` instead. "
            "The `allow_single_model` argument is fixed to `True` now."
        ),
        DeprecationWarning,
        stacklevel=2,
    )
    # `Models.from_yaml` names this parameter `model_subspace_petab_problem`.
    return Models.from_yaml(
        filename=model_list_yaml,
        model_subspace_petab_problem=petab_problem,
        problem=problem,
    )
[docs]
def models_to_yaml_list(
    models: Models,
    output_yaml: TYPE_PATH,
    relative_paths: bool = True,
) -> None:
    """Deprecated. Use `petab_select.Models.to_yaml` instead.

    Args:
        models:
            The models to save.
        output_yaml:
            Location of the YAML file.
        relative_paths:
            Forwarded to `Models.to_yaml`.
    """
    deprecation_message = "Use `petab_select.Models.to_yaml` instead."
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)

    # Wrap the input in a `Models` collection before saving it.
    collection = Models(models=models)
    collection.to_yaml(filename=output_yaml, relative_paths=relative_paths)
# YAML standard built from the `Models` schema. `Models.from_yaml` loads with
# `ModelsStandard.load_data` and `Models.to_yaml` saves with
# `ModelsStandard.save_data`. Defined here, after `Models`, which it takes as
# its `model` argument.
ModelsStandard = mkstd.YamlStandard(model=Models)