"""PEtab wrappers for JAX models.""" ""
import logging
import os
import re
import shutil
from collections.abc import Callable, Iterable, Sized
from numbers import Number
from pathlib import Path
import diffrax
import equinox as eqx
import jax.lax
import jax.numpy as jnp
import jaxtyping as jt
import numpy as np
import optimistix
import pandas as pd
import petab.v1 as petabv1
import petab.v2 as petabv2
from optimistix import AbstractRootFinder
from amici import _module_from_path
from amici.logging import get_logger
from amici.sim._parameter_mapping import ParameterMappingForCondition
from amici.sim.jax.model import JAXModel, ReturnValue
# Default step-size controller settings (tolerances and PID coefficients,
# as accepted by diffrax adaptive controllers — confirm against solver setup).
DEFAULT_CONTROLLER_SETTINGS = {
    "atol": 1e-8,
    "rtol": 1e-8,
    "pcoeff": 0.4,
    "icoeff": 0.3,
    "dcoeff": 0.0,
}
# Default tolerances for steady-state root finding.
DEFAULT_ROOT_FINDER_SETTINGS = {
    "atol": 1e-12,
    "rtol": 1e-12,
}
# Integer codes for PEtab parameter scales / observable transformations
# (used e.g. for the iy_trafos arrays in JAXProblem._get_measurements).
SCALE_TO_INT = {
    petabv2.C.LIN: 0,
    petabv2.C.LOG: 1,
    petabv2.C.LOG10: 2,
}
logger = get_logger(__name__, logging.WARNING)
def jax_unscale(
    parameter: jnp.float_,
    scale_str: str,
) -> jnp.float_:
    """Undo the PEtab parameter scaling given by ``scale_str``.

    Arguments:
        parameter:
            Parameter value on the scaled domain.
        scale_str:
            One of ``petabv2.C.LIN``, ``petabv2.C.LOG``, ``petabv2.C.LOG10``.

    Returns:
        The parameter on the linear (unscaled) domain.

    Raises:
        ValueError: If ``scale_str`` is not a known scale.
    """
    # an empty/None scale is treated the same as linear scaling
    if not scale_str or scale_str == petabv2.C.LIN:
        return parameter
    if scale_str == petabv2.C.LOG:
        return jnp.exp(parameter)
    if scale_str == petabv2.C.LOG10:
        return jnp.power(10, parameter)
    raise ValueError(f"Invalid parameter scaling: {scale_str}")
# IDEA: Implement this class in petab-sciml instead?
class HybridProblem(petabv1.Problem):
    """PEtab v1 problem extended with a SciML hybridization table."""

    hybridization_df: pd.DataFrame

    def __init__(self, petab_problem: petabv1.Problem):
        # adopt all attributes of the wrapped problem, then attach the
        # hybridization table (None if there is no sciml extension)
        self.__dict__.update(vars(petab_problem))
        self.hybridization_df = _get_hybridization_df(petab_problem)
class HybridV2Problem(petabv2.Problem):
    """PEtab v2 problem extended with a SciML hybridization table."""

    hybridization_df: pd.DataFrame
    extensions_config: dict

    def __init__(self, petab_problem: petabv2.Problem):
        # fallback so the attribute always exists; if the wrapped problem
        # carries its own configuration, the update below provides it
        if not hasattr(petab_problem, "extensions_config"):
            self.extensions_config = {}
        self.__dict__.update(vars(petab_problem))
        self.hybridization_df = _get_hybridization_df(petab_problem)
def _get_hybridization_df(petab_problem):
if not hasattr(petab_problem, "extensions_config"):
return None
if "sciml" in petab_problem.extensions_config:
hybridizations = [
pd.read_csv(hf, sep="\t", index_col=0)
for hf in petab_problem.extensions_config["sciml"][
"hybridization_files"
]
]
hybridization_df = pd.concat(hybridizations)
return hybridization_df
def _get_hybrid_petab_problem(
    petab_problem: petabv1.Problem | petabv2.Problem,
):
    """Wrap a PEtab problem in the matching hybrid problem class."""
    wrapper = (
        HybridV2Problem
        if isinstance(petab_problem, petabv2.Problem)
        else HybridProblem
    )
    return wrapper(petab_problem)
[docs]
class JAXProblem(eqx.Module):
    """
    PEtab problem wrapper for JAX models.

    :ivar parameters:
        Values for the model parameters. Do not change dimensions, values may be changed during, e.g. model training.
    :ivar model:
        JAXModel instance to use for simulation.
    :ivar _parameter_mappings:
        :class:`ParameterMappingForCondition` instances for each simulation condition.
    :ivar _measurements:
        Preprocessed arrays for each simulation condition.
    :ivar _petab_problem:
        PEtab problem to simulate.
    """

    # estimated parameter values, same ordering as :attr:`parameter_ids`
    parameters: jnp.ndarray
    model: JAXModel
    # condition ids in simulation order
    simulation_conditions: tuple[tuple[str, ...], ...]
    _parameter_mappings: dict[str, ParameterMappingForCondition]
    # padded per-condition measurement arrays (built in _get_measurements):
    # dynamic / post-equilibrium time points, measurements, observable
    # indices, observable transformation codes and validity masks
    _ts_dyn: np.ndarray
    _ts_posteq: np.ndarray
    _my: np.ndarray
    _iys: np.ndarray
    _iy_trafos: np.ndarray
    _ts_masks: np.ndarray
    # observable-parameter override decomposition (numeric values,
    # free-parameter mask, free-parameter indices)
    _op_numeric: np.ndarray
    _op_mask: np.ndarray
    _op_indices: np.ndarray
    # noise-parameter override decomposition, analogous to the above
    _np_numeric: np.ndarray
    _np_mask: np.ndarray
    _np_indices: np.ndarray
    # row indices into the PEtab measurement table, per condition
    _petab_measurement_indices: np.ndarray
    _petab_problem: petabv1.Problem | HybridProblem | petabv2.Problem
[docs]
def __init__(
    self, model: JAXModel, petab_problem: petabv1.Problem | petabv2.Problem
):
    """
    Initialize a JAXProblem instance with a model and a PEtab problem.

    :param model:
        JAXModel instance to use for simulation.
    :param petab_problem:
        PEtab problem to simulate. Only PEtab v2 problems are supported.
    :raises TypeError:
        If a PEtab v1 problem is passed.
    """
    # despite the type annotation, v1 problems are explicitly rejected
    if isinstance(petab_problem, petabv1.Problem):
        raise TypeError(
            "JAXProblem does not support PEtab v1 problems. Upgrade the problem to PEtab v2."
        )
    petab_problem = add_default_experiment_names_to_v2_problem(
        petab_problem
    )
    scs = get_simulation_conditions_v2(petab_problem)
    self.simulation_conditions = scs.conditionId.to_list()
    self._petab_problem = _get_hybrid_petab_problem(petab_problem)
    # set model parameters/inputs to nominal values from the PEtab problem
    self.parameters, self.model = (
        self._initialize_model_with_nominal_values(model)
    )
    # not computed here -- presumably filled in on demand; TODO confirm
    self._parameter_mappings = None
    # precompute padded measurement arrays for all simulation conditions
    (
        self._ts_dyn,
        self._ts_posteq,
        self._my,
        self._iys,
        self._iy_trafos,
        self._ts_masks,
        self._petab_measurement_indices,
        self._op_numeric,
        self._op_mask,
        self._op_indices,
        self._np_numeric,
        self._np_mask,
        self._np_indices,
    ) = self._get_measurements(scs)
[docs]
def save(self, directory: Path):
    """
    Save the problem to a directory.

    Writes the PEtab tables and yaml, a copy of the generated model module,
    and the serialised parameter values (``parameters.pkl``).

    :param directory:
        Directory to save the problem to.
    """
    self._petab_problem.to_files(
        prefix_path=directory,
        model_file="model",
        condition_file="conditions.tsv",
        measurement_file="measurements.tsv",
        parameter_file="parameters.tsv",
        observable_file="observables.tsv",
        yaml_file="problem.yaml",
    )
    # copy the generated model module so it can be re-imported on load
    shutil.copy(self.model.jax_py_file, directory / "jax_py_file.py")
    # serialise all equinox leaves (parameter values etc.)
    with open(directory / "parameters.pkl", "wb") as f:
        eqx.tree_serialise_leaves(f, self)
[docs]
@classmethod
def load(cls, directory: Path):
    """
    Load a problem from a directory.

    Counterpart to :meth:`save`.

    :param directory:
        Directory to load the problem from.
    :return:
        Loaded problem instance.
    """
    petab_problem = petabv2.Problem.from_yaml(
        directory / "problem.yaml",
    )
    # re-import the generated model module saved alongside the tables
    model = _module_from_path("jax", directory / "jax_py_file.py").Model()
    problem = cls(model, petab_problem)
    # restore the serialised leaves into the freshly constructed instance
    with open(directory / "parameters.pkl", "rb") as f:
        return eqx.tree_deserialise_leaves(f, problem)
