# Source code for amici.sim.jax.petab

"""PEtab wrappers for JAX models.""" ""

import logging
import os
import re
import shutil
from collections.abc import Callable, Iterable, Sized
from numbers import Number
from pathlib import Path

import diffrax
import equinox as eqx
import jax.lax
import jax.numpy as jnp
import jaxtyping as jt
import numpy as np
import optimistix
import pandas as pd
import petab.v1 as petabv1
import petab.v2 as petabv2
from optimistix import AbstractRootFinder

from amici import _module_from_path
from amici.logging import get_logger
from amici.sim._parameter_mapping import ParameterMappingForCondition
from amici.sim.jax.model import JAXModel, ReturnValue

# Default settings for the diffrax adaptive step-size controller
# (absolute/relative tolerances and PID controller coefficients).
DEFAULT_CONTROLLER_SETTINGS = {
    "atol": 1e-8,
    "rtol": 1e-8,
    "pcoeff": 0.4,
    "icoeff": 0.3,
    "dcoeff": 0.0,
}

# Default tolerances for the optimistix root finder (steady-state solving).
DEFAULT_ROOT_FINDER_SETTINGS = {
    "atol": 1e-12,
    "rtol": 1e-12,
}

# Map PEtab scale identifiers to integer codes, used to encode observable
# transformations in numeric arrays (see JAXProblem._get_measurements).
SCALE_TO_INT = {
    petabv2.C.LIN: 0,
    petabv2.C.LOG: 1,
    petabv2.C.LOG10: 2,
}

# Module-level logger; only warnings and above are emitted by default.
logger = get_logger(__name__, logging.WARNING)


def jax_unscale(
    parameter: jnp.float_,
    scale_str: str,
) -> jnp.float_:
    """Unscale parameter according to ``scale_str``.

    Arguments:
        parameter:
            Parameter to be unscaled.
        scale_str:
            One of ``petabv2.C.LIN``, ``petabv2.C.LOG``, ``petabv2.C.LOG10``.

    Returns:
        The unscaled parameter.
    """
    # Linear scale (or no scale given) is the identity.
    if not scale_str or scale_str == petabv2.C.LIN:
        return parameter
    # Dispatch table for the remaining, non-trivial transformations.
    transforms: dict[str, Callable] = {
        petabv2.C.LOG: jnp.exp,
        petabv2.C.LOG10: lambda value: jnp.power(10, value),
    }
    try:
        return transforms[scale_str](parameter)
    except KeyError:
        raise ValueError(f"Invalid parameter scaling: {scale_str}") from None


# IDEA: Implement this class in petab-sciml instead?
# IDEA: Implement this class in petab-sciml instead?
class HybridProblem(petabv1.Problem):
    """PEtab v1 problem carrying an optional SciML hybridization table."""

    hybridization_df: pd.DataFrame

    def __init__(self, petab_problem: petabv1.Problem):
        # Adopt every attribute of the wrapped problem, then attach the
        # hybridization table (None when no sciml extension is configured).
        vars(self).update(vars(petab_problem))
        self.hybridization_df = _get_hybridization_df(petab_problem)

class HybridV2Problem(petabv2.Problem):
    """PEtab v2 problem carrying an optional SciML hybridization table."""

    hybridization_df: pd.DataFrame
    extensions_config: dict

    def __init__(self, petab_problem: petabv2.Problem):
        # Provide a default before copying: the copy below never sets the
        # attribute when the wrapped problem lacks it.
        if not hasattr(petab_problem, "extensions_config"):
            self.extensions_config = {}
        vars(self).update(vars(petab_problem))
        self.hybridization_df = _get_hybridization_df(petab_problem)


def _get_hybridization_df(petab_problem):
    """Concatenate all SciML hybridization tables of ``petab_problem``.

    Returns ``None`` when the problem has no ``extensions_config`` attribute
    or no ``sciml`` extension is configured.
    """
    if not hasattr(petab_problem, "extensions_config"):
        return None
    if "sciml" not in petab_problem.extensions_config:
        return None

    sciml_config = petab_problem.extensions_config["sciml"]
    frames = [
        pd.read_csv(table_file, sep="\t", index_col=0)
        for table_file in sciml_config["hybridization_files"]
    ]
    return pd.concat(frames)


def _get_hybrid_petab_problem(
    petab_problem: petabv1.Problem | petabv2.Problem,
):
    """Wrap a PEtab problem in the matching hybrid-problem class."""
    wrapper = (
        HybridV2Problem
        if isinstance(petab_problem, petabv2.Problem)
        else HybridProblem
    )
    return wrapper(petab_problem)


