
adjoint

Transient parameter sensitivity via the discrete adjoint method.

Computes ∂loss/∂params through a transient simulation trajectory using the discrete adjoint of the Backward Euler time-stepping scheme. Parameter derivatives of the device equations are taken by finite differences, so no autodiff passes through the OSDI XLA FFI.

Mathematical background

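Time-stepping residual at step k (Backward Euler, step size dt[k] = t[k] - t[k-1]):

G[k](y[k], y[k-1], p) = F(y[k], p) + (Q(y[k], p) - Q(y[k-1], p)) / dt[k] = 0

Implicit function: y[k] = Φ(y[k-1], p) defined by G[k] = 0.

Sensitivity through implicit differentiation:

∂y[k]/∂y[k-1]  =  J_eff[k]^{-1} · (J_q[k-1] / dt[k])
∂y[k]/∂p       = -J_eff[k]^{-1} · (∂F[k]/∂p + (∂Q[k]/∂p - ∂Q[k-1]/∂p) / dt[k])

where J_eff[k] = J_f[k] + J_q[k] / dt[k], with J_f = ∂F/∂y and J_q = ∂Q/∂y, and the index [k] denotes evaluation at y[k]. The ∂Q[k-1]/∂p term is present because Q(y[k-1], p) in G[k] also depends on p.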
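As a quick check of these formulas, the sketch below differentiates one Backward Euler step of a small linear network and compares against central finite differences. It is a standalone NumPy illustration, not circulax code: the 2-node matrices `G0`, `C0`, the source `u`, and the parameterisation F(y, p) = g·G0·y - u, Q(y, p) = c·C0·y with p = (g, c) are all hypothetical. Because Q(y[k-1], p) also depends on p, ∂G[k]/∂p includes -∂Q/∂p(y[k-1]) / dt; with c scaling Q, that contribution is nonzero here.

```python
import numpy as np

# Hypothetical 2-node network: F(y, p) = g*G0 @ y - u,  Q(y, p) = c*C0 @ y,  p = (g, c)
G0 = np.array([[2.0, -1.0], [-1.0, 1.0]])
C0 = np.array([[1.0, 0.0], [0.0, 0.5]])
u = np.array([1.0, 0.0])


def be_step(y_prev, g, c, dt):
    """Solve G[k] = F(y) + (Q(y) - Q(y_prev)) / dt = 0 for y (linear, so one solve)."""
    J_eff = g * G0 + c * C0 / dt
    return np.linalg.solve(J_eff, u + c * C0 @ y_prev / dt)


def step_sensitivities(y_prev, g, c, dt):
    """Implicit-function sensitivities of one Backward Euler step."""
    y = be_step(y_prev, g, c, dt)
    J_eff = g * G0 + c * C0 / dt              # J_f + J_q / dt
    J_q_prev = c * C0                          # ∂Q/∂y at y[k-1] (constant for linear Q)
    dy_dyprev = np.linalg.solve(J_eff, J_q_prev / dt)
    # ∂G[k]/∂p with y[k] and y[k-1] held fixed; columns correspond to (g, c)
    dG_dp = np.column_stack([G0 @ y, C0 @ (y - y_prev) / dt])
    dy_dp = -np.linalg.solve(J_eff, dG_dp)
    return y, dy_dyprev, dy_dp


y_prev, g, c, dt, h = np.array([0.3, -0.2]), 1.5, 2.0, 0.1, 1e-6
y, dy_dyprev, dy_dp = step_sensitivities(y_prev, g, c, dt)

# Central finite differences of the step map, for comparison
fd_dyprev = np.column_stack([
    (be_step(y_prev + h * e, g, c, dt) - be_step(y_prev - h * e, g, c, dt)) / (2 * h)
    for e in np.eye(2)
])
fd_dp = np.column_stack([
    (be_step(y_prev, g + h, c, dt) - be_step(y_prev, g - h, c, dt)) / (2 * h),
    (be_step(y_prev, g, c + h, dt) - be_step(y_prev, g, c - h, dt)) / (2 * h),
])
print(np.max(abs(dy_dyprev - fd_dyprev)), np.max(abs(dy_dp - fd_dp)))
```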

Discrete adjoint recurrence (backward sweep, k = N down to 1; dt[k+1] = t[k+1] - t[k] is the size of the following step):

ψ[N] = ∂L/∂y[N]|direct
ψ[k] = ∂L/∂y[k]|direct + (J_q[k] / dt[k+1])^T · λ[k+1]     (k < N)

J_eff[k]^T · λ[k] = ψ[k]     (adjoint linear solve at each step)

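Parameter gradient accumulation (the initial state y[0] is treated as independent of p, so only steps 1..N contribute):

∂loss/∂p = -Σ_{k=1}^{N} λ[k]^T · (∂F/∂p(y[k]) + (∂Q/∂p(y[k]) - ∂Q/∂p(y[k-1])) / dt[k])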
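The recurrence and the gradient sum can be exercised end to end on a toy problem. This standalone NumPy sketch is an illustration of the technique, not the circulax implementation (which obtains ∂F/∂p and ∂Q/∂p by finite differences on OSDI devices). Everything in it is hypothetical: a 2-node network with p = (g, c), F = g·G0·y - u, Q = c·C0·y, the loss L = Σ_k y[k][1]², non-uniform checkpoints, and a fixed y[0]. It runs the backward sweep with the following step size in the coupling term and compares the result with central finite differences of the full simulated loss.

```python
import numpy as np

# Hypothetical 2-node network: F(y, p) = g*G0 @ y - u,  Q(y, p) = c*C0 @ y,  p = (g, c)
G0 = np.array([[2.0, -1.0], [-1.0, 1.0]])
C0 = np.array([[1.0, 0.0], [0.0, 0.5]])
u = np.array([1.0, 0.0])
y0 = np.zeros(2)                                   # initial state, independent of p
ts = np.array([0.0, 0.05, 0.15, 0.3, 0.5, 0.6])    # non-uniform checkpoints


def simulate(p):
    """Backward Euler trajectory y[0..N], one step per checkpoint interval."""
    g, c = p
    ys = [y0]
    for k in range(1, len(ts)):
        dt = ts[k] - ts[k - 1]
        ys.append(np.linalg.solve(g * G0 + c * C0 / dt, u + c * C0 @ ys[-1] / dt))
    return np.array(ys)


def loss(ys):
    return np.sum(ys[:, 1] ** 2)                   # L = Σ_k y[k][1]^2


def adjoint_gradient(p):
    g, c = p
    ys = simulate(p)
    N = len(ts) - 1
    grad = np.zeros(2)
    lam_next = np.zeros(2)                         # λ[N+1] = 0
    for k in range(N, 0, -1):
        dt = ts[k] - ts[k - 1]
        psi = np.array([0.0, 2.0 * ys[k, 1]])      # ∂L/∂y[k]|direct
        if k < N:
            dt_next = ts[k + 1] - ts[k]
            psi += (c * C0 / dt_next).T @ lam_next  # J_q coupling to step k+1
        lam = np.linalg.solve((g * G0 + c * C0 / dt).T, psi)  # J_eff[k]^T λ[k] = ψ[k]
        # ∂G[k]/∂p for p = (g, c), including the Q(y[k-1], p) contribution
        dG_dp = np.column_stack([G0 @ ys[k], C0 @ (ys[k] - ys[k - 1]) / dt])
        grad -= lam @ dG_dp                        # -λ[k]^T ∂G[k]/∂p
        lam_next = lam
    return grad


p, h = np.array([1.5, 2.0]), 1e-6
grad = adjoint_gradient(p)
fd = np.array([(loss(simulate(p + h * e)) - loss(simulate(p - h * e))) / (2 * h) for e in np.eye(2)])
print(grad, fd)
```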

Note on the J_q coupling term

The term (J_q[k] / dt[k+1])^T · λ[k+1] propagates sensitivity backward through the capacitive (charge) coupling between consecutive time steps. For purely resistive circuits (J_q = 0) the term vanishes and the adjoint reduces to N independent DC-like adjoint solves. For RC/RLC circuits, where each state depends on the previous one, dropping it yields wrong gradients.
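To see the effect of the coupling term, the sketch below (NumPy, hypothetical 2-node network F = g·G0·y - u, Q = c·C0·y, loss L = Σ_k y[k][1]², derivative with respect to the conductance scale g only) computes the adjoint gradient with and without (J_q[k] / dt)^T · λ[k+1], where dt is the following step ts[k+1] - ts[k]. With c = 0 the two versions agree exactly; with c = 2 only the coupled version matches finite differences.

```python
import numpy as np

# Hypothetical 2-node network: F = g*G0 @ y - u,  Q = c*C0 @ y
G0 = np.array([[2.0, -1.0], [-1.0, 1.0]])
C0 = np.array([[1.0, 0.0], [0.0, 0.5]])
u = np.array([1.0, 0.0])
ts = np.linspace(0.0, 0.5, 6)


def simulate(g, c):
    """Backward Euler trajectory from y[0] = 0, one step per checkpoint interval."""
    ys = [np.zeros(2)]
    for k in range(1, len(ts)):
        dt = ts[k] - ts[k - 1]
        ys.append(np.linalg.solve(g * G0 + c * C0 / dt, u + c * C0 @ ys[-1] / dt))
    return np.array(ys)


def adjoint_dloss_dg(g, c, couple=True):
    """Adjoint dL/dg for L = Σ_k y[k][1]^2; couple=False drops the J_q term."""
    ys = simulate(g, c)
    N = len(ts) - 1
    grad, lam_next = 0.0, np.zeros(2)
    for k in range(N, 0, -1):
        dt = ts[k] - ts[k - 1]
        psi = np.array([0.0, 2.0 * ys[k, 1]])
        if couple and k < N:
            psi += (c * C0 / (ts[k + 1] - ts[k])).T @ lam_next
        lam = np.linalg.solve((g * G0 + c * C0 / dt).T, psi)
        grad -= lam @ (G0 @ ys[k])                 # ∂G[k]/∂g = G0 @ y[k]
        lam_next = lam
    return grad


def fd_dloss_dg(g, c, h=1e-6):
    """Central finite difference of the simulated loss with respect to g."""
    def total_loss(gg):
        return np.sum(simulate(gg, c)[:, 1] ** 2)
    return (total_loss(g + h) - total_loss(g - h)) / (2 * h)


for c in (0.0, 2.0):  # purely resistive, then RC
    print(c, adjoint_dloss_dg(1.5, c), adjoint_dloss_dg(1.5, c, couple=False), fd_dloss_dg(1.5, c))
```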

Functions:

| Name | Description |
|---|---|
| `transient_parameter_sensitivity` | Compute ∂loss/∂params via discrete adjoint over a transient trajectory. |
| `transient_parameter_sensitivity_dense` | Dense-solver fallback for `transient_parameter_sensitivity`. |

transient_parameter_sensitivity

transient_parameter_sensitivity(
    component_groups: dict,
    solver: CircuitLinearSolver,
    y_trajectory: Array,
    ts: Array,
    loss_fn: Callable,
    *,
    osdi_group_key: str,
    param_names: list[str],
    model_descriptor: Any | None = None,
    param_to_col: dict[str, int] | None = None,
    eps: float = 1e-06,
    shared_params: bool = False
) -> dict[str, Array | float]

Compute ∂loss/∂params via discrete adjoint over a transient trajectory.

Implements the discrete adjoint of the Backward Euler time-stepping scheme, with full inter-step coupling through the capacitance matrix (J_q). Parameter derivatives ∂F/∂p and ∂Q/∂p are approximated by finite differences (step controlled by eps).

Note

This function uses host-side loops and jax.device_get for finite-difference perturbations and cannot be JIT-compiled.

The adjoint recurrence is (k = N, N-1, ..., 1, with dt[k] = ts[k] - ts[k-1] and λ[N+1] = 0):

ψ[k] = ∂L/∂y[k]|direct + (J_q[k] / dt[k+1])^T · λ[k+1]
J_eff[k]^T · λ[k] = ψ[k]                  (linear solve)
∂loss/∂p -= λ[k]^T · (∂F[k]/∂p + (∂Q[k]/∂p - ∂Q[k-1]/∂p) / dt[k])   (gradient)
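`eps` is described only as a relative finite-difference step; the exact perturbation rule (and whether the difference is one-sided) lives in a helper that this page does not show. The sketch below assumes a forward difference with step h = eps·|p|, falling back to h = eps at p = 0, and applies it to a hypothetical diode current F(v; Is) = Is·(exp(v/Vt) - 1) at Is = 1e-14. A relative step keeps the perturbation proportional to the parameter, which matters for model parameters that span many orders of magnitude.

```python
import numpy as np


def fd_param_derivative(fn, p, eps=1e-6):
    """Hypothetical helper: forward difference d fn / d p with an assumed step h = eps*|p|."""
    h = eps * abs(p) if p != 0.0 else eps   # absolute fallback when p == 0
    return (fn(p + h) - fn(p)) / h


# Hypothetical diode branch current F(v; Is) = Is * (exp(v / Vt) - 1)
VT = 0.02585   # thermal voltage near 300 K, in volts
V = 0.6        # branch voltage at the current checkpoint, in volts


def diode_current(i_s):
    return i_s * np.expm1(V / VT)


i_s = 1e-14
dF_dIs = fd_param_derivative(diode_current, i_s)
print(dF_dIs, np.expm1(V / VT))   # F is linear in Is, so both agree to rounding error
```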

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `component_groups` | `dict` | Compiled circuit groups (the dict returned by `circulax.compiler.compile_netlist`). | required |
| `solver` | `CircuitLinearSolver` | A `circulax.solvers.linear.CircuitLinearSolver` instance built from the same `component_groups`. Must expose the KLU coalescing attributes (`u_rows`, `u_cols`, `map_idx`, `n_unique`, `sys_size`, `g_leak`, `ground_indices`). For `DenseSolver`, use `transient_parameter_sensitivity_dense`. | required |
| `y_trajectory` | `Array` | Saved trajectory of shape `(n_checkpoints, sys_size)`, i.e. `solution.ys` from `diffrax.diffeqsolve` with `SaveAt(ts=...)`. Consecutive checkpoints are treated as single Backward Euler steps of size `ts[k] - ts[k-1]`. | required |
| `ts` | `Array` | Checkpoint times, shape `(n_checkpoints,)`. | required |
| `loss_fn` | `Callable` | Callable `(y_trajectory, ts) -> scalar` or `y_final -> scalar`. A callable whose signature has two or more parameters receives the full trajectory and `ts`; otherwise it receives only the final state. Must be differentiable with respect to `y_trajectory` via `jax.grad`. If the loss depends only on the final state, pass a single-argument callable. | required |
| `osdi_group_key` | `str` | Key in `component_groups` identifying the `bosdi.circulax.OsdiComponentGroup` whose parameters are differentiated. | required |
| `param_names` | `list[str]` | Canonical OSDI parameter names to differentiate. | required |
| `model_descriptor` | `Any \| None` | The `bosdi.circulax.OsdiModelDescriptor` returned by `circulax.osdi_component`. Either this or `param_to_col` must be provided. | `None` |
| `param_to_col` | `dict[str, int] \| None` | Explicit `{param_name: column_index}` mapping. | `None` |
| `eps` | `float` | Relative step size for the finite-difference parameter derivatives. | `1e-06` |
| `shared_params` | `bool` | If `True`, all devices share the same parameter values (process parameters). Each parameter is perturbed on all devices at once, reducing OSDI evaluations from `n_params × n_devices` to `n_params` per checkpoint, and scalar gradients are returned instead of per-device arrays. | `False` |
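The two accepted `loss_fn` forms differ in which checkpoints receive a direct gradient ∂L/∂y[k]|direct. The sketch below mirrors the dispatch visible in the source listing (count the signature's parameters with `inspect.signature`, then make one `jax.grad` call over the whole trajectory); the helper `direct_loss_gradient` and both example losses are hypothetical. A final-state loss yields a gradient that is zero everywhere except at the last checkpoint.

```python
import inspect

import jax
import jax.numpy as jnp


def direct_loss_gradient(loss_fn, y_trajectory, ts):
    """Hypothetical helper: ∂L/∂y[k] for every checkpoint, dispatching on arity."""
    takes_two_args = len(inspect.signature(loss_fn).parameters) >= 2
    if takes_two_args:
        # Gradient w.r.t. the first argument (the whole trajectory) only
        return jax.grad(loss_fn)(y_trajectory, ts)
    # Final-state loss: wrap it so the gradient still has the trajectory's shape
    return jax.grad(lambda yt: loss_fn(yt[-1]))(y_trajectory)


def weighted_energy(y_traj, ts):
    """Hypothetical trajectory loss: time-weighted squared voltage on node 0."""
    return jnp.sum(ts * y_traj[:, 0] ** 2)


def final_error(y_final):
    """Hypothetical final-state loss: squared error of node 1 against 0.5 V."""
    return (y_final[1] - 0.5) ** 2


ts = jnp.linspace(0.0, 1.0, 5)
y_traj = jnp.stack([ts, 2.0 * ts], axis=1)        # shape (5, 2)

g_traj = direct_loss_gradient(weighted_energy, y_traj, ts)   # nonzero in every row
g_final = direct_loss_gradient(final_error, y_traj, ts)      # nonzero only in the last row
print(g_traj)
print(g_final)
```

Because the form is chosen by counting signature parameters, a loss declared as `def loss(*args)` has a single parameter and would be treated as a final-state loss; give a trajectory loss two explicit positional parameters.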

Returns:

| Type | Description |
|---|---|
| `dict[str, Array \| float]` | Dict mapping each name in `param_names` to its gradient: an array of shape `(n_devices,)` when `shared_params=False`, or a scalar `float` when `shared_params=True`. |
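With `shared_params=True`, every device holds the same value p and all devices are perturbed together, so the returned scalar is the derivative of L(p, ..., p). By the chain rule that equals the sum of the per-device derivatives that `shared_params=False` estimates. A NumPy check on a hypothetical three-device loss:

```python
import numpy as np

w = np.array([1.0, 2.0, 0.5])   # hypothetical per-device weights


def loss(p_dev):
    """Hypothetical scalar loss with one parameter value per device."""
    return np.sum(w * np.sin(p_dev)) + 0.1 * np.sum(p_dev) ** 2


def per_device_grad(p_dev):
    """Analytic ∂L/∂p_i: the quantity shared_params=False approximates, shape (n_devices,)."""
    return w * np.cos(p_dev) + 0.2 * np.sum(p_dev)


p_shared = 0.7
p_dev = np.full(3, p_shared)    # every device holds the same value
g_dev = per_device_grad(p_dev)

# shared_params=True style: perturb all devices together and get one scalar
h = 1e-6
g_shared = (loss(p_dev + h) - loss(p_dev - h)) / (2 * h)
print(g_shared, g_dev.sum())
```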

Raises:

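| Type | Description |
|---|---|
| `ValueError` | If `osdi_group_key` is missing or a parameter name is not found. |
| `TypeError` | If the group under `osdi_group_key` is not an `OsdiComponentGroup`, or if `solver` does not expose the KLU coalescing attributes. |
| `ImportError` | If bosdi / osdi_jax (or klujax) is not available. |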

Source code in circulax/solvers/adjoint.py
def transient_parameter_sensitivity(
    component_groups: dict,
    solver: CircuitLinearSolver,
    y_trajectory: jax.Array,
    ts: jax.Array,
    loss_fn: Callable,
    *,
    osdi_group_key: str,
    param_names: list[str],
    model_descriptor: Any | None = None,
    param_to_col: dict[str, int] | None = None,
    eps: float = 1e-6,
    shared_params: bool = False,
) -> dict[str, jax.Array | float]:
    """Compute ∂loss/∂params via discrete adjoint over a transient trajectory.

    Implements the discrete adjoint of the Backward Euler time-stepping
    scheme, with full inter-step coupling through the capacitance matrix (J_q).
    Parameter derivatives are approximated by finite differences.

    Note:
        This function uses host-side loops and ``jax.device_get`` for
        finite-difference perturbations and cannot be JIT-compiled.

    The adjoint recurrence is (k = N, N-1, ..., 1, with
    dt[k] = ts[k] - ts[k-1] and λ[N+1] = 0):

        ψ[k] = ∂L/∂y[k]|direct + (J_q[k] / dt[k+1])^T · λ[k+1]
        J_eff[k]^T · λ[k] = ψ[k]                  (linear solve)
        ∂loss/∂p -= λ[k]^T · (∂F[k]/∂p + (∂Q[k]/∂p - ∂Q[k-1]/∂p) / dt[k])   (gradient)

    Args:
        component_groups: Compiled circuit groups (dict returned by
            :func:`~circulax.compiler.compile_netlist`).
        solver: A :class:`~circulax.solvers.linear.CircuitLinearSolver`
            instance built from the same ``component_groups``.  Must expose
            KLU coalescing attributes (``u_rows``, ``u_cols``, ``map_idx``,
            ``n_unique``, ``sys_size``, ``g_leak``, ``ground_indices``).
            For DenseSolver, use :func:`transient_parameter_sensitivity_dense`.
        y_trajectory: Saved trajectory array of shape
            ``(n_checkpoints, sys_size)`` — the ``solution.ys`` from
            ``diffrax.diffeqsolve`` with ``SaveAt(ts=...)``.
        ts: Time points of the checkpoints, shape ``(n_checkpoints,)``.
        loss_fn: Callable ``(y_trajectory, ts) -> scalar`` or
            ``y_final -> scalar``.  Must be differentiable w.r.t.
            ``y_trajectory`` via ``jax.grad``.  If loss depends only on the
            final state, pass a single-argument callable.
        osdi_group_key: Key in ``component_groups`` identifying the
            :class:`~bosdi.circulax.OsdiComponentGroup` whose parameters are
            differentiated.
        param_names: List of canonical OSDI parameter names to differentiate.
        model_descriptor: The :class:`~bosdi.circulax.OsdiModelDescriptor`
            returned by :func:`~circulax.osdi_component`.  Either this or
            ``param_to_col`` must be provided.
        param_to_col: Explicit ``{param_name: column_index}`` mapping.
        eps: Relative finite difference step size.
        shared_params: If True, all devices share the same parameter values
            (process params). Perturbs all devices at once per parameter,
            reducing OSDI evals from ``n_params × n_devices`` to ``n_params``
            per checkpoint. Returns scalar gradients instead of per-device.

    Returns:
        Dict mapping each name in ``param_names`` to a gradient.
        When ``shared_params=False``: array of shape ``(n_devices,)``.
        When ``shared_params=True``: scalar ``float``.

    Raises:
        ValueError: If ``osdi_group_key`` is missing or a parameter name is
            not found.
        TypeError: If the group is not an ``OsdiComponentGroup`` or
            ``solver`` does not expose KLU coalescing attributes.
        ImportError: If bosdi / osdi_jax (or klujax) is not available.

    """
    try:
        from bosdi.circulax import OsdiComponentGroup
    except ImportError as err:
        msg = "transient_parameter_sensitivity requires bosdi."
        raise ImportError(msg) from err

    # --- Validate ---
    if osdi_group_key not in component_groups:
        available = list(component_groups.keys())
        msg = f"OSDI group key {osdi_group_key!r} not found. Available: {available}"
        raise ValueError(msg)

    group = component_groups[osdi_group_key]
    if not isinstance(group, OsdiComponentGroup):
        msg = f"Group {osdi_group_key!r} is not an OsdiComponentGroup (got {type(group).__name__})."
        raise TypeError(msg)

    if not hasattr(solver, "u_rows") or not hasattr(solver, "map_idx"):
        msg = (
            "transient_parameter_sensitivity requires a KLU-based solver. "
            f"Got {type(solver).__name__}. "
            "For DenseSolver, use transient_parameter_sensitivity_dense instead."
        )
        raise TypeError(msg)

    param_cols = _resolve_param_cols(
        group, param_names, model_descriptor=model_descriptor, param_to_col=param_to_col
    )

    import klujax

    sys_size = solver.sys_size
    n_checkpoints = y_trajectory.shape[0]
    n_devices = group.params.shape[0]

    import inspect
    sig = inspect.signature(loss_fn)
    _loss_takes_two_args = len(sig.parameters) >= 2

    # Precompute ∂L/∂y[k] for ALL checkpoints in a single jax.grad call.
    # Calling jax.grad once per checkpoint with a different compile-time
    # constant would trigger a separate JIT compilation for every checkpoint.
    dL_dy_all = (
        jax.grad(loss_fn)(y_trajectory, ts)
        if _loss_takes_two_args
        else jax.grad(lambda yt: loss_fn(yt[-1]))(y_trajectory)
    )

    if shared_params:
        grad_accum = np.zeros((len(param_names),), dtype=np.float64)
    else:
        grad_accum = np.zeros((len(param_names), n_devices), dtype=np.float64)

    lam_next = None

    for k in range(n_checkpoints - 1, 0, -1):
        y_cur = y_trajectory[k].astype(jnp.float64)
        y_prev = y_trajectory[k - 1].astype(jnp.float64)
        dt = float(ts[k]) - float(ts[k - 1])

        psi_k = dL_dy_all[k].astype(jnp.float64)

        if lam_next is not None:
            dt_next = float(ts[k + 1]) - float(ts[k]) if k + 1 < n_checkpoints else dt
            coupling = _jq_matvec_klu(component_groups, y_cur, dt_next, lam_next)
            psi_k = psi_k + coupling

        coalesced_vals = _build_jeff_klu(component_groups, y_cur, dt, solver)
        lam_k = klujax.tsolve_with_symbol(
            solver.u_rows,
            solver.u_cols,
            coalesced_vals,
            psi_k,
            solver._handle_wrapper,  # noqa: SLF001
        )

        _compute_transient_fd_gradients(
            group, y_cur, y_prev, lam_k, dt, sys_size,
            param_names, param_cols, eps, grad_accum,
            shared_params=shared_params,
        )

        lam_next = lam_k

    if shared_params:
        return {pname: float(grad_accum[pi]) for pi, pname in enumerate(param_names)}
    return {pname: jnp.array(grad_accum[pi]) for pi, pname in enumerate(param_names)}

transient_parameter_sensitivity_dense

transient_parameter_sensitivity_dense(
    component_groups: dict,
    y_trajectory: Array,
    ts: Array,
    loss_fn: Callable,
    *,
    osdi_group_key: str,
    param_names: list[str],
    model_descriptor: Any | None = None,
    param_to_col: dict[str, int] | None = None,
    eps: float = 1e-06,
    shared_params: bool = False
) -> dict[str, Array | float]

Dense-solver fallback for `transient_parameter_sensitivity`.

Uses `jnp.linalg.solve` instead of KLU for the adjoint solves and takes no `solver` argument, so it works with trajectories from any solver backend, including `circulax.solvers.linear.DenseSolver`. Intended for small circuits and unit tests, since it builds dense `sys_size × sys_size` matrices.

Arguments match `transient_parameter_sensitivity`, except that there is no `solver` argument.

Returns:

| Type | Description |
|---|---|
| `dict[str, Array \| float]` | Same as `transient_parameter_sensitivity`. |
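In the dense fallback, each step's adjoint is a single call `jnp.linalg.solve(J.T, psi_k)`. The sketch below (NumPy, random stand-in matrices rather than circulax data) checks the identity that makes this worthwhile: ψ^T · J_eff^{-1} · ∂G/∂p equals λ^T · ∂G/∂p with J_eff^T · λ = ψ. One transposed solve per step therefore serves every parameter, whereas the forward route needs one solve per parameter column.

```python
import numpy as np

rng = np.random.default_rng(0)
n, n_params = 4, 3
J_eff = rng.normal(size=(n, n)) + n * np.eye(n)   # stand-in for J_eff[k], kept well-conditioned
psi = rng.normal(size=n)                          # stand-in for ψ[k]
dG_dp = rng.normal(size=(n, n_params))            # stand-in for ∂G[k]/∂p

# Adjoint route: one transposed solve, reused for every parameter column
lam = np.linalg.solve(J_eff.T, psi)
grad_adjoint = -lam @ dG_dp

# Forward route: one solve per parameter column, then project onto ψ
grad_forward = -psi @ np.linalg.solve(J_eff, dG_dp)
print(grad_adjoint, grad_forward)
```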

Source code in circulax/solvers/adjoint.py
def transient_parameter_sensitivity_dense(
    component_groups: dict,
    y_trajectory: jax.Array,
    ts: jax.Array,
    loss_fn: Callable,
    *,
    osdi_group_key: str,
    param_names: list[str],
    model_descriptor: Any | None = None,
    param_to_col: dict[str, int] | None = None,
    eps: float = 1e-6,
    shared_params: bool = False,
) -> dict[str, jax.Array | float]:
    """Dense-solver fallback for :func:`transient_parameter_sensitivity`.

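    Uses ``jnp.linalg.solve`` instead of KLU for the adjoint solves and takes
    no ``solver`` argument, so it works with trajectories from any solver
    backend, including :class:`~circulax.solvers.linear.DenseSolver`.
    Intended for small circuits and unit tests, since it builds dense
    ``sys_size × sys_size`` matrices.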

    Args match :func:`transient_parameter_sensitivity` except ``solver`` is
    not required.

    Returns:
        Same as :func:`transient_parameter_sensitivity`.

    """
    try:
        from bosdi.circulax import OsdiComponentGroup
    except ImportError as err:
        msg = "transient_parameter_sensitivity_dense requires bosdi."
        raise ImportError(msg) from err

    if osdi_group_key not in component_groups:
        available = list(component_groups.keys())
        msg = f"OSDI group key {osdi_group_key!r} not found. Available: {available}"
        raise ValueError(msg)

    group = component_groups[osdi_group_key]
    if not isinstance(group, OsdiComponentGroup):
        msg = f"Group {osdi_group_key!r} is not an OsdiComponentGroup (got {type(group).__name__})."
        raise TypeError(msg)

    param_cols = _resolve_param_cols(
        group, param_names, model_descriptor=model_descriptor, param_to_col=param_to_col
    )

    sys_size = y_trajectory.shape[1]
    n_checkpoints = y_trajectory.shape[0]
    n_devices = group.params.shape[0]

    import inspect
    sig = inspect.signature(loss_fn)
    _loss_takes_two_args = len(sig.parameters) >= 2

    dL_dy_all = (
        jax.grad(loss_fn)(y_trajectory, ts)
        if _loss_takes_two_args
        else jax.grad(lambda yt: loss_fn(yt[-1]))(y_trajectory)
    )

    if shared_params:
        grad_accum = np.zeros((len(param_names),), dtype=np.float64)
    else:
        grad_accum = np.zeros((len(param_names), n_devices), dtype=np.float64)

    lam_next = None

    for k in range(n_checkpoints - 1, 0, -1):
        y_cur = y_trajectory[k].astype(jnp.float64)
        y_prev = y_trajectory[k - 1].astype(jnp.float64)
        dt = float(ts[k]) - float(ts[k - 1])

        psi_k = dL_dy_all[k].astype(jnp.float64)

        if lam_next is not None:
            dt_next = float(ts[k + 1]) - float(ts[k]) if k + 1 < n_checkpoints else dt
            C_scaled = _build_jq_matvec(component_groups, y_cur, dt_next, sys_size)
            psi_k = psi_k + C_scaled.T @ lam_next

        J = _build_jeff_dense(component_groups, y_cur, dt, sys_size)
        lam_k = jnp.linalg.solve(J.T, psi_k)

        _compute_transient_fd_gradients(
            group, y_cur, y_prev, lam_k, dt, sys_size,
            param_names, param_cols, eps, grad_accum,
            shared_params=shared_params,
        )

        lam_next = lam_k

    if shared_params:
        return {pname: float(grad_accum[pi]) for pi, pname in enumerate(param_names)}
    return {pname: jnp.array(grad_accum[pi]) for pi, pname in enumerate(param_names)}