
otf.system

Utilities and system classes for OTF data assimilation.

This package exposes base system abstractions and helpers used when constructing dynamical systems for on-the-fly (OTF) data assimilation.

Modules:

base: System abstractions for on-the-fly (OTF) data assimilation.

linear_nonlinear: Helpers for systems with a separable linear and nonlinear part.

utils: Utility functions for adapting ODEs and observation masks for use with

Classes:

BaseSystem: Base abstraction for a dynamical system used with OTF assimilation.

System_LinearNonlinear_ModelKnown: Concrete system with a known true model and separable linear/nonlinear parts.

System_LinearNonlinear_ModelUnknown: Concrete system with an unknown true model; only the assimilated parts are provided.

System_ModelKnown: System where the true ODE is known and can be simulated.

System_ModelUnknown: System where the true ODE is unknown and cannot be simulated.

BaseSystem

BaseSystem(
    mu: float,
    cs: jndarray,
    observed_mask: jndarray,
    assimilated_ode: Callable[
        [jndarray, jndarray], jndarray
    ],
    complex_differentiation: bool = False,
)

Base abstraction for a dynamical system used with OTF assimilation.

This class wraps an ODE for an assimilated system together with nudging behavior that pushes the assimilated state toward observed portions of a (possibly partially observed) true state. Subclasses may provide a known true system or leave the true system unspecified.

Initialize the base system.

Parameters:

mu (float, required): Nudging parameter.

cs (jndarray, required): Estimated parameter values used by the assimilated system (may differ from the true system parameters gs used by subclasses).

observed_mask (jndarray, required): Boolean jnp.ndarray mask indicating observed entries of a flattened state. Nudging is applied only on these entries.

assimilated_ode (Callable[[jndarray, jndarray], jndarray], required): Callable (cs, state) -> state_dot producing the time derivative for the assimilated system given parameters cs.

complex_differentiation (bool, default False): If True, the Jacobians of assimilated_ode with respect to cs and to the state are computed with jax.jacrev(..., holomorphic=True), and cs is cast to complex before differentiation.
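A minimal sketch of building an observed_mask and the quantities the constructor derives from it. The state length and the "every other entry" layout are hypothetical. The mask must be a JAX array: the constructor rejects anything that is not a jndarray (assumed here to be an alias for jnp.ndarray), so a NumPy mask would raise ValueError.

```python
import jax
import jax.numpy as jnp

n = 6
# Hypothetical layout: observe every other entry of a flattened state of length n.
observed_mask = jnp.arange(n) % 2 == 0

# The same derived quantities BaseSystem.__init__ stores.
unobserved_mask = ~observed_mask
observe_all = not bool(jnp.any(unobserved_mask))

# The observed slice of a true state, in the order f_assimilated expects
# for its true_observed argument.
true_state = jnp.arange(n, dtype=jnp.float32)
true_observed = true_state[observed_mask]
```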

Methods:

f_assimilated: Return time derivative of the assimilated state with nudging.

Source code in src/otf/system/base.py
def __init__(
    self,
    mu: float,
    cs: jndarray,
    observed_mask: jndarray,
    assimilated_ode: Callable[[jndarray, jndarray], jndarray],
    complex_differentiation: bool = False,
):
    """Initialize the base system.

    Parameters
    ----------
    mu
        Nudging parameter.
    cs
        Estimated parameter values used by the assimilated system (may
        differ from the true system parameters `gs` used by subclasses).
    observed_mask
        Boolean `jnp.ndarray` mask indicating observed entries of a
        flattened state. Nudging is applied only on these entries.
    assimilated_ode
        Callable `(cs, state) -> state_dot` producing the time derivative
        for the assimilated system given parameters `cs`.
    complex_differentiation
        If True, treat arrays as potentially complex for autodiff.
    """
    if not isinstance(observed_mask, jndarray):
        raise ValueError(
            "`observed_mask` must be jnp.ndarray boolean array"
        )

    self._mu = mu
    self._observed_mask = observed_mask
    self._unobserved_mask = ~observed_mask
    self._observe_all = not jnp.any(self._unobserved_mask)
    self._cs = cs
    self._assimilated_ode = assimilated_ode

    self._complex_differentiation = complex_differentiation

    _df_dc = jax.jacrev(
        self.assimilated_ode,
        0,
        holomorphic=self._complex_differentiation,
    )
    _df_dv = jax.jacrev(
        self.assimilated_ode,
        1,
        holomorphic=self._complex_differentiation,
    )
    if self._complex_differentiation:

        def df_dc(cs: jndarray, assimilated: jndarray) -> jndarray:
            return _df_dc(cs.astype(complex), assimilated)

        def df_dv(cs: jndarray, assimilated: jndarray) -> jndarray:
            return _df_dv(cs.astype(complex), assimilated)
    else:

        def df_dc(cs: jndarray, assimilated: jndarray) -> jndarray:
            return _df_dc(cs, assimilated)

        def df_dv(cs: jndarray, assimilated: jndarray) -> jndarray:
            return _df_dv(cs, assimilated)

    self._df_dc = df_dc
    self._df_dv = df_dv
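The constructor builds the Jacobians of assimilated_ode with respect to the parameters (argument 0) and the state (argument 1) using jax.jacrev. The sketch below applies the same calls to a hypothetical toy ODE, F(c, v) = -c * v, whose Jacobians are diag(-v) and diag(-c), and shows the complex path, where holomorphic=True requires complex inputs (hence the cast of cs).

```python
import jax
import jax.numpy as jnp

def toy_ode(cs, state):
    # Hypothetical assimilated ODE: componentwise linear decay with rates cs.
    return -cs * state

# Jacobians with respect to the parameters and the state, built the same
# way as _df_dc and _df_dv in BaseSystem.__init__.
df_dc = jax.jacrev(toy_ode, 0)
df_dv = jax.jacrev(toy_ode, 1)

cs = jnp.array([1.0, 2.0, 3.0])
state = jnp.array([4.0, 5.0, 6.0])

J_c = df_dc(cs, state)  # diag(-state), shape (3, 3)
J_v = df_dv(cs, state)  # diag(-cs), shape (3, 3)

# Complex variant: holomorphic=True needs complex inputs and outputs.
df_dc_complex = jax.jacrev(toy_ode, 0, holomorphic=True)
J_c_complex = df_dc_complex(cs.astype(complex), state.astype(complex))
```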

f_assimilated

f_assimilated(
    cs: jndarray,
    true_observed: jndarray,
    assimilated: jndarray,
) -> jndarray

Return time derivative of the assimilated state with nudging.

The method applies assimilated_ode and then subtracts a nudging term on observed entries: mu * (assimilated - true_observed).
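Written per entry, with F for assimilated_ode, c for cs, v for the assimilated state, u for the true state and O for the set of entries where observed_mask is True, the returned derivative is (true_observed supplies the values u_i for i in O):

$$
\frac{dv_i}{dt} =
\begin{cases}
F_i(c, v) - \mu\,(v_i - u_i), & i \in \mathcal{O},\\
F_i(c, v), & i \notin \mathcal{O}.
\end{cases}
$$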

This method is suitable for JIT compilation.

Parameters:

cs (jndarray, required): Estimated parameter values for the assimilated ODE.

true_observed (jndarray, required): Observed portion of the true state, with one value per True entry of observed_mask, in mask order.

assimilated (jndarray, required): Current assimilated (flattened) state.

Returns:

jnp.ndarray: Time derivative of assimilated after applying nudging.

Source code in src/otf/system/base.py
def f_assimilated(
    self,
    cs: jndarray,
    true_observed: jndarray,
    assimilated: jndarray,
) -> jndarray:
    """Return time derivative of the assimilated state with nudging.

    The method applies `assimilated_ode` and then subtracts a nudging term
    on observed entries: `mu * (assimilated - true_observed)`.

    This method is suitable for JIT compilation.

    Parameters
    ----------
    cs
        Estimated parameter values for the assimilated ODE.
    true_observed
        Observed portion of the true state (matches `observed_mask`).
    assimilated
        Current assimilated (flattened) state.

    Returns
    -------
    jnp.ndarray
        Time derivative of `assimilated` after applying nudging.
    """
    mask = self.observed_mask

    assimilated_p = self._assimilated_ode(cs, assimilated)
    assimilated_p = assimilated_p.at[mask].subtract(
        self.mu * (assimilated[mask] - true_observed)
    )

    return assimilated_p
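A standalone sketch of the same nudging step, outside the class, with a hypothetical toy ODE. Boolean indexing under jax.jit only works when the mask is a concrete value rather than a traced argument; here it is captured in a closure as a NumPy array, playing the role that self.observed_mask plays in the method. The term is applied with .at[...].add of the negated value, which is equivalent to the subtraction above.

```python
import numpy as np
import jax
import jax.numpy as jnp

def toy_ode(cs, state):
    # Hypothetical assimilated ODE: componentwise linear decay with rates cs.
    return -cs * state

def make_f_assimilated(mu, observed_mask):
    # Keep the mask concrete so boolean indexing is resolved at trace time.
    mask = np.asarray(observed_mask)

    @jax.jit
    def f_assimilated(cs, true_observed, assimilated):
        rhs = toy_ode(cs, assimilated)
        # Nudge only the observed entries: rhs[mask] -= mu * (v[mask] - u_obs).
        return rhs.at[mask].add(-mu * (assimilated[mask] - true_observed))

    return f_assimilated

f = make_f_assimilated(10.0, jnp.array([True, False, True]))
out = f(jnp.array([1.0, 2.0, 3.0]), jnp.array([0.5, 2.0]), jnp.ones(3))
```

The unobserved middle entry keeps its plain ODE value (-2.0), while the two observed entries are pulled toward 0.5 and 2.0 respectively.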

System_LinearNonlinear_ModelKnown

System_LinearNonlinear_ModelKnown(
    mu: float,
    gs: jndarray,
    cs: jndarray,
    observed_mask: jndarray,
    linear_assimilated: Callable[[jndarray], jndarray],
    nonlinear_assimilated_ode: Callable[
        [jndarray, jndarray], jndarray
    ],
    linear_true: Callable[[jndarray], jndarray],
    nonlinear_true_ode: Callable[
        [jndarray, jndarray], jndarray
    ],
    complex_differentiation: bool = False,
    true_observed_mask: jndarray | None = None,
)

Bases: _AssimilatedLinearNonlinearMixin, _TrueLinearNonlinearMixin, System_ModelKnown

Concrete system: known true-model with separable linear/nonlinear parts.

The class constructs compatible assimilated and true ODEs from provided linear and nonlinear component callables.

Source code in src/otf/system/linear_nonlinear.py
def __init__(
    self,
    mu: float,
    gs: jndarray,
    cs: jndarray,
    observed_mask: jndarray,
    linear_assimilated: Callable[[jndarray], jndarray],
    nonlinear_assimilated_ode: Callable[[jndarray, jndarray], jndarray],
    linear_true: Callable[[jndarray], jndarray],
    nonlinear_true_ode: Callable[[jndarray, jndarray], jndarray],
    complex_differentiation: bool = False,
    true_observed_mask: jndarray | None = None,
):
    self._set_assimilated_parts(
        linear_assimilated, nonlinear_assimilated_ode
    )
    assimilated_ode = self._define_ode(
        self.linear_assimilated, self.nonlinear_assimilated_ode
    )
    self._set_true_parts(linear_true, nonlinear_true_ode)
    true_ode = self._define_ode(linear_true, nonlinear_true_ode)
    super().__init__(
        mu,
        gs,
        cs,
        observed_mask,
        assimilated_ode,
        true_ode,
        complex_differentiation,
        true_observed_mask,
    )

System_LinearNonlinear_ModelUnknown

System_LinearNonlinear_ModelUnknown(
    mu: float,
    cs: jndarray,
    observed_mask: jndarray,
    linear_assimilated: Callable[[jndarray], jndarray],
    nonlinear_assimilated_ode: Callable[
        [jndarray, jndarray], jndarray
    ],
    complex_differentiation: bool = False,
)

Bases: _AssimilatedLinearNonlinearMixin, System_ModelUnknown

Concrete system: unknown true-model, only assimilated parts provided.

Source code in src/otf/system/linear_nonlinear.py
def __init__(
    self,
    mu: float,
    cs: jndarray,
    observed_mask: jndarray,
    linear_assimilated: Callable[[jndarray], jndarray],
    nonlinear_assimilated_ode: Callable[[jndarray, jndarray], jndarray],
    complex_differentiation: bool = False,
):
    self._set_assimilated_parts(
        linear_assimilated, nonlinear_assimilated_ode
    )
    assimilated_ode = self._define_ode(
        self.linear_assimilated, self.nonlinear_assimilated_ode
    )
    super().__init__(
        mu, cs, observed_mask, assimilated_ode, complex_differentiation
    )

System_ModelKnown

System_ModelKnown(
    mu: float,
    gs: jndarray,
    cs: jndarray,
    observed_mask: jndarray,
    assimilated_ode: Callable[
        [jndarray, jndarray], jndarray
    ],
    true_ode: Callable[[jndarray, jndarray], jndarray],
    complex_differentiation: bool = False,
    true_observed_mask: jndarray | None = None,
)

Bases: BaseSystem

System where the true ODE is known and can be simulated.

This subclass stores gs (true-system parameters) and a true_ode allowing simultaneous integration of the true and assimilated systems.
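A minimal sketch of what simultaneous integration can look like: the true and nudged assimilated systems are advanced together, and the observations fed to the nudging term come from the current true state. The toy ODE, the parameter values and the forward Euler scheme are all hypothetical stand-ins; this page does not say which integrator the package uses. With cs equal to gs and every entry observed, the error obeys e' = -(1 + mu) e and decays.

```python
import numpy as np
import jax
import jax.numpy as jnp

def toy_ode(params, state):
    # Hypothetical dynamics shared by the true and assimilated systems.
    return -params * state

gs = jnp.array([1.0, 1.0])     # true parameters
cs = jnp.array([1.0, 1.0])     # estimated parameters (exact here, for illustration)
mu = 5.0
mask = np.array([True, True])  # every entry observed

@jax.jit
def step(true, assimilated, dt):
    # One forward Euler step of both systems from the same time level.
    true_dot = toy_ode(gs, true)
    assim_dot = toy_ode(cs, assimilated)
    assim_dot = assim_dot.at[mask].add(-mu * (assimilated[mask] - true[mask]))
    return true + dt * true_dot, assimilated + dt * assim_dot

true = jnp.array([1.0, 2.0])
assimilated = jnp.zeros(2)
initial_error = float(jnp.linalg.norm(assimilated - true))
for _ in range(500):
    true, assimilated = step(true, assimilated, 0.01)
final_error = float(jnp.linalg.norm(assimilated - true))
```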

Initialize a System_ModelKnown with a provided true ODE.

Parameters:

Parameters not listed here (mu, cs, observed_mask, assimilated_ode, complex_differentiation) are documented under BaseSystem.

gs (jndarray, required): True-system parameter values used by true_ode.

true_ode (Callable[[jndarray, jndarray], jndarray], required): Callable (gs, true_state) -> true_state_dot describing the dynamics of the true system.

true_observed_mask (jndarray | None, default None): Boolean mask indicating observed entries of the true state. If None, the value of observed_mask is reused.
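A sketch of the true_observed_mask fallback, mirroring the conditional in System_ModelKnown.__init__. One reading of this parameter is that it lets the true state use a different layout from the assimilated state; the 5-entry true state below is hypothetical. Whatever mask is used, the number of extracted values should match the number of True entries in observed_mask, because f_assimilated subtracts them from assimilated[observed_mask].

```python
import jax.numpy as jnp

def resolve_true_observed_mask(observed_mask, true_observed_mask=None):
    # Mirrors the fallback in System_ModelKnown.__init__.
    return true_observed_mask if true_observed_mask is not None else observed_mask

observed_mask = jnp.array([True, False, True])  # 3-entry assimilated state

# Default: the true state shares the assimilated layout.
same_layout = resolve_true_observed_mask(observed_mask)

# Hypothetical 5-entry true state whose entries 0 and 4 are the observations.
true_mask = resolve_true_observed_mask(
    observed_mask, jnp.array([True, False, False, False, True])
)
true_state = jnp.array([10.0, 11.0, 12.0, 13.0, 14.0])
true_observed = true_state[true_mask]
```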

Methods:

f_true: Return the time derivative of the true state using true_ode.

Source code in src/otf/system/base.py
def __init__(
    self,
    mu: float,
    gs: jndarray,
    cs: jndarray,
    observed_mask: jndarray,
    assimilated_ode: Callable[[jndarray, jndarray], jndarray],
    true_ode: Callable[[jndarray, jndarray], jndarray],
    complex_differentiation: bool = False,
    true_observed_mask: jndarray | None = None,
):
    """Initialize a System_ModelKnown with a provided true ODE.

    Parameters not listed below (`mu`, `cs`, `observed_mask`,
    `assimilated_ode`, `complex_differentiation`) are documented in
    `BaseSystem`.

    Parameters
    ----------
    gs
        True-system parameter values used by `true_ode`.
    true_ode
        Callable `(gs, true_state) -> true_state_dot` describing the
        dynamics of the true system.
    true_observed_mask
        Boolean mask indicating observed entries of the true state. If
        `None`, the value of `observed_mask` is reused.
    """
    super().__init__(
        mu, cs, observed_mask, assimilated_ode, complex_differentiation
    )

    self._gs = gs
    self._true_ode = true_ode
    self._true_observed_mask = (
        true_observed_mask
        if true_observed_mask is not None
        else observed_mask
    )

f_true

f_true(true: jndarray) -> jndarray

Return the time derivative of the true state using true_ode.

This method is suitable for JIT compilation.

Parameters:

true (jndarray, required): Current true (flattened) state.

Returns:

jnp.ndarray: Time derivative of true.

Source code in src/otf/system/base.py
def f_true(
    self,
    true: jndarray,
) -> jndarray:
    """Return the time derivative of the true state using `true_ode`.

    This method is suitable for JIT compilation.

    Parameters
    ----------
    true
        Current true (flattened) state.

    Returns
    -------
    jnp.ndarray
        Time derivative of `true`.
    """
    return self._true_ode(self.gs, true)

System_ModelUnknown

System_ModelUnknown(
    mu: float,
    cs: jndarray,
    observed_mask: jndarray,
    assimilated_ode: Callable[
        [jndarray, jndarray], jndarray
    ],
    complex_differentiation: bool = False,
)

Bases: BaseSystem

System where the true ODE is unknown and cannot be simulated.

This subclass does not provide a true_ode or gs; it is suitable when only an assimilated model is available and the true dynamics cannot be integrated alongside the assimilated system.

See BaseSystem for shared behavior and API.

The constructor source rendered for System_ModelUnknown (src/otf/system/base.py) is identical to BaseSystem.__init__ listed above, including its docstring.