def _get_measurements(
    self, simulation_conditions: pd.DataFrame
) -> tuple[
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
    np.ndarray,
]:
    """
    Get measurements for the model based on the provided simulation conditions.

    :param simulation_conditions:
        Simulation conditions to create parameter mappings for. Same format as returned by
        :meth:`petabv1.Problem.get_simulation_conditions_from_measurement_df`.
    :return:
        tuple of padded

        - dynamic time points
        - post-equilibrium time points
        - measurements
        - observable indices
        - observable transformations indices
        - measurement masks
        - data indices (index in petab measurement dataframe).
        - numeric values for observable parameter overrides
        - non-numeric mask for observable parameter overrides
        - parameter indices (problem parameters) for observable parameter overrides
        - numeric values for noise parameter overrides
        - non-numeric mask for noise parameter overrides
        - parameter indices (problem parameters) for noise parameter overrides
    """
    measurements = dict()
    petab_indices = dict()
    # maximal number of override parameters per measurement, per column
    n_pars = dict()
    for col in [
        petabv2.C.OBSERVABLE_PARAMETERS,
        petabv2.C.NOISE_PARAMETERS,
    ]:
        n_pars[col] = 0
        if col in self._petab_problem.measurement_df:
            if pd.api.types.is_numeric_dtype(
                self._petab_problem.measurement_df[col].dtype
            ):
                # numeric column: at most one override; zero if all-NaN
                n_pars[col] = 1 - int(
                    self._petab_problem.measurement_df[col].isna().all()
                )
            else:
                # string column: overrides are separator-joined lists;
                # take the longest list over all measurements
                n_pars[col] = (
                    self._petab_problem.measurement_df[col]
                    .str.split(petabv2.C.PARAMETER_SEPARATOR)
                    .apply(
                        lambda x: (
                            len(x)
                            if isinstance(x, Sized)
                            else 1 - int(pd.isna(x))
                        )
                    )
                    .max()
                )
    for _, simulation_condition in simulation_conditions.iterrows():
        # skip conditions that only serve as preequilibration
        if (
            "preequilibration"
            in simulation_condition[petabv2.C.CONDITION_ID]
        ):
            continue
        # select the measurements belonging to this condition
        if isinstance(self._petab_problem, HybridV2Problem):
            query = " & ".join(
                [
                    f"{k} == '{v}'"
                    if isinstance(v, str)
                    else f"{k} == {v}"
                    for k, v in simulation_condition.items()
                    if k != petabv2.C.CONDITION_ID
                ]
            )
        else:
            query = " & ".join(
                [f"{k} == '{v}'" for k, v in simulation_condition.items()]
            )
        m = self._petab_problem.measurement_df.query(query).sort_values(
            by=petabv2.C.TIME
        )
        ts = m[petabv2.C.TIME]
        # finite time points are simulated dynamically; non-finite ones
        # are post-equilibrium measurements
        ts_dyn = ts[np.isfinite(ts)]
        ts_posteq = ts[np.logical_not(np.isfinite(ts))]
        index = pd.concat([ts_dyn, ts_posteq]).index
        ts_dyn = ts_dyn.values
        ts_posteq = ts_posteq.values
        my = m[petabv2.C.MEASUREMENT].values
        # map observable ids to model observable indices
        iys = np.array(
            [
                self.model.observable_ids.index(oid)
                for oid in m[petabv2.C.OBSERVABLE_ID].values
            ]
        )
        if (
            petabv2.C.NOISE_DISTRIBUTION
            in self._petab_problem.observable_df
        ):
            # encode log-normal noise as a log transformation index
            iy_trafos = np.array(
                [
                    SCALE_TO_INT[petabv2.C.LOG]
                    if obs.noise_distribution == petabv2.C.LOG_NORMAL
                    else SCALE_TO_INT[petabv2.C.LIN]
                    for obs in self._petab_problem.observables
                ]
            )
        else:
            iy_trafos = np.zeros_like(iys)
        parameter_overrides_par_indices = dict()
        parameter_overrides_numeric_vals = dict()
        parameter_overrides_mask = dict()

        def get_parameter_override(x):
            # replace non-estimated parameters by their nominal values;
            # estimated parameters keep their id for index lookup below
            if (
                x in self._petab_problem.parameter_df.index
                and not self._petab_problem.parameter_df.loc[
                    x, petabv2.C.ESTIMATE
                ]
            ):
                return self._petab_problem.parameter_df.loc[
                    x, petabv2.C.NOMINAL_VALUE
                ]
            return x

        for col in [
            petabv2.C.OBSERVABLE_PARAMETERS,
            petabv2.C.NOISE_PARAMETERS,
        ]:
            if col not in m or m[col].isna().all() or all(m[col] == ""):
                # no overrides: dummy values, nothing masked or indexed
                mat_numeric = jnp.ones((len(m), n_pars[col]))
                par_mask = np.zeros_like(mat_numeric, dtype=bool)
                par_index = np.zeros_like(mat_numeric, dtype=int)
            elif pd.api.types.is_numeric_dtype(m[col].dtype):
                # purely numeric overrides: single column, no free params
                mat_numeric = np.expand_dims(m[col].values, axis=1)
                par_mask = np.zeros_like(mat_numeric, dtype=bool)
                par_index = np.zeros_like(mat_numeric, dtype=int)
            else:
                split_vals = m[col].str.split(
                    petabv2.C.PARAMETER_SEPARATOR
                )
                list_vals = split_vals.apply(
                    lambda x: (
                        [get_parameter_override(y) for y in x]
                        if isinstance(x, list)
                        else []
                        if pd.isna(x)
                        else [x]
                    )  # every string gets transformed to lists, so this is already a float
                )
                # pad shorter override lists with dummy value 1.0
                vals = list_vals.apply(
                    lambda x: np.pad(
                        x,
                        (0, n_pars[col] - len(x)),
                        mode="constant",
                        constant_values=1.0,
                    )
                )
                mat = np.stack(vals)
                # deconstruct such that we can reconstruct mapped parameter overrides via vectorized operations
                # mat = np.where(par_mask, map(lambda ip: p.at[ip], par_index), mat_numeric)
                par_index = np.vectorize(
                    lambda x: (
                        self.parameter_ids.index(x)
                        if x in self.parameter_ids
                        else -1
                    )
                )(mat)
                # map out numeric values
                par_mask = par_index != -1
                # remove non-numeric values
                mat[par_mask] = 0.0
                mat_numeric = mat.astype(float)
                # replace dummy index with some valid index
                par_index[~par_mask] = 0
            parameter_overrides_numeric_vals[col] = mat_numeric
            parameter_overrides_mask[col] = par_mask
            parameter_overrides_par_indices[col] = par_index
        # positional layout of the per-condition tuple, consumed by
        # pad_and_stack below via the numeric indices in the comments
        measurements[tuple(simulation_condition)] = (
            ts_dyn,  # 0
            ts_posteq,  # 1
            my,  # 2
            iys,  # 3
            iy_trafos,  # 4
            parameter_overrides_numeric_vals[
                petabv2.C.OBSERVABLE_PARAMETERS
            ],  # 5
            parameter_overrides_mask[petabv2.C.OBSERVABLE_PARAMETERS],  # 6
            parameter_overrides_par_indices[
                petabv2.C.OBSERVABLE_PARAMETERS
            ],  # 7
            parameter_overrides_numeric_vals[
                petabv2.C.NOISE_PARAMETERS
            ],  # 8
            parameter_overrides_mask[petabv2.C.NOISE_PARAMETERS],  # 9
            parameter_overrides_par_indices[
                petabv2.C.NOISE_PARAMETERS
            ],  # 10
        )
        petab_indices[tuple(simulation_condition)] = tuple(index.tolist())
    # compute maximum lengths
    n_ts_dyn = max(len(mv[0]) for mv in measurements.values())
    n_ts_posteq = max(len(mv[1]) for mv in measurements.values())
    # pad with last value and stack
    ts_dyn = np.stack(
        [
            np.pad(mv[0], (0, n_ts_dyn - len(mv[0])), mode="edge")
            for mv in measurements.values()
        ]
    )
    ts_posteq = np.stack(
        [
            np.pad(mv[1], (0, n_ts_posteq - len(mv[1])), mode="edge")
            for mv in measurements.values()
        ]
    )

    def pad_measurement(x_dyn, x_peq):
        # only pad first axis
        pad_width_dyn = tuple(
            [(0, n_ts_dyn - len(x_dyn))] + [(0, 0)] * (x_dyn.ndim - 1)
        )
        pad_width_peq = tuple(
            [(0, n_ts_posteq - len(x_peq))] + [(0, 0)] * (x_peq.ndim - 1)
        )
        return np.concatenate(
            (
                np.pad(x_dyn, pad_width_dyn, mode="edge"),
                np.pad(x_peq, pad_width_peq, mode="edge"),
            )
        )

    def pad_and_stack(output_index: int):
        # split each per-condition array into its dynamic and
        # post-equilibrium parts, pad both, and stack across conditions
        return np.stack(
            [
                pad_measurement(
                    mv[output_index][: len(mv[0])],
                    mv[output_index][len(mv[0]) :],
                )
                for mv in measurements.values()
            ]
        )

    my = pad_and_stack(2)
    iys = pad_and_stack(3)
    iy_trafos = pad_and_stack(4)
    op_numeric = pad_and_stack(5)
    op_mask = pad_and_stack(6)
    op_indices = pad_and_stack(7)
    np_numeric = pad_and_stack(8)
    np_mask = pad_and_stack(9)
    np_indices = pad_and_stack(10)
    # boolean mask marking actual (non-padding) time points
    ts_masks = np.stack(
        [
            np.concatenate(
                (
                    np.pad(
                        np.ones_like(mv[0]), (0, n_ts_dyn - len(mv[0]))
                    ),
                    np.pad(
                        np.ones_like(mv[1]), (0, n_ts_posteq - len(mv[1]))
                    ),
                )
            )
            for mv in measurements.values()
        ]
    ).astype(bool)
    petab_indices = np.stack(
        [
            pad_measurement(
                np.array(idx[: len(mv[0])]),
                np.array(idx[len(mv[0]) :]),
            )
            for mv, idx in zip(
                measurements.values(), petab_indices.values()
            )
        ]
    )
    return (
        ts_dyn,
        ts_posteq,
        my,
        iys,
        iy_trafos,
        ts_masks,
        petab_indices,
        op_numeric,
        op_mask,
        op_indices,
        np_numeric,
        np_mask,
        np_indices,
    )
