Skip to content

circuit_diffeq ¤

Stripped-down ODE integrator for circuit simulation.

Compared to diffrax.diffeqsolve this removes all features not used by Circulax transient solvers:

Removed:

- Events (`event=`, event state, root-finding after the loop)
- Dense output (`saveat.dense`)
- Progress meter
- SDE / CDE support and related warnings
- Backward time direction – always assumes `t0 < t1`
- `made_jump` – treated as static `False`, not carried in state
- Deprecated `discrete_terminating_event`
- Term-compatibility checking (`ODETerm` is always assumed)
- `solver_state` / `controller_state` passthrough arguments
- `saveat.steps` saving mode

Retained:

- `SaveAt(ts=..., t1=True, t0=False, fn=...)` – used by benchmarks
- `RecursiveCheckpointAdjoint` – binomial checkpointing for autodiff
- Full step-size control (`PIDController`, `ConstantStepSize`)
- `diffrax.Solution` output – fully compatible with existing code

Classes:

| Name | Description |
| --- | --- |
| `CircuitState` | Carry state for the circuit integration while-loop. |

Functions:

| Name | Description |
| --- | --- |
| `circuit_diffeqsolve` | Stripped-down `diffrax.diffeqsolve` for circuit simulation. |

CircuitState ¤

Bases: Module

Carry state for the circuit integration while-loop.

circuit_diffeqsolve ¤

circuit_diffeqsolve(
    terms: PyTree[AbstractTerm],
    solver,
    t0: RealScalarLike,
    t1: RealScalarLike,
    dt0: RealScalarLike,
    y0: PyTree[ArrayLike],
    args: PyTree[Any] = None,
    *,
    saveat: SaveAt = SaveAt(t1=True),
    stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
    max_steps: int | None = 4096,
    throw: bool = True,
    checkpoints: int | None = None
) -> Solution

Stripped-down `diffrax.diffeqsolve` for circuit simulation.

Identical calling convention to diffrax.diffeqsolve except:

  • adjoint, event, progress_meter, solver_state, controller_state, and made_jump arguments are absent.
  • t0 < t1 is assumed; no backward-time integration.
  • No SDE/CDE terms.
  • saveat.dense and saveat.steps are not supported.

`checkpoints` controls the number of binomial checkpoints used by `RecursiveCheckpointAdjoint` (`None` = chosen automatically from `max_steps`).

