"""The model selection problem class."""
from __future__ import annotations
import copy
import warnings
from collections.abc import Callable, Iterable
from functools import partial
from os.path import relpath
from pathlib import Path
from typing import Annotated, Any
import mkstd
from pydantic import (
BaseModel,
Field,
PlainSerializer,
PrivateAttr,
ValidationInfo,
ValidatorFunctionWrapHandler,
model_validator,
)
from .analyze import get_best
from .candidate_space import CandidateSpace, method_to_candidate_space_class
from .constants import (
CRITERION,
PREDECESSOR_MODEL,
ROOT_PATH,
TYPE_PATH,
Criterion,
Method,
)
from .model import Model, ModelHash, default_compare
from .model_space import ModelSpace
from .models import Models
# Public API of this module; the ``State`` helper class is not exported.
__all__ = [
"Problem",
"ProblemStandard",
]
class State(BaseModel):
    """Bookkeeping for model selection methods applied to a problem.

    Tracks the calibrated models and the current iteration counter.
    """

    models: Models = Field(default_factory=Models)
    """All calibrated models."""
    iteration: int = Field(default=0)
    """The latest iteration of model selection."""

    def increment_iteration(self) -> None:
        """Advance the iteration counter by one."""
        self.iteration = self.iteration + 1

    def reset(self) -> None:
        """Reset the state to its initial values.

        N.B.: this does not reset all state information, because some of it
        currently lives in other classes. Open a GitHub issue if you see
        unusual behavior. A quick fix is to recreate the PEtab Select problem,
        and any other objects that you use (e.g. the candidate space),
        whenever you need a full reset.

        https://github.com/PEtab-dev/petab_select/issues
        """
        # FIXME state information is currently distributed across multiple
        #       classes, e.g. exclusions in model subspaces and candidate
        #       spaces. Move all state information here.
        self.iteration = 0
        self.models = Models()