[docs]
def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]:
    """Return all simulation conditions as a tuple of condition-id tuples."""
    if isinstance(self._petab_problem, HybridV2Problem):
        # v2: one-element tuples built from the conditionId column
        scs = get_simulation_conditions_v2(self._petab_problem)
        return tuple((row.conditionId,) for _, row in scs.iterrows())
    scs = (
        self._petab_problem.get_simulation_conditions_from_measurement_df()
    )
    return tuple(tuple(row) for _, row in scs.iterrows())
def _initialize_model_parameters(self, model: JAXModel) -> dict:
    """
    Initialize model parameter structure with zeros.

    :param model:
        JAX model with neural networks
    :return:
        Nested dictionary structure for model parameters
    """
    params: dict = {}
    for net_id, nn in model.nns.items():
        params[net_id] = {}
        for layer_id, layer in nn.layers.items():
            # only include weights/biases the layer actually has
            params[net_id][layer_id] = {
                attribute: jnp.zeros_like(getattr(layer, attribute))
                for attribute in ["weight", "bias"]
                if hasattr(layer, attribute)
            }
    return params
def _load_parameter_arrays_from_files(self) -> dict:
    """
    Load neural network parameter arrays from HDF5 files.

    :return:
        Dictionary mapping network IDs to parameter arrays
    """
    if not self._petab_problem.extensions_config:
        return {}
    array_files = self._petab_problem.extensions_config["sciml"].get(
        "array_files", []
    )
    import h5py

    par_arrays = {}
    for file_spec in array_files:
        # the network id is encoded as the "_"-prefix of the file name
        net_id = file_spec.split("_")[0]
        # open each file only once (previously opened twice, leaking the
        # handle used for the membership check)
        f = h5py.File(file_spec, "r")
        if "parameters" in f.keys():
            # keep this handle alive: the returned dataset reads lazily
            par_arrays[net_id] = f["parameters"][net_id]
        else:
            f.close()
    return par_arrays
def _load_input_arrays_from_files(self) -> dict:
    """
    Load neural network input arrays from HDF5 files.

    :return:
        Dictionary mapping network IDs to input arrays
    """
    if not self._petab_problem.extensions_config:
        return {}
    array_files = self._petab_problem.extensions_config["sciml"].get(
        "array_files", []
    )
    import h5py

    input_arrays = {}
    for file_spec in array_files:
        # the network id is encoded as the "_"-prefix of the file name
        net_id = file_spec.split("_")[0]
        # open each file only once (previously opened twice, leaking the
        # handle used for the membership check)
        f = h5py.File(file_spec, "r")
        if "inputs" in f.keys():
            # keep this handle alive: the returned group reads lazily
            input_arrays[net_id] = f["inputs"]
        else:
            f.close()
    return input_arrays
def _parse_parameter_name(
    self, pname: str, model_pars: dict
) -> list[tuple[str, str]]:
    """
    Resolve which (layer, attribute) pairs a PEtab parameter name covers.

    :param pname:
        Parameter name from PEtab (format: net.layer.attribute)
    :param model_pars:
        Model parameters dictionary
    :return:
        List of (layer_name, attribute_name) tuples to set
    """
    # NOTE(review): the network id is taken from the "_"-prefix while
    # layer/attribute are "."-separated -- confirm both conventions hold
    net_pars = model_pars[pname.split("_")[0]]
    parts = pname.split(".")
    if len(parts) == 1:
        # bare network name: every attribute of every layer
        return [
            (layer_name, attribute)
            for layer_name, layer in net_pars.items()
            for attribute in layer.keys()
        ]
    layer_name = parts[1]
    # raises KeyError for unknown layers, as before
    layer = net_pars[layer_name]
    if len(parts) > 2:
        # fully qualified: a single attribute of a single layer
        return [(layer_name, parts[2])]
    # layer only: all attributes of that layer
    return [(layer_name, attribute) for attribute in layer.keys()]
def _extract_nominal_values_from_petab(
    self, model: JAXModel, model_pars: dict, par_arrays: dict
) -> None:
    """
    Populate ``model_pars`` with nominal values from the PEtab parameter table.

    :param model:
        JAX model
    :param model_pars:
        Model parameters dictionary to populate (modified in place)
    :param par_arrays:
        Parameter arrays loaded from files
    """
    for pname, row in self._petab_problem.parameter_df.iterrows():
        net = pname.split("_")[0]
        if net not in model.nns:
            # not a neural-network parameter
            continue
        nn = model_pars[net]
        nominal = row[petabv2.C.NOMINAL_VALUE]
        # a NaN nominal value signals that values come from an array file
        is_scalar = not np.isnan(nominal)
        value = float(nominal) if is_scalar else par_arrays[net]
        for layer, attribute in self._parse_parameter_name(
            pname, model_pars
        ):
            if is_scalar:
                # broadcast the scalar over the full attribute shape
                nn[layer][attribute] = value * jnp.ones_like(
                    getattr(model.nns[net].layers[layer], attribute)
                )
            else:
                nn[layer][attribute] = jnp.array(
                    value[layer][attribute][:]
                )
def _set_model_parameters(
    self, model: JAXModel, model_pars: dict
) -> JAXModel:
    """
    Write parameter values into the model via :func:`equinox.tree_at`.

    :param model:
        JAX model to update
    :param model_pars:
        Dictionary of parameter values to set
    :return:
        Updated JAX model
    """
    for net_id, layers in model_pars.items():
        for layer_id, attributes in layers.items():
            for attribute, value in attributes.items():
                logger.debug(
                    f"Setting {attribute} of layer {layer_id} in network "
                    f"{net_id} to {value}"
                )
                # replace the single leaf addressed by the selector
                model = eqx.tree_at(
                    lambda model: getattr(
                        model.nns[net_id].layers[layer_id], attribute
                    ),
                    model,
                    value,
                )
    return model
def _set_input_arrays(
    self, model: JAXModel, nn_input_arrays: dict, model_pars: dict
) -> JAXModel:
    """
    Set input arrays in the model if provided.

    :param model:
        JAX model to update
    :param nn_input_arrays:
        Input arrays loaded from files, keyed by network ID
    :param model_pars:
        Model parameters dictionary (for network IDs)
    :return:
        Updated JAX model
    """
    if len(nn_input_arrays) == 0:
        return model
    # match float precision to the global jax configuration
    dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32
    for net_id in model_pars:
        if net_id not in nn_input_arrays:
            # no arrays provided for this network; previously this raised
            # a KeyError when only some networks had input arrays
            continue
        input_array = {
            input_id: {
                k: jnp.array(arr[:], dtype=dtype)
                for k, arr in nn_input_arrays[net_id][input_id].items()
            }
            for input_id in model.nns[net_id].inputs
        }
        model = eqx.tree_at(
            lambda model: model.nns[net_id].inputs, model, input_array
        )
    return model
