@eqx.filter_jit
def circuit_diffeqsolve(
    terms: PyTree[AbstractTerm],
    solver,
    t0: RealScalarLike,
    t1: RealScalarLike,
    dt0: RealScalarLike,
    y0: PyTree[ArrayLike],
    args: PyTree[Any] = None,
    *,
    saveat: SaveAt = SaveAt(t1=True),
    stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
    max_steps: int | None = 4096,
    throw: bool = True,
    checkpoints: int | None = None,
) -> Solution:
    """Stripped-down :func:`diffrax.diffeqsolve` for circuit simulation.

    Identical calling convention to ``diffrax.diffeqsolve`` except:

    - ``adjoint``, ``event``, ``progress_meter``, ``solver_state``,
      ``controller_state``, and ``made_jump`` arguments are absent.
    - ``t0 <= t1`` is required; no backward-time integration.
    - No SDE/CDE terms.
    - ``saveat.dense`` and ``saveat.steps`` are not supported.

    ``checkpoints`` controls the number of binomial checkpoints used by the
    checkpointed while loop (``None`` = auto from ``max_steps``). Only when
    both ``max_steps`` and ``checkpoints`` are ``None`` is a plain ``lax``
    while loop used instead; that loop is not reverse-mode differentiable.

    Args:
        terms: PyTree of ``AbstractTerm`` describing the vector field.
        solver: diffrax solver instance. If it is an implicit solver whose
            root finder uses ``use_stepsize_tol``, the adaptive controller's
            ``rtol``/``atol``/``norm`` are copied into the root finder.
        t0: Start time.
        t1: End time; must satisfy ``t0 <= t1``.
        dt0: Initial step size; must satisfy ``(t1 - t0) * dt0 >= 0``.
        y0: Initial state PyTree.
        args: Extra arguments forwarded to the terms.
        saveat: Which times to save at (``dense``/``steps`` unsupported).
        stepsize_controller: Step size controller.
        max_steps: Upper bound on the number of steps, or ``None``.
        throw: If true, raise when the solve does not finish successfully.
        checkpoints: Number of checkpoints for the checkpointed while loop.

    Returns:
        A ``Solution`` whose ``ts``/``ys`` mirror the structure of
        ``saveat.subs``. ``interpolation``, ``solver_state``,
        ``controller_state``, ``made_jump`` and ``event_mask`` are ``None``.

    Raises:
        ValueError: At trace time if ``saveat.dense`` is set or any
            ``SubSaveAt`` has ``steps=True``.

    Invalid time arguments (``t1 < t0``, wrongly-signed ``dt0``, unsorted or
    out-of-range ``saveat.ts``) raise a runtime error via ``eqxi.error_if``,
    as does an unsuccessful solve when ``throw=True``.
    """
    # ------------------------------------------------------------------
    # Reject unsupported SaveAt options up-front. Output buffers below are
    # only sized for t0/ts/t1, so `steps=True` would silently drop saves,
    # and no interpolation is ever built for `dense=True`.
    # ------------------------------------------------------------------
    if saveat.dense:
        raise ValueError("`saveat.dense=True` is not supported by circuit_diffeqsolve.")
    if any(sub.steps for sub in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat)):
        raise ValueError("`SubSaveAt(steps=True)` is not supported by circuit_diffeqsolve.")

    # ------------------------------------------------------------------
    # dtype promotion for times (same logic as diffrax)
    # ------------------------------------------------------------------
    timelikes = [t0, t1, dt0] + [
        s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat) if s.ts is not None
    ]
    with jax.numpy_dtype_promotion("standard"):
        time_dtype = jnp.result_type(*timelikes)
    if jnp.issubdtype(time_dtype, jnp.integer):
        # Integer times would make step arithmetic truncate; use a float dtype.
        time_dtype = lxi.default_floating_dtype()
    t0 = jnp.asarray(t0, dtype=time_dtype)
    t1 = jnp.asarray(t1, dtype=time_dtype)
    dt0 = jnp.asarray(dt0, dtype=time_dtype)

    # Enforce the forward-only assumption instead of silently misbehaving.
    t1 = eqxi.error_if(t1, t1 < t0, "circuit_diffeqsolve requires t0 <= t1.")
    dt0 = eqxi.error_if(dt0, (t1 - t0) * dt0 < 0, "Must have (t1 - t0) * dt0 >= 0.")

    # Selector for every non-None `ts` array inside saveat.subs (used with
    # eqx.tree_at to replace those arrays).
    def _cast_ts(saveat):
        out = [s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat)]
        return [x for x in out if x is not None]

    saveat = eqx.tree_at(_cast_ts, saveat, replace_fn=lambda ts: ts.astype(time_dtype))

    # Promote y0 dtype to be consistent with time (avoids weak-dtype issues)
    def _promote(yi):
        with jax.numpy_dtype_promotion("standard"):
            _dtype = jnp.result_type(yi, time_dtype)
        return jnp.asarray(yi, dtype=_dtype)

    y0 = jtu.tree_map(_promote, y0)

    # ------------------------------------------------------------------
    # Wrap terms with direction=1 (forward-only; WrapTerm still needed
    # so that solver.step receives a properly wrapped term object)
    # ------------------------------------------------------------------
    direction = jnp.asarray(1, dtype=time_dtype)

    def _wrap(term):
        assert isinstance(term, AbstractTerm)
        return WrapTerm(term, direction)

    terms = jtu.tree_map(
        _wrap,
        terms,
        # MultiTerm is descended into so each constituent term is wrapped.
        is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, diffrax._term.MultiTerm),
    )

    # ------------------------------------------------------------------
    # Propagate adaptive-controller tolerances into implicit solver root finder
    # ------------------------------------------------------------------
    if isinstance(solver, diffrax.AbstractImplicitSolver):
        from diffrax._root_finder import use_stepsize_tol

        # Collect those of rtol/atol/norm that the root finder delegates to
        # the step size controller (marked with `use_stepsize_tol`).
        def _get_tols(x):
            outs = []
            for attr in ("rtol", "atol", "norm"):
                if getattr(solver.root_finder, attr) is use_stepsize_tol:
                    outs.append(getattr(x, attr))
            return tuple(outs)

        if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
            solver = eqx.tree_at(
                lambda s: _get_tols(s.root_finder),
                solver,
                _get_tols(stepsize_controller),
            )

    # ------------------------------------------------------------------
    # Validate save timestamps
    # ------------------------------------------------------------------
    def _check_ts(ts):
        ts = eqxi.error_if(ts, ts[1:] < ts[:-1], "saveat.ts must be increasing.")
        ts = eqxi.error_if(ts, (ts > t1) | (ts < t0), "saveat.ts must lie between t0 and t1.")
        return ts

    saveat = eqx.tree_at(_cast_ts, saveat, replace_fn=_check_ts)

    # Wrap custom fn with direction (identity when direction=1, but kept for correctness)
    def _wrap_fn(x):
        if _is_subsaveat(x) and x.fn is not save_y:
            direction_fn = lambda t, y, a: x.fn(direction * t, y, a)
            return eqx.tree_at(lambda x: x.fn, x, direction_fn)
        return x

    saveat = jtu.tree_map(_wrap_fn, saveat, is_leaf=_is_subsaveat)

    # ------------------------------------------------------------------
    # Initialise solver & controller states
    # ------------------------------------------------------------------
    tprev = t0
    error_order = solver.error_order(terms)
    tnext, controller_state = stepsize_controller.init(terms, t0, t1, y0, dt0, args, solver.func, error_order)
    # Never let the first step overshoot the end of the interval.
    tnext = jnp.minimum(tnext, t1)
    solver_state = solver.init(terms, t0, tnext, y0, args)

    # ------------------------------------------------------------------
    # Allocate output buffers (pre-filled with inf; XLA needs static sizes)
    # ------------------------------------------------------------------
    def _allocate(subsaveat: SubSaveAt) -> SaveState:
        out_size = 0
        if subsaveat.t0:
            out_size += 1
        if subsaveat.ts is not None:
            out_size += len(subsaveat.ts)
        if subsaveat.t1:
            out_size += 1
        # Shape/dtype of what subsaveat.fn returns, without running it.
        struct = eqx.filter_eval_shape(subsaveat.fn, t0, y0, args)
        ts = jnp.full(out_size, jnp.inf, dtype=time_dtype)
        ys = jtu.tree_map(lambda y: jnp.full((out_size,) + y.shape, jnp.inf, dtype=y.dtype), struct)
        return SaveState(ts=ts, ys=ys, save_index=0, saveat_ts_index=0)

    save_state = jtu.tree_map(_allocate, saveat.subs, is_leaf=_is_subsaveat)

    # ------------------------------------------------------------------
    # Build initial CircuitState
    # ------------------------------------------------------------------
    init_state = CircuitState(
        y=y0,
        tprev=tprev,
        tnext=tnext,
        solver_state=solver_state,
        controller_state=controller_state,
        result=RESULTS.successful,
        num_steps=0,
        num_accepted_steps=0,
        num_rejected_steps=0,
        save_state=save_state,
    )

    # ------------------------------------------------------------------
    # Choose while-loop variant. Mirrors diffrax's RecursiveCheckpointAdjoint:
    # the checkpointed loop needs either a step bound or an explicit
    # checkpoint count; only when both are absent fall back to a plain lax
    # loop (which is not reverse-mode differentiable).
    # ------------------------------------------------------------------
    if max_steps is None and checkpoints is None:
        inner_while_loop = ft.partial(_inner_loop, kind="lax")
        outer_while_loop = ft.partial(_outer_loop, kind="lax")
    else:
        inner_while_loop = ft.partial(_inner_loop, kind="checkpointed")
        outer_while_loop = ft.partial(_outer_loop, kind="checkpointed", checkpoints=checkpoints)

    # ------------------------------------------------------------------
    # Run the integration
    # ------------------------------------------------------------------
    final_state = _circuit_loop(
        solver=solver,
        stepsize_controller=stepsize_controller,
        saveat=saveat,
        t0=t0,
        t1=t1,
        max_steps=max_steps,
        terms=terms,
        args=args,
        init_state=init_state,
        inner_while_loop=inner_while_loop,
        outer_while_loop=outer_while_loop,
    )

    # ------------------------------------------------------------------
    # Build the Solution (compatible with diffrax.Solution)
    # ------------------------------------------------------------------
    ts = jtu.tree_map(lambda s: s.ts, final_state.save_state, is_leaf=_is_save_state)
    ys = jtu.tree_map(lambda s: s.ys, final_state.save_state, is_leaf=_is_save_state)
    stats = {
        "num_steps": final_state.num_steps,
        "num_accepted_steps": final_state.num_accepted_steps,
        "num_rejected_steps": final_state.num_rejected_steps,
        "max_steps": max_steps,
    }
    sol = Solution(
        t0=t0,
        t1=t1,
        ts=ts,
        ys=ys,
        interpolation=None,
        stats=stats,
        result=final_state.result,
        solver_state=None,
        controller_state=None,
        made_jump=None,
        event_mask=None,
    )
    if throw:
        sol = final_state.result.error_if(sol, jnp.invert(is_okay(final_state.result)))
    return sol