otf.optim

Optimization utilities and implementations used by the package.

This package exposes optimizers, learning-rate schedulers and gradient computers used for parameter estimation when assimilating data into BaseSystem instances.

Modules:

Name Description
base

Optimization utilities for parameter estimation during data assimilation.

gradient
lr_scheduler

Learning-rate schedulers for optimizers.

optimizer

Optimization algorithms for estimating unknown system parameters.

Classes:

Name Description
AdjointGradient

Compute parameter gradients using adjoint-based methods.

DummyLRScheduler

No-op scheduler for testing and compatibility.

DummyOptimizer

Optimizer that performs no parameter updates (useful for testing).

ExponentialLR

Multiply an optimizer's learning rate by a constant factor on each step() call.

GradientDescent

Simple gradient-descent optimizer.

LevenbergMarquardt

Levenberg–Marquardt optimizer using sensitivity-based gradients.

MultiStepLR

Reduce learning rate at specified step milestones.

OptaxWrapper

Adapter that wraps an Optax optimizer as a BaseOptimizer.

OptimizerChain

Combine several optimizers by summing their weighted updates.

PartialOptimizer

Wrap an optimizer to update only a subset of parameters.

Regularizer

Produce parameter-space penalties used as a regularization term.

SensitivityGradient

Compute gradients using sensitivity (forward) equations.

WeightedLevenbergMarquardt

Weighted Levenberg–Marquardt optimizer (Gauss–Newton variant).

Functions:

Name Description
pruned_factory

Return a 'pruned' variant of system_type.

AdjointGradient

AdjointGradient(
    system: BaseSystem,
    update_option: UpdateOption = UpdateOption.asymptotic,
    solver: tuple[type[SinglestepSolver | MultistepSolver]]
    | type[SinglestepSolver | MultistepSolver]
    | None = None,
    dt: float | None = None,
    interval_fraction: float = 1,
)

Bases: GradientComputer

Compute parameter gradients using adjoint-based methods.

Supports different update strategies controlled by UpdateOption.

Initialize an AdjointGradient.

Parameters:

Name Type Description Default
system BaseSystem

BaseSystem instance to analyze.

required
update_option UpdateOption

Which adjoint/update method to use (UpdateOption).

asymptotic
solver tuple[type[SinglestepSolver | MultistepSolver]] | type[SinglestepSolver | MultistepSolver] | None

Solver class or tuple of solver classes used when simulation-based adjoint computation is selected (complete or unobserved).

None
dt float | None

Time-step used with the solver (required when solver is used).

None
interval_fraction float

Fraction of the input time series to use for gradient computation (value in (0, 1]).

1
Source code in src/otf/optim/gradient/adjoint.py
def __init__(
    self,
    system: BaseSystem,
    update_option: UpdateOption = UpdateOption.asymptotic,
    solver: tuple[type[SinglestepSolver | MultistepSolver]]
    | type[SinglestepSolver | MultistepSolver]
    | None = None,
    dt: float | None = None,
    interval_fraction: float = 1,
):
    """Initialize an `AdjointGradient`.

    Parameters
    ----------
    system
        `BaseSystem` instance to analyze.
    update_option
        Which adjoint/update method to use (`UpdateOption`).
    solver
        Solver class or tuple of solver classes used when simulation-based
        adjoint computation is selected (`complete` or `unobserved`).
    dt
        Time-step used with the solver (required when `solver` is used).
    interval_fraction
        Fraction of the input time series to use for gradient computation
        (value in (0, 1]).
    """
    super().__init__(system)

    if not (0 < interval_fraction <= 1):
        raise ValueError(
            "`interval_fraction` should be in (0, 1]"
            f" (was {interval_fraction})"
        )
    self._interval_fraction = interval_fraction

    self._compute_adjoint = self._set_up_adjoint_method(update_option)

    if update_option is not UpdateOption.asymptotic:
        if dt is None:
            raise ValueError("`dt` must not be None for this update option")
        if solver is None:
            raise ValueError(
                "`solver` must not be None for the given update option"
            )

        adjoint_system = self._set_up_adjoint_system(system, update_option)

        self._dt = dt
        self._solver = self._set_up_solver(adjoint_system, solver)

DummyLRScheduler

DummyLRScheduler(*args, **kwargs)

Bases: LRScheduler

No-op scheduler for testing and compatibility.

Initialize a dummy scheduler (accepts arbitrary args).

This scheduler performs no action when step is called.

Source code in src/otf/optim/lr_scheduler.py
def __init__(self, *args, **kwargs):
    """Initialize a dummy scheduler (accepts arbitrary args).

    This scheduler performs no action when `step` is called.
    """
    pass

DummyOptimizer

DummyOptimizer(
    system: BaseSystem,
    gradient_computer: GradientComputer | None = None,
)

Bases: BaseOptimizer

Optimizer that performs no parameter updates (useful for testing).

Initialize the dummy optimizer.

Parameters:

Name Type Description Default
system BaseSystem

Target BaseSystem instance.

required
gradient_computer GradientComputer | None

Optional GradientComputer.

None
Source code in src/otf/optim/optimizer.py
def __init__(
    self,
    system: BaseSystem,
    gradient_computer: gradient.GradientComputer | None = None,
):
    """Initialize the dummy optimizer.

    Parameters
    ----------
    system
        Target `BaseSystem` instance.
    gradient_computer
        Optional `GradientComputer`.
    """
    super().__init__(system, gradient_computer)

ExponentialLR

ExponentialLR(
    optimizer: BaseOptimizer, gamma: float = 0.99
)

Bases: LRScheduler

Multiply an optimizer's learning rate by a constant factor on each step() call.

Initialize the exponential scheduler.

Parameters:

Name Type Description Default
optimizer BaseOptimizer

An instance of BaseOptimizer with a learning_rate attribute.

required
gamma float

Factor to multiply the learning rate by on each step.

0.99
Source code in src/otf/optim/lr_scheduler.py
def __init__(self, optimizer: BaseOptimizer, gamma: float = 0.99):
    """Initialize the exponential scheduler.

    Parameters
    ----------
    optimizer
        An instance of `BaseOptimizer` with a `learning_rate` attribute.
    gamma
        Factor to multiply the learning rate by on each `step`.
    """
    super().__init__(optimizer)
    self.gamma = gamma
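
As a minimal sketch of the decay described above (not the library's implementation): assuming `step()` does nothing but multiply `optimizer.learning_rate` by `gamma`, the rate after k calls is the initial rate times `gamma**k`. `ToyOptimizer` and `ToyExponentialLR` are hypothetical stand-ins and are not part of `otf.optim`.

```python
class ToyOptimizer:
    """Hypothetical stand-in exposing only a `learning_rate` attribute."""

    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate


class ToyExponentialLR:
    """Assumed behavior: scale the learning rate by `gamma` on every step."""

    def __init__(self, optimizer: ToyOptimizer, gamma: float = 0.99):
        self.optimizer = optimizer
        self.gamma = gamma

    def step(self) -> None:
        self.optimizer.learning_rate *= self.gamma


opt = ToyOptimizer(learning_rate=1e-2)
scheduler = ToyExponentialLR(opt, gamma=0.5)

# Record the learning rate after each of three scheduler steps.
rates = []
for _ in range(3):
    scheduler.step()
    rates.append(opt.learning_rate)
```

With `gamma=0.5` the rate halves on every call, so a `gamma` close to 1 (such as the default 0.99) gives a slow, smooth decay.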

GradientDescent

GradientDescent(
    system: BaseSystem,
    learning_rate: float = 0.0001,
    gradient_computer: GradientComputer | None = None,
)

Bases: BaseOptimizer

Simple gradient-descent optimizer.

Create a gradient-descent optimizer.

Parameters:

Name Type Description Default
learning_rate float

Scalar learning rate used to scale the negative gradient.

0.0001
Source code in src/otf/optim/optimizer.py
def __init__(
    self,
    system: BaseSystem,
    learning_rate: float = 1e-4,
    gradient_computer: gradient.GradientComputer | None = None,
):
    """Create a gradient-descent optimizer.

    Parameters
    ----------
    learning_rate
        Scalar learning rate used to scale the negative gradient.
    """
    super().__init__(system, gradient_computer)
    self.learning_rate = learning_rate
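
The update rule implied by "scale the negative gradient" can be sketched in plain Python. This is an illustration under the assumption that a step returns `cs - learning_rate * grad`; the helper `gradient_descent_step` and the quadratic test problem are hypothetical and not taken from the package.

```python
def gradient_descent_step(cs, grad, learning_rate=1e-4):
    """Return cs - learning_rate * grad, element by element."""
    return [c - learning_rate * g for c, g in zip(cs, grad)]


# Minimize f(c) = (c0 - 3)^2 + (c1 + 1)^2, whose gradient is 2 * (c - target).
target = [3.0, -1.0]
cs = [0.0, 0.0]
for _ in range(200):
    grad = [2.0 * (c - t) for c, t in zip(cs, target)]
    cs = gradient_descent_step(cs, grad, learning_rate=0.1)
```

On this problem each step shrinks the distance to `target` by a constant factor, so the iterates settle at the minimizer.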

LevenbergMarquardt

LevenbergMarquardt(
    system: BaseSystem,
    learning_rate: float = 0.001,
    lam: float = 0.01,
    gradient_computer: SensitivityGradient | None = None,
)

Bases: BaseOptimizer

Levenberg–Marquardt optimizer using sensitivity-based gradients.

This implementation requires a SensitivityGradient instance and is currently implemented only for the UpdateOption.last_state update method of the gradient computer.

Parameters:

Name Type Description Default
learning_rate float

Scalar multiplier applied to the computed step.

0.001
lam float

Levenberg–Marquardt damping parameter.

0.01
Source code in src/otf/optim/optimizer.py
def __init__(
    self,
    system: BaseSystem,
    learning_rate: float = 1e-3,
    lam: float = 1e-2,
    gradient_computer: gradient.SensitivityGradient | None = None,
):
    """Levenberg–Marquardt optimizer using sensitivity-based gradients.

    This implementation requires a `SensitivityGradient` instance and is
    currently implemented only for the `UpdateOption.last_state` update
    method of the gradient computer.

    Parameters
    ----------
    learning_rate
        Scalar multiplier applied to the computed step.
    lam
        Levenberg–Marquardt damping parameter.
    """
    if not isinstance(gradient_computer, gradient.SensitivityGradient):
        raise NotImplementedError(
            "not yet implemented for adjoint-based gradient computation"
        )
    if gradient_computer.update_option is not (
        gradient.sensitivity.UpdateOption.last_state
    ):
        raise NotImplementedError(
            "currently implemented only for last state gradient computation"
        )

    super().__init__(system, gradient_computer)
    self.learning_rate = learning_rate
    self.lam = lam
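
For orientation, the textbook Levenberg–Marquardt step solves `(J^T J + lam * I) delta = J^T r`, where `J` holds the sensitivities of the model output to the parameters and `r` the residuals. The library's exact formula is not shown on this page, so the single-parameter sketch below is an assumption about the general shape of the update; `lm_step` and the linear model `c * t` are hypothetical.

```python
def lm_step(residuals, sensitivities, lam=1e-2, learning_rate=1.0):
    """Damped Gauss-Newton step for one scalar parameter.

    Solves (J^T J + lam) * delta = J^T r and scales delta by learning_rate.
    """
    jtj = sum(j * j for j in sensitivities)
    jtr = sum(j * r for j, r in zip(sensitivities, residuals))
    return learning_rate * jtr / (jtj + lam)


# Fit c in the model y = c * t to data generated with c = 2.
times = [1.0, 2.0, 3.0]
observed = [2.0 * t for t in times]

c = 0.0
for _ in range(20):
    residuals = [y - c * t for y, t in zip(observed, times)]
    sensitivities = list(times)  # d(c * t)/dc = t
    c += lm_step(residuals, sensitivities)
```

A larger `lam` shortens the step and makes it behave more like scaled gradient descent; as `lam` approaches zero the step approaches plain Gauss–Newton.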

MultiStepLR

MultiStepLR(
    optimizer: BaseOptimizer,
    milestones: list[int] | tuple[int],
    gamma: float = 0.5,
)

Bases: LRScheduler

Reduce learning rate at specified step milestones.

Initialize the multi-step scheduler.

Parameters:

Name Type Description Default
optimizer BaseOptimizer

An instance of BaseOptimizer with a learning_rate attribute.

required
milestones list[int] | tuple[int]

For each milestone, update the learning rate after that many calls to step. Specifying the same milestone multiple times multiplies the learning rate repeatedly at that milestone.

required
gamma float

Factor by which to multiply the learning rate at each milestone.

0.5
Source code in src/otf/optim/lr_scheduler.py
def __init__(
    self,
    optimizer: BaseOptimizer,
    milestones: list[int] | tuple[int],
    gamma: float = 0.5,
):
    """Initialize the multi-step scheduler.

    Parameters
    ----------
    optimizer
        An instance of `BaseOptimizer` with a `learning_rate` attribute.
    milestones
        For each milestone, update the learning rate after that many calls
        to `step`. Specifying the same milestone multiple times multiplies
        the learning rate repeatedly at that milestone.
    gamma
        Factor by which to multiply the learning rate at each milestone.
    """
    super().__init__(optimizer)
    self.milestones = Counter(milestones)
    self.gamma = gamma
    self.steps = 0
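
The `Counter` of milestones suggests how repeated milestones compound. The sketch below assumes that `step()` increments a call counter and multiplies the learning rate by `gamma` once per occurrence of the current count in `milestones`; the `step` method itself is not shown on this page, and `ToyOptimizer` and `ToyMultiStepLR` are hypothetical stand-ins.

```python
from collections import Counter


class ToyOptimizer:
    """Hypothetical stand-in exposing only a `learning_rate` attribute."""

    def __init__(self, learning_rate: float):
        self.learning_rate = learning_rate


class ToyMultiStepLR:
    def __init__(self, optimizer, milestones, gamma=0.5):
        self.optimizer = optimizer
        self.milestones = Counter(milestones)  # milestone -> multiplicity
        self.gamma = gamma
        self.steps = 0

    def step(self) -> None:
        self.steps += 1
        # A Counter returns 0 for missing keys, so non-milestones are no-ops.
        repeats = self.milestones[self.steps]
        if repeats:
            self.optimizer.learning_rate *= self.gamma**repeats


opt = ToyOptimizer(learning_rate=1.0)
scheduler = ToyMultiStepLR(opt, milestones=[2, 4, 4], gamma=0.5)

history = []
for _ in range(5):
    scheduler.step()
    history.append(opt.learning_rate)
```

Here milestone 4 is listed twice, so the rate is multiplied by `gamma` twice on the fourth call.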

OptaxWrapper

OptaxWrapper(
    system: BaseSystem,
    optimizer: GradientTransformationExtraArgs,
    gradient_computer: GradientComputer | None = None,
)

Bases: BaseOptimizer

Adapter that wraps an Optax optimizer as a BaseOptimizer.

Wrap a given Optax optimizer.

Parameters:

Name Type Description Default
optimizer GradientTransformationExtraArgs

Instance of optax.GradientTransformationExtraArgs, for example optax.adam(learning_rate=1e-1).

required
Source code in src/otf/optim/optimizer.py
def __init__(
    self,
    system: BaseSystem,
    optimizer: optax.GradientTransformationExtraArgs,
    gradient_computer: gradient.GradientComputer | None = None,
):
    """Wrap a given Optax optimizer.

    Parameters
    ----------
    optimizer
        Instance of `optax.GradientTransformationExtraArgs`, for example
        `optax.adam(learning_rate=1e-1)`.
    """
    super().__init__(system, gradient_computer)
    self.optimizer = optimizer
    self.opt_state = self.optimizer.init(system.cs)

OptimizerChain

OptimizerChain(
    system: BaseSystem,
    learning_rate: float,
    optimizers: list[BaseOptimizer],
    weights: list[float],
)

Bases: BaseOptimizer

Combine several optimizers by summing their weighted updates.

Initialize an OptimizerChain.

Parameters:

Name Type Description Default
learning_rate float

Scalar applied to the total weighted update (controls overall step size).

required
optimizers list[BaseOptimizer]

Sequence of BaseOptimizer instances whose step results will be combined. All optimizers should target the same system.

required
weights list[float]

Sequence of relative weights (one per optimizer). These are normalized internally so their sum is one.

required
Notes

It can be convenient to set each individual optimizer's internal learning rate to 1.0 and control the combined step size with learning_rate supplied here.

Source code in src/otf/optim/base.py
def __init__(
    self,
    system: BaseSystem,
    learning_rate: float,
    optimizers: list[BaseOptimizer],
    weights: list[float],
):
    """Initialize an `OptimizerChain`.

    Parameters
    ----------
    learning_rate
        Scalar applied to the total weighted update (controls overall step
        size).
    optimizers
        Sequence of `BaseOptimizer` instances whose `step` results will be
        combined. All optimizers should target the same `system`.
    weights
        Sequence of relative weights (one per optimizer). These are
        normalized internally so their sum is one.

    Notes
    -----
    It can be convenient to set each individual optimizer's internal
    learning rate to 1.0 and control the combined step size with
    `learning_rate` supplied here.
    """
    assert len(optimizers) == len(weights), (
        "`optimizers` and `weights` should have same length"
    )

    super().__init__(system)
    self.learning_rate = learning_rate
    self._optimizers = optimizers
    self._weights = jnp.array(weights) / sum(weights)
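
The weight normalization and the combined update can be sketched as follows. This assumes each wrapped optimizer's `step` yields an update vector and that the chain returns `learning_rate` times the weighted sum; `chain_update` is a hypothetical helper written with plain lists rather than JAX arrays.

```python
def chain_update(updates, weights, learning_rate):
    """Combine per-optimizer updates with weights normalized to sum to one."""
    if len(updates) != len(weights):
        raise ValueError("`updates` and `weights` should have same length")

    total = sum(weights)
    normalized = [w / total for w in weights]

    n_params = len(updates[0])
    return [
        learning_rate * sum(w * u[i] for w, u in zip(normalized, updates))
        for i in range(n_params)
    ]


# Two optimizers proposing different updates, weighted 3:1.
updates = [[1.0, 0.0], [0.0, 2.0]]
combined = chain_update(updates, weights=[3.0, 1.0], learning_rate=0.5)
```

Because the weights are normalized, `[3.0, 1.0]` and `[0.75, 0.25]` give the same result; only their ratio matters.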

PartialOptimizer

PartialOptimizer(
    optimizer: BaseOptimizer,
    param_idx: jndarray | None = None,
)

Bases: BaseOptimizer

Wrap an optimizer to update only a subset of parameters.

The wrapped optimizer computes a full update vector; PartialOptimizer masks that update so only selected indices in system.cs are changed.

Initialize the wrapper.

Parameters:

Name Type Description Default
optimizer BaseOptimizer

Base optimizer used to compute the full parameter update.

required
param_idx jndarray | None

Indices of system.cs that should be updated. Provide an explicit array-like index (e.g. np.array([0, 2])).

None

Methods:

Name Description
__getattr__

For attributes that aren't defined in this class, route access to the wrapped optimizer.

__setattr__

For attributes that aren't defined in this class, route access to the wrapped optimizer.

Source code in src/otf/optim/base.py
def __init__(
    self,
    optimizer: BaseOptimizer,
    param_idx: jndarray | None = None,
):
    """Initialize the wrapper.

    Parameters
    ----------
    optimizer
        Base optimizer used to compute the full parameter update.
    param_idx
        Indices of `system.cs` that should be updated. Provide an explicit
        array-like index (e.g. `np.array([0, 2])`).
    """
    # Define the attributes that belong to this class (versus those of the
    # wrapped class) so they can be distinguished and routed properly.
    super().__setattr__(
        "_own_attrs", {"_system", "system", "optimizer", "mask"}
    )
    super().__init__(optimizer.system)

    self.optimizer = optimizer

    n = len(self.optimizer.system.cs)
    self.mask = jnp.zeros(n, dtype=bool)
    self.mask = self.mask.at[param_idx].set(True)
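
The effect of the boolean mask can be illustrated without JAX. The sketch assumes the masked update keeps the wrapped optimizer's values at the selected indices and zeroes the rest; `masked_update` is a hypothetical helper, not the method used by `PartialOptimizer`.

```python
def masked_update(full_update, param_idx):
    """Keep entries of `full_update` at `param_idx`; zero all others."""
    mask = [False] * len(full_update)
    for i in param_idx:
        mask[i] = True
    return [u if keep else 0.0 for u, keep in zip(full_update, mask)]


# Only parameters 0 and 2 are allowed to change.
update = masked_update([0.1, -0.2, 0.3], param_idx=[0, 2])
```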

__getattr__

__getattr__(name)

For attributes that aren't defined in this class, route access to the wrapped optimizer.

Source code in src/otf/optim/base.py
def __getattr__(self, name):
    """For attributes that aren't defined in this class, route access to the
    wrapped optimizer.
    """
    return getattr(self.optimizer, name)

__setattr__

__setattr__(name, value)

For attributes that aren't defined in this class, route access to the wrapped optimizer.

Source code in src/otf/optim/base.py
def __setattr__(self, name, value):
    """For attributes that aren't defined in this class, route access to the
    wrapped optimizer.
    """
    if name in self._own_attrs:
        super().__setattr__(name, value)
    else:
        setattr(self.optimizer, name, value)
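
The routing pattern used by `__getattr__` and `__setattr__` can be reproduced in a small self-contained example. `Inner` and `Wrapper` are hypothetical classes that mirror the `_own_attrs` idea shown above; they are not part of the package.

```python
class Inner:
    def __init__(self):
        self.learning_rate = 0.1


class Wrapper:
    def __init__(self, inner):
        # Bypass our own __setattr__ while the routing table does not exist yet.
        object.__setattr__(self, "_own_attrs", {"inner", "mask"})
        self.inner = inner
        self.mask = None

    def __getattr__(self, name):
        # Called only when normal lookup fails, so own attributes never land here.
        return getattr(self.inner, name)

    def __setattr__(self, name, value):
        if name in self._own_attrs:
            object.__setattr__(self, name, value)
        else:
            setattr(self.inner, name, value)


inner = Inner()
wrapped = Wrapper(inner)
wrapped.learning_rate = 0.5  # forwarded to `inner`
wrapped.mask = [True, False]  # stored on the wrapper itself
```

This is why a learning-rate scheduler can be pointed at the wrapper: assignments to `learning_rate` reach the wrapped optimizer rather than creating a shadow attribute on the wrapper.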

Regularizer

Regularizer(
    system: BaseSystem,
    ord: int | float | Callable,
    prior: jndarray | None = None,
    callable_is_derivative: bool | None = None,
)

Bases: BaseOptimizer

Produce parameter-space penalties used as a regularization term.

The step method returns the negative derivative of the regularization penalty (same shape as system.cs) and can be combined with a gradient-based optimizer.

Initialize a Regularizer.

Parameters:

Name Type Description Default
ord int | float | Callable

If an int/float, interpreted as the ord-norm penalty (1 or 2 supported explicitly). If a callable, see callable_is_derivative.

required
prior jndarray | None

Optional prior parameter vector; penalty is applied to (cs - prior). If omitted, the prior is taken to be zero.

None
callable_is_derivative bool | None

If ord is callable, set to True when ord already returns the derivative (an array shaped like system.cs); set to False when ord returns a scalar penalty and its derivative should be auto-differentiated.

None

Methods:

Name Description
step

Return the negative derivative of the regularization penalty.

Source code in src/otf/optim/base.py
def __init__(
    self,
    system: BaseSystem,
    ord: int | float | Callable,
    prior: jndarray | None = None,
    callable_is_derivative: bool | None = None,
):
    """Initialize a `Regularizer`.

    Parameters
    ----------
    ord
        If an int/float, interpreted as the `ord`-norm penalty (1 or 2
        supported explicitly). If a callable, see `callable_is_derivative`.
    prior
        Optional prior parameter vector; penalty is applied to `(cs -
        prior)`. If omitted, the prior is taken to be zero.
    callable_is_derivative
        If `ord` is callable, set to True when `ord` already returns the
        derivative (an array shaped like `system.cs`); set to False when
        `ord` returns a scalar penalty and its derivative should be
        auto-differentiated.
    """
    if prior is None:
        self._prior = jnp.zeros_like(system.cs)
    else:
        if prior.shape != system.cs.shape:
            raise ValueError(
                "`prior` should have same shape as `system.cs`"
            )
        self._prior = prior

    match ord:
        case int() | float():
            pass
        case Callable() if callable_is_derivative is None:
            raise ValueError(
                "`callable_is_derivative` must be a bool when `ord` is a "
                "callable"
            )
        case Callable() if callable_is_derivative:
            if ord(system.cs).shape != system.cs.shape:
                raise ValueError(
                    "`ord` must return an array of the same shape as the "
                    "parameters `system.cs`"
                )
        case Callable() if not callable_is_derivative:
            if not jnp.isscalar(ord(system.cs)):
                raise ValueError(
                    "`ord` must be scalar-valued since "
                    "`callable_is_derivative` is False"
                )
        case _:
            raise ValueError("`ord` is an invalid type")

    super().__init__(system)
    self._ord = ord
    self._callable_is_derivative = callable_is_derivative

step

step(*_)

Return the negative derivative of the regularization penalty.

The returned array has the same shape as system.cs and can be added to a gradient-based update.

Source code in src/otf/optim/base.py
def step(self, *_):
    """Return the negative derivative of the regularization penalty.

    The returned array has the same shape as `system.cs` and can be added to
    a gradient-based update.
    """
    ord, prior = self.ord, self.prior
    cs = self.system.cs
    match ord:
        case 2:
            return -2 * (cs - prior)
        case 1:
            return -jnp.sign(cs - prior)
        case int() | float():
            # FutureFIXME: Evaluating at `cs - prior` might not be right.
            return -jax.jacfwd(
                lambda ps: jnp.linalg.norm(ps, ord=ord)
            )(cs - prior)
        case Callable() if self.callable_is_derivative:
            # FutureFIXME: Evaluating at `cs - prior` might not be right.
            return -ord(cs - prior)
        case Callable() if not self.callable_is_derivative:
            # FutureFIXME: Evaluating at `cs - prior` might not be right.
            return -jax.jacfwd(ord, holomorphic=True)(cs - prior)
        case _:
            raise ValueError("`self.ord` is no longer a valid value")
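
The two explicitly supported cases of `step` can be checked by hand with a plain-Python sketch. `regularizer_step` is a hypothetical helper that mirrors the `ord == 2` and `ord == 1` branches above using lists instead of JAX arrays.

```python
def sign(x):
    """Return -1, 0, or 1 according to the sign of x."""
    return (x > 0) - (x < 0)


def regularizer_step(cs, prior=None, ord=2):
    """Negative derivative of the penalty, evaluated at cs - prior."""
    if prior is None:
        prior = [0.0] * len(cs)
    diff = [c - p for c, p in zip(cs, prior)]

    if ord == 2:
        # Derivative of the squared 2-norm: 2 * (cs - prior).
        return [-2.0 * d for d in diff]
    if ord == 1:
        # Subgradient of the 1-norm: sign(cs - prior).
        return [-float(sign(d)) for d in diff]
    raise ValueError("only ord 1 and 2 are covered by this sketch")


cs = [1.0, -2.0, 0.0]
prior = [0.5, 0.0, 0.0]
l2 = regularizer_step(cs, prior, ord=2)
l1 = regularizer_step(cs, prior, ord=1)
```

Both results point from `cs` back toward `prior`, which is what lets the output be added to a gradient-based update as a pull toward the prior.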

SensitivityGradient

SensitivityGradient(
    system: BaseSystem,
    update_option: UpdateOption = UpdateOption.last_state,
    solver: tuple[type[SinglestepSolver | MultistepSolver]]
    | type[SinglestepSolver | MultistepSolver]
    | None = None,
    dt: float | None = None,
    use_unobserved_asymptotics: bool = False,
)

Bases: GradientComputer

Compute gradients using sensitivity (forward) equations.

Different UpdateOptions select how the sensitivity information is assembled (last state, mean state, mean gradient, or complete simulation).

Initialize a SensitivityGradient.

Parameters:

Name Type Description Default
system BaseSystem

BaseSystem instance to analyze.

required
update_option UpdateOption

Strategy for forming the gradient (UpdateOption).

last_state
solver tuple[type[SinglestepSolver | MultistepSolver]] | type[SinglestepSolver | MultistepSolver] | None

Solver class or tuple of solver classes used when the complete update option is selected.

None
dt float | None

Time-step used with the solver (required when solver is used).

None
use_unobserved_asymptotics bool

When True, attempt to use asymptotic information from unobserved state components (experimental).

False
Source code in src/otf/optim/gradient/sensitivity.py
def __init__(
    self,
    system: BaseSystem,
    update_option: UpdateOption = UpdateOption.last_state,
    solver: tuple[type[SinglestepSolver | MultistepSolver]]
    | type[SinglestepSolver | MultistepSolver]
    | None = None,
    dt: float | None = None,
    use_unobserved_asymptotics: bool = False,
):
    """Initialize a `SensitivityGradient`.

    Parameters
    ----------
    system
        `BaseSystem` instance to analyze.
    update_option
        Strategy for forming the gradient (`UpdateOption`).
    solver
        Solver class or tuple of solver classes used when the `complete`
        update option is selected.
    dt
        Time-step used with the solver (required when `solver` is used).
    use_unobserved_asymptotics
        When True, attempt to use asymptotic information from unobserved
        state components (experimental).
    """
    super().__init__(system)

    self._update_option = update_option
    self.compute_gradient = self._set_up_gradient(update_option)

    if update_option is UpdateOption.complete:
        if dt is None:
            raise ValueError("`dt` must not be None for this update option")
        if solver is None:
            raise ValueError(
                "`solver` must not be None for the given update option"
            )

        sensitivity_system = _SensitivitySystem(system)

        self._dt = dt
        self._solver = self._set_up_solver(sensitivity_system, solver)

    self._use_unobserved_asymptotics = use_unobserved_asymptotics
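
As background on forward sensitivity equations in general (not a description of this class's internals): for a system `dx/dt = f(x, c)`, the sensitivity `s = dx/dc` obeys `ds/dt = (df/dx) * s + df/dc` and can be integrated alongside the state. The sketch below does this for the scalar toy model `dx/dt = -c * x` with forward Euler; the model, the function `simulate_with_sensitivity`, and the final-time loss are all illustrative assumptions.

```python
import math


def simulate_with_sensitivity(c, x0=1.0, t_final=1.0, dt=1e-3):
    """Integrate dx/dt = -c*x together with s = dx/dc using forward Euler."""
    x, s = x0, 0.0
    for _ in range(round(t_final / dt)):
        dx = -c * x
        ds = -c * s - x  # (df/dx) * s + df/dc with f(x, c) = -c * x
        x, s = x + dt * dx, s + dt * ds
    return x, s


# Gradient of the loss 0.5 * (x(T) - observed)^2 with respect to c,
# formed from the state and sensitivity at the final time only.
c, observed = 0.5, 0.5
x_final, s_final = simulate_with_sensitivity(c)
grad = (x_final - observed) * s_final
```

For this model the exact sensitivity is `-t * x0 * exp(-c * t)`, which gives a direct check on the integration.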

WeightedLevenbergMarquardt

WeightedLevenbergMarquardt(
    system: BaseSystem,
    learning_rate: float = 0.001,
    lam: float = 0.01,
    gradient_computer: GradientComputer | None = None,
)

Bases: BaseOptimizer

Weighted Levenberg–Marquardt optimizer (Gauss–Newton variant).

Perform a weighted version of the Levenberg–Marquardt modification of Gauss–Newton.

Parameters:

Name Type Description Default
learning_rate float

The learning rate (scalar by which to multiply the step).

0.001
lam float

Levenberg–Marquardt damping parameter.

0.01
Source code in src/otf/optim/optimizer.py
def __init__(
    self,
    system: BaseSystem,
    learning_rate: float = 1e-3,
    lam: float = 1e-2,
    gradient_computer: gradient.GradientComputer | None = None,
):
    """Perform a weighted version of the Levenberg–Marquardt modification of
    Gauss–Newton.

    Parameters
    ----------
    learning_rate
        The learning rate (scalar by which to multiply the step).
    lam
        Levenberg–Marquardt damping parameter.
    """
    super().__init__(system, gradient_computer)
    self.learning_rate = learning_rate
    self.lam = lam

pruned_factory

pruned_factory(
    system_type: type[BaseSystem],
) -> type[BaseSystem]

Return a 'pruned' variant of system_type.

If a parameter in cs of the system is to be set below its corresponding threshold (in absolute value), it will be set to zero permanently. Optionally require that this occur at least a specified number of times consecutively before setting a parameter to zero permanently.

Parameters:

Name Type Description Default
system_type type[BaseSystem]

The type of BaseSystem (not an instance) to be wrapped, e.g., system.System_ModelKnown.

required
Source code in src/otf/optim/base.py
def pruned_factory(system_type: type[BaseSystem]) -> type[BaseSystem]:
    """Return a 'pruned' variant of `system_type`.

    If a parameter in `cs` of the system is to be set below its corresponding
    threshold (in absolute value), it will be set to zero permanently.
    Optionally require that this occur at least a specified number of times
    consecutively before setting a parameter to zero permanently.

    Parameters
    ----------
    system_type
        The type of `BaseSystem` (not an instance) to be wrapped, e.g.,
        `system.System_ModelKnown`.
    """

    class Pruned(system_type):
        def __init__(
            self,
            *args,
            threshold: float | jndarray | np.ndarray,
            iterations: int | jndarray | np.ndarray | None = None,
            **kwargs,
        ):
            """

            Parameters
            ----------
            threshold
                If a float, all parameters `cs` are compared against this common
                value. If an array, each parameter is compared against the value
                of `threshold` in the same position. To disable pruning for a
                parameter, set its threshold to zero.
            iterations
                Require each parameter to be less than its corresponding
                threshold at least `iterations` times consecutively before
                setting it to zero (permanently). As with `threshold`, if an
                `int`, each parameter will use this common value, but if an
                array, then each parameter will use its corresponding value. If
                None, only one time being less than `threshold` is needed to set
                a parameter to zero.
            """
            super().__init__(*args, **kwargs)

            if isinstance(threshold, (jndarray, np.ndarray)):
                if self._cs.shape != threshold.shape:
                    raise ValueError(
                        "`threshold` must have same shape as `system.cs`"
                    )
            self.threshold = np.array(threshold)

            if isinstance(iterations, (jndarray, np.ndarray)):
                if self._cs.shape != iterations.shape:
                    raise ValueError(
                        "`iterations` must have same shape as `system.cs`"
                    )
            self.iterations = (
                None if iterations is None else np.array(iterations)
            )

            # A mask in which True indicates the corresponding parameter should
            # be set to zero.
            self._set_zero = np.zeros_like(self.cs, dtype=bool)

            # Count the number of times each parameter is below its threshold in
            # a row.
            self._counter = np.zeros_like(self.cs, dtype=int)

        def _set_cs(self, cs):
            # Test the incoming values against the thresholds, as documented
            # ("is to be set below its corresponding threshold"). Entries
            # already marked in `self._set_zero` keep their mask value.
            below_threshold = np.abs(cs) < self.threshold

            # Increment the counter for parameters below their threshold and
            # reset to zero the counter for parameters not below their
            # threshold.
            if self.iterations is not None:
                self._counter += below_threshold
                self._counter[~below_threshold] = 0
                at_least_counter = self._counter >= self.iterations
            else:
                at_least_counter = True

            set_zero = below_threshold & at_least_counter
            self._set_zero[set_zero] = True
            self._cs = jnp.where(self._set_zero, 0, cs)

            # Reset the counter for parameters already set to zero (no point
            # continuing to count).
            if self.iterations is not None:
                self._counter[self._set_zero] = 0

        cs = property(
            lambda self: self._cs, lambda self, value: self._set_cs(value)
        )

    # TODO: Check if this is working correctly. Or rework how pruned systems are
    # constructed.
    doc = (
        "This is a 'Pruned' version of the original class "
        f"({system_type.__module__}.{system_type.__qualname__}); "
        "that is, if a parameter in `self.cs` is to be set below its "
        "corresponding threshold (in absolute value), it will be set to zero "
        "permanently."
    )

    Pruned.__module__ = system_type.__module__
    Pruned.__name__ = f"{system_type.__name__}_Pruned"
    Pruned.__qualname__ = system_type.__qualname__
    Pruned.__doc__ = (
        doc
        if system_type.__doc__ is None
        else system_type.__doc__ + "\n\n" + doc
    )
    Pruned.__annotations__ = system_type.__annotations__

    return Pruned
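
The pruning rule described in the docstrings can be sketched with a small stand-alone class. This follows the documented behavior (a value is tested when it is about to be assigned, and must fall below the threshold `iterations` times in a row); `ToyPruned` is hypothetical, uses a single scalar threshold, and stores parameters in a plain list rather than an array.

```python
class ToyPruned:
    """Hypothetical stand-in: a parameter list with permanent pruning."""

    def __init__(self, cs, threshold, iterations=None):
        self._cs = list(cs)
        self.threshold = threshold
        self.iterations = iterations
        self._set_zero = [False] * len(self._cs)  # True means pruned for good
        self._counter = [0] * len(self._cs)  # consecutive below-threshold hits

    @property
    def cs(self):
        return self._cs

    @cs.setter
    def cs(self, new_cs):
        needed = 1 if self.iterations is None else self.iterations
        for i, value in enumerate(new_cs):
            if abs(value) < self.threshold:
                self._counter[i] += 1
            else:
                self._counter[i] = 0
            if self._counter[i] >= needed:
                self._set_zero[i] = True
        self._cs = [0.0 if z else v for v, z in zip(new_cs, self._set_zero)]


system = ToyPruned([1.0, 1.0], threshold=0.1, iterations=2)
history = []
for proposal in ([0.05, 1.0], [0.5, 0.05], [0.01, 0.02], [0.03, 5.0]):
    system.cs = proposal
    history.append(list(system.cs))
```

The second parameter is pruned on the third assignment, after two consecutive small values, and stays at zero even when the next proposal is large. The first parameter's streak is broken by the value 0.5, so it is pruned only on the fourth assignment.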