def _create_scaled_parameter_array(self) -> jt.Float[jt.Array, "np"]:
    """
    Create array of scaled nominal parameter values for estimation.

    :return:
        JAX array of scaled parameter values
    """
    parameter_df = self._petab_problem.parameter_df
    # NOTE(review): ``petabv2.PARAMETER_SCALE`` is referenced without the
    # ``C`` constants namespace, unlike the other columns -- confirm
    scaled_values = [
        petabv2.scale(
            float(parameter_df.loc[pid, petabv2.C.NOMINAL_VALUE]),
            parameter_df.loc[pid, petabv2.PARAMETER_SCALE],
        )
        for pid in self.parameter_ids
    ]
    return jnp.array(scaled_values)
def _initialize_model_with_nominal_values(
    self, model: JAXModel
) -> tuple[jt.Float[jt.Array, "np"], JAXModel]:
    """
    Initialize the model with nominal parameter values and inputs from the PEtab problem.

    This method:
    - Initializes the model parameter structure
    - Loads parameter and input arrays from HDF5 files
    - Extracts nominal values from the PEtab problem
    - Sets parameter values and input arrays in the model
    - Creates the parameter array initialized to nominal values

    :param model:
        JAX model to initialize
    :return:
        Tuple of (scaled parameter array, initialized model)
    """
    model_pars = self._initialize_model_parameters(model)
    par_arrays = self._load_parameter_arrays_from_files()
    nn_input_arrays = self._load_input_arrays_from_files()
    self._extract_nominal_values_from_petab(model, model_pars, par_arrays)
    model = self._set_model_parameters(model, model_pars)
    model = self._set_input_arrays(model, nn_input_arrays, model_pars)
    if isinstance(self._petab_problem, HybridV2Problem):
        # v2: take nominal values directly from the parameter objects
        nominals = {
            p.id: p.nominal_value for p in self._petab_problem.parameters
        }
        parameter_array = jnp.array(
            [float(nominals[pid]) for pid in self.parameter_ids]
        )
    else:
        parameter_array = self._create_scaled_parameter_array()
    return parameter_array, model
def _get_inputs(self) -> dict:
    """Collect network input arrays referenced as files in the mapping table."""
    if self._petab_problem.mapping_df is None:
        return {}
    inputs = {net: {} for net in self.model.nns.keys()}
    for petab_id, row in self._petab_problem.mapping_df.iterrows():
        filepath = Path(petab_id)
        if not filepath.is_file():
            continue
        # flat TSV with an "ix" column of ";"-separated array indices
        data_flat = pd.read_csv(filepath, sep="\t").sort_values(by="ix")
        indices = np.stack(
            data_flat["ix"].astype(str).str.split(";").apply(np.array)
        ).astype(int)
        # array extent is one past the largest index along each axis
        shape = tuple(indices.max(axis=0) + 1)
        inputs[row["netId"]][row[petabv2.C.MODEL_ENTITY_ID]] = (
            data_flat["value"].values.reshape(shape)
        )
    return inputs
@property
def parameter_ids(self) -> list[str]:
    """
    Parameter ids that are estimated in the PEtab problem. Same ordering as values in :attr:`parameters`.

    :return:
        PEtab parameter ids
    """
    parameter_df = self._petab_problem.parameter_df
    # restrict to estimated parameters; previously the estimate filter was
    # missing (the column selection was vestigial) and ALL parameter ids
    # were returned, contradicting the docstring
    return parameter_df[
        parameter_df[petabv2.C.ESTIMATE].astype(bool)
    ].index.tolist()
@property
def nn_output_ids(self) -> list[str]:
    """
    PEtab entity ids that are mapped to neural-network outputs via the
    mapping table.

    :return:
        PEtab entity ids whose mapped model entity is a network ``output``
    """
    if self._petab_problem.mapping_df is None:
        return []
    if (
        self._petab_problem.mapping_df[petabv2.C.MODEL_ENTITY_ID]
        .isnull()
        .all()
    ):
        return []
    # model entity ids look like ``<net>.<part>...``; select rows whose
    # second "."-component starts with "output"
    return self._petab_problem.mapping_df[
        self._petab_problem.mapping_df[petabv2.C.MODEL_ENTITY_ID]
        .str.split(".")
        .str[1]
        .str.startswith("output")
    ].index.tolist()
[docs]
def get_petab_parameter_by_id(self, name: str) -> jnp.float_:
    """
    Get the value of a PEtab parameter by name.

    :param name:
        PEtab parameter id, as returned by :attr:`parameter_ids`.
    :return:
        Value of the parameter
    """
    position = self.parameter_ids.index(name)
    return self.parameters[position]
def _unscale(
    self, p: jt.Float[jt.Array, "np"], scales: tuple[str, ...]
) -> jt.Float[jt.Array, "np"]:
    """
    Unscaling of parameters.

    :param p:
        Parameter values
    :param scales:
        Parameter scalings
    :return:
        Unscaled parameter values
    """
    unscaled = [
        jax_unscale(value, scale) for value, scale in zip(p, scales)
    ]
    return jnp.array(unscaled)
def _eval_nn(self, output_par: str, condition_id: str):
    """
    Evaluate the neural network that produces ``output_par``.

    Network inputs are resolved, in order of precedence, from the condition
    table, from pre-loaded per-condition input arrays, and finally from
    model attributes / problem parameters referenced in the mapping table.

    :param output_par:
        PEtab entity id of a network output (see :attr:`nn_output_ids`).
    :param condition_id:
        Condition id used to resolve condition-specific inputs.
    :return:
        Squeezed network output.
    """
    # the network id is the first "."-component of the mapped model entity
    net_id = self._petab_problem.mapping_df.loc[
        output_par, petabv2.C.MODEL_ENTITY_ID
    ].split(".")[0]
    nn = self.model.nns[net_id]

    def _is_net_input(model_id):
        # input entities of this network look like ``<net_id>.inputs...``
        comps = model_id.split(".")
        return comps[0] == net_id and comps[1].startswith("inputs")

    # model entity id -> petab entity id, for this network's inputs
    model_id_map = (
        self._petab_problem.mapping_df[
            self._petab_problem.mapping_df[
                petabv2.C.MODEL_ENTITY_ID
            ].apply(_is_net_input)
        ]
        .reset_index()
        .set_index(petabv2.C.MODEL_ENTITY_ID)[petabv2.C.PETAB_ENTITY_ID]
        .to_dict()
    )
    # inputs set via the condition table; entries referencing parameters
    # are resolved to the parameter's nominal value
    condition_input_map = (
        dict(
            [
                (
                    petab_id,
                    self._petab_problem.parameter_df.loc[
                        self._petab_problem.condition_df.loc[
                            condition_id, petab_id
                        ],
                        petabv2.C.NOMINAL_VALUE,
                    ],
                )
                if self._petab_problem.condition_df.loc[
                    condition_id, petab_id
                ]
                in self._petab_problem.parameter_df.index
                else (
                    petab_id,
                    np.float64(
                        self._petab_problem.condition_df.loc[
                            condition_id, petab_id
                        ]
                    ),
                )
                for petab_id in model_id_map.values()
            ]
        )
        if not self._petab_problem.condition_df.empty
        else {}
    )
    # inputs redirected through the hybridization table
    hybridization_parameter_map = {
        petab_id: self._petab_problem.hybridization_df.loc[
            petab_id, "targetValue"
        ]
        for petab_id in model_id_map.values()
        if petab_id in set(self._petab_problem.hybridization_df.index)
    }
    # handle conditions
    if len(condition_input_map) > 0:
        net_input = jnp.array(
            [
                condition_input_map[petab_id]
                for _, petab_id in model_id_map.items()
            ]
        )
        return nn.forward(net_input).squeeze()
    # handle array inputs
    if isinstance(self.model.nns[net_id].inputs, dict):
        net_input = jnp.array(
            [
                # fall back to the "0" entry if there is no
                # condition-specific array
                self.model.nns[net_id].inputs[petab_id][condition_id]
                if condition_id in self.model.nns[net_id].inputs[petab_id]
                else self.model.nns[net_id].inputs[petab_id]["0"]
                for _, petab_id in model_id_map.items()
            ]
        )
        return nn.forward(net_input).squeeze()
    # fall back to model attributes, estimated problem parameters,
    # nominal values, and hybridization-redirected nominal values
    net_input = jnp.array(
        [
            jax.lax.stop_gradient(self.model.nns[net_id][model_id])
            if model_id in self.model.nns[net_id].inputs
            else self.get_petab_parameter_by_id(petab_id)
            if petab_id in self.parameter_ids
            else self._petab_problem.parameter_df.loc[
                petab_id, petabv2.C.NOMINAL_VALUE
            ]
            if petab_id in set(self._petab_problem.parameter_df.index)
            else self._petab_problem.parameter_df.loc[
                hybridization_parameter_map[petab_id],
                petabv2.C.NOMINAL_VALUE,
            ]
            for model_id, petab_id in model_id_map.items()
        ]
    )
    return nn.forward(net_input).squeeze()