class Problem(BaseModel):
    """The model selection problem.

    The serialized fields are documented below. In addition, the comparison
    function, the selection state and the loaded model space are kept on the
    instance and exposed via the ``compare``, ``state`` and ``model_space``
    properties.
    """

    format_version: str = Field(default="1.0.0")
    """The file format version."""
    criterion: Annotated[
        Criterion, PlainSerializer(lambda x: x.value, return_type=str)
    ]
    """The criterion used to compare models."""
    method: Annotated[
        Method, PlainSerializer(lambda x: x.value, return_type=str)
    ]
    """The method used to search the model space."""
    model_space_files: list[Path]
    """The files that define the model space."""
    candidate_space_arguments: dict[str, Any] = Field(default_factory=dict)
    """Method-specific arguments.

    These are forwarded to the candidate space constructor.
    """

    _compare: Callable[[Model, Model], bool] = PrivateAttr(default=None)
    """The method by which models are compared."""
    _state: State = PrivateAttr(default_factory=State)
    """The state of the model selection."""

    @model_validator(mode="wrap")
    @classmethod
    def _check_input(
        cls,
        data: dict[str, Any] | Problem,
        handler: ValidatorFunctionWrapHandler,
        info: ValidationInfo,
    ) -> Problem:
        """Validate the input and set up the non-field attributes.

        Args:
            data:
                The input data, or an already constructed problem (which is
                returned unchanged). In addition to the model fields, the
                keys ``compare``/``_compare``, ``state``/``_state`` and
                ``ROOT_PATH`` are consumed here.
            handler:
                The pydantic handler that performs the field validation.
            info:
                The validation info. Currently unused.

        Returns:
            The problem, with its comparison function, state and model space
            set, and with the predecessor model path (if any) prefixed by the
            root path.
        """
        if isinstance(data, Problem):
            return data

        # Work on a shallow copy, so that the caller's mapping is not
        # modified by the ``pop`` calls below.
        data = dict(data)

        # Pop both spellings unconditionally, so that neither of them is
        # forwarded to the field validation.
        compare = data.pop("compare", None)
        private_compare = data.pop("_compare", None)
        if compare is None:
            compare = private_compare

        state = data.pop("state", None)
        private_state = data.pop("_state", None)
        if state is None:
            state = private_state

        root_path = Path(data.pop(ROOT_PATH, ""))

        problem = handler(data)

        if compare is None:
            compare = partial(default_compare, criterion=problem.criterion)
        problem._compare = compare

        # Private attributes are not populated from the input data by the
        # field validation, so a provided state is assigned explicitly.
        if state is not None:
            problem._state = (
                state
                if isinstance(state, State)
                else State.model_validate(state)
            )

        problem._model_space = ModelSpace.load(
            [
                root_path / model_space_file
                for model_space_file in problem.model_space_files
            ]
        )

        if PREDECESSOR_MODEL in problem.candidate_space_arguments:
            problem.candidate_space_arguments[PREDECESSOR_MODEL] = (
                root_path
                / problem.candidate_space_arguments[PREDECESSOR_MODEL]
            )

        return problem

    @property
    def state(self) -> State:
        """The state of the model selection."""
        return self._state

    @staticmethod
    def from_yaml(filename: TYPE_PATH) -> Problem:
        """Load a problem from a YAML file.

        The directory of ``filename`` is passed on as the root path.

        Args:
            filename:
                Location of the YAML file.

        Returns:
            The problem.
        """
        return ProblemStandard.load_data(
            filename=filename,
            root_path=Path(filename).parent,
        )

    def to_yaml(
        self,
        filename: str | Path,
    ) -> None:
        """Save a problem to a YAML file.

        All paths will be made relative to the ``filename`` directory. This
        covers the model space files and, if present, the predecessor model
        in the candidate space arguments.

        Args:
            filename:
                Location of the YAML file.
        """
        root_path = Path(filename).parent

        # Only the copy is modified; ``self`` keeps its paths.
        problem = copy.deepcopy(self)
        # NOTE(review): relative entries are resolved against the current
        # working directory here, whereas ``_check_input`` reads them
        # relative to the problem's root path -- confirm this is intended.
        problem.model_space_files = [
            relpath(
                Path(model_space_file).resolve(),
                start=root_path,
            )
            for model_space_file in problem.model_space_files
        ]
        # Same treatment of the predecessor model path as in ``save``.
        if PREDECESSOR_MODEL in problem.candidate_space_arguments:
            problem.candidate_space_arguments[PREDECESSOR_MODEL] = relpath(
                problem.candidate_space_arguments[PREDECESSOR_MODEL],
                start=root_path,
            )
        ProblemStandard.save_data(data=problem, filename=filename)

    def save(
        self,
        directory: str | Path,
    ) -> None:
        """Save all data (problem and model space) to a ``directory``.

        Inside the directory, two files will be created:
        (1) ``petab_select_problem.yaml``, and
        (2) ``model_space.tsv``.

        All paths will be made relative to the ``directory``.

        Args:
            directory:
                The output directory. It is created if it does not exist.
        """
        directory = Path(directory)
        directory.mkdir(exist_ok=True, parents=True)

        # Only the copy is modified; ``self`` keeps its paths.
        problem = copy.deepcopy(self)
        problem.model_space_files = ["model_space.tsv"]
        if PREDECESSOR_MODEL in problem.candidate_space_arguments:
            problem.candidate_space_arguments[PREDECESSOR_MODEL] = relpath(
                problem.candidate_space_arguments[PREDECESSOR_MODEL],
                start=directory,
            )
        ProblemStandard.save_data(
            data=problem, filename=directory / "petab_select_problem.yaml"
        )
        problem.model_space.save(filename=directory / "model_space.tsv")

    @property
    def compare(self) -> Callable[[Model, Model], bool]:
        """The method by which models are compared."""
        return self._compare

    @property
    def model_space(self) -> ModelSpace:
        """The model space loaded from the model space files."""
        return self._model_space

    def __str__(self) -> str:
        return (
            f"Method: {self.method}\n"
            f"Criterion: {self.criterion}\n"
            f"Format version: {self.format_version}\n"
        )

    def exclude_models(
        self,
        models: Models,
    ) -> None:
        """Exclude models from the model space.

        Args:
            models:
                The models.
        """
        self.model_space.exclude_models(models)

    def exclude_model_hashes(
        self,
        model_hashes: Iterable[str],
    ) -> None:
        """Exclude models from the model space, by model hashes.

        Deprecated: use :meth:`exclude_models` instead.

        Args:
            model_hashes:
                The model hashes.
        """
        # FIXME think about design here -- should we have exclude_models here?
        warnings.warn(
            "Use `exclude_models` instead. It also accepts hashes.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.exclude_models(models=Models(models=model_hashes, problem=self))

    def get_best(
        self,
        models: Models,
        criterion: Criterion | str | None = None,
        compute_criterion: bool = False,
    ) -> Model:
        """Get the best model from a collection of models.

        The best model is selected based on the selection problem's criterion.

        Deprecated: use ``petab_select.ui.get_best`` instead.

        Args:
            models:
                The models.
            criterion:
                The criterion. Defaults to the problem criterion.
            compute_criterion:
                Whether to try computing criterion values, if sufficient
                information is available (e.g., likelihood and number of
                parameters, to compute AIC).

        Returns:
            The best model.
        """
        warnings.warn(
            "Use ``petab_select.ui.get_best`` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if criterion is None:
            criterion = self.criterion
        return get_best(
            models=models,
            criterion=criterion,
            compare=self.compare,
            compute_criterion=compute_criterion,
        )

    def model_hash_to_model(self, model_hash: str | ModelHash) -> Model:
        """Get the model that matches a model hash.

        Args:
            model_hash:
                The model hash.

        Returns:
            The model.
        """
        return ModelHash.from_hash(model_hash).get_model(
            petab_select_problem=self,
        )

    def get_model(
        self, model_subspace_id: str, model_subspace_indices: list[int]
    ) -> Model:
        """Get a model from a model subspace, by its indices.

        Args:
            model_subspace_id:
                The ID of the model subspace.
            model_subspace_indices:
                The indices of the model within the model subspace.

        Returns:
            The model.
        """
        model_subspace = self.model_space.model_subspaces[model_subspace_id]
        return model_subspace.indices_to_model(model_subspace_indices)

    def new_candidate_space(
        self,
        *args,
        method: Method | None = None,
        **kwargs,
    ) -> CandidateSpace:
        """Construct a new candidate space.

        Args:
            args, kwargs:
                Arguments are passed to the candidate space constructor.
                Keyword arguments take precedence over the problem's
                ``candidate_space_arguments``. The criterion defaults to the
                problem criterion.
            method:
                The model selection method. Defaults to the problem method.

        Returns:
            The candidate space.
        """
        if method is None:
            method = self.method
        kwargs.setdefault(CRITERION, self.criterion)

        candidate_space_class = method_to_candidate_space_class(method)
        candidate_space_arguments = (
            candidate_space_class.read_arguments_from_yaml_dict(
                self.candidate_space_arguments
            )
        )
        # Explicit keyword arguments override the problem-level arguments.
        candidate_space_kwargs = {
            **candidate_space_arguments,
            **kwargs,
        }
        return candidate_space_class(*args, **candidate_space_kwargs)
# YAML standard for ``Problem``; its ``load_data`` and ``save_data`` methods
# back ``Problem.from_yaml``, ``Problem.to_yaml`` and ``Problem.save``.
ProblemStandard = mkstd.YamlStandard(model=Problem)