class JAXProblem(eqx.Module):
    """
    PEtab problem wrapper for JAX models.

    :ivar parameters:
        Values for the model parameters. Do not change dimensions, values
        may be changed during, e.g. model training.
    :ivar model:
        JAXModel instance to use for simulation.
    :ivar _parameter_mappings:
        :class:`ParameterMappingForCondition` instances for each simulation
        condition.
    :ivar _measurements:
        Preprocessed arrays for each simulation condition.
    :ivar _petab_problem:
        PEtab problem to simulate.
    """

    # Problem parameter values, ordered like `parameter_ids`.
    parameters: jnp.ndarray
    model: JAXModel
    simulation_conditions: tuple[tuple[str, ...], ...]
    _parameter_mappings: dict[str, ParameterMappingForCondition]
    # Padded, stacked per-condition measurement arrays; all produced by
    # `_get_measurements` (see its docstring for the exact semantics).
    _ts_dyn: np.ndarray
    _ts_posteq: np.ndarray
    _my: np.ndarray
    _iys: np.ndarray
    _iy_trafos: np.ndarray
    _ts_masks: np.ndarray
    _op_numeric: np.ndarray
    _op_mask: np.ndarray
    _op_indices: np.ndarray
    _np_numeric: np.ndarray
    _np_mask: np.ndarray
    _np_indices: np.ndarray
    _petab_measurement_indices: np.ndarray
    _petab_problem: petabv1.Problem | HybridProblem | petabv2.Problem
    def __init__(
        self, model: JAXModel, petab_problem: petabv1.Problem | petabv2.Problem
    ):
        """
        Initialize a JAXProblem instance with a model and a PEtab problem.

        :param model:
            JAXModel instance to use for simulation.
        :param petab_problem:
            PEtab problem to simulate.
        :raises TypeError:
            If a PEtab v1 problem is passed; only v2 problems are supported.
        """
        if isinstance(petab_problem, petabv1.Problem):
            raise TypeError(
                "JAXProblem does not support PEtab v1 problems. Upgrade the problem to PEtab v2."
            )
        petab_problem = add_default_experiment_names_to_v2_problem(
            petab_problem
        )
        scs = get_simulation_conditions_v2(petab_problem)
        self.simulation_conditions = scs.conditionId.to_list()
        self._petab_problem = _get_hybrid_petab_problem(petab_problem)
        self.parameters, self.model = (
            self._initialize_model_with_nominal_values(model)
        )
        self._parameter_mappings = None
        # The unpacking order below must match the return order of
        # `_get_measurements` exactly.
        (
            self._ts_dyn,
            self._ts_posteq,
            self._my,
            self._iys,
            self._iy_trafos,
            self._ts_masks,
            self._petab_measurement_indices,
            self._op_numeric,
            self._op_mask,
            self._op_indices,
            self._np_numeric,
            self._np_mask,
            self._np_indices,
        ) = self._get_measurements(scs)
    def save(self, directory: Path):
        """
        Save the problem to a directory.

        Writes the PEtab definition as TSV/YAML files, copies the generated
        JAX model module, and serialises the equinox leaves (parameters).

        :param directory:
            Directory to save the problem to.
        """
        self._petab_problem.to_files(
            prefix_path=directory,
            model_file="model",
            condition_file="conditions.tsv",
            measurement_file="measurements.tsv",
            parameter_file="parameters.tsv",
            observable_file="observables.tsv",
            yaml_file="problem.yaml",
        )
        shutil.copy(self.model.jax_py_file, directory / "jax_py_file.py")
        with open(directory / "parameters.pkl", "wb") as f:
            eqx.tree_serialise_leaves(f, self)
    @classmethod
    def load(cls, directory: Path):
        """
        Load a problem from a directory.

        Counterpart of :meth:`save`: rebuilds the PEtab problem and model
        module, then restores the serialised equinox leaves.

        :param directory:
            Directory to load the problem from.
        :return:
            Loaded problem instance.
        """
        petab_problem = petabv2.Problem.from_yaml(
            directory / "problem.yaml",
        )
        model = _module_from_path("jax", directory / "jax_py_file.py").Model()
        problem = cls(model, petab_problem)
        with open(directory / "parameters.pkl", "rb") as f:
            return eqx.tree_deserialise_leaves(f, problem)
    def _get_measurements(
        self, simulation_conditions: pd.DataFrame
    ) -> tuple[
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
    ]:
        """
        Get measurements for the model based on the provided simulation conditions.

        :param simulation_conditions:
            Simulation conditions to create parameter mappings for. Same
            format as returned by
            :meth:`petabv1.Problem.get_simulation_conditions_from_measurement_df`.
        :return:
            tuple of padded
            - dynamic time points
            - post-equilibrium time points
            - measurements
            - observable indices
            - observable transformations indices
            - measurement masks
            - data indices (index in petab measurement dataframe).
            - numeric values for observable parameter overrides
            - non-numeric mask for observable parameter overrides
            - parameter indices (problem parameters) for observable parameter overrides
            - numeric values for noise parameter overrides
            - non-numeric mask for noise parameter overrides
            - parameter indices (problem parameters) for noise parameter overrides
        """
        measurements = dict()
        petab_indices = dict()

        # Maximum number of override entries per measurement row, per column
        # (observableParameters / noiseParameters); determines padding width.
        n_pars = dict()
        for col in [
            petabv2.C.OBSERVABLE_PARAMETERS,
            petabv2.C.NOISE_PARAMETERS,
        ]:
            n_pars[col] = 0
            if col in self._petab_problem.measurement_df:
                if pd.api.types.is_numeric_dtype(
                    self._petab_problem.measurement_df[col].dtype
                ):
                    # numeric column: at most one override per row (0 if all NaN)
                    n_pars[col] = 1 - int(
                        self._petab_problem.measurement_df[col].isna().all()
                    )
                else:
                    # string column: count separator-split entries per row
                    n_pars[col] = (
                        self._petab_problem.measurement_df[col]
                        .str.split(petabv2.C.PARAMETER_SEPARATOR)
                        .apply(
                            lambda x: (
                                len(x)
                                if isinstance(x, Sized)
                                else 1 - int(pd.isna(x))
                            )
                        )
                        .max()
                    )

        for _, simulation_condition in simulation_conditions.iterrows():
            # preequilibration conditions are handled separately
            if (
                "preequilibration"
                in simulation_condition[petabv2.C.CONDITION_ID]
            ):
                continue
            # select measurement rows belonging to this condition
            if isinstance(self._petab_problem, HybridV2Problem):
                query = " & ".join(
                    [
                        f"{k} == '{v}'" if isinstance(v, str) else f"{k} == {v}"
                        for k, v in simulation_condition.items()
                        if k != petabv2.C.CONDITION_ID
                    ]
                )
            else:
                query = " & ".join(
                    [f"{k} == '{v}'" for k, v in simulation_condition.items()]
                )
            m = self._petab_problem.measurement_df.query(query).sort_values(
                by=petabv2.C.TIME
            )

            # split finite (dynamic) from infinite (post-equilibrium) times
            ts = m[petabv2.C.TIME]
            ts_dyn = ts[np.isfinite(ts)]
            ts_posteq = ts[np.logical_not(np.isfinite(ts))]
            index = pd.concat([ts_dyn, ts_posteq]).index
            ts_dyn = ts_dyn.values
            ts_posteq = ts_posteq.values
            my = m[petabv2.C.MEASUREMENT].values
            iys = np.array(
                [
                    self.model.observable_ids.index(oid)
                    for oid in m[petabv2.C.OBSERVABLE_ID].values
                ]
            )
            if (
                petabv2.C.NOISE_DISTRIBUTION
                in self._petab_problem.observable_df
            ):
                # NOTE(review): this iterates all problem observables rather
                # than the observables of the queried measurements `m` —
                # confirm the resulting length matches `iys` downstream.
                iy_trafos = np.array(
                    [
                        SCALE_TO_INT[petabv2.C.LOG]
                        if obs.noise_distribution == petabv2.C.LOG_NORMAL
                        else SCALE_TO_INT[petabv2.C.LIN]
                        for obs in self._petab_problem.observables
                    ]
                )
            else:
                iy_trafos = np.zeros_like(iys)

            parameter_overrides_par_indices = dict()
            parameter_overrides_numeric_vals = dict()
            parameter_overrides_mask = dict()

            def get_parameter_override(x):
                # Resolve non-estimated PEtab parameters to their nominal
                # value; anything else is passed through unchanged.
                if (
                    x in self._petab_problem.parameter_df.index
                    and not self._petab_problem.parameter_df.loc[
                        x, petabv2.C.ESTIMATE
                    ]
                ):
                    return self._petab_problem.parameter_df.loc[
                        x, petabv2.C.NOMINAL_VALUE
                    ]
                return x

            for col in [
                petabv2.C.OBSERVABLE_PARAMETERS,
                petabv2.C.NOISE_PARAMETERS,
            ]:
                if col not in m or m[col].isna().all() or all(m[col] == ""):
                    # no overrides: dummy numeric matrix of ones
                    mat_numeric = jnp.ones((len(m), n_pars[col]))
                    par_mask = np.zeros_like(mat_numeric, dtype=bool)
                    par_index = np.zeros_like(mat_numeric, dtype=int)
                elif pd.api.types.is_numeric_dtype(m[col].dtype):
                    # purely numeric overrides
                    mat_numeric = np.expand_dims(m[col].values, axis=1)
                    par_mask = np.zeros_like(mat_numeric, dtype=bool)
                    par_index = np.zeros_like(mat_numeric, dtype=int)
                else:
                    # mixed numeric / parameter-id overrides
                    split_vals = m[col].str.split(
                        petabv2.C.PARAMETER_SEPARATOR
                    )
                    list_vals = split_vals.apply(
                        lambda x: (
                            [get_parameter_override(y) for y in x]
                            if isinstance(x, list)
                            else []
                            if pd.isna(x)
                            else [x]
                        )  # every string gets transformed to lists, so this is already a float
                    )
                    # right-pad each row with 1.0 up to the maximum width
                    vals = list_vals.apply(
                        lambda x: np.pad(
                            x,
                            (0, n_pars[col] - len(x)),
                            mode="constant",
                            constant_values=1.0,
                        )
                    )
                    mat = np.stack(vals)
                    # deconstruct such that we can reconstruct mapped parameter overrides via vectorized operations
                    # mat = np.where(par_mask, map(lambda ip: p.at[ip], par_index), mat_numeric)
                    par_index = np.vectorize(
                        lambda x: (
                            self.parameter_ids.index(x)
                            if x in self.parameter_ids
                            else -1
                        )
                    )(mat)
                    # map out numeric values
                    par_mask = par_index != -1
                    # remove non-numeric values
                    mat[par_mask] = 0.0
                    mat_numeric = mat.astype(float)
                    # replace dummy index with some valid index
                    par_index[~par_mask] = 0

                parameter_overrides_numeric_vals[col] = mat_numeric
                parameter_overrides_mask[col] = par_mask
                parameter_overrides_par_indices[col] = par_index

            measurements[tuple(simulation_condition)] = (
                ts_dyn,  # 0
                ts_posteq,  # 1
                my,  # 2
                iys,  # 3
                iy_trafos,  # 4
                parameter_overrides_numeric_vals[
                    petabv2.C.OBSERVABLE_PARAMETERS
                ],  # 5
                parameter_overrides_mask[petabv2.C.OBSERVABLE_PARAMETERS],  # 6
                parameter_overrides_par_indices[
                    petabv2.C.OBSERVABLE_PARAMETERS
                ],  # 7
                parameter_overrides_numeric_vals[
                    petabv2.C.NOISE_PARAMETERS
                ],  # 8
                parameter_overrides_mask[petabv2.C.NOISE_PARAMETERS],  # 9
                parameter_overrides_par_indices[
                    petabv2.C.NOISE_PARAMETERS
                ],  # 10
            )
            petab_indices[tuple(simulation_condition)] = tuple(index.tolist())

        # compute maximum lengths
        n_ts_dyn = max(len(mv[0]) for mv in measurements.values())
        n_ts_posteq = max(len(mv[1]) for mv in measurements.values())

        # pad with last value and stack
        ts_dyn = np.stack(
            [
                np.pad(mv[0], (0, n_ts_dyn - len(mv[0])), mode="edge")
                for mv in measurements.values()
            ]
        )
        ts_posteq = np.stack(
            [
                np.pad(mv[1], (0, n_ts_posteq - len(mv[1])), mode="edge")
                for mv in measurements.values()
            ]
        )

        def pad_measurement(x_dyn, x_peq):
            # Pad dynamic and post-equilibrium parts independently along the
            # first axis, then concatenate them into one row.
            # only pad first axis
            pad_width_dyn = tuple(
                [(0, n_ts_dyn - len(x_dyn))] + [(0, 0)] * (x_dyn.ndim - 1)
            )
            pad_width_peq = tuple(
                [(0, n_ts_posteq - len(x_peq))] + [(0, 0)] * (x_peq.ndim - 1)
            )
            return np.concatenate(
                (
                    np.pad(x_dyn, pad_width_dyn, mode="edge"),
                    np.pad(x_peq, pad_width_peq, mode="edge"),
                )
            )

        def pad_and_stack(output_index: int):
            # Apply pad_measurement to one slot of the per-condition tuples.
            return np.stack(
                [
                    pad_measurement(
                        mv[output_index][: len(mv[0])],
                        mv[output_index][len(mv[0]) :],
                    )
                    for mv in measurements.values()
                ]
            )

        my = pad_and_stack(2)
        iys = pad_and_stack(3)
        iy_trafos = pad_and_stack(4)
        op_numeric = pad_and_stack(5)
        op_mask = pad_and_stack(6)
        op_indices = pad_and_stack(7)
        np_numeric = pad_and_stack(8)
        np_mask = pad_and_stack(9)
        np_indices = pad_and_stack(10)
        # boolean mask marking real (non-padding) time points per condition
        ts_masks = np.stack(
            [
                np.concatenate(
                    (
                        np.pad(
                            np.ones_like(mv[0]), (0, n_ts_dyn - len(mv[0]))
                        ),
                        np.pad(
                            np.ones_like(mv[1]), (0, n_ts_posteq - len(mv[1]))
                        ),
                    )
                )
                for mv in measurements.values()
            ]
        ).astype(bool)
        petab_indices = np.stack(
            [
                pad_measurement(
                    np.array(idx[: len(mv[0])]),
                    np.array(idx[len(mv[0]) :]),
                )
                for mv, idx in zip(
                    measurements.values(), petab_indices.values()
                )
            ]
        )
        return (
            ts_dyn,
            ts_posteq,
            my,
            iys,
            iy_trafos,
            ts_masks,
            petab_indices,
            op_numeric,
            op_mask,
            op_indices,
            np_numeric,
            np_mask,
            np_indices,
        )
    def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]:
        """
        Get all simulation conditions of the PEtab problem.

        :return:
            Tuple with one tuple of condition id(s) per simulation condition.
        """
        if isinstance(self._petab_problem, HybridV2Problem):
            simulation_conditions = get_simulation_conditions_v2(
                self._petab_problem
            )
            return tuple(
                tuple([row.conditionId])
                for _, row in simulation_conditions.iterrows()
            )
        else:
            simulation_conditions = self._petab_problem.get_simulation_conditions_from_measurement_df()
            return tuple(
                tuple(row) for _, row in simulation_conditions.iterrows()
            )
    def _initialize_model_parameters(self, model: JAXModel) -> dict:
        """
        Initialize model parameter structure with zeros.

        :param model: JAX model with neural networks
        :return: Nested dictionary structure for model parameters
        """
        return {
            net_id: {
                layer_id: {
                    attribute: jnp.zeros_like(getattr(layer, attribute))
                    for attribute in ["weight", "bias"]
                    if hasattr(layer, attribute)
                }
                for layer_id, layer in nn.layers.items()
            }
            for net_id, nn in model.nns.items()
        }

    def _load_parameter_arrays_from_files(self) -> dict:
        """
        Load neural network parameter arrays from HDF5 files.

        :return: Dictionary mapping network IDs to parameter arrays
        """
        if not self._petab_problem.extensions_config:
            return {}
        array_files = self._petab_problem.extensions_config["sciml"].get(
            "array_files", []
        )
        import h5py

        # TODO(performance): Avoid opening each file multiple times
        # NOTE(review): the network id is assumed to be the part of the file
        # name before the first underscore — confirm against the file layout.
        return {
            file_spec.split("_")[0]: h5py.File(file_spec, "r")["parameters"][
                file_spec.split("_")[0]
            ]
            for file_spec in array_files
            if "parameters" in h5py.File(file_spec, "r").keys()
        }

    def _load_input_arrays_from_files(self) -> dict:
        """
        Load neural network input arrays from HDF5 files.

        :return: Dictionary mapping network IDs to input arrays
        """
        if not self._petab_problem.extensions_config:
            return {}
        array_files = self._petab_problem.extensions_config["sciml"].get(
            "array_files", []
        )
        import h5py

        # TODO(performance): Avoid opening each file multiple times
        return {
            file_spec.split("_")[0]: h5py.File(file_spec, "r")["inputs"]
            for file_spec in array_files
            if "inputs" in h5py.File(file_spec, "r").keys()
        }

    def _parse_parameter_name(
        self, pname: str, model_pars: dict
    ) -> list[tuple[str, str]]:
        """
        Parse parameter name to determine which layers and attributes to set.

        :param pname: Parameter name from PEtab (format: net.layer.attribute)
        :param model_pars: Model parameters dictionary
        :return: List of (layer_name, attribute_name) tuples to set
        """
        # NOTE(review): the network id is extracted via "_" while the
        # layer/attribute parts are split on "." — confirm naming scheme.
        net = pname.split("_")[0]
        nn = model_pars[net]
        to_set = []
        name_parts = pname.split(".")
        if len(name_parts) > 1:
            layer_name = name_parts[1]
            layer = nn[layer_name]
            if len(name_parts) > 2:
                # Specific attribute specified
                attribute_name = name_parts[2]
                to_set.append((layer_name, attribute_name))
            else:
                # All attributes of the layer
                to_set.extend(
                    [(layer_name, attribute) for attribute in layer.keys()]
                )
        else:
            # All layers and attributes
            to_set.extend(
                [
                    (layer_name, attribute)
                    for layer_name, layer in nn.items()
                    for attribute in layer.keys()
                ]
            )
        return to_set

    def _extract_nominal_values_from_petab(
        self, model: JAXModel, model_pars: dict, par_arrays: dict
    ) -> None:
        """
        Extract nominal parameter values from PEtab problem and populate model_pars.

        :param model: JAX model
        :param model_pars: Model parameters dictionary to populate (modified in place)
        :param par_arrays: Parameter arrays loaded from files
        """
        for pname, row in self._petab_problem.parameter_df.iterrows():
            net = pname.split("_")[0]
            if net not in model.nns:
                continue
            nn = model_pars[net]
            scalar = True
            # Determine value source (scalar from PEtab or array from file)
            if np.isnan(row[petabv2.C.NOMINAL_VALUE]):
                value = par_arrays[net]
                scalar = False
            else:
                value = float(row[petabv2.C.NOMINAL_VALUE])
            # Parse parameter name and set values
            to_set = self._parse_parameter_name(pname, model_pars)
            for layer, attribute in to_set:
                if scalar:
                    # broadcast the scalar to the attribute's shape
                    nn[layer][attribute] = value * jnp.ones_like(
                        getattr(model.nns[net].layers[layer], attribute)
                    )
                else:
                    nn[layer][attribute] = jnp.array(
                        value[layer][attribute][:]
                    )

    def _set_model_parameters(
        self, model: JAXModel, model_pars: dict
    ) -> JAXModel:
        """
        Set parameter values in the model using equinox tree_at.

        :param model: JAX model to update
        :param model_pars: Dictionary of parameter values to set
        :return: Updated JAX model
        """
        for net_id in model_pars:
            for layer_id in model_pars[net_id]:
                for attribute in model_pars[net_id][layer_id]:
                    logger.debug(
                        f"Setting {attribute} of layer {layer_id} in network "
                        f"{net_id} to {model_pars[net_id][layer_id][attribute]}"
                    )
                    # eqx.Module is immutable; tree_at returns an updated copy
                    model = eqx.tree_at(
                        lambda model: getattr(
                            model.nns[net_id].layers[layer_id], attribute
                        ),
                        model,
                        model_pars[net_id][layer_id][attribute],
                    )
        return model

    def _set_input_arrays(
        self, model: JAXModel, nn_input_arrays: dict, model_pars: dict
    ) -> JAXModel:
        """
        Set input arrays in the model if provided.

        :param model: JAX model to update
        :param nn_input_arrays: Input arrays loaded from files
        :param model_pars: Model parameters dictionary (for network IDs)
        :return: Updated JAX model
        """
        if len(nn_input_arrays) == 0:
            return model
        for net_id in model_pars:
            input_array = {
                input: {
                    k: jnp.array(
                        arr[:],
                        # match the globally configured float precision
                        dtype=jnp.float64
                        if jax.config.jax_enable_x64
                        else jnp.float32,
                    )
                    for k, arr in nn_input_arrays[net_id][input].items()
                }
                for input in model.nns[net_id].inputs
            }
            model = eqx.tree_at(
                lambda model: model.nns[net_id].inputs, model, input_array
            )
        return model

    def _create_scaled_parameter_array(self) -> jt.Float[jt.Array, "np"]:
        """
        Create array of scaled nominal parameter values for estimation.

        :return: JAX array of scaled parameter values
        """
        # NOTE(review): uses `petabv2.PARAMETER_SCALE` while the rest of the
        # module uses `petabv2.C.PARAMETER_SCALE` — confirm this re-export
        # exists in petab.v2.
        return jnp.array(
            [
                petabv2.scale(
                    float(
                        self._petab_problem.parameter_df.loc[
                            pval, petabv2.C.NOMINAL_VALUE
                        ]
                    ),
                    self._petab_problem.parameter_df.loc[
                        pval, petabv2.PARAMETER_SCALE
                    ],
                )
                for pval in self.parameter_ids
            ]
        )

    def _initialize_model_with_nominal_values(
        self, model: JAXModel
    ) -> tuple[jt.Float[jt.Array, "np"], JAXModel]:
        """
        Initialize the model with nominal parameter values and inputs from
        the PEtab problem.

        This method:
        - Initializes model parameter structure
        - Loads parameter and input arrays from HDF5 files
        - Extracts nominal values from PEtab problem
        - Sets parameter values in the model
        - Sets input arrays in the model
        - Creates scaled parameter array initialized to nominal values

        :param model: JAX model to initialize
        :return: Tuple of (scaled parameter array, initialized model)
        """
        # Initialize model parameters structure
        model_pars = self._initialize_model_parameters(model)
        # Load arrays from files (getters)
        par_arrays = self._load_parameter_arrays_from_files()
        nn_input_arrays = self._load_input_arrays_from_files()
        # Extract nominal values from PEtab problem
        self._extract_nominal_values_from_petab(model, model_pars, par_arrays)
        # Set values in model (setters)
        model = self._set_model_parameters(model, model_pars)
        model = self._set_input_arrays(model, nn_input_arrays, model_pars)
        # Create scaled parameter array
        if isinstance(self._petab_problem, HybridV2Problem):
            # v2 hybrid problems: nominal values are used unscaled
            param_map = {
                p.id: p.nominal_value
                for p in self._petab_problem.parameters
            }
            parameter_array = jnp.array(
                [float(param_map[pval]) for pval in self.parameter_ids]
            )
        else:
            parameter_array = self._create_scaled_parameter_array()
        return parameter_array, model

    def _get_inputs(self) -> dict:
        """
        Load neural-network input arrays referenced via the mapping table.

        Mapping-table entries whose petab id is an existing TSV file path are
        read, reshaped according to their ";"-separated `ix` index column,
        and grouped per network.

        :return: Nested dict ``{net_id: {model_entity_id: ndarray}}``.
        """
        if self._petab_problem.mapping_df is None:
            return {}
        inputs = {net: {} for net in self.model.nns.keys()}
        for petab_id, row in self._petab_problem.mapping_df.iterrows():
            if (filepath := Path(petab_id)).is_file():
                data_flat = pd.read_csv(filepath, sep="\t").sort_values(
                    by="ix"
                )
                # infer the target shape from the maximum multi-index per axis
                shape = tuple(
                    np.stack(
                        data_flat["ix"]
                        .astype(str)
                        .str.split(";")
                        .apply(np.array)
                    )
                    .astype(int)
                    .max(axis=0)
                    + 1
                )
                inputs[row["netId"]][row[petabv2.C.MODEL_ENTITY_ID]] = (
                    data_flat["value"].values.reshape(shape)
                )
        return inputs

    @property
    def parameter_ids(self) -> list[str]:
        """
        Parameter ids that are estimated in the PEtab problem. Same ordering
        as values in :attr:`parameters`.

        :return: PEtab parameter ids
        """
        # NOTE(review): selecting the ESTIMATE column and taking its index
        # yields ALL parameter-table ids, not only estimated ones — confirm
        # whether filtering by estimate==1 was intended.
        return self._petab_problem.parameter_df[
            petabv2.C.ESTIMATE
        ].index.tolist()

    @property
    def nn_output_ids(self) -> list[str]:
        """
        PEtab ids that are mapped to neural-network outputs.

        :return: PEtab entity ids whose model entity is a network output
        """
        if self._petab_problem.mapping_df is None:
            return []
        if (
            self._petab_problem.mapping_df[petabv2.C.MODEL_ENTITY_ID]
            .isnull()
            .all()
        ):
            return []
        # mapping entries of the form <net>.output... identify NN outputs
        return self._petab_problem.mapping_df[
            self._petab_problem.mapping_df[petabv2.C.MODEL_ENTITY_ID]
            .str.split(".")
            .str[1]
            .str.startswith("output")
        ].index.tolist()
    def get_petab_parameter_by_id(self, name: str) -> jnp.float_:
        """
        Get the value of a PEtab parameter by name.

        :param name:
            PEtab parameter id, as returned by :attr:`parameter_ids`.
        :return:
            Value of the parameter
        """
        # position in `parameter_ids` matches position in `parameters`
        return self.parameters[self.parameter_ids.index(name)]
    def _unscale(
        self, p: jt.Float[jt.Array, "np"], scales: tuple[str, ...]
    ) -> jt.Float[jt.Array, "np"]:
        """
        Unscaling of parameters.

        :param p: Parameter values
        :param scales: Parameter scalings
        :return: Unscaled parameter values
        """
        return jnp.array(
            [jax_unscale(pval, scale) for pval, scale in zip(p, scales)]
        )

    def _eval_nn(self, output_par: str, condition_id: str):
        """
        Evaluate the neural network that produces ``output_par``.

        Network inputs are resolved, in order of precedence, from condition
        table entries, per-condition input arrays, or parameter values.

        :param output_par: PEtab id mapped to a network output
        :param condition_id: Simulation condition id used to resolve inputs
        :return: Squeezed network output
        """
        # network producing this output: first component of the model entity id
        net_id = self._petab_problem.mapping_df.loc[
            output_par, petabv2.C.MODEL_ENTITY_ID
        ].split(".")[0]
        nn = self.model.nns[net_id]

        def _is_net_input(model_id):
            # mapping entries of the form <net_id>.inputs...
            comps = model_id.split(".")
            return comps[0] == net_id and comps[1].startswith("inputs")

        # model entity id -> petab entity id, for this network's inputs
        model_id_map = (
            self._petab_problem.mapping_df[
                self._petab_problem.mapping_df[
                    petabv2.C.MODEL_ENTITY_ID
                ].apply(_is_net_input)
            ]
            .reset_index()
            .set_index(petabv2.C.MODEL_ENTITY_ID)[petabv2.C.PETAB_ENTITY_ID]
            .to_dict()
        )
        # condition-table overrides: either a parameter reference (resolved
        # to its nominal value) or a literal numeric value
        condition_input_map = (
            dict(
                [
                    (
                        petab_id,
                        self._petab_problem.parameter_df.loc[
                            self._petab_problem.condition_df.loc[
                                condition_id, petab_id
                            ],
                            petabv2.C.NOMINAL_VALUE,
                        ],
                    )
                    if self._petab_problem.condition_df.loc[
                        condition_id, petab_id
                    ]
                    in self._petab_problem.parameter_df.index
                    else (
                        petab_id,
                        np.float64(
                            self._petab_problem.condition_df.loc[
                                condition_id, petab_id
                            ]
                        ),
                    )
                    for petab_id in model_id_map.values()
                ]
            )
            if not self._petab_problem.condition_df.empty
            else {}
        )
        # hybridization-table redirections for network inputs
        hybridization_parameter_map = {
            petab_id: self._petab_problem.hybridization_df.loc[
                petab_id, "targetValue"
            ]
            for petab_id in model_id_map.values()
            if petab_id in set(self._petab_problem.hybridization_df.index)
        }
        # handle conditions
        if len(condition_input_map) > 0:
            net_input = jnp.array(
                [
                    condition_input_map[petab_id]
                    for _, petab_id in model_id_map.items()
                ]
            )
            return nn.forward(net_input).squeeze()
        # handle array inputs
        if isinstance(self.model.nns[net_id].inputs, dict):
            # fall back to the "0" entry when there is no condition-specific array
            net_input = jnp.array(
                [
                    self.model.nns[net_id].inputs[petab_id][condition_id]
                    if condition_id in self.model.nns[net_id].inputs[petab_id]
                    else self.model.nns[net_id].inputs[petab_id]["0"]
                    for _, petab_id in model_id_map.items()
                ]
            )
            return nn.forward(net_input).squeeze()
        # otherwise resolve each input from model inputs, estimated
        # parameters, nominal values, or hybridization redirections
        net_input = jnp.array(
            [
                jax.lax.stop_gradient(self.model.nns[net_id][model_id])
                if model_id in self.model.nns[net_id].inputs
                else self.get_petab_parameter_by_id(petab_id)
                if petab_id in self.parameter_ids
                else self._petab_problem.parameter_df.loc[
                    petab_id, petabv2.C.NOMINAL_VALUE
                ]
                if petab_id in set(self._petab_problem.parameter_df.index)
                else self._petab_problem.parameter_df.loc[
                    hybridization_parameter_map[petab_id],
                    petabv2.C.NOMINAL_VALUE,
                ]
                for model_id, petab_id in model_id_map.items()
            ]
        )
        return nn.forward(net_input).squeeze()

    def _map_model_parameter_value(
        self,
        mapping: ParameterMappingForCondition,
        pname: str,
        condition_id: str,
    ) -> jt.Float[jt.Scalar, ""] | float:  # noqa: F722
        """
        Resolve a model parameter via the condition's parameter mapping.

        NN outputs are evaluated; numeric mapped values are returned as-is;
        anything else is looked up among the problem parameters.
        """
        pval = mapping.map_sim_var[pname]
        if hasattr(self, "nn_output_ids") and pval in self.nn_output_ids:
            nn_output = self._eval_nn(pval, condition_id)
            if nn_output.size > 1:
                # pick the output component encoded as [...][<ind>] in the
                # model entity id
                entityId = self._petab_problem.mapping_df.loc[
                    pval, petabv2.C.MODEL_ENTITY_ID
                ]
                ind = int(re.search(r"\[\d+\]\[(\d+)\]", entityId).group(1))
                return nn_output[ind]
            else:
                return nn_output
        if isinstance(pval, Number):
            return pval
        return self.get_petab_parameter_by_id(pval)
    def load_model_parameters(
        self, experiment: petabv2.Experiment, is_preeq: bool
    ) -> jt.Float[jt.Array, "np"]:
        """
        Load parameters for an experiment.

        :param experiment:
            Experiment to load parameters for.
        :param is_preeq:
            Whether to load preequilibration or simulation parameters.
        :return:
            Parameters for the experiment.
        """
        p = jnp.array(
            [
                self._map_experiment_model_parameter_value(
                    pname, ind, experiment, is_preeq
                )
                for ind, pname in enumerate(self.model.parameter_ids)
            ]
        )
        # all model parameters are treated as linear-scale here, so the
        # unscaling below is effectively a no-op
        pscale = tuple([petabv2.C.LIN for _ in self.model.parameter_ids])
        return self._unscale(p, pscale)