def _map_model_parameter_value(
    self,
    mapping: ParameterMappingForCondition,
    pname: str,
    condition_id: str,
) -> jt.Float[jt.Scalar, ""] | float:  # noqa: F722
    """
    Resolve the value of a single model parameter for a condition.

    :param mapping:
        Parameter mapping for the condition.
    :param pname:
        Model parameter id to resolve.
    :param condition_id:
        Condition id (needed for neural-network evaluation).
    :return:
        Numeric override, network output (component), or estimated value.
    """
    pval = mapping.map_sim_var[pname]
    if hasattr(self, "nn_output_ids") and pval in self.nn_output_ids:
        # the parameter is produced by a neural network
        nn_output = self._eval_nn(pval, condition_id)
        if nn_output.size > 1:
            # vector-valued output: select the component encoded in the
            # mapped model entity id, e.g. ``net.outputs[0][3]``
            entityId = self._petab_problem.mapping_df.loc[
                pval, petabv2.C.MODEL_ENTITY_ID
            ]
            ind = int(re.search(r"\[\d+\]\[(\d+)\]", entityId).group(1))
            return nn_output[ind]
        else:
            return nn_output
    if isinstance(pval, Number):
        # literal numeric override
        return pval
    # otherwise it must be an estimated PEtab parameter id
    return self.get_petab_parameter_by_id(pval)
[docs]
def load_model_parameters(
    self, experiment: petabv2.Experiment, is_preeq: bool
) -> jt.Float[jt.Array, "np"]:
    """
    Load parameters for an experiment.

    :param experiment:
        Experiment to load parameters for.
    :param is_preeq:
        Whether to load preequilibration or simulation parameters.
    :return:
        Parameters for the experiment.
    """
    values = [
        self._map_experiment_model_parameter_value(
            pname, ind, experiment, is_preeq
        )
        for ind, pname in enumerate(self.model.parameter_ids)
    ]
    # model parameters are on linear scale, so unscaling is a no-op here
    scales = tuple(petabv2.C.LIN for _ in self.model.parameter_ids)
    return self._unscale(jnp.array(values), scales)
def _map_experiment_model_parameter_value(
    self,
    pname: str,
    p_index: int,
    experiment: petabv2.Experiment,
    is_preeq: bool,
):
    """
    Get values for the given parameter `pname` from the relevant petab tables.

    Lookup order: parameter-table nominal values, condition changes of the
    selected experiment period, observable/noise placeholder overrides of
    the first measurement, and finally the model's default value.

    :param pname: PEtab parameter id
    :param p_index: Index of the parameter in the model's parameter list
    :param experiment: PEtab experiment
    :param is_preeq: Whether to get preequilibration or simulation parameter value
    :return: Value of the parameter
    """
    # pick the condition ids of the first period matching is_preeq
    condition_ids = []
    for period in experiment.sorted_periods:
        if period.is_preequilibration != is_preeq:
            continue
        condition_ids = period.condition_ids
        break
    init_val = self.model.parameters[p_index]
    params_nominals = {
        p.id: p.nominal_value for p in self._petab_problem.parameters
    }
    targets_map = {
        ch.target_id: ch.target_value
        for c in self._petab_problem.conditions
        for ch in c.changes
        if c.id in condition_ids
    }
    if pname in params_nominals:
        return params_nominals[pname]
    if pname in targets_map:
        return float(targets_map[pname])
    for placeholder_attr, param_attr in (
        ("observable_placeholders", "observable_parameters"),
        ("noise_placeholders", "noise_parameters"),
    ):
        # fixed: the original rebound its own iterable here
        # ("for placeholders in placeholders"), which worked only by
        # accident of iterator binding -- use distinct names instead
        for observable in self._petab_problem.observables:
            placeholder_list = getattr(observable, placeholder_attr)
            # NOTE(review): overrides are always taken from the first
            # measurement -- confirm this is intended
            params_list = getattr(
                self._petab_problem.measurements[0], param_attr
            )
            for i, placeholder in enumerate(placeholder_list):
                if str(placeholder) == pname:
                    return self._find_val(
                        str(params_list[i]), params_nominals
                    )
    return init_val
def _find_val(self, param_entry: str, params_nominals: dict):
    """Resolve an override entry to a numeric value where possible."""
    as_float = _try_float(param_entry)
    if isinstance(as_float, float):
        return as_float
    # fall back to the nominal value, or return the entry unchanged
    return params_nominals.get(param_entry, param_entry)
def _state_needs_reinitialisation(
    self,
    simulation_condition: str,
    state_id: str,
) -> bool:
    """
    Check if a state needs reinitialisation for a simulation condition.

    :param simulation_condition:
        simulation condition to check reinitialisation for
    :param state_id:
        state id to check reinitialisation for
    :return:
        True if state needs reinitialisation, False otherwise
    """
    # network-produced states are always recomputed
    if state_id in self.nn_output_ids:
        return True
    if state_id not in self._petab_problem.condition_df:
        return False
    value = self._petab_problem.condition_df.loc[
        simulation_condition, state_id
    ]
    # a numeric NaN entry means "keep the current state"
    is_nan_entry = isinstance(value, Number) and np.isnan(value)
    return not is_nan_entry
def _state_reinitialisation_value(
    self,
    simulation_condition: str,
    state_id: str,
    p: jt.Float[jt.Array, "np"],
) -> jt.Float[jt.Scalar, ""] | float:  # noqa: F722
    """
    Get the reinitialisation value for a state.

    :param simulation_condition:
        simulation condition to get reinitialisation value for
    :param state_id:
        state id to get reinitialisation value for
    :param p:
        parameters for the simulation condition
    :return:
        reinitialisation value for the state
    """
    if state_id in self.nn_output_ids:
        # fixed: _eval_nn requires the condition id as second argument;
        # it was previously omitted, raising a TypeError at runtime
        return self._eval_nn(state_id, simulation_condition)
    if state_id not in self._petab_problem.condition_df:
        # no reinitialisation, return dummy value
        return 0.0
    xval = self._petab_problem.condition_df.loc[
        simulation_condition, state_id
    ]
    if isinstance(xval, Number) and np.isnan(xval):
        # no reinitialisation, return dummy value
        return 0.0
    if isinstance(xval, Number):
        # numerical value, return as is
        return xval
    if xval in self.model.parameter_ids:
        # model parameter, return value
        return p[self.model.parameter_ids.index(xval)]
    if xval in self.parameter_ids:
        # estimated PEtab parameter, return unscaled value
        return jax_unscale(
            self.get_petab_parameter_by_id(xval),
            self._petab_problem.parameter_df.loc[
                xval, petabv2.PARAMETER_SCALE
            ],
        )
    # only remaining option is nominal value for PEtab parameter
    # that is not estimated, return nominal value
    return self._petab_problem.parameter_df.loc[
        xval, petabv2.C.NOMINAL_VALUE
    ]
[docs]
def load_reinitialisation(
    self,
    simulation_condition: str,
    p: jt.Float[jt.Array, "np"],
) -> tuple[jt.Bool[jt.Array, "nx"], jt.Float[jt.Array, "nx"]]:  # noqa: F821
    """
    Load reinitialisation values and mask for the state vector for a simulation condition.

    :param simulation_condition:
        Simulation condition to load reinitialisation for.
    :param p:
        Parameters for the simulation condition.
    :return:
        Tuple of reinitialisation mask and values for the states.
    """
    # bail out early when no state can possibly be reinitialised
    any_reinit = any(
        x_id in self._petab_problem.condition_df
        or hasattr(self, "nn_output_ids")
        and x_id in self.nn_output_ids
        for x_id in self.model.state_ids
    )
    if not any_reinit:
        return jnp.array([]), jnp.array([])
    mask = jnp.array(
        [
            self._state_needs_reinitialisation(simulation_condition, x_id)
            for x_id in self.model.state_ids
        ]
    )
    values = jnp.array(
        [
            self._state_reinitialisation_value(
                simulation_condition, x_id, p
            )
            for x_id in self.model.state_ids
        ]
    )
    return mask, values
[docs]
def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem":
    """
    Return a copy of this problem with its parameter vector replaced.

    :param p:
        New parameter values.
    :return:
        New problem instance with updated parameters.
    """
    return eqx.tree_at(lambda problem: problem.parameters, self, p)
