
sensitivity

DC parameter sensitivity via implicit differentiation.

Computes ∂loss/∂params at a DC operating point y* using the adjoint method:

    F(y*, p) = 0                (DC equilibrium)
    ∂loss/∂p = -λᵀ · ∂F/∂p      where J(y*)ᵀ λ = ∂L/∂y
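As a concrete check of these formulas, the sketch below applies them to a hypothetical two-unknown nonlinear system in plain NumPy, with hand-written derivatives standing in for jax.grad and the OSDI Jacobian. The toy residual, loss and Newton loop are illustrative assumptions, not circulax code. The adjoint gradient is compared against a central difference taken through the full nonlinear solve.

```python
import numpy as np

def residual(y, p):
    # Hypothetical two-unknown "DC" system F(y, p) = 0
    return np.array([y[0] + p * y[0] ** 3 - 1.0, y[1] - 2.0 * y[0]])

def jacobian(y, p):
    # J = dF/dy; it is not symmetric, so using J^T (not J) matters
    return np.array([[1.0 + 3.0 * p * y[0] ** 2, 0.0], [-2.0, 1.0]])

def dF_dp(y, p):
    return np.array([y[0] ** 3, 0.0])

def loss(y):
    return y[1] ** 2

def dloss_dy(y):
    return np.array([0.0, 2.0 * y[1]])

def solve_dc(p, tol=1e-14, max_iter=50):
    # Plain Newton iteration to the operating point y*
    y = np.zeros(2)
    for _ in range(max_iter):
        dy = np.linalg.solve(jacobian(y, p), -residual(y, p))
        y += dy
        if np.max(np.abs(dy)) < tol:
            break
    return y

p = 0.5
y_star = solve_dc(p)

# Adjoint: J(y*)^T lam = dL/dy, then dloss/dp = -lam . dF/dp
lam = np.linalg.solve(jacobian(y_star, p).T, dloss_dy(y_star))
grad_adjoint = -lam @ dF_dp(y_star, p)

# Reference: central difference through the whole nonlinear solve
h = 1e-6
grad_fd = (loss(solve_dc(p + h)) - loss(solve_dc(p - h))) / (2 * h)
print(grad_adjoint, grad_fd)
```

Because this J is lower triangular, solving with J instead of Jᵀ would give λ₀ = 0 and a gradient of zero, so the transpose is essential.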

The adjoint approach costs:

1. One linear solve for the adjoint vector λ (the KLU path reuses the symbolic analysis computed when the solver was built).
2. n_params × n_devices OSDI residual evaluations via finite differences.
3. One dot product per parameter.
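The cost items can be sketched with a dense stand-in. Assuming an LU factorization of J is already on hand, the adjoint system Jᵀ λ = ∂L/∂y reuses it (SciPy's lu_solve with trans=1 solves with the transpose), after which each parameter costs one dot product. The random matrices are placeholders for J(y*), ∂L/∂y and the ∂F/∂p_k columns; the finite-difference step that produces those columns is not shown. The forward method is included for comparison: it needs one solve per parameter, while the adjoint needs one solve in total.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve

rng = np.random.default_rng(0)
n, n_params = 6, 3

# Placeholders for J(y*), dL/dy and the columns dF/dp_k (hypothetical data)
J = rng.standard_normal((n, n)) + n * np.eye(n)
dL_dy = rng.standard_normal(n)
dF_dp = rng.standard_normal((n, n_params))  # column k holds dF/dp_k

# Factor J once; in a DC Newton loop this factorization would already exist
lu_piv = lu_factor(J)

# One transposed solve J^T lam = dL/dy, reusing the same LU factors
lam = lu_solve(lu_piv, dL_dy, trans=1)

# One dot product per parameter
grads = -lam @ dF_dp  # shape (n_params,)

# Forward (direct) approach for comparison: one solve per parameter column
dy_dp = -np.linalg.solve(J, dF_dp)
grads_direct = dL_dy @ dy_dp
print(np.allclose(grads, grads_direct))  # True
```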

No autodiff through OSDI FFI calls is required.

Functions:

Name                              Description
dc_parameter_sensitivity          Compute ∂loss/∂params for named parameters of an OSDI group.
dc_parameter_sensitivity_dense    Dense-solver fallback for dc_parameter_sensitivity.

dc_parameter_sensitivity

dc_parameter_sensitivity(
    component_groups: dict,
    solver: CircuitLinearSolver,
    y_star: Array,
    loss_fn,
    *,
    osdi_group_key: str,
    param_names: list[str],
    model_descriptor=None,
    param_to_col: dict[str, int] | None = None,
    eps: float = 1e-06
) -> dict[str, Array]

Compute ∂loss/∂params for named parameters of an OSDI group.

Uses the implicit differentiation (adjoint) approach at the DC operating point y*:

F(y*, p) = 0
J(y*)ᵀ λ = ∂L/∂y      [adjoint solve]
∂loss/∂p_k = -λᵀ · ∂F/∂p_k      [parameter gradient]
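These equations follow from differentiating the equilibrium condition with respect to p_k, assuming J(y*) is nonsingular and the loss depends on p only through y* (loss_fn takes y alone). Writing L for loss_fn:

```latex
\begin{aligned}
0 &= \frac{d}{dp_k} F\big(y^*(p), p\big)
   = J\,\frac{\partial y^*}{\partial p_k} + \frac{\partial F}{\partial p_k},
   \qquad J = \frac{\partial F}{\partial y}\Big|_{y^*} \\
\frac{\partial y^*}{\partial p_k} &= -J^{-1}\,\frac{\partial F}{\partial p_k} \\
\frac{\partial L}{\partial p_k}
  &= \Big(\frac{\partial L}{\partial y}\Big)^{\top}\frac{\partial y^*}{\partial p_k}
   = -\Big(J^{-\top}\frac{\partial L}{\partial y}\Big)^{\top}\frac{\partial F}{\partial p_k}
   = -\lambda^{\top}\frac{\partial F}{\partial p_k},
   \qquad J^{\top}\lambda = \frac{\partial L}{\partial y}
\end{aligned}
```

The adjoint vector λ does not depend on k, which is why one solve serves every parameter.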

∂F/∂p_k is computed via forward finite differences through osdi_residual_eval. This avoids autodiff through the XLA FFI.
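A minimal sketch of the finite-difference step, assuming each device instance is perturbed separately (which is where an n_params × n_devices evaluation count comes from). The body of _compute_fd_gradients is not shown on this page, so the step rule here (h = eps·|p|, falling back to eps when p = 0) is an assumption, and group_residual is a hypothetical toy standing in for osdi_residual_eval.

```python
import numpy as np

def group_residual(y, params, nodes):
    """Hypothetical stand-in for osdi_residual_eval on one device group.

    Device i injects g_i * v + a_i * v**2 into node nodes[i], with params[i] = [g_i, a_i].
    """
    v = y[nodes]
    current = params[:, 0] * v + params[:, 1] * v ** 2
    F = np.zeros_like(y)
    np.add.at(F, nodes, current)  # several devices may share a node
    return F

def fd_param_gradients(y, lam, params, nodes, col, eps=1e-6):
    """Return -lam . dF/dp for parameter column `col`, one entry per device (forward FD)."""
    F0 = group_residual(y, params, nodes)  # baseline evaluation, reused for every device
    grads = np.empty(params.shape[0])
    for i in range(params.shape[0]):
        p = params[i, col]
        h = eps * abs(p) if p != 0.0 else eps  # assumed relative step with an absolute fallback
        p_pert = params.copy()
        p_pert[i, col] = p + h  # perturb only device i
        dF = (group_residual(y, p_pert, nodes) - F0) / h  # forward difference
        grads[i] = -lam @ dF
    return grads

y = np.array([0.0, 0.8, 1.5])  # toy node voltages; index 0 is ground
params = np.array([[1e-3, 2e-4], [5e-3, 0.0], [2e-3, 1e-4]])
nodes = np.array([1, 2, 2])
lam = np.array([0.0, 0.3, -0.7])

g = fd_param_gradients(y, lam, params, nodes, col=0)
# dF/dg_i is v placed at row nodes[i], so the exact result is -lam[nodes] * v
print(np.allclose(g, -lam[nodes] * y[nodes], rtol=1e-6))  # True
```

The dot product with λ collapses each full residual difference to one scalar, so the per-device loop produces exactly one gradient entry per instance.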

Parameters:

component_groups (dict, required): Compiled circuit groups (the dict returned by circulax.compiler.compile_netlist).

solver (CircuitLinearSolver, required): A circulax.solvers.linear.CircuitLinearSolver instance built from the same component_groups. It must be a KLU-based solver that exposes u_rows, u_cols, map_idx, n_unique, sys_size, g_leak and ground_indices.

y_star (Array, required): DC operating point, shape (sys_size,).

loss_fn (required): Callable y -> scalar. Must be differentiable with respect to y via jax.grad.

osdi_group_key (str, required): Key in component_groups identifying the bosdi.circulax.OsdiComponentGroup whose parameters are to be differentiated.

param_names (list[str], required): Canonical OSDI parameter names to differentiate. Must be a subset of the model's parameter names.

model_descriptor (default None): The bosdi.circulax.OsdiModelDescriptor returned by circulax.osdi_component when the model was loaded, used to map parameter names to column indices. Either this or param_to_col must be provided.

param_to_col (dict[str, int] | None, default None): Explicit {param_name: column_index} mapping. Takes precedence over the model_descriptor lookup. Column indices follow the column order of group.params, which is the descriptor.param_names order.

eps (float, default 1e-06): Relative finite-difference step size. The default of 1e-6 works for most parameters.
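The name-to-column lookup described for model_descriptor and param_to_col can be mirrored in a few lines. This is a hypothetical helper, not the private _resolve_param_cols used by the source: it takes the descriptor's param_names list directly instead of the descriptor object, and the parameter names in the usage line are made up.

```python
def resolve_param_cols(param_names, descriptor_param_names=None, param_to_col=None):
    """Hypothetical mirror of the documented name-to-column lookup."""
    if param_to_col is None:
        if descriptor_param_names is None:
            raise ValueError("Provide either model_descriptor or param_to_col.")
        # group.params columns follow descriptor.param_names order
        param_to_col = {name: i for i, name in enumerate(descriptor_param_names)}
    missing = [name for name in param_names if name not in param_to_col]
    if missing:
        raise ValueError(f"Parameter(s) not found in the OSDI group: {missing}")
    return [param_to_col[name] for name in param_names]

model_params = ["is", "n", "rs", "cjo"]  # made-up parameter list for illustration
print(resolve_param_cols(["rs", "is"], descriptor_param_names=model_params))  # [2, 0]
```

An explicit param_to_col wins even when a descriptor list is also passed, matching the "overrides" rule above.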

Returns:

dict[str, Array]: Dict mapping each name in param_names to a gradient array of shape (n_devices,), one scalar per device instance in the group.
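Each entry of a returned array is the sensitivity to one instance's copy of the parameter. If several instances are meant to share a single model-level value, the chain rule gives the shared gradient as the sum over devices. The sketch checks this on a hypothetical toy (two conductances in parallel fed by a current source, with the loss equal to the node voltage), using the same adjoint steps in NumPy; the summation is ordinary calculus, not a circulax feature.

```python
import numpy as np

def per_device_grads(g, current):
    """Toy adjoint: dL/dg_i for L = v, where F(v, g) = sum(g) * v - current = 0."""
    J = np.sum(g)  # 1x1 Jacobian dF/dv
    v = current / J  # operating point
    lam = 1.0 / J  # J^T lam = dL/dv = 1
    dF_dg = v * np.ones_like(g)  # dF/dg_i = v for every device
    return -lam * dF_dg  # one entry per device

g_shared, current = 2e-3, 1e-3
grads = per_device_grads(np.array([g_shared, g_shared]), current)

# Both devices tied to one model-level value: the chain rule sums the per-device entries
grad_shared = grads.sum()
# Closed form: v = current / (2 g), so dv/dg = -current / (2 g**2)
print(np.isclose(grad_shared, -current / (2 * g_shared ** 2)))  # True
```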

Raises:

ValueError: If osdi_group_key is missing from component_groups, or if a requested parameter name is not found in the OSDI group.

ImportError: If bosdi / osdi_jax is not available.

TypeError: If solver does not expose the KLU coalescing attributes needed to build the Jacobian matrix, or if the selected group is not an OsdiComponentGroup.

Source code in circulax/solvers/sensitivity.py
def dc_parameter_sensitivity(
    component_groups: dict,
    solver: CircuitLinearSolver,
    y_star: jax.Array,
    loss_fn,
    *,
    osdi_group_key: str,
    param_names: list[str],
    model_descriptor=None,
    param_to_col: dict[str, int] | None = None,
    eps: float = 1e-6,
) -> dict[str, jax.Array]:
    """Compute ∂loss/∂params for named parameters of an OSDI group.

    Uses the implicit differentiation (adjoint) approach at the DC operating
    point ``y*``:

        F(y*, p) = 0
        J(y*)ᵀ λ = ∂L/∂y      [adjoint solve]
        ∂loss/∂p_k = -λᵀ · ∂F/∂p_k      [parameter gradient]

    ``∂F/∂p_k`` is computed via forward finite differences through
    ``osdi_residual_eval``.  This avoids autodiff through the XLA FFI.

    Args:
        component_groups: Compiled circuit groups (dict returned by
            :func:`~circulax.compiler.compile_netlist`).
        solver: A :class:`~circulax.solvers.linear.CircuitLinearSolver`
            instance built from the same ``component_groups``.  Must expose
            ``u_rows``, ``u_cols``, ``map_idx``, ``n_unique``, ``sys_size``,
            ``g_leak``, ``ground_indices`` (i.e. must be a KLU-based solver).
        y_star: DC operating point, shape ``(sys_size,)``.
        loss_fn: Callable ``y -> scalar``.  Must be differentiable w.r.t. ``y``
            via ``jax.grad``.
        osdi_group_key: Key in ``component_groups`` identifying the
            :class:`~bosdi.circulax.OsdiComponentGroup` whose parameters are
            to be differentiated.
        param_names: List of canonical OSDI parameter names to differentiate.
            Must be a subset of the model's parameter names.
        model_descriptor: The :class:`~bosdi.circulax.OsdiModelDescriptor`
            returned by :func:`~circulax.osdi_component` when the model was
            loaded.  Needed to map parameter names to column indices.  Either
            this or ``param_to_col`` must be provided.
        param_to_col: Explicit ``{param_name: column_index}`` mapping.
            Overrides ``model_descriptor`` lookup.  Column indices match the
            ordering of ``group.params`` columns (= ``descriptor.param_names``
            order).
        eps: Relative finite difference step size.  Default ``1e-6`` works
            for most parameters.

    Returns:
        Dict mapping each name in ``param_names`` to a gradient array of
        shape ``(n_devices,)`` — one scalar per device instance in the group.

    Raises:
        ValueError: If ``osdi_group_key`` is missing from ``component_groups``
            or if a requested parameter name is not found in the OSDI group.
        ImportError: If bosdi / osdi_jax is not available.
        TypeError: If ``solver`` does not expose the KLU coalescing attributes
            needed to build the Jacobian matrix, or if the selected group is
            not an ``OsdiComponentGroup``.

    """
    try:
        from bosdi.circulax import OsdiComponentGroup
    except ImportError as err:
        raise ImportError("dc_parameter_sensitivity requires bosdi.") from err

    # -----------------------------------------------------------------------
    # Validate inputs
    # -----------------------------------------------------------------------
    if osdi_group_key not in component_groups:
        available = list(component_groups.keys())
        msg = f"OSDI group key {osdi_group_key!r} not found in component_groups. Available: {available}"
        raise ValueError(msg)

    group = component_groups[osdi_group_key]
    if not isinstance(group, OsdiComponentGroup):
        msg = f"Group {osdi_group_key!r} is not an OsdiComponentGroup (got {type(group).__name__})."
        raise TypeError(msg)

    if not hasattr(solver, "u_rows") or not hasattr(solver, "map_idx"):
        msg = (
            "dc_parameter_sensitivity requires a KLU-based solver "
            "(KLUSolver, KLUSplitLinear, KLUSplitQuadratic). "
            f"Got {type(solver).__name__}. "
            "For DenseSolver, use dc_parameter_sensitivity_dense instead."
        )
        raise TypeError(msg)

    # Resolve param names to column indices
    param_cols = _resolve_param_cols(
        group, param_names, model_descriptor=model_descriptor, param_to_col=param_to_col
    )

    # -----------------------------------------------------------------------
    # Step 1: ∂L/∂y  (via JAX autodiff — loss_fn must be JAX-differentiable)
    # -----------------------------------------------------------------------
    dL_dy = jax.grad(loss_fn)(y_star)  # shape: (sys_size,)

    # -----------------------------------------------------------------------
    # Step 2: Assemble J(y*) and coalesce into KLU format
    # -----------------------------------------------------------------------
    _, _, all_vals = assemble_system_real(
        y_star, component_groups, t1=0.0, dt=DC_DT
    )
    coalesced_vals = _build_klu_matrix_vals(solver, all_vals)

    # -----------------------------------------------------------------------
    # Step 3: Solve adjoint system J(y*)ᵀ λ = ∂L/∂y
    #
    # klujax.tsolve_with_symbol solves Aᵀ x = b using the KLU symbolic
    # analysis pre-computed during solver construction.
    # -----------------------------------------------------------------------
    lam = klujax.tsolve_with_symbol(
        solver.u_rows,
        solver.u_cols,
        coalesced_vals,
        dL_dy.astype(jnp.float64),
        solver._handle_wrapper,
    )  # shape: (sys_size,)

    # -----------------------------------------------------------------------
    # Step 4+5: FD ∂F/∂p and compute -λᵀ · ∂F/∂p for each param
    # -----------------------------------------------------------------------
    return _compute_fd_gradients(
        group, y_star, lam, param_names, param_cols, eps, model_id_override=None
    )

dc_parameter_sensitivity_dense

dc_parameter_sensitivity_dense(
    component_groups: dict,
    y_star: Array,
    loss_fn,
    *,
    osdi_group_key: str,
    param_names: list[str],
    model_descriptor=None,
    param_to_col: dict[str, int] | None = None,
    eps: float = 1e-06
) -> dict[str, Array]

Dense-solver fallback for dc_parameter_sensitivity.

Uses jnp.linalg.solve instead of KLU for the adjoint solve, so it works with any solver backend, including circulax.solvers.linear.DenseSolver. Intended for small circuits and unit tests.

Arguments match dc_parameter_sensitivity except that solver is not required: the Jacobian is assembled densely from component_groups, with the default leakage conductance (1e-9) added to every diagonal entry and GROUND_STIFFNESS added to the node-0 diagonal.
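The dense assembly step can be sketched in NumPy. One detail matters when porting it: COO entries repeat whenever several devices stamp the same matrix position, and JAX's J.at[rows, cols].add(vals) sums those duplicates, while NumPy's J[rows, cols] += vals keeps only one of them; np.add.at is the summing equivalent. The g_leak default of 1e-9 follows the source comment; ground_stiffness is a placeholder because the value of GROUND_STIFFNESS is not shown on this page.

```python
import numpy as np

def dense_adjoint(rows, cols, vals, dL_dy, sys_size, g_leak=1e-9, ground_stiffness=1e9):
    """Assemble a dense J from COO triplets and solve J^T lam = dL/dy.

    ground_stiffness is a placeholder; the real GROUND_STIFFNESS constant is not shown here.
    """
    J = np.zeros((sys_size, sys_size))
    np.add.at(J, (rows, cols), vals)  # sums duplicate (row, col) stamps
    J[np.diag_indices(sys_size)] += g_leak  # leakage conductance on every diagonal entry
    J[0, 0] += ground_stiffness  # pin node 0 (ground)
    lam = np.linalg.solve(J.T, dL_dy)
    return lam, J

# Two hypothetical devices stamping the same diagonal entry of node 1
rows = np.array([1, 1])
cols = np.array([1, 1])
vals = np.array([1e-3, 1e-3])
lam, J = dense_adjoint(rows, cols, vals, np.array([0.0, 1.0]), sys_size=2)

naive = np.zeros((2, 2))
naive[rows, cols] += vals  # buffered fancy indexing: duplicates are NOT summed
print(J[1, 1], naive[1, 1])
```

The row, column and value arrays must also be concatenated in the same group order, which is why the source iterates over sorted(component_groups.keys()) to match the KLU index construction.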

Returns:

dict[str, Array]: Same as dc_parameter_sensitivity.

Source code in circulax/solvers/sensitivity.py
def dc_parameter_sensitivity_dense(
    component_groups: dict,
    y_star: jax.Array,
    loss_fn,
    *,
    osdi_group_key: str,
    param_names: list[str],
    model_descriptor=None,
    param_to_col: dict[str, int] | None = None,
    eps: float = 1e-6,
) -> dict[str, jax.Array]:
    """Dense-solver fallback for :func:`dc_parameter_sensitivity`.

    Uses ``jnp.linalg.solve`` instead of KLU for the adjoint solve, so it
    works with any solver backend — including
    :class:`~circulax.solvers.linear.DenseSolver`.  Intended for small circuits
    and unit tests.

    Args match :func:`dc_parameter_sensitivity` except ``solver`` is not
    required; the Jacobian is built densely from ``component_groups``.

    Returns:
        Same as :func:`dc_parameter_sensitivity`.

    """
    try:
        from bosdi.circulax import OsdiComponentGroup
    except ImportError as err:
        raise ImportError("dc_parameter_sensitivity_dense requires bosdi.") from err

    if osdi_group_key not in component_groups:
        available = list(component_groups.keys())
        msg = f"OSDI group key {osdi_group_key!r} not found. Available: {available}"
        raise ValueError(msg)

    group = component_groups[osdi_group_key]
    if not isinstance(group, OsdiComponentGroup):
        msg = f"Group {osdi_group_key!r} is not an OsdiComponentGroup."
        raise TypeError(msg)

    # Resolve param names to column indices
    param_cols = _resolve_param_cols(
        group, param_names, model_descriptor=model_descriptor, param_to_col=param_to_col
    )

    sys_size = y_star.shape[0]

    # -----------------------------------------------------------------------
    # Step 1: ∂L/∂y
    # -----------------------------------------------------------------------
    dL_dy = jax.grad(loss_fn)(y_star)

    # -----------------------------------------------------------------------
    # Step 2: Assemble dense J(y*)
    # -----------------------------------------------------------------------
    _, _, all_vals = assemble_system_real(y_star, component_groups, t1=0.0, dt=DC_DT)

    # Collect COO row/col from component_groups in sorted order (same as
    # _build_index_arrays in linear.py — must match to get consistent values)
    all_rows_list, all_cols_list = [], []
    for k in sorted(component_groups.keys()):
        g = component_groups[k]
        all_rows_list.append(jnp.array(g.jac_rows).reshape(-1))
        all_cols_list.append(jnp.array(g.jac_cols).reshape(-1))
    static_rows = jnp.concatenate(all_rows_list)
    static_cols = jnp.concatenate(all_cols_list)

    J = jnp.zeros((sys_size, sys_size), dtype=jnp.float64)
    J = J.at[static_rows, static_cols].add(all_vals)

    # Add leakage conductance (g_leak=1e-9 default matches DenseSolver default)
    diag_idx = jnp.arange(sys_size)
    J = J.at[diag_idx, diag_idx].add(1e-9)

    # Ground constraint: add GROUND_STIFFNESS to node 0 diagonal
    J = J.at[0, 0].add(GROUND_STIFFNESS)

    # -----------------------------------------------------------------------
    # Step 3: Solve adjoint Jᵀ λ = ∂L/∂y
    # -----------------------------------------------------------------------
    lam = jnp.linalg.solve(J.T, dL_dy.astype(jnp.float64))

    # -----------------------------------------------------------------------
    # Step 4+5: FD ∂F/∂p and compute -λᵀ · ∂F/∂p for each param
    # -----------------------------------------------------------------------
    return _compute_fd_gradients(
        group, y_star, lam, param_names, param_cols, eps, model_id_override=None
    )