def _map_experiment_model_parameter_value( self, pname: str, p_index: int, experiment: petabv2.Experiment, is_preeq: bool, ): """ Get values for the given parameter `pname` from the relevant petab tables. :param pname: PEtab parameter id :param p_index: Index of the parameter in the model's parameter list :param experiment: PEtab experiment :param is_preeq: Whether to get preequilibration or simulation parameter value :return: Value of the parameter """ condition_ids = [] for p in experiment.sorted_periods: if is_preeq: if not p.is_preequilibration: continue else: condition_ids = p.condition_ids break else: if p.is_preequilibration: continue else: condition_ids = p.condition_ids break init_val = self.model.parameters[p_index] params_nominals = { p.id: p.nominal_value for p in self._petab_problem.parameters } targets_map = { ch.target_id: ch.target_value for c in self._petab_problem.conditions for ch in c.changes if c.id in condition_ids } if pname in params_nominals: return params_nominals[pname] elif pname in targets_map: return float(targets_map[pname]) else: for placeholder_attr, param_attr in ( ("observable_placeholders", "observable_parameters"), ("noise_placeholders", "noise_parameters"), ): placeholders = [ getattr(o, placeholder_attr) for o in self._petab_problem.observables ] for placeholders in placeholders: params_list = getattr( self._petab_problem.measurements[0], param_attr ) for i, p in enumerate(placeholders): if str(p) == pname: val = self._find_val( str(params_list[i]), params_nominals ) return val return init_val def _find_val(self, param_entry: str, params_nominals: dict): val_float = _try_float(param_entry) if isinstance(val_float, float): return val_float elif param_entry in params_nominals: return params_nominals[param_entry] else: return param_entry def _state_needs_reinitialisation( self, simulation_condition: str, state_id: str, ) -> bool: """ Check if a state needs reinitialisation for a simulation condition. 
:param simulation_condition: simulation condition to check reinitialisation for :param state_id: state id to check reinitialisation for :return: True if state needs reinitialisation, False otherwise """ if state_id in self.nn_output_ids: return True if state_id not in self._petab_problem.condition_df: return False xval = self._petab_problem.condition_df.loc[ simulation_condition, state_id ] if isinstance(xval, Number) and np.isnan(xval): return False return True def _state_reinitialisation_value( self, simulation_condition: str, state_id: str, p: jt.Float[jt.Array, "np"], ) -> jt.Float[jt.Scalar, ""] | float: # noqa: F722 """ Get the reinitialisation value for a state. :param simulation_condition: simulation condition to get reinitialisation value for :param state_id: state id to get reinitialisation value for :param p: parameters for the simulation condition :return: reinitialisation value for the state """ if state_id in self.nn_output_ids: return self._eval_nn(state_id) if state_id not in self._petab_problem.condition_df: # no reinitialisation, return dummy value return 0.0 xval = self._petab_problem.condition_df.loc[ simulation_condition, state_id ] if isinstance(xval, Number) and np.isnan(xval): # no reinitialisation, return dummy value return 0.0 if isinstance(xval, Number): # numerical value, return as is return xval if xval in self.model.parameter_ids: # model parameter, return value return p[self.model.parameter_ids.index(xval)] if xval in self.parameter_ids: # estimated PEtab parameter, return unscaled value return jax_unscale( self.get_petab_parameter_by_id(xval), self._petab_problem.parameter_df.loc[ xval, petabv2.PARAMETER_SCALE ], ) # only remaining option is nominal value for PEtab parameter # that is not estimated, return nominal value return self._petab_problem.parameter_df.loc[ xval, petabv2.C.NOMINAL_VALUE ]
    def load_reinitialisation(
        self,
        simulation_condition: str,
        p: jt.Float[jt.Array, "np"],
    ) -> tuple[jt.Bool[jt.Array, "nx"], jt.Float[jt.Array, "nx"]]:  # noqa: F821
        """
        Load reinitialisation values and mask for the state vector for a
        simulation condition.

        :param simulation_condition:
            Simulation condition to load reinitialisation for.
        :param p:
            Parameters for the simulation condition.
        :return:
            Tuple of reinitialisation mask and value for states. Empty
            arrays when no state requires reinitialisation.
        """
        if not any(
            x_id in self._petab_problem.condition_df
            or hasattr(self, "nn_output_ids")
            and x_id in self.nn_output_ids
            for x_id in self.model.state_ids
        ):
            return jnp.array([]), jnp.array([])
        mask = jnp.array(
            [
                self._state_needs_reinitialisation(simulation_condition, x_id)
                for x_id in self.model.state_ids
            ]
        )
        reinit_x = jnp.array(
            [
                self._state_reinitialisation_value(
                    simulation_condition, x_id, p
                )
                for x_id in self.model.state_ids
            ]
        )
        return mask, reinit_x