def _prepare_experiments(
    self,
    experiments: list[petabv2.Experiment],
    conditions: list[str],
    is_preeq: bool,
    op_numeric: np.ndarray | None = None,
    op_mask: np.ndarray | None = None,
    op_indices: np.ndarray | None = None,
    np_numeric: np.ndarray | None = None,
    np_mask: np.ndarray | None = None,
    np_indices: np.ndarray | None = None,
) -> tuple[
    jt.Float[jt.Array, "nc np"],  # noqa: F821, F722
    jt.Bool[jt.Array, "nc nx"],  # noqa: F821, F722
    jt.Float[jt.Array, "nc nx"],  # noqa: F821, F722
    jt.Float[jt.Array, "nc nt nop"],  # noqa: F821, F722
    jt.Float[jt.Array, "nc nt nnp"],  # noqa: F821, F722
    jt.Bool[jt.Array, "nexp ne"],  # noqa: F821, F722
    jt.Float[jt.Array, "nc"],  # noqa: F821
]:
    """
    Prepare experiments for simulation.

    :param experiments:
        Experiments to prepare simulation arrays for.
    :param conditions:
        Simulation conditions to prepare.
    :param is_preeq:
        Whether to load preequilibration or simulation parameters.
    :param op_numeric:
        Numeric values for observable parameter overrides. If None, no overrides are used.
    :param op_mask:
        Mask for observable parameter overrides. True for free parameter overrides, False for numeric values.
    :param op_indices:
        Free parameter indices (wrt. `self.parameters`) for observable parameter overrides.
    :param np_numeric:
        Numeric values for noise parameter overrides. If None, no overrides are used.
    :param np_mask:
        Mask for noise parameter overrides. True for free parameter overrides, False for numeric values.
    :param np_indices:
        Free parameter indices (wrt. `self.parameters`) for noise parameter overrides.
    :return:
        Tuple of parameter arrays, reinitialisation masks, reinitialisation
        values, observable parameters, noise parameters, event masks (one row
        per experiment of the problem) and simulation start times.
    """
    p_array = jnp.stack(
        [self.load_model_parameters(exp, is_preeq) for exp in experiments]
    )
    # events are only enabled for experiments that are part of this batch
    exp_ids = {exp.id for exp in experiments}
    h_mask = jnp.stack(
        [
            jnp.ones(self.model.n_events)
            if exp.id in exp_ids
            else jnp.zeros(self.model.n_events)
            for exp in self._petab_problem.experiments
        ]
    )
    # negative period start times indicate pre-equilibration; clip to 0
    t_zeros = jnp.stack(
        [
            exp.periods[0].time if exp.periods[0].time >= 0.0 else 0.0
            for exp in experiments
        ]
    )
    if self.parameters.size:
        if isinstance(self._petab_problem, HybridV2Problem):
            # hybrid (v2) problems store parameters unscaled already
            unscaled_parameters = jnp.stack(
                [
                    self.parameters[ip]
                    for ip in range(len(self.parameter_ids))
                ]
            )
        else:
            unscaled_parameters = jnp.stack(
                [
                    jax_unscale(
                        self.parameters[ip],
                        self._petab_problem.parameter_df.loc[
                            p_id, petabv2.C.PARAMETER_SCALE
                        ],
                    )
                    for ip, p_id in enumerate(self.parameter_ids)
                ]
            )
    else:
        unscaled_parameters = jnp.zeros((*self._ts_masks.shape[:2], 0))
        # placeholder values from sundials code may be needed here

    def _apply_overrides(numeric, mask, indices):
        # replace free-parameter placeholders (mask==True) by the unscaled
        # parameter values; numeric entries are kept as-is
        if numeric is None or not numeric.size:
            return jnp.zeros((*self._ts_masks.shape[:2], 0))
        return jnp.where(
            mask,
            jax.vmap(
                jax.vmap(jax.vmap(lambda ip: unscaled_parameters[ip]))
            )(indices),
            numeric,
        )

    op_array = _apply_overrides(op_numeric, op_mask, op_indices)
    np_array = _apply_overrides(np_numeric, np_mask, np_indices)
    # compute reinitialisation mask and values in a single pass instead of
    # calling load_reinitialisation twice per condition
    reinit = [
        self.load_reinitialisation(sc, p)
        for sc, p in zip(conditions, p_array)
    ]
    mask_reinit_array = jnp.stack([mask for mask, _ in reinit])
    x_reinit_array = jnp.stack([x for _, x in reinit])
    return (
        p_array,
        mask_reinit_array,
        x_reinit_array,
        op_array,
        np_array,
        h_mask,
        t_zeros,
    )
[docs]
@eqx.filter_vmap(
    in_axes={
        "max_steps": None,
        "self": None,
    },  # only list arguments here where eqx.is_array(0) is not the right thing
)
def run_simulation(
    self,
    p: jt.Float[jt.Array, "np"],  # noqa: F821, F722
    ts_dyn: np.ndarray,
    ts_posteq: np.ndarray,
    my: np.ndarray,
    iys: np.ndarray,
    iy_trafos: np.ndarray,
    ops: jt.Float[jt.Array, "nt *nop"],  # noqa: F821, F722
    nps: jt.Float[jt.Array, "nt *nnp"],  # noqa: F821, F722
    mask_reinit: jt.Bool[jt.Array, "nx"],  # noqa: F821, F722
    x_reinit: jt.Float[jt.Array, "nx"],  # noqa: F821, F722
    init_override: jt.Float[jt.Array, "nx"],  # noqa: F821, F722
    init_override_mask: jt.Bool[jt.Array, "nx"],  # noqa: F821, F722
    h_mask: jt.Bool[jt.Array, "nx"],  # noqa: F821, F722
    solver: diffrax.AbstractSolver,
    controller: diffrax.AbstractStepSizeController,
    root_finder: AbstractRootFinder,
    steady_state_event: Callable[
        ..., diffrax._custom_types.BoolScalarLike
    ],
    max_steps: jnp.int_,
    x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),  # noqa: F821, F722
    h_preeq: jt.Bool[jt.Array, "*ne"] = jnp.array([]),  # noqa: F821, F722
    ts_mask: np.ndarray = np.array([]),
    t_zeros: jnp.float_ = 0.0,
    ret: ReturnValue = ReturnValue.llh,
) -> tuple[jnp.float_, dict]:
    """
    Run a simulation for a given simulation experiment.

    :param p:
        Parameters for the simulation experiment
    :param ts_dyn:
        (Padded) dynamic time points
    :param ts_posteq:
        (Padded) post-equilibrium time points
    :param my:
        (Padded) measurements
    :param iys:
        (Padded) observable indices
    :param iy_trafos:
        (Padded) observable transformations indices
    :param ops:
        (Padded) observable parameters
    :param nps:
        (Padded) noise parameters
    :param mask_reinit:
        Mask for states that need reinitialisation
    :param x_reinit:
        Reinitialisation values for states
    :param init_override:
        Initial-state override values, forwarded to
        :meth:`JAXModel.simulate_condition`
    :param init_override_mask:
        Mask selecting the states whose initial value is taken from
        ``init_override``, forwarded to :meth:`JAXModel.simulate_condition`
    :param h_mask:
        Mask for the events that are part of the current experiment
    :param solver:
        ODE solver to use for simulation
    :param controller:
        Step size controller to use for simulation
    :param root_finder:
        Root finder used for event detection during simulation
    :param steady_state_event:
        Steady state event function to use for post-equilibration. Allows customisation of the steady state
        condition, see :func:`diffrax.steady_state_event` for details.
    :param max_steps:
        Maximum number of steps to take during simulation
    :param x_preeq:
        Pre-equilibration state. Can be empty if no pre-equilibration is available, in which case the states will
        be initialised to the model default values.
    :param h_preeq:
        Pre-equilibration event mask. Can be empty if no pre-equilibration is available
    :param ts_mask:
        padding mask, see :meth:`JAXModel.simulate_condition` for details.
    :param t_zeros:
        simulation start time for the current experiment.
    :param ret:
        which output to return. See :class:`ReturnValue` for available options.
    :return:
        Tuple of output value and simulation statistics
    """
    # stop_gradient on data/mask arrays: they are constants of the
    # optimisation problem and must not contribute to derivatives
    return self.model.simulate_condition(
        p=p,
        ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)),
        ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),
        my=jax.lax.stop_gradient(jnp.array(my)),
        iys=jax.lax.stop_gradient(jnp.array(iys)),
        iy_trafos=jax.lax.stop_gradient(jnp.array(iy_trafos)),
        nps=nps,
        ops=ops,
        x_preeq=x_preeq,
        h_preeq=h_preeq,
        mask_reinit=jax.lax.stop_gradient(mask_reinit),
        x_reinit=x_reinit,
        init_override=init_override,
        init_override_mask=jax.lax.stop_gradient(init_override_mask),
        ts_mask=jax.lax.stop_gradient(jnp.array(ts_mask)),
        h_mask=jax.lax.stop_gradient(jnp.array(h_mask)),
        t_zero=t_zeros,
        solver=solver,
        controller=controller,
        root_finder=root_finder,
        max_steps=max_steps,
        steady_state_event=steady_state_event,
        # scalar outputs (llh/chi2) use the reverse-mode-friendly
        # checkpointing adjoint; other return values use DirectAdjoint
        adjoint=diffrax.RecursiveCheckpointAdjoint()
        if ret in (ReturnValue.llh, ReturnValue.chi2)
        else diffrax.DirectAdjoint(),
        ret=ret,
    )