Source code in circulax/solvers/circuit_diffeq.py
@eqx.filter_jit
def circuit_diffeqsolve(
    terms: PyTree[AbstractTerm],
    solver,
    t0: RealScalarLike,
    t1: RealScalarLike,
    dt0: RealScalarLike | None,
    y0: PyTree[ArrayLike],
    args: PyTree[Any] = None,
    *,
    saveat: SaveAt = SaveAt(t1=True),
    stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
    max_steps: int | None = 4096,
    throw: bool = True,
    checkpoints: int | None = None,
) -> Solution:
    """Stripped-down :func:`diffrax.diffeqsolve` for circuit simulation.

    Identical calling convention to ``diffrax.diffeqsolve`` except:

    - ``adjoint``, ``event``, ``progress_meter``, ``solver_state``,
      ``controller_state``, and ``made_jump`` arguments are absent.
    - Forward time only: ``t0 <= t1`` is required (``t1 < t0`` raises at
      runtime); no backward-time integration.
    - No SDE/CDE terms.
    - ``saveat.dense`` and ``SubSaveAt(steps=True)`` are not supported and
      raise ``ValueError``.

    Args:
        terms: PyTree of diffrax terms describing the ODE.
        solver: diffrax solver. If it is an ``AbstractImplicitSolver`` whose
            root finder defers its tolerances to the step-size controller
            (``use_stepsize_tol``), those tolerances are copied from an
            adaptive ``stepsize_controller``.
        t0: Start time.
        t1: End time (must satisfy ``t1 >= t0``).
        dt0: Initial step size. May be ``None`` when the step-size controller
            can choose the first step itself (as in diffrax, adaptive
            controllers support this; ``ConstantStepSize`` does not).
        y0: Initial state PyTree. Leaves are promoted to at least the time dtype.
        args: Extra arguments forwarded to the terms and to ``saveat`` fns.
        saveat: Which times/values to record (``t0``, ``ts``, ``t1``, ``fn``).
        stepsize_controller: diffrax step-size controller.
        max_steps: Maximum number of solver steps, or ``None`` for unbounded.
        throw: If ``True``, raise when the solve did not finish successfully.
        checkpoints: Number of binomial checkpoints used by the
            ``RecursiveCheckpointAdjoint``-style loop (``None`` = automatic
            from ``max_steps``). When both ``checkpoints`` and ``max_steps``
            are ``None``, a plain ``lax`` while-loop is used, which cannot be
            reverse-mode differentiated.

    Returns:
        A ``diffrax.Solution`` with ``ts``/``ys`` laid out per ``saveat`` and
        ``stats`` holding ``num_steps``, ``num_accepted_steps``,
        ``num_rejected_steps`` and ``max_steps``. ``interpolation``,
        ``solver_state``, ``controller_state``, ``made_jump`` and
        ``event_mask`` are always ``None``.

    Raises:
        ValueError: if ``saveat`` requests dense output or per-step saving, or
            if a non-adaptive controller is combined with an implicit solver
            whose root-finder tolerances were left unspecified.
        TypeError: if a leaf of ``terms`` is not an ``AbstractTerm``.

    Runtime (traced) errors are also raised for non-increasing ``saveat.ts``,
    ``saveat.ts`` outside ``[t0, t1]``, ``t1 < t0``, and, when ``throw`` is
    true, an unsuccessful solve result.
    """
    # ------------------------------------------------------------------
    # Reject unsupported SaveAt features up front. These are static Python
    # values under filter_jit; silently ignoring them would return output
    # buffers that do not match what the caller asked for.
    # ------------------------------------------------------------------
    if saveat.dense:
        raise ValueError("circuit_diffeqsolve does not support `saveat.dense`.")
    if any(s.steps for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat)):
        raise ValueError("circuit_diffeqsolve does not support `SubSaveAt(steps=True)`.")

    # ------------------------------------------------------------------
    # dtype promotion for times (same logic as diffrax; dt0 may be None)
    # ------------------------------------------------------------------
    timelikes = [t0, t1, dt0] + [s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat) if s.ts is not None]
    timelikes = [x for x in timelikes if x is not None]
    with jax.numpy_dtype_promotion("standard"):
        time_dtype = jnp.result_type(*timelikes)
    if jnp.issubdtype(time_dtype, jnp.integer):
        time_dtype = lxi.default_floating_dtype()

    t0 = jnp.asarray(t0, dtype=time_dtype)
    t1 = jnp.asarray(t1, dtype=time_dtype)
    if dt0 is not None:
        dt0 = jnp.asarray(dt0, dtype=time_dtype)

    # This integrator only steps forward; a reversed interval would otherwise
    # terminate immediately and report y0 as the value at t1. t0 == t1 is allowed.
    t1 = eqxi.error_if(t1, t1 < t0, "circuit_diffeqsolve requires t0 <= t1 (no backward-time integration).")

    # Selector for every non-None saveat.ts array (used with eqx.tree_at).
    def _get_ts(saveat):
        leaves = jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat)
        return [s.ts for s in leaves if s.ts is not None]

    # Cast save ts to the common time dtype
    saveat = eqx.tree_at(_get_ts, saveat, replace_fn=lambda ts: ts.astype(time_dtype))

    # Promote y0 dtype to be consistent with time (avoids weak-dtype issues)
    def _promote(yi):
        with jax.numpy_dtype_promotion("standard"):
            _dtype = jnp.result_type(yi, time_dtype)
        return jnp.asarray(yi, dtype=_dtype)

    y0 = jtu.tree_map(_promote, y0)

    # ------------------------------------------------------------------
    # Wrap terms with direction=1 (forward-only; WrapTerm still needed
    # so that solver.step receives a properly wrapped term object)
    # ------------------------------------------------------------------
    direction = jnp.asarray(1, dtype=time_dtype)

    def _wrap(term):
        if not isinstance(term, AbstractTerm):
            raise TypeError(f"`terms` must be a PyTree of diffrax terms; got a leaf of type {type(term).__name__}.")
        return WrapTerm(term, direction)

    terms = jtu.tree_map(
        _wrap,
        terms,
        # Wrap the individual terms inside a MultiTerm, not the MultiTerm itself.
        is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, diffrax.MultiTerm),
    )

    # ------------------------------------------------------------------
    # Propagate PIDController tolerances into implicit solver root finder
    # ------------------------------------------------------------------
    if isinstance(solver, diffrax.AbstractImplicitSolver):
        from diffrax._root_finder import use_stepsize_tol

        # Collect x.rtol / x.atol / x.norm for exactly those attributes the
        # solver's root finder left as the `use_stepsize_tol` sentinel.
        def _get_tols(x):
            return tuple(
                getattr(x, attr)
                for attr in ("rtol", "atol", "norm")
                if getattr(solver.root_finder, attr) is use_stepsize_tol
            )

        if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
            solver = eqx.tree_at(
                lambda s: _get_tols(s.root_finder),
                solver,
                _get_tols(stepsize_controller),
            )
        elif len(_get_tols(solver.root_finder)) > 0:
            # A non-adaptive controller has no tolerances to hand over, so the
            # sentinel would otherwise reach the root finder unresolved.
            raise ValueError(
                "A non-adaptive step size controller is being used with an implicit solver, "
                "but the tolerances of the solver's root finder have not been specified "
                "(they default to `use_stepsize_tol`). Pass them explicitly, e.g. "
                "`root_finder=diffrax.VeryChord(rtol=1e-3, atol=1e-6)`, or use an adaptive "
                "step size controller such as `diffrax.PIDController`."
            )

    # ------------------------------------------------------------------
    # Validate save timestamps
    # ------------------------------------------------------------------
    def _check_ts(ts):
        ts = eqxi.error_if(ts, ts[1:] < ts[:-1], "saveat.ts must be increasing.")
        ts = eqxi.error_if(ts, (ts > t1) | (ts < t0), "saveat.ts must lie between t0 and t1.")
        return ts

    saveat = eqx.tree_at(_get_ts, saveat, replace_fn=_check_ts)

    # Wrap custom fn with direction (identity when direction=1, but kept for correctness)
    def _wrap_fn(sub):
        if _is_subsaveat(sub) and sub.fn is not save_y:
            user_fn = sub.fn
            return eqx.tree_at(lambda s: s.fn, sub, lambda t, y, a: user_fn(direction * t, y, a))
        return sub

    saveat = jtu.tree_map(_wrap_fn, saveat, is_leaf=_is_subsaveat)

    # ------------------------------------------------------------------
    # Initialise solver & controller states
    # ------------------------------------------------------------------
    tprev = t0
    error_order = solver.error_order(terms)
    tnext, controller_state = stepsize_controller.init(terms, t0, t1, y0, dt0, args, solver.func, error_order)
    tnext = jnp.minimum(tnext, t1)
    solver_state = solver.init(terms, t0, tnext, y0, args)

    # ------------------------------------------------------------------
    # Allocate output buffers (pre-filled with inf; one slot per save point)
    # ------------------------------------------------------------------
    def _allocate(subsaveat: SubSaveAt) -> SaveState:
        out_size = 0
        if subsaveat.t0:
            out_size += 1
        if subsaveat.ts is not None:
            out_size += len(subsaveat.ts)
        if subsaveat.t1:
            out_size += 1
        struct = eqx.filter_eval_shape(subsaveat.fn, t0, y0, args)
        ts = jnp.full(out_size, jnp.inf, dtype=time_dtype)
        ys = jtu.tree_map(lambda y: jnp.full((out_size,) + y.shape, jnp.inf, dtype=y.dtype), struct)
        return SaveState(ts=ts, ys=ys, save_index=0, saveat_ts_index=0)

    save_state = jtu.tree_map(_allocate, saveat.subs, is_leaf=_is_subsaveat)

    # ------------------------------------------------------------------
    # Build initial CircuitState
    # ------------------------------------------------------------------
    init_state = CircuitState(
        y=y0,
        tprev=tprev,
        tnext=tnext,
        solver_state=solver_state,
        controller_state=controller_state,
        result=RESULTS.successful,
        num_steps=0,
        num_accepted_steps=0,
        num_rejected_steps=0,
        save_state=save_state,
    )

    # ------------------------------------------------------------------
    # Choose while-loop variant. As in diffrax, a checkpointed loop needs
    # either a step bound or an explicit checkpoint budget; only when both
    # are missing do we fall back to the (non reverse-differentiable) lax loop.
    # ------------------------------------------------------------------
    if max_steps is None and checkpoints is None:
        inner_while_loop = ft.partial(_inner_loop, kind="lax")
        outer_while_loop = ft.partial(_outer_loop, kind="lax")
    else:
        inner_while_loop = ft.partial(_inner_loop, kind="checkpointed")
        outer_while_loop = ft.partial(_outer_loop, kind="checkpointed", checkpoints=checkpoints)

    # ------------------------------------------------------------------
    # Run the integration
    # ------------------------------------------------------------------
    final_state = _circuit_loop(
        solver=solver,
        stepsize_controller=stepsize_controller,
        saveat=saveat,
        t0=t0,
        t1=t1,
        max_steps=max_steps,
        terms=terms,
        args=args,
        init_state=init_state,
        inner_while_loop=inner_while_loop,
        outer_while_loop=outer_while_loop,
    )

    # ------------------------------------------------------------------
    # Build the Solution (compatible with diffrax.Solution)
    # ------------------------------------------------------------------
    ts = jtu.tree_map(lambda s: s.ts, final_state.save_state, is_leaf=_is_save_state)
    ys = jtu.tree_map(lambda s: s.ys, final_state.save_state, is_leaf=_is_save_state)

    stats = {
        "num_steps": final_state.num_steps,
        "num_accepted_steps": final_state.num_accepted_steps,
        "num_rejected_steps": final_state.num_rejected_steps,
        "max_steps": max_steps,
    }

    sol = Solution(
        t0=t0,
        t1=t1,
        ts=ts,
        ys=ys,
        interpolation=None,
        stats=stats,
        result=final_state.result,
        solver_state=None,
        controller_state=None,
        made_jump=None,
        event_mask=None,
    )

    if throw:
        sol = final_state.result.error_if(sol, jnp.invert(is_okay(final_state.result)))

    return sol