[docs] def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem": """ Update parameters for the model. :param p: New problem instance with updated parameters. """ return eqx.tree_at(lambda p: p.parameters, self, p)
def _prepare_experiments( self, experiments: list[petabv2.Experiment], conditions: list[str], is_preeq: bool, op_numeric: np.ndarray | None = None, op_mask: np.ndarray | None = None, op_indices: np.ndarray | None = None, np_numeric: np.ndarray | None = None, np_mask: np.ndarray | None = None, np_indices: np.ndarray | None = None, ) -> tuple[ jt.Float[jt.Array, "nc np"], # noqa: F821, F722 jt.Bool[jt.Array, "nx"], # noqa: F821 jt.Float[jt.Array, "nx"], # noqa: F821 jt.Float[jt.Array, "nc nt nop"], # noqa: F821, F722 jt.Float[jt.Array, "nc nt nnp"], # noqa: F821, F722 ]: """ Prepare experiments for simulation. :param experiments: Experiments to prepare simulation arrays for. :param conditions: Simulation conditions to prepare. :param is_preeq: Whether to load preequilibration or simulation parameters. :param op_numeric: Numeric values for observable parameter overrides. If None, no overrides are used. :param op_mask: Mask for observable parameter overrides. True for free parameter overrides, False for numeric values. :param op_indices: Free parameter indices (wrt. `self.parameters`) for observable parameter overrides. :param np_numeric: Numeric values for noise parameter overrides. If None, no overrides are used. :param np_mask: Mask for noise parameter overrides. True for free parameter overrides, False for numeric values. :param np_indices: Free parameter indices (wrt. `self.parameters`) for noise parameter overrides. :return: Tuple of parameter arrays, reinitialisation masks and reinitialisation values, observable parameters and noise parameters. 
""" p_array = jnp.stack( [self.load_model_parameters(exp, is_preeq) for exp in experiments] ) exp_ids = [exp.id for exp in experiments] all_exp_ids = [exp.id for exp in self._petab_problem.experiments] h_mask = jnp.stack( [ jnp.ones(self.model.n_events) if (exp_id in exp_ids) else jnp.zeros(self.model.n_events) for exp_id in all_exp_ids ] ) t_zeros = jnp.stack( [ exp.periods[0].time if exp.periods[0].time >= 0.0 else 0.0 for exp in experiments ] ) if self.parameters.size: if isinstance(self._petab_problem, HybridV2Problem): unscaled_parameters = jnp.stack( [ self.parameters[ip] for ip, p_id in enumerate(self.parameter_ids) ] ) else: unscaled_parameters = jnp.stack( [ jax_unscale( self.parameters[ip], self._petab_problem.parameter_df.loc[ p_id, petabv2.C.PARAMETER_SCALE ], ) for ip, p_id in enumerate(self.parameter_ids) ] ) else: unscaled_parameters = jnp.zeros((*self._ts_masks.shape[:2], 0)) # placeholder values from sundials code may be needed here if op_numeric is not None and op_numeric.size: op_array = jnp.where( op_mask, jax.vmap( jax.vmap(jax.vmap(lambda ip: unscaled_parameters[ip])) )(op_indices), op_numeric, ) else: op_array = jnp.zeros((*self._ts_masks.shape[:2], 0)) if np_numeric is not None and np_numeric.size: np_array = jnp.where( np_mask, jax.vmap( jax.vmap(jax.vmap(lambda ip: unscaled_parameters[ip])) )(np_indices), np_numeric, ) else: np_array = jnp.zeros((*self._ts_masks.shape[:2], 0)) mask_reinit_array = jnp.stack( [ self.load_reinitialisation(sc, p)[0] for sc, p in zip(conditions, p_array) ] ) x_reinit_array = jnp.stack( [ self.load_reinitialisation(sc, p)[1] for sc, p in zip(conditions, p_array) ] ) return ( p_array, mask_reinit_array, x_reinit_array, op_array, np_array, h_mask, t_zeros, )
[docs] @eqx.filter_vmap( in_axes={ "max_steps": None, "self": None, }, # only list arguments here where eqx.is_array(0) is not the right thing ) def run_simulation( self, p: jt.Float[jt.Array, "np"], # noqa: F821, F722 ts_dyn: np.ndarray, ts_posteq: np.ndarray, my: np.ndarray, iys: np.ndarray, iy_trafos: np.ndarray, ops: jt.Float[jt.Array, "nt *nop"], # noqa: F821, F722 nps: jt.Float[jt.Array, "nt *nnp"], # noqa: F821, F722 mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 init_override: jt.Float[jt.Array, "nx"], # noqa: F821, F722 init_override_mask: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 h_mask: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, steady_state_event: Callable[ ..., diffrax._custom_types.BoolScalarLike ], max_steps: jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722 h_preeq: jt.Bool[jt.Array, "*ne"] = jnp.array([]), # noqa: F821, F722 ts_mask: np.ndarray = np.array([]), t_zeros: jnp.float_ = 0.0, ret: ReturnValue = ReturnValue.llh, ) -> tuple[jnp.float_, dict]: """ Run a simulation for a given simulation experiment. 
:param p: Parameters for the simulation experiment :param ts_dyn: (Padded) dynamic time points :param ts_posteq: (Padded) post-equilibrium time points :param my: (Padded) measurements :param iys: (Padded) observable indices :param iy_trafos: (Padded) observable transformations indices :param ops: (Padded) observable parameters :param nps: (Padded) noise parameters :param mask_reinit: Mask for states that need reinitialisation :param x_reinit: Reinitialisation values for states :param h_mask: Mask for the events that are part of the current experiment :param solver: ODE solver to use for simulation :param controller: Step size controller to use for simulation :param steady_state_event: Steady state event function to use for post-equilibration. Allows customisation of the steady state condition, see :func:`diffrax.steady_state_event` for details. :param max_steps: Maximum number of steps to take during simulation :param x_preeq: Pre-equilibration state. Can be empty if no pre-equilibration is available, in which case the states will be initialised to the model default values. :param h_preeq: Pre-equilibration event mask. Can be empty if no pre-equilibration is available :param ts_mask: padding mask, see :meth:`JAXModel.simulate_condition` for details. :param t_zeros: simulation start time for the current experiment. :param ret: which output to return. See :class:`ReturnValue` for available options. 
:return: Tuple of output value and simulation statistics """ return self.model.simulate_condition( p=p, ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)), ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)), my=jax.lax.stop_gradient(jnp.array(my)), iys=jax.lax.stop_gradient(jnp.array(iys)), iy_trafos=jax.lax.stop_gradient(jnp.array(iy_trafos)), nps=nps, ops=ops, x_preeq=x_preeq, h_preeq=h_preeq, mask_reinit=jax.lax.stop_gradient(mask_reinit), x_reinit=x_reinit, init_override=init_override, init_override_mask=jax.lax.stop_gradient(init_override_mask), ts_mask=jax.lax.stop_gradient(jnp.array(ts_mask)), h_mask=jax.lax.stop_gradient(jnp.array(h_mask)), t_zero=t_zeros, solver=solver, controller=controller, root_finder=root_finder, max_steps=max_steps, steady_state_event=steady_state_event, adjoint=diffrax.RecursiveCheckpointAdjoint() if ret in (ReturnValue.llh, ReturnValue.chi2) else diffrax.DirectAdjoint(), ret=ret, )
[docs] def run_simulations( self, experiments: list[petabv2.Experiment], preeq_array: jt.Float[jt.Array, "ncond *nx"], # noqa: F821, F722 h_preeqs: jt.Bool[jt.Array, "ncond *ne"], # noqa: F821 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, steady_state_event: Callable[ ..., diffrax._custom_types.BoolScalarLike ], max_steps: jnp.int_, ret: ReturnValue = ReturnValue.llh, ): """ Run simulations for a list of simulation experiments. :param experiments: Experiments to run simulations for. :param preeq_array: Matrix of pre-equilibrated states for the simulation conditions. Ordering must match the simulation conditions. If no pre-equilibration is available for a condition, the corresponding row must be empty. :param h_preeqs: Matrix of pre-equilibration event heaviside variables indicating whether an event condition is false or true after preequilibration. :param solver: ODE solver to use for simulation. :param controller: Step size controller to use for simulation. :param steady_state_event: Steady state event function to use for post-equilibration. Allows customisation of the steady state condition, see :func:`diffrax.steady_state_event` for details. :param max_steps: Maximum number of steps to take during simulation. :param ret: which output to return. See :class:`ReturnValue` for available options. :return: Output value and condition specific results and statistics. Results and statistics are returned as a dict with arrays with the leading dimension corresponding to the simulation conditions. 
""" simulation_conditions = [ cid for exp in experiments for p in exp.periods for cid in p.condition_ids ] dynamic_conditions = list( sc for sc in simulation_conditions if "preequilibration" not in sc ) dynamic_conditions = list(dict.fromkeys(dynamic_conditions)) ( p_array, mask_reinit_array, x_reinit_array, op_array, np_array, h_mask, t_zeros, ) = self._prepare_experiments( experiments, dynamic_conditions, False, self._op_numeric, self._op_mask, self._op_indices, self._np_numeric, self._np_mask, self._np_indices, ) init_override_mask = jnp.stack( [ jnp.array( [ p in set(self.model.parameter_ids) for p in self.model.state_ids ] ) for _ in experiments ] ) init_override = jnp.stack( [ jnp.array( [ self._eval_nn( p, exp.periods[-1].condition_ids[0] ) # TODO: Add mapping of p to eval_nn? if p in set(self.model.parameter_ids) else 1.0 for p in self.model.state_ids ] ) for exp in experiments ] ) return self.run_simulation( p_array, self._ts_dyn, self._ts_posteq, self._my, self._iys, self._iy_trafos, op_array, np_array, mask_reinit_array, x_reinit_array, init_override, init_override_mask, h_mask, solver, controller, root_finder, steady_state_event, max_steps, preeq_array, h_preeqs, self._ts_masks, t_zeros, ret, )
[docs] @eqx.filter_vmap( in_axes={ "max_steps": None, "self": None, }, # only list arguments here where eqx.is_array(0) is not the right thing ) def run_preequilibration( self, p: jt.Float[jt.Array, "np"], # noqa: F821, F722 mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 h_mask: jt.Bool[jt.Array, "ne"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, steady_state_event: Callable[ ..., diffrax._custom_types.BoolScalarLike ], max_steps: jnp.int_, ) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821 """ Run a pre-equilibration simulation for a given simulation experiment. :param p: Parameters for the simulation experiment :param mask_reinit: Mask for states that need reinitialisation :param x_reinit: Reinitialisation values for states :param h_mask: Mask for the events that are part of the current experiment :param solver: ODE solver to use for simulation :param controller: Step size controller to use for simulation :param steady_state_event: Steady state event function to use for pre-equilibration. Allows customisation of the steady state condition, see :func:`diffrax.steady_state_event` for details. :param max_steps: Maximum number of steps to take during simulation :return: Pre-equilibration state """ return self.model.preequilibrate_condition( p=p, mask_reinit=mask_reinit, x_reinit=x_reinit, h_mask=h_mask, solver=solver, controller=controller, root_finder=root_finder, max_steps=max_steps, steady_state_event=steady_state_event, )
[docs] def run_preequilibrations( self, experiments: list[petabv2.Experiment], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, steady_state_event: Callable[ ..., diffrax._custom_types.BoolScalarLike ], max_steps: jnp.int_, ): simulation_conditions = [ cid for exp in experiments for p in exp.periods for cid in p.condition_ids ] preequilibration_conditions = list( {sc for sc in simulation_conditions if "preequilibration" in sc} ) p_array, mask_reinit_array, x_reinit_array, _, _, h_mask, _ = ( self._prepare_experiments( experiments, preequilibration_conditions, True, None, None ) ) return self.run_preequilibration( p_array, mask_reinit_array, x_reinit_array, h_mask, solver, controller, root_finder, steady_state_event, max_steps, )
def run_simulations(
    problem: JAXProblem,
    simulation_experiments: Iterable[str] | None = None,
    solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
    controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
        **DEFAULT_CONTROLLER_SETTINGS
    ),
    root_finder: AbstractRootFinder = optimistix.Newton(
        **DEFAULT_ROOT_FINDER_SETTINGS
    ),
    steady_state_event: Callable[
        ..., diffrax._custom_types.BoolScalarLike
    ] = diffrax.steady_state_event(),
    max_steps: int = 2**13,
    ret: ReturnValue | str = ReturnValue.llh,
):
    """
    Run simulations for a problem.

    :param problem: Problem to run simulations for.
    :param simulation_experiments:
        Simulation experiments to run simulations for. This is an iterable
        of experiment ids. Default is to run simulations for all
        experiments.
    :param solver: ODE solver to use for simulation.
    :param controller: Step size controller to use for simulation.
    :param root_finder: Root finder to use for event detection.
    :param steady_state_event:
        Steady state event function to use for pre-/post-equilibration.
        Allows customisation of the steady state condition, see
        :func:`diffrax.steady_state_event` for details.
    :param max_steps: Maximum number of steps to take during simulation.
    :param ret:
        which output to return. See :class:`ReturnValue` for available
        options.
    :raises TypeError: If the problem is a PEtab v1 problem.
    :return:
        Overall output value and condition specific results and statistics.
    """
    if isinstance(problem, HybridProblem) or isinstance(
        problem._petab_problem, petabv1.Problem
    ):
        raise TypeError(
            "run_simulations does not support PEtab v1 problems. Upgrade the problem to PEtab v2."
        )
    if isinstance(ret, str):
        ret = ReturnValue[ret]
    if simulation_experiments is None:
        experiments = problem._petab_problem.experiments
    else:
        experiments = [
            exp
            for exp in problem._petab_problem.experiments
            if exp.id in simulation_experiments
        ]
    simulation_conditions = [
        cid
        for exp in experiments
        for period in exp.periods
        for cid in period.condition_ids
    ]
    # Order-preserving deduplication of the non-preequilibration conditions.
    dynamic_conditions = list(
        dict.fromkeys(
            sc for sc in simulation_conditions if "preequilibration" not in sc
        )
    )
    conditions = {
        "dynamic_conditions": dynamic_conditions,
    }
    # Preequilibration is signalled by a negative first-period time.
    has_preeq = any(exp.periods[0].time < 0.0 for exp in experiments)
    if has_preeq:
        # NOTE(review): three values are unpacked here, but
        # JAXProblem.run_preequilibration is annotated as returning a
        # 2-tuple — confirm the return arity of preequilibrate_condition.
        preeqs_array, preresults, h_preeqs = problem.run_preequilibrations(
            experiments,
            solver,
            controller,
            root_finder,
            steady_state_event,
            max_steps,
        )
    else:
        preresults = {
            "stats_preeq": None,
        }
        # Empty rows signal "no preequilibration" downstream.
        preeqs_array = jnp.stack([jnp.array([]) for _ in experiments])
        h_preeqs = jnp.stack([jnp.array([]) for _ in experiments])
    output, results = problem.run_simulations(
        experiments,
        preeqs_array,
        h_preeqs,
        solver,
        controller,
        root_finder,
        steady_state_event,
        max_steps,
        ret,
    )
    if ret in (ReturnValue.llh, ReturnValue.chi2):
        if os.getenv("JAX_DEBUG") == "1":
            jax.debug.print(
                "ret: {}",
                ret,
            )
        # Scalar objectives are summed over conditions.
        output = jnp.sum(output)
    return output, results | preresults | conditions
def petab_simulate(
    problem: JAXProblem,
    solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
    controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
        **DEFAULT_CONTROLLER_SETTINGS
    ),
    steady_state_event: Callable[
        ..., diffrax._custom_types.BoolScalarLike
    ] = diffrax.steady_state_event(),
    max_steps: int = 2**13,
):
    """
    Run simulations for a problem and return the results as a petab
    simulation dataframe.

    :param problem: Problem to run simulations for.
    :param solver: ODE solver to use for simulation.
    :param controller: Step size controller to use for simulation.
    :param steady_state_event:
        Steady state event function to use for pre-/post-equilibration.
        Allows customisation of the steady state condition, see
        :func:`diffrax.steady_state_event` for details.
    :param max_steps: Maximum number of steps to take during simulation.
    :return: petab simulation dataframe.
    """
    # Simulate with ret=y so per-timepoint observable values are returned.
    y, r = run_simulations(
        problem,
        solver=solver,
        controller=controller,
        steady_state_event=steady_state_event,
        max_steps=max_steps,
        ret=ReturnValue.y,
    )
    if isinstance(problem._petab_problem, HybridV2Problem):
        return _build_simulation_df_v2(problem, y, r["dynamic_conditions"])
    else:
        dfs = []
        for ic, sc in enumerate(r["dynamic_conditions"]):
            # Observable ids for the unpadded (mask-selected) time points.
            obs = [
                problem.model.observable_ids[io]
                for io in problem._iys[ic, problem._ts_masks[ic, :]]
            ]
            # Full (padded) time vector: dynamic then post-equilibrium.
            t = jnp.concat(
                (
                    problem._ts_dyn[ic, :],
                    problem._ts_posteq[ic, :],
                )
            )
            # NOTE(review): SIMULATION/TIME/OBSERVABLE_ID use mask-selected
            # lengths while CONDITION_ID uses the padded len(t) — these only
            # agree when the mask is all-True for this condition; confirm.
            df_sc = pd.DataFrame(
                {
                    petabv2.C.SIMULATION: y[ic, problem._ts_masks[ic, :]],
                    petabv2.C.TIME: t[problem._ts_masks[ic, :]],
                    petabv2.C.OBSERVABLE_ID: obs,
                    petabv2.C.CONDITION_ID: [sc] * len(t),
                },
                index=problem._petab_measurement_indices[ic, :],
            )
            # Carry over optional measurement-table columns for this
            # condition; assignment aligns on the DataFrame index.
            if (
                petabv2.C.OBSERVABLE_PARAMETERS
                in problem._petab_problem.measurement_df
            ):
                df_sc[petabv2.C.OBSERVABLE_PARAMETERS] = (
                    problem._petab_problem.measurement_df.query(
                        f"{petabv2.C.CONDITION_ID} == '{sc}'"
                    )[petabv2.C.OBSERVABLE_PARAMETERS]
                )
            if (
                petabv2.C.NOISE_PARAMETERS
                in problem._petab_problem.measurement_df
            ):
                df_sc[petabv2.C.NOISE_PARAMETERS] = (
                    problem._petab_problem.measurement_df.query(
                        f"{petabv2.C.CONDITION_ID} == '{sc}'"
                    )[petabv2.C.NOISE_PARAMETERS]
                )
            if (
                petabv2.C.PREEQUILIBRATION_CONDITION_ID
                in problem._petab_problem.measurement_df
            ):
                df_sc[petabv2.C.PREEQUILIBRATION_CONDITION_ID] = (
                    problem._petab_problem.measurement_df.query(
                        f"{petabv2.C.CONDITION_ID} == '{sc}'"
                    )[petabv2.C.PREEQUILIBRATION_CONDITION_ID]
                )
            dfs.append(df_sc)
        # Restore the original measurement row order.
        return pd.concat(dfs).sort_index()
def add_default_experiment_names_to_v2_problem(petab_problem: petabv2.Problem): """Add default experiment names to PEtab v2 problem. Args: petab_problem: PEtab v2 problem to modify. """ if not hasattr(petab_problem, "extensions_config"): petab_problem.extensions_config = {} petab_problem.visualization_df = None if petab_problem.condition_df is None: default_condition = petabv2.core.Condition( id="__default__", changes=[], conditionId="__default__" ) petab_problem.condition_tables[0].elements = [default_condition] if ( petab_problem.experiment_df is None or petab_problem.experiment_df.empty ): condition_ids = petab_problem.condition_df[ petabv2.C.CONDITION_ID ].values condition_ids = [ c for c in condition_ids if "preequilibration" not in c ] default_experiment = petabv2.core.Experiment( id="__default__", periods=[ petabv2.core.ExperimentPeriod( time=0.0, condition_ids=condition_ids ) ], ) petab_problem.experiment_tables[0].elements = [default_experiment] measurement_tables = petab_problem.measurement_tables.copy() for mt in measurement_tables: for m in mt.elements: m.experiment_id = "__default__" petab_problem.measurement_tables = measurement_tables return petab_problem def get_simulation_conditions_v2(petab_problem) -> pd.DataFrame: """Get simulation conditions from PEtab v2 measurement DataFrame. Returns: A pandas DataFrame mapping experiment_ids to condition ids. 
""" experiment_df = petab_problem.experiment_df exps = {} for exp_id in experiment_df[petabv2.C.EXPERIMENT_ID].unique(): exps[exp_id] = experiment_df[ experiment_df[petabv2.C.EXPERIMENT_ID] == exp_id ][petabv2.C.CONDITION_ID].unique() experiment_df = experiment_df.drop(columns=[petabv2.C.TIME]) return experiment_df def _build_simulation_df_v2(problem, y, dyn_conditions): """Build petab simulation DataFrame of similation results from a PEtab v2 problem.""" dfs = [] for ic, sc in enumerate(dyn_conditions): experiment_id = _conditions_to_experiment_map( problem._petab_problem.experiment_df )[sc] if experiment_id == "__default__": experiment_id = jnp.nan obs = [ problem.model.observable_ids[io] for io in problem._iys[ic, problem._ts_masks[ic, :]] ] t = jnp.concat( ( problem._ts_dyn[ic, :], problem._ts_posteq[ic, :], ) ) df_sc = pd.DataFrame( { petabv2.C.MODEL_ID: [float("nan")] * len(t), petabv2.C.OBSERVABLE_ID: obs, petabv2.C.EXPERIMENT_ID: [experiment_id] * len(t), petabv2.C.TIME: t[problem._ts_masks[ic, :]], petabv2.C.SIMULATION: y[ic, problem._ts_masks[ic, :]], }, index=problem._petab_measurement_indices[ic, :], ) if ( petabv2.C.OBSERVABLE_PARAMETERS in problem._petab_problem.measurement_df ): df_sc[petabv2.C.OBSERVABLE_PARAMETERS] = ( problem._petab_problem.measurement_df.query( f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'" )[petabv2.C.OBSERVABLE_PARAMETERS] ) if petabv2.C.NOISE_PARAMETERS in problem._petab_problem.measurement_df: df_sc[petabv2.C.NOISE_PARAMETERS] = ( problem._petab_problem.measurement_df.query( f"{petabv2.C.EXPERIMENT_ID} == '{experiment_id}'" )[petabv2.C.NOISE_PARAMETERS] ) dfs.append(df_sc) return pd.concat(dfs).sort_index() def _conditions_to_experiment_map( experiment_df: pd.DataFrame, ) -> dict[str, str]: condition_to_experiment = { row.conditionId: row.experimentId for row in experiment_df.itertuples() } return condition_to_experiment def _try_float(value): try: return float(value) except Exception as e: msg = str(e).lower() if 
isinstance(e, ValueError) and "could not convert" in msg: return value raise