[docs]
def run_simulations(
    self,
    experiments: list[petabv2.Experiment],
    preeq_array: jt.Float[jt.Array, "ncond *nx"],  # noqa: F821, F722
    h_preeqs: jt.Bool[jt.Array, "ncond *ne"],  # noqa: F821
    solver: diffrax.AbstractSolver,
    controller: diffrax.AbstractStepSizeController,
    root_finder: AbstractRootFinder,
    steady_state_event: Callable[
        ..., diffrax._custom_types.BoolScalarLike
    ],
    max_steps: jnp.int_,
    ret: ReturnValue = ReturnValue.llh,
):
    """
    Run simulations for a list of simulation experiments.

    :param experiments:
        Experiments to run simulations for.
    :param preeq_array:
        Matrix of pre-equilibrated states for the simulation conditions. Ordering must match the simulation
        conditions. If no pre-equilibration is available for a condition, the corresponding row must be empty.
    :param h_preeqs:
        Matrix of pre-equilibration event heaviside variables indicating whether an event condition is false or
        true after preequilibration.
    :param solver:
        ODE solver to use for simulation.
    :param controller:
        Step size controller to use for simulation.
    :param root_finder:
        Root finder used for event detection during simulation.
    :param steady_state_event:
        Steady state event function to use for post-equilibration. Allows customisation of the steady state
        condition, see :func:`diffrax.steady_state_event` for details.
    :param max_steps:
        Maximum number of steps to take during simulation.
    :param ret:
        which output to return. See :class:`ReturnValue` for available options.
    :return:
        Output value and condition specific results and statistics. Results and statistics are returned as a dict
        with arrays with the leading dimension corresponding to the simulation conditions.
    """
    simulation_conditions = [
        cid
        for exp in experiments
        for period in exp.periods
        for cid in period.condition_ids
    ]
    # dynamic (non-pre-equilibration) conditions, deduplicated while
    # preserving order
    dynamic_conditions = list(
        dict.fromkeys(
            sc
            for sc in simulation_conditions
            if "preequilibration" not in sc
        )
    )
    (
        p_array,
        mask_reinit_array,
        x_reinit_array,
        op_array,
        np_array,
        h_mask,
        t_zeros,
    ) = self._prepare_experiments(
        experiments,
        dynamic_conditions,
        False,
        self._op_numeric,
        self._op_mask,
        self._op_indices,
        self._np_numeric,
        self._np_mask,
        self._np_indices,
    )
    # states whose initial value is overridden by a model parameter
    # (e.g. a neural-network output in hybrid models); the membership set
    # and mask row are experiment-independent, so build them once instead
    # of per state per experiment
    model_parameter_ids = set(self.model.parameter_ids)
    override_row = jnp.array(
        [x_id in model_parameter_ids for x_id in self.model.state_ids]
    )
    init_override_mask = jnp.stack([override_row for _ in experiments])
    init_override = jnp.stack(
        [
            jnp.array(
                [
                    self._eval_nn(
                        x_id, exp.periods[-1].condition_ids[0]
                    )  # TODO: Add mapping of x_id to eval_nn?
                    if x_id in model_parameter_ids
                    else 1.0
                    for x_id in self.model.state_ids
                ]
            )
            for exp in experiments
        ]
    )
    return self.run_simulation(
        p_array,
        self._ts_dyn,
        self._ts_posteq,
        self._my,
        self._iys,
        self._iy_trafos,
        op_array,
        np_array,
        mask_reinit_array,
        x_reinit_array,
        init_override,
        init_override_mask,
        h_mask,
        solver,
        controller,
        root_finder,
        steady_state_event,
        max_steps,
        preeq_array,
        h_preeqs,
        self._ts_masks,
        t_zeros,
        ret,
    )
[docs]
@eqx.filter_vmap(
    in_axes={
        "max_steps": None,
        "self": None,
    },  # only list arguments here where eqx.is_array(0) is not the right thing
)
def run_preequilibration(
    self,
    p: jt.Float[jt.Array, "np"],  # noqa: F821, F722
    mask_reinit: jt.Bool[jt.Array, "nx"],  # noqa: F821, F722
    x_reinit: jt.Float[jt.Array, "nx"],  # noqa: F821, F722
    h_mask: jt.Bool[jt.Array, "ne"],  # noqa: F821, F722
    solver: diffrax.AbstractSolver,
    controller: diffrax.AbstractStepSizeController,
    root_finder: AbstractRootFinder,
    steady_state_event: Callable[
        ..., diffrax._custom_types.BoolScalarLike
    ],
    max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]:  # noqa: F821
    """
    Run a pre-equilibration simulation for a given simulation experiment.

    :param p:
        Parameters for the simulation experiment
    :param mask_reinit:
        Mask for states that need reinitialisation
    :param x_reinit:
        Reinitialisation values for states
    :param h_mask:
        Mask for the events that are part of the current experiment
    :param solver:
        ODE solver to use for simulation
    :param controller:
        Step size controller to use for simulation
    :param root_finder:
        Root finder used for event detection during pre-equilibration
    :param steady_state_event:
        Steady state event function to use for pre-equilibration. Allows customisation of the steady state
        condition, see :func:`diffrax.steady_state_event` for details.
    :param max_steps:
        Maximum number of steps to take during simulation
    :return:
        Pre-equilibration state
    """
    return self.model.preequilibrate_condition(
        p=p,
        mask_reinit=mask_reinit,
        x_reinit=x_reinit,
        h_mask=h_mask,
        solver=solver,
        controller=controller,
        root_finder=root_finder,
        max_steps=max_steps,
        steady_state_event=steady_state_event,
    )
[docs]
def run_preequilibrations(
    self,
    experiments: list[petabv2.Experiment],
    solver: diffrax.AbstractSolver,
    controller: diffrax.AbstractStepSizeController,
    root_finder: AbstractRootFinder,
    steady_state_event: Callable[
        ..., diffrax._custom_types.BoolScalarLike
    ],
    max_steps: jnp.int_,
):
    """
    Run pre-equilibration for a list of simulation experiments.

    :param experiments:
        Experiments to run pre-equilibration for.
    :param solver:
        ODE solver to use for simulation.
    :param controller:
        Step size controller to use for simulation.
    :param root_finder:
        Root finder used for event detection during pre-equilibration.
    :param steady_state_event:
        Steady state event function to use for pre-equilibration, see
        :func:`diffrax.steady_state_event` for details.
    :param max_steps:
        Maximum number of steps to take during simulation.
    :return:
        Result of :meth:`run_preequilibration`, vmapped over the
        pre-equilibration conditions.
    """
    simulation_conditions = [
        cid
        for exp in experiments
        for p in exp.periods
        for cid in p.condition_ids
    ]
    # conditions whose id contains "preequilibration" are treated as
    # pre-equilibration conditions (naming convention used throughout
    # this module)
    preequilibration_conditions = list(
        {sc for sc in simulation_conditions if "preequilibration" in sc}
    )
    p_array, mask_reinit_array, x_reinit_array, _, _, h_mask, _ = (
        self._prepare_experiments(
            experiments, preequilibration_conditions, True, None, None
        )
    )
    return self.run_preequilibration(
        p_array,
        mask_reinit_array,
        x_reinit_array,
        h_mask,
        solver,
        controller,
        root_finder,
        steady_state_event,
        max_steps,
    )
[docs]
def run_simulations(
    problem: JAXProblem,
    simulation_experiments: Iterable[str] | None = None,
    solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
    controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
        **DEFAULT_CONTROLLER_SETTINGS
    ),
    root_finder: AbstractRootFinder = optimistix.Newton(
        **DEFAULT_ROOT_FINDER_SETTINGS
    ),
    steady_state_event: Callable[
        ..., diffrax._custom_types.BoolScalarLike
    ] = diffrax.steady_state_event(),
    max_steps: int = 2**13,
    ret: ReturnValue | str = ReturnValue.llh,
):
    """
    Run simulations for a problem.

    :param problem:
        Problem to run simulations for.
    :param simulation_experiments:
        Iterable of experiment ids to simulate. Defaults to all experiments
        of the problem.
    :param solver:
        ODE solver to use for simulation.
    :param controller:
        Step size controller to use for simulation.
    :param root_finder:
        Root finder to use for event detection.
    :param steady_state_event:
        Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state
        condition, see :func:`diffrax.steady_state_event` for details.
    :param max_steps:
        Maximum number of steps to take during simulation.
    :param ret:
        which output to return. See :class:`ReturnValue` for available options.
    :return:
        Overall output value and condition specific results and statistics.
    """
    if isinstance(problem, HybridProblem) or isinstance(
        problem._petab_problem, petabv1.Problem
    ):
        raise TypeError(
            "run_simulations does not support PEtab v1 problems. Upgrade the problem to PEtab v2."
        )
    ret = ReturnValue[ret] if isinstance(ret, str) else ret

    all_experiments = problem._petab_problem.experiments
    experiments = (
        all_experiments
        if simulation_experiments is None
        else [
            exp
            for exp in all_experiments
            if exp.id in simulation_experiments
        ]
    )

    condition_ids = [
        cid
        for exp in experiments
        for period in exp.periods
        for cid in period.condition_ids
    ]
    # deduplicate while preserving order
    dynamic_conditions = list(
        dict.fromkeys(
            cid for cid in condition_ids if "preequilibration" not in cid
        )
    )
    conditions = {
        "dynamic_conditions": dynamic_conditions,
    }

    # a negative start time of the first period indicates pre-equilibration
    if any(exp.periods[0].time < 0.0 for exp in experiments):
        preeqs_array, preresults, h_preeqs = problem.run_preequilibrations(
            experiments,
            solver,
            controller,
            root_finder,
            steady_state_event,
            max_steps,
        )
    else:
        preresults = {
            "stats_preeq": None,
        }
        preeqs_array = jnp.stack([jnp.array([]) for _ in experiments])
        h_preeqs = jnp.stack([jnp.array([]) for _ in experiments])

    output, results = problem.run_simulations(
        experiments,
        preeqs_array,
        h_preeqs,
        solver,
        controller,
        root_finder,
        steady_state_event,
        max_steps,
        ret,
    )
    if ret in (ReturnValue.llh, ReturnValue.chi2):
        if os.getenv("JAX_DEBUG") == "1":
            jax.debug.print(
                "ret: {}",
                ret,
            )
        output = jnp.sum(output)
    return output, results | preresults | conditions
[docs]
def petab_simulate(
    problem: JAXProblem,
    solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
    controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
        **DEFAULT_CONTROLLER_SETTINGS
    ),
    steady_state_event: Callable[
        ..., diffrax._custom_types.BoolScalarLike
    ] = diffrax.steady_state_event(),
    max_steps: int = 2**13,
):
    """
    Run simulations for a problem and return the results as a petab simulation dataframe.

    :param problem:
        Problem to run simulations for.
    :param solver:
        ODE solver to use for simulation.
    :param controller:
        Step size controller to use for simulation.
    :param steady_state_event:
        Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state
        condition, see :func:`diffrax.steady_state_event` for details.
    :param max_steps:
        Maximum number of steps to take during simulation.
    :return:
        petab simulation dataframe.
    """
    y, r = run_simulations(
        problem,
        solver=solver,
        controller=controller,
        steady_state_event=steady_state_event,
        max_steps=max_steps,
        ret=ReturnValue.y,
    )
    if isinstance(problem._petab_problem, HybridV2Problem):
        return _build_simulation_df_v2(problem, y, r["dynamic_conditions"])

    dfs = []
    for ic, sc in enumerate(r["dynamic_conditions"]):
        ts_mask = problem._ts_masks[ic, :]
        obs = [
            problem.model.observable_ids[io]
            for io in problem._iys[ic, ts_mask]
        ]
        t = jnp.concat(
            (
                problem._ts_dyn[ic, :],
                problem._ts_posteq[ic, :],
            )
        )
        # all columns must have the masked (unpadded) length; using the
        # padded length len(t) for the condition column would mismatch the
        # masked data columns whenever padding is present
        df_sc = pd.DataFrame(
            {
                petabv2.C.SIMULATION: y[ic, ts_mask],
                petabv2.C.TIME: t[ts_mask],
                petabv2.C.OBSERVABLE_ID: obs,
                petabv2.C.CONDITION_ID: [sc] * len(obs),
            },
            index=problem._petab_measurement_indices[ic, :],
        )
        # carry over optional measurement-table columns for this condition
        for col in (
            petabv2.C.OBSERVABLE_PARAMETERS,
            petabv2.C.NOISE_PARAMETERS,
            petabv2.C.PREEQUILIBRATION_CONDITION_ID,
        ):
            if col in problem._petab_problem.measurement_df:
                df_sc[col] = problem._petab_problem.measurement_df.query(
                    f"{petabv2.C.CONDITION_ID} == '{sc}'"
                )[col]
        dfs.append(df_sc)
    return pd.concat(dfs).sort_index()
def add_default_experiment_names_to_v2_problem(petab_problem: petabv2.Problem):
    """Add default experiment names to PEtab v2 problem.

    Ensures the problem has a condition table entry and an experiment table
    entry (both with id ``"__default__"``) and, when no experiments were
    defined, assigns every measurement to the default experiment.

    Args:
        petab_problem: PEtab v2 problem to modify (modified in place).

    Returns:
        The modified problem.
    """
    if not hasattr(petab_problem, "extensions_config"):
        petab_problem.extensions_config = {}
    # visualisation tables are not used downstream
    petab_problem.visualization_df = None
    if petab_problem.condition_df is None:
        # NOTE(review): both `id` and `conditionId` are passed here —
        # presumably one aliases the other; confirm against the petab.v2 API
        default_condition = petabv2.core.Condition(
            id="__default__", changes=[], conditionId="__default__"
        )
        petab_problem.condition_tables[0].elements = [default_condition]
    if (
        petab_problem.experiment_df is None
        or petab_problem.experiment_df.empty
    ):
        condition_ids = petab_problem.condition_df[
            petabv2.C.CONDITION_ID
        ].values
        # pre-equilibration conditions (by naming convention) are handled
        # separately and must not be part of the default experiment
        condition_ids = [
            c for c in condition_ids if "preequilibration" not in c
        ]
        default_experiment = petabv2.core.Experiment(
            id="__default__",
            periods=[
                petabv2.core.ExperimentPeriod(
                    time=0.0, condition_ids=condition_ids
                )
            ],
        )
        petab_problem.experiment_tables[0].elements = [default_experiment]
        # assign all measurements to the newly created default experiment
        measurement_tables = petab_problem.measurement_tables.copy()
        for mt in measurement_tables:
            for m in mt.elements:
                m.experiment_id = "__default__"
        petab_problem.measurement_tables = measurement_tables
    return petab_problem
def get_simulation_conditions_v2(petab_problem) -> pd.DataFrame:
    """Get simulation conditions from a PEtab v2 problem.

    :param petab_problem:
        PEtab v2 problem to extract simulation conditions from.

    Returns:
        The experiment table with the time column dropped, i.e. a pandas
        DataFrame mapping experiment ids to condition ids.
    """
    # NOTE(review): a per-experiment dict of unique condition ids used to be
    # computed here but was never used; removed as dead code.
    return petab_problem.experiment_df.drop(columns=[petabv2.C.TIME])
def _build_simulation_df_v2(problem, y, dyn_conditions):
    """Build a petab simulation DataFrame of simulation results from a PEtab v2 problem.

    :param problem:
        Problem the simulation results belong to.
    :param y:
        Simulated observables, one (padded) row per dynamic condition.
    :param dyn_conditions:
        Dynamic condition ids; ordering must match the rows of ``y``.
    :return:
        petab simulation dataframe.
    """
    # build the condition -> experiment map once instead of once per condition
    condition_to_experiment = _conditions_to_experiment_map(
        problem._petab_problem.experiment_df
    )
    dfs = []
    for ic, sc in enumerate(dyn_conditions):
        experiment_id = condition_to_experiment[sc]
        if experiment_id == "__default__":
            experiment_id = jnp.nan
        ts_mask = problem._ts_masks[ic, :]
        obs = [
            problem.model.observable_ids[io]
            for io in problem._iys[ic, ts_mask]
        ]
        t = jnp.concat(
            (
                problem._ts_dyn[ic, :],
                problem._ts_posteq[ic, :],
            )
        )
        # all columns must have the masked (unpadded) length; using the
        # padded length len(t) for the constant columns would mismatch the
        # masked data columns whenever padding is present
        n_rows = len(obs)
        df_sc = pd.DataFrame(
            {
                petabv2.C.MODEL_ID: [float("nan")] * n_rows,
                petabv2.C.OBSERVABLE_ID: obs,
                petabv2.C.EXPERIMENT_ID: [experiment_id] * n_rows,
                petabv2.C.TIME: t[ts_mask],
                petabv2.C.SIMULATION: y[ic, ts_mask],
            },
            index=problem._petab_measurement_indices[ic, :],
        )
        # carry over optional measurement-table columns for this experiment
        for col in (
            petabv2.C.OBSERVABLE_PARAMETERS,
            petabv2.C.NOISE_PARAMETERS,
        ):
            if col in problem._petab_problem.measurement_df:
                df_sc[col] = problem._petab_problem.measurement_df.query(
                    f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'"
                )[col]
        dfs.append(df_sc)
    return pd.concat(dfs).sort_index()
def _conditions_to_experiment_map(
experiment_df: pd.DataFrame,
) -> dict[str, str]:
condition_to_experiment = {
row.conditionId: row.experimentId for row in experiment_df.itertuples()
}
return condition_to_experiment
def _try_float(value):
try:
return float(value)
except Exception as e:
msg = str(e).lower()
if isinstance(e, ValueError) and "could not convert" in msg:
return value
raise