
linear

Circuit linear solver strategies.

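Separates physics assembly (computing the Jacobian values) from the linear solve (factorizing the matrix and solving for the Newton update). All solvers implement the lineax abstract interface and are designed to work under JAX transformations. See the per-solver notes for limitations; for example, KLUSolver does not support vmap automatically.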

Classes:

| Name | Description |
|---|---|
| CircuitLinearSolver | Abstract base for all circuit linear solvers. |
| DenseSolver | Solves the system using dense matrix factorization (LU). |
| KLUSolver | Solves the system using the KLU sparse solver (via klujax). |
| KLUSplitLinear | KLU split solver paired with modified Newton (frozen Jacobian) for linear convergence. |
| KLUSplitQuadratic | KLU split solver paired with full Newton for quadratic convergence via klu_refactor. |
| KLUSplitSolver | Solves the system using the KLU sparse solver (via klujax) with a split interface. |
| SparseSolver | Solves the system using JAX's iterative BiCGStab solver. |

Functions:

| Name | Description |
|---|---|
| analyze_circuit | Initializes a linear solver strategy for circuit analysis. |

Attributes:

| Name | Type | Description |
|---|---|---|
| DAMPING_EPS | float | Small additive epsilon that prevents division by zero in the damping formula. |
| DAMPING_FACTOR | float | Newton-step damping coefficient: bounds the largest component of each applied Newton step by DAMPING_FACTOR. |
| DC_DT | float | Effective timestep used for DC analysis; makes capacitor stamps vanish (C/dt → 0), so capacitors act as open circuits. |
| GROUND_STIFFNESS | float | Penalty added to ground-node diagonal entries to enforce V = 0. |

DAMPING_EPS module-attribute

DAMPING_EPS: float = 1e-09

Small additive epsilon that prevents division by zero in the damping formula.

DAMPING_FACTOR module-attribute

DAMPING_FACTOR: float = 0.5

Newton-step damping coefficient: each Newton update δy is scaled by a factor of at most DAMPING_FACTOR / |δy|_max, so no component of an applied step exceeds DAMPING_FACTOR in magnitude.
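A minimal sketch of the damping rule follows. It assumes the scale factor is min(1, DAMPING_FACTOR / (|δy|_max + DAMPING_EPS)). That is one formula consistent with the two constants above, but circulax's actual expression is not shown on this page, and damped_step is a hypothetical helper.

```python
import jax.numpy as jnp

DAMPING_FACTOR = 0.5
DAMPING_EPS = 1e-9

def damped_step(delta_y):
    """Hypothetical helper: scale a Newton update so its largest entry is at most ~DAMPING_FACTOR.

    Assumed formula: alpha = min(1, DAMPING_FACTOR / (max|delta_y| + DAMPING_EPS)).
    DAMPING_EPS keeps the division finite when delta_y is all zeros.
    """
    max_abs = jnp.max(jnp.abs(delta_y))
    alpha = jnp.minimum(1.0, DAMPING_FACTOR / (max_abs + DAMPING_EPS))
    return alpha * delta_y

small = damped_step(jnp.array([0.1, -0.2]))  # already small: returned unchanged
large = damped_step(jnp.array([4.0, -2.0]))  # scaled by ~1/8, largest entry becomes ~0.5
```

Small steps pass through untouched, and large steps keep their direction but are shortened. This is what keeps exponential device models (diodes, lasers) from overflowing early in the iteration.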

DC_DT module-attribute

DC_DT: float = 1e+18

Effective timestep used for DC analysis; makes capacitor stamps vanish (C/dt → 0), so capacitors act as open circuits.
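The C/dt term is the companion conductance of a capacitor under backward-Euler integration. The arithmetic below assumes that companion model; the integration scheme is not stated on this page, and capacitor_conductance is a hypothetical helper. With DC_DT, a 1 pF capacitor contributes about 1e-30 S, far below the default g_leak of 1e-9 S, so it effectively disappears from the DC system.

```python
DC_DT = 1e18

def capacitor_conductance(c_farads, dt):
    """Hypothetical helper: backward-Euler companion conductance of a capacitor (C / dt)."""
    return c_farads / dt

g_transient = capacitor_conductance(1e-12, 1e-12)  # 1 pF at a 1 ps step -> 1.0 S
g_dc = capacitor_conductance(1e-12, DC_DT)         # 1 pF at DC_DT -> ~1e-30 S (open circuit)
```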

GROUND_STIFFNESS module-attribute

GROUND_STIFFNESS: float = 1000000000.0

Penalty added to ground-node diagonal entries to enforce V=0.
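The penalty approach can be checked on a hypothetical two-node circuit. A 1 mA current source drives node 1, and a 1 kΩ resistor connects node 1 to node 0, which is treated as ground. The floating conductance matrix is singular on its own. Adding GROUND_STIFFNESS to the ground diagonal makes that row dominant and drives V[0] to about 0. The circuit, the sign convention for the source, and the use of jnp.linalg.solve are illustrative assumptions, not circulax internals.

```python
import jax.numpy as jnp

GROUND_STIFFNESS = 1e9

g = 1.0 / 1e3   # 1 kOhm resistor between node 0 (ground) and node 1
i_src = 1e-3    # 1 mA source: +I injected at node 1, -I drawn from node 0
A = jnp.array([[g, -g],
               [-g, g]])          # floating conductance matrix: singular by itself
b = jnp.array([-i_src, i_src])

ground_indices = jnp.array([0])
# Penalty method: a huge diagonal entry dominates the ground row, pinning V[0] near 0.
A_pinned = A.at[ground_indices, ground_indices].add(GROUND_STIFFNESS)
v = jnp.linalg.solve(A_pinned, b)  # v ≈ [0.0, 1.0] volts
```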

CircuitLinearSolver

Bases: AbstractLinearSolver

Abstract base for all circuit linear solvers.

Attributes:

| Name | Type | Description |
|---|---|---|
| ground_indices | Array | Indices of nodes connected to ground (forced to 0 V). |
| is_complex | bool | If True, the system is 2N×2N (real and imaginary parts unrolled); otherwise N×N. |
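When is_complex is True, the complex system (Ar + j·Ai)(xr + j·xi) = br + j·bi can be rewritten as an equivalent real system of size 2N. The sketch below assumes a block layout with all real parts first and all imaginary parts second, which matches the [2N] vector shape used elsewhere on this page. circulax may order the unknowns differently, and unroll_complex is a hypothetical helper.

```python
import jax.numpy as jnp

def unroll_complex(A, b):
    """Hypothetical helper: rewrite complex A x = b as a real 2N x 2N system.

    Assumed layout: y = [Re(x); Im(x)], so that
        [[Ar, -Ai], [Ai, Ar]] @ [xr; xi] = [br; bi].
    """
    Ar, Ai = jnp.real(A), jnp.imag(A)
    A_real = jnp.block([[Ar, -Ai],
                        [Ai,  Ar]])
    b_real = jnp.concatenate([jnp.real(b), jnp.imag(b)])
    return A_real, b_real

A = jnp.array([[2.0 + 1.0j, -1.0],
               [-1.0, 2.0 - 0.5j]])
b = jnp.array([1.0 + 0.0j, 0.0 + 1.0j])

A_real, b_real = unroll_complex(A, b)
y = jnp.linalg.solve(A_real, b_real)  # real vector of shape [2N]
n = A.shape[0]
x = y[:n] + 1j * y[n:]                # recombine into a complex vector of shape [N]
```

Unrolling keeps every operation in real arithmetic. This can help with backends or sparse solvers that only accept real matrices, at the cost of a system twice the size.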

Methods:

| Name | Description |
|---|---|
| assume_full_rank | Return whether the solver assumes the operator is full rank. |
| compute | Required by the lineax API but not implemented; call _solve_impl directly for internal use. |
| init | Initialize the solver state (no-op for stateless solvers). |
| solve_dc | DC operating point via damped Newton-Raphson. |
| solve_dc_auto | DC Operating Point with automatic homotopy fallback. |
| solve_dc_checked | DC Operating Point with convergence status. |
| solve_dc_gmin | DC Operating Point via GMIN stepping (homotopy rescue). |
| solve_dc_source | DC Operating Point via source stepping (homotopy rescue). |

assume_full_rank

assume_full_rank() -> bool

Return whether the solver assumes the operator is full rank. Circuit solvers return False.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Present only to satisfy the lineax API; it always raises NotImplementedError. Call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

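Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| component_groups | dict[str, Any] | Compiled circuit components. | required |
| y_guess | Array | Initial guess vector (shape [N] or [2N]). | required |
| rtol | float | Relative tolerance for optx.fixed_point. | 1e-06 |
| atol | float | Absolute tolerance for optx.fixed_point. | 1e-06 |
| max_steps | int | Maximum Newton iterations. | 100 |

Returns:

| Type | Description |
|---|---|
| Array | Solution vector after the Newton iteration. Convergence is not checked; use solve_dc_checked to obtain a success flag. |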

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y
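A self-contained sketch of a damped Newton-Raphson DC solve on a hypothetical one-node circuit follows: a 5 V source feeds a 1 kΩ resistor into a node that is connected to ground through a diode. The code assumes the damping rule min(1, DAMPING_FACTOR / (|δy|_max + DAMPING_EPS)) and a convergence test on the undamped step size. The rtol/atol docs above indicate that circulax's own _run_newton is built on optx.fixed_point, so treat damped_newton and residual as illustrations, not the library's implementation.

```python
import jax
import jax.numpy as jnp

DAMPING_FACTOR = 0.5
DAMPING_EPS = 1e-9

# Hypothetical circuit: 5 V source -> 1 kOhm -> node v -> diode -> ground.
V_SRC, R, I_S, V_T = 5.0, 1e3, 1e-14, 0.02585

def residual(y):
    v = y[0]
    i_r = (v - V_SRC) / R                  # current leaving the node through the resistor
    i_d = I_S * (jnp.exp(v / V_T) - 1.0)   # diode current to ground
    return jnp.array([i_r + i_d])          # KCL: currents leaving the node sum to zero

def damped_newton(y0, atol=1e-6, max_steps=100):
    jac = jax.jacfwd(residual)

    def body(state):
        y, _, k = state
        dy = jnp.linalg.solve(jac(y), -residual(y))  # full Newton step
        step = jnp.max(jnp.abs(dy))
        alpha = jnp.minimum(1.0, DAMPING_FACTOR / (step + DAMPING_EPS))
        return y + alpha * dy, step, k + 1

    def cond(state):
        _, step, k = state
        return (step > atol) & (k < max_steps)

    init = (y0, jnp.array(jnp.inf, dtype=y0.dtype), jnp.array(0))
    y, step, _ = jax.lax.while_loop(cond, body, init)
    return y, step <= atol  # (solution, converged flag), like solve_dc_checked

y, converged = damped_newton(jnp.array([0.0], dtype=jnp.float32))
```

Without damping, the first undamped step from 0 V would jump to about 5 V, where exp(v / V_T) overflows in float32. Damping walks the node up in steps of at most 0.5 V instead.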

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, falls back to GMIN stepping followed by source stepping, all inside a single JIT-compiled computation via jax.lax.cond.

Strategy:

1. _run_newton: plain damped Newton from y_guess.
2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) from the GMIN result.

Because jax.lax.cond traces both branches but executes only one at runtime, this compiles to a single program with no Python-level branching. If the convergence flag is batched under jax.vmap, however, lax.cond is lowered to a select, and both branches execute for every batch element.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| component_groups | dict[str, Any] | Compiled circuit components. | required |
| y_guess | Array | Initial guess (typically jnp.zeros(sys_size)). | required |
| g_start | float | Starting leakage conductance for GMIN stepping (rescue branch). | 0.01 |
| n_gmin | int | Number of GMIN steps in the rescue branch. | 10 |
| n_source | int | Number of source steps in the rescue branch. | 10 |
| rtol | float | Relative tolerance for each inner Newton solve. | 1e-06 |
| atol | float | Absolute tolerance for each inner Newton solve. | 1e-06 |
| max_steps | int | Maximum Newton iterations per solve. | 100 |

Returns:

| Type | Description |
|---|---|
| Array | Solution vector: the direct Newton result if it converged, otherwise the result of GMIN stepping followed by source stepping. |

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)
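The control flow of solve_dc_auto reduces to the pattern below: run the cheap solve, then let jax.lax.cond pick either its result or the rescue path, all under jax.jit. direct_solve and rescue_solve are hypothetical stand-ins chosen only so the example runs (a square root that "fails" on negative input). They are not circulax functions.

```python
import jax
import jax.numpy as jnp

def direct_solve(x):
    # Hypothetical stand-in for _run_newton: returns (solution, converged flag).
    y = jnp.sqrt(x)               # NaN for negative x, which we treat as "did not converge"
    return y, jnp.isfinite(y)

def rescue_solve(x):
    # Hypothetical stand-in for the GMIN + source-stepping rescue path.
    return jnp.sqrt(jnp.abs(x))

@jax.jit
def solve_auto(x):
    y_direct, converged = direct_solve(x)
    # Both branches are traced; with an unbatched predicate only the selected one runs.
    return jax.lax.cond(converged, lambda _: y_direct, lambda _: rescue_solve(x), None)
```

solve_auto(4.0) takes the direct branch, and solve_auto(-9.0) falls through to the rescue branch. The direct result is always computed, because the predicate depends on it.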

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can be consumed inside compiled programs:

    y, converged = solver.solve_dc_checked(groups, y0)
    # Option 1, outside JIT: inspect the flag in Python.
    if not converged:
        y = solver.solve_dc_gmin(groups, y0)
    # Option 2, inside JIT: branch without Python-level control flow.
    y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| component_groups | dict[str, Any] | Compiled circuit components. | required |
| y_guess | Array | Initial guess vector (shape [N] or [2N]). | required |
| rtol | float | Relative tolerance for optx.fixed_point. | 1e-06 |
| atol | float | Absolute tolerance for optx.fixed_point. | 1e-06 |
| max_steps | int | Maximum Newton iterations. | 100 |

Returns:

| Type | Description |
|---|---|
| tuple[Array, Array] | (y, converged): solution vector and boolean success flag. |

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each converged solution as the warm start for the next step. The large initial g_leak linearises highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0V start.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| component_groups | dict[str, Any] | Compiled circuit components. | required |
| y_guess | Array | Initial guess (typically jnp.zeros(sys_size)). | required |
| g_start | float | Starting leakage conductance (a large value, e.g. 1e-2). | 0.01 |
| n_steps | int | Number of log-spaced g_leak values from g_start to self.g_leak, both endpoints included. | 10 |
| rtol | float | Relative tolerance for each inner Newton solve. | 1e-06 |
| atol | float | Absolute tolerance for each inner Newton solve. | 1e-06 |
| max_steps | int | Maximum Newton iterations per step. | 100 |

Returns:

| Type | Description |
|---|---|
| Array | Solution vector after the full GMIN schedule. Per-step convergence flags are discarded. |

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final
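A minimal GMIN-stepping sketch on the same kind of hypothetical diode circuit follows. A leakage term g_leak · v from the node to ground is swept log-uniformly from 1e-2 down to 1e-9 with jax.lax.scan, and each inner solve is warm-started from the previous one. For brevity, the inner solver runs a fixed number of damped Newton iterations instead of a convergence test. The circuit, residual, newton and gmin_stepping are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical circuit: 5 V -> 1 kOhm -> node v -> diode -> ground, plus g_leak from v to ground.
V_SRC, R, I_S, V_T = 5.0, 1e3, 1e-14, 0.02585

def residual(v, g_leak):
    return (v - V_SRC) / R + I_S * (jnp.exp(v / V_T) - 1.0) + g_leak * v

def newton(v, g_leak, iters=60):
    """Damped scalar Newton with a fixed iteration count (sketch only)."""
    d_res = jax.grad(residual)  # derivative with respect to v

    def body(_, v):
        dv = -residual(v, g_leak) / d_res(v, g_leak)
        return v + dv * jnp.minimum(1.0, 0.5 / (jnp.abs(dv) + 1e-9))

    return jax.lax.fori_loop(0, iters, body, v)

def gmin_stepping(v0, g_start=1e-2, g_final=1e-9, n_steps=10):
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(g_final), n_steps)

    def step(v, g):
        v_new = newton(v, g)   # warm start from the previous, more heavily shunted, solution
        return v_new, v_new    # carry the solution forward and record it

    return jax.lax.scan(step, v0, g_values)

v_final, v_path = gmin_stepping(jnp.array(0.0, dtype=jnp.float32))
```

The first step, with g_leak = 1e-2 S, settles near 0.45 V, where the diode barely conducts. Later steps relax the shunt and move the node to the diode-clamped operating point.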

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10 % to 100 % of their netlist values, using each converged solution as the warm start for the next step. This guides Newton through the nonlinear region without the large initial step from 0V to full excitation.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| component_groups | dict[str, Any] | Compiled circuit components. | required |
| y_guess | Array | Initial guess (typically jnp.zeros(sys_size)). | required |
| n_steps | int | Number of linearly spaced source scales from 0.1 to 1.0, both endpoints included. | 10 |
| rtol | float | Relative tolerance for each inner Newton solve. | 1e-06 |
| atol | float | Absolute tolerance for each inner Newton solve. | 1e-06 |
| max_steps | int | Maximum Newton iterations per step. | 100 |

Returns:

| Type | Description |
|---|---|
| Array | Solution vector at full source amplitude. Per-step convergence flags are discarded. |

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final
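Source stepping can be sketched the same way. Only the independent source is multiplied by a scale that ramps from 0.1 to 1.0 with jax.lax.scan, while the device equations stay untouched. As before, the circuit and the fixed-iteration newton helper are hypothetical and only mirror the warm-start structure of solve_dc_source.

```python
import jax
import jax.numpy as jnp

# Hypothetical circuit: (scale * 5 V) -> 1 kOhm -> node v -> diode -> ground.
V_SRC, R, I_S, V_T = 5.0, 1e3, 1e-14, 0.02585

def residual(v, scale):
    # Only the independent source is scaled; the diode model is unchanged.
    return (v - scale * V_SRC) / R + I_S * (jnp.exp(v / V_T) - 1.0)

def newton(v, scale, iters=60):
    """Damped scalar Newton with a fixed iteration count (sketch only)."""
    d_res = jax.grad(residual)

    def body(_, v):
        dv = -residual(v, scale) / d_res(v, scale)
        return v + dv * jnp.minimum(1.0, 0.5 / (jnp.abs(dv) + 1e-9))

    return jax.lax.fori_loop(0, iters, body, v)

def source_stepping(v0, n_steps=10):
    scales = jnp.linspace(0.1, 1.0, n_steps)  # 10% ... 100% of the netlist amplitude

    def step(v, scale):
        v_new = newton(v, scale)  # warm start from the previous, lower-amplitude solution
        return v_new, v_new

    return jax.lax.scan(step, v0, scales)

v_final, v_path = source_stepping(jnp.array(0.0, dtype=jnp.float32))
```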

DenseSolver

Bases: CircuitLinearSolver

Solves the system using dense matrix factorization (LU).

Best For
  • Small to medium circuits (N < 2000).
  • Wavelength sweeps (AC analysis) on GPU.
  • Systems where vmap parallelism is critical.
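The vmap use case above can be sketched with plain jnp.linalg.solve: one small dense complex system per frequency, batched with jax.vmap. The 2-node RC network is hypothetical (node 1 is tied to ground through R and to node 2 through R, and node 2 is tied to ground through C), and this code does not call the DenseSolver API.

```python
import jax
import jax.numpy as jnp

R, C = 1e3, 1e-9  # hypothetical element values: 1 kOhm, 1 nF

def nodal_matrix(omega):
    """Complex nodal admittance matrix of a hypothetical 2-node RC network."""
    g, y_c = 1.0 / R, 1j * omega * C
    return jnp.array([[2 * g, -g],
                      [-g, g + y_c]])

def solve_at(omega):
    b = jnp.array([1e-3 + 0j, 0j])  # 1 mA AC current injected at node 1
    return jnp.linalg.solve(nodal_matrix(omega), b)

omegas = 2 * jnp.pi * jnp.logspace(3, 7, 64)  # 1 kHz to 10 MHz
v = jax.vmap(solve_at)(omegas)                # shape (64, 2): one dense solve per frequency
```

At low frequency the capacitor is nearly open and |V1| ≈ 1 V. At high frequency it shorts node 2 and |V1| approaches 0.5 V. Every frequency point is independent, which is why dense batched solves map well onto a GPU.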

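Attributes:

| Name | Type | Description |
|---|---|---|
| static_rows | Array | Row indices for scattering Jacobian values into the dense matrix. |
| static_cols | Array | Column indices for scattering Jacobian values into the dense matrix. |
| g_leak | float | Leakage conductance added to the diagonal to prevent a singular matrix. |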

Methods:

| Name | Description |
|---|---|
| assume_full_rank | Return whether the solver assumes the operator is full rank. |
| compute | Required by the lineax API but not implemented; call _solve_impl directly for internal use. |
| from_component_groups | Factory method to pre-calculate indices for the dense matrix. |
| init | Initialize the solver state (no-op for stateless solvers). |
| solve_dc | DC operating point via damped Newton-Raphson. |
| solve_dc_auto | DC Operating Point with automatic homotopy fallback. |
| solve_dc_checked | DC Operating Point with convergence status. |
| solve_dc_gmin | DC Operating Point via GMIN stepping (homotopy rescue). |
| solve_dc_source | DC Operating Point via source stepping (homotopy rescue). |

assume_full_rank

assume_full_rank() -> bool

Return whether the solver assumes the operator is full rank. Circuit solvers return False.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Present only to satisfy the lineax API; it always raises NotImplementedError. Call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any],
    num_vars: int,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> DenseSolver

Factory method to pre-calculate indices for the dense matrix.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False, g_leak: float = 1e-9
) -> "DenseSolver":
    """Factory method to pre-calculate indices for the dense matrix."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(component_groups, num_vars, is_complex)
    return cls(
        static_rows=jnp.array(rows),
        static_cols=jnp.array(cols),
        sys_size=sys_size,
        ground_indices=jnp.array(ground_idxs),
        is_complex=is_complex,
        g_leak=g_leak,
    )
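The point of pre-computing static_rows and static_cols is that each Newton iteration only needs a flat vector of Jacobian values, which can be scattered into a dense matrix and solved. The sketch below assumes scatter-add semantics (duplicate (row, col) stamps are summed), g_leak on every diagonal entry, and GROUND_STIFFNESS added on ground rows. assemble_and_solve and the 3-node network are hypothetical, not circulax API.

```python
import jax.numpy as jnp

def assemble_and_solve(values, rows, cols, rhs, sys_size, ground_indices,
                       g_leak=1e-9, ground_stiffness=1e9):
    """Hypothetical helper: scatter COO Jacobian values into a dense matrix and solve."""
    A = jnp.zeros((sys_size, sys_size), dtype=values.dtype)
    A = A.at[rows, cols].add(values)                          # duplicate (row, col) pairs are summed
    A = A + g_leak * jnp.eye(sys_size, dtype=values.dtype)    # leakage on every diagonal entry
    A = A.at[ground_indices, ground_indices].add(ground_stiffness)  # pin ground nodes to ~0 V
    return jnp.linalg.solve(A, rhs)

# Hypothetical 3-node network: node 0 is ground, R1 = 1 kOhm between nodes 1 and 2,
# R2 = 1 kOhm between node 2 and ground, and a 1 mA source driving node 1.
g = 1e-3
rows = jnp.array([1, 2, 1, 2, 2, 0, 2, 0])  # R1 stamp, then R2 stamp
cols = jnp.array([1, 2, 2, 1, 2, 0, 0, 2])
values = jnp.array([g, g, -g, -g, g, g, -g, -g])
rhs = jnp.array([-1e-3, 1e-3, 0.0])

v = assemble_and_solve(values, rows, cols, rhs, sys_size=3, ground_indices=jnp.array([0]))
# v ≈ [0.0, 2.0, 1.0] volts
```

Because rows and cols stay fixed for a given netlist, only values changes between iterations. This keeps the scatter cheap and shape-stable under jit and vmap.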

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

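Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| component_groups | dict[str, Any] | Compiled circuit components. | required |
| y_guess | Array | Initial guess vector (shape [N] or [2N]). | required |
| rtol | float | Relative tolerance for optx.fixed_point. | 1e-06 |
| atol | float | Absolute tolerance for optx.fixed_point. | 1e-06 |
| max_steps | int | Maximum Newton iterations. | 100 |

Returns:

| Type | Description |
|---|---|
| Array | Solution vector after the Newton iteration. Convergence is not checked; use solve_dc_checked to obtain a success flag. |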

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, falls back to GMIN stepping followed by source stepping, all inside a single JIT-compiled computation via jax.lax.cond.

Strategy:

1. _run_newton: plain damped Newton from y_guess.
2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) from the GMIN result.

Because jax.lax.cond traces both branches but executes only one at runtime, this compiles to a single program with no Python-level branching. If the convergence flag is batched under jax.vmap, however, lax.cond is lowered to a select, and both branches execute for every batch element.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| component_groups | dict[str, Any] | Compiled circuit components. | required |
| y_guess | Array | Initial guess (typically jnp.zeros(sys_size)). | required |
| g_start | float | Starting leakage conductance for GMIN stepping (rescue branch). | 0.01 |
| n_gmin | int | Number of GMIN steps in the rescue branch. | 10 |
| n_source | int | Number of source steps in the rescue branch. | 10 |
| rtol | float | Relative tolerance for each inner Newton solve. | 1e-06 |
| atol | float | Absolute tolerance for each inner Newton solve. | 1e-06 |
| max_steps | int | Maximum Newton iterations per solve. | 100 |

Returns:

| Type | Description |
|---|---|
| Array | Solution vector: the direct Newton result if it converged, otherwise the result of GMIN stepping followed by source stepping. |

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can be consumed inside compiled programs:

    y, converged = solver.solve_dc_checked(groups, y0)
    # Option 1, outside JIT: inspect the flag in Python.
    if not converged:
        y = solver.solve_dc_gmin(groups, y0)
    # Option 2, inside JIT: branch without Python-level control flow.
    y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| component_groups | dict[str, Any] | Compiled circuit components. | required |
| y_guess | Array | Initial guess vector (shape [N] or [2N]). | required |
| rtol | float | Relative tolerance for optx.fixed_point. | 1e-06 |
| atol | float | Absolute tolerance for optx.fixed_point. | 1e-06 |
| max_steps | int | Maximum Newton iterations. | 100 |

Returns:

| Type | Description |
|---|---|
| tuple[Array, Array] | (y, converged): solution vector and boolean success flag. |

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each converged solution as the warm start for the next step. The large initial g_leak linearises highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0V start.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| component_groups | dict[str, Any] | Compiled circuit components. | required |
| y_guess | Array | Initial guess (typically jnp.zeros(sys_size)). | required |
| g_start | float | Starting leakage conductance (a large value, e.g. 1e-2). | 0.01 |
| n_steps | int | Number of log-spaced g_leak values from g_start to self.g_leak, both endpoints included. | 10 |
| rtol | float | Relative tolerance for each inner Newton solve. | 1e-06 |
| atol | float | Absolute tolerance for each inner Newton solve. | 1e-06 |
| max_steps | int | Maximum Newton iterations per step. | 100 |

Returns:

| Type | Description |
|---|---|
| Array | Solution vector after the full GMIN schedule. Per-step convergence flags are discarded. |

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10 % to 100 % of their netlist values, using each converged solution as the warm start for the next step. This guides Newton through the nonlinear region without the large initial step from 0V to full excitation.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| component_groups | dict[str, Any] | Compiled circuit components. | required |
| y_guess | Array | Initial guess (typically jnp.zeros(sys_size)). | required |
| n_steps | int | Number of linearly spaced source scales from 0.1 to 1.0, both endpoints included. | 10 |
| rtol | float | Relative tolerance for each inner Newton solve. | 1e-06 |
| atol | float | Absolute tolerance for each inner Newton solve. | 1e-06 |
| max_steps | int | Maximum Newton iterations per step. | 100 |

Returns:

| Type | Description |
|---|---|
| Array | Solution vector at full source amplitude. Per-step convergence flags are discarded. |

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final

KLUSolver

Bases: CircuitLinearSolver

Solves the system using the KLU sparse solver (via klujax).

Best For
  • Large circuits (N > 5000) running on CPU.
  • DC Operating Points of massive meshes.
  • Cases where DenseSolver runs out of memory (OOM).
Note

Does NOT support vmap (batching) automatically.

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Stub required by the lineax API; it always raises NotImplementedError. Call _solve_impl directly for internal use.

from_component_groups

Factory method that precomputes deduplicated (row, col) index arrays for sparse coalescence.

init

Initialize the solver state (no-op for stateless solvers).

solve_dc

DC operating point via damped Newton-Raphson.

solve_dc_auto

DC Operating Point with automatic homotopy fallback.

solve_dc_checked

DC Operating Point with convergence status.

solve_dc_gmin

DC Operating Point via GMIN stepping (homotopy rescue).

solve_dc_source

DC Operating Point via source stepping (homotopy rescue).

assume_full_rank

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Stub required by the lineax API; it always raises NotImplementedError. Call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any],
    num_vars: int,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> KLUSolver

Factory method that precomputes deduplicated (row, col) index arrays for sparse coalescence.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False, g_leak: float = 1e-9
) -> "KLUSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(component_groups, num_vars, is_complex)
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
        g_leak=g_leak,
    )
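The coalescing that this factory prepares can be pictured with a small pure-Python sketch. Each component stamps values at (row, col) positions, several stamps can land on the same position, and KLU needs each position only once. The sketch gives every raw entry the index of its unique position (map_idx), then sums values that share an index, which is what jax.ops.segment_sum does later in the solve. The helper names, the first-seen ordering of unique entries and the three-node example are assumptions for illustration; circulax's _build_index_arrays and _klu_deduplicate may order or store entries differently.

```python
# Hypothetical stamps for a 3-node system (node 0 is ground):
# conductance GA between nodes 1 and 2, conductance GB between nodes 2 and 0.
GA, GB = 1e-3, 2e-3
GROUND_STIFFNESS, G_LEAK = 1e9, 1e-9  # module constant and default g_leak
SYS_SIZE, GROUND = 3, [0]


def deduplicate(rows, cols):
    """Return (u_rows, u_cols, map_idx, n_unique); map_idx[k] is entry k's unique slot."""
    slot = {}
    map_idx = []
    for r, c in zip(rows, cols):
        # First time a position is seen it gets the next free slot number.
        map_idx.append(slot.setdefault((r, c), len(slot)))
    u_rows = [r for r, _ in slot]
    u_cols = [c for _, c in slot]
    return u_rows, u_cols, map_idx, len(slot)


def segment_sum(values, map_idx, n):
    """Pure-Python equivalent of jax.ops.segment_sum(values, map_idx, num_segments=n)."""
    out = [0.0] * n
    for val, k in zip(values, map_idx):
        out[k] += val
    return out


# Component stamps first, then ground-penalty diagonals, then leak diagonals.
rows = [1, 2, 1, 2, 2, 0, 2, 0] + GROUND + list(range(SYS_SIZE))
cols = [1, 2, 2, 1, 2, 0, 0, 2] + GROUND + list(range(SYS_SIZE))
vals = ([GA, GA, -GA, -GA, GB, GB, -GB, -GB]
        + [GROUND_STIFFNESS] * len(GROUND)
        + [G_LEAK] * SYS_SIZE)

u_rows, u_cols, map_idx, n_unique = deduplicate(rows, cols)
coalesced = segment_sum(vals, map_idx, n_unique)
entries = dict(zip(zip(u_rows, u_cols), coalesced))  # (row, col) -> summed value
```

Because the index work depends only on the circuit topology, it can be done once here, leaving only the cheap value summation for every Newton iteration.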

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess vector (shape [N] or [2N]).

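required
rtol float

Relative tolerance for optx.fixed_point.

1e-06
atol float

Absolute tolerance for optx.fixed_point.

1e-06
max_steps int

Maximum Newton iterations.

100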

Returns:

Type Description
Array

Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y
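Damped Newton-Raphson can be sketched the same way. The damping rule below is an assumption inferred from the module constants DAMPING_FACTOR and DAMPING_EPS: each step δy is scaled by min(1, DAMPING_FACTOR / (|δy|_max + DAMPING_EPS)), so no single update moves the solution by more than about DAMPING_FACTOR volts. The one-node diode circuit and all helper names are hypothetical; circulax's _run_newton may apply damping differently.

```python
import math

DAMPING_FACTOR = 0.5  # module constant DAMPING_FACTOR
DAMPING_EPS = 1e-9    # module constant DAMPING_EPS

# Hypothetical one-node circuit: 5 V source -> 1 kOhm resistor -> diode to ground.
VS, R, IS, VT = 5.0, 1e3, 1e-14, 0.02585


def residual(v: float) -> float:
    return (v - VS) / R + IS * (math.exp(v / VT) - 1.0)


def jacobian(v: float) -> float:
    return 1.0 / R + (IS / VT) * math.exp(v / VT)


def newton(v: float, damped: bool, tol: float = 1e-12, max_steps: int = 100):
    """Return (v, converged, iterations_used)."""
    for it in range(1, max_steps + 1):
        dv = -residual(v) / jacobian(v)
        # Assumed damping rule: shrink the step so that |step| <= DAMPING_FACTOR.
        alpha = min(1.0, DAMPING_FACTOR / (abs(dv) + DAMPING_EPS)) if damped else 1.0
        v += alpha * dv
        if abs(dv) < tol:  # full Newton step is negligible: converged
            return v, True, it
    return v, False, max_steps


v_plain, ok_plain, _ = newton(0.0, damped=False)
v_damped, ok_damped, n_damped = newton(0.0, damped=True)
```

With the step capped, the damped iteration climbs from 0 V in two 0.5 V moves and then converges well within the 100-iteration budget, while the undamped run overshoots to about 5 V and does not converge in time.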

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, falls back to GMIN stepping followed by source stepping — all inside a single JIT-compiled kernel via jax.lax.cond.

Strategy:

1. _run_newton: plain damped Newton from y_guess.
2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) from the GMIN result.

Because jax.lax.cond traces both branches at trace time but executes only one at runtime, this compiles to a single kernel with no Python-level branching. Under jax.vmap with a batched convergence flag, however, lax.cond is lowered to a select and both branches execute.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage for GMIN stepping (rescue branch).

0.01
n_gmin int

Number of GMIN steps in the rescue branch.

10
n_source int

Number of source steps in the rescue branch.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)
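The control flow of this fallback can be mirrored in pure Python on a hypothetical one-node diode circuit. Here a Python if takes the place of jax.lax.cond, so only the needed branch runs, whereas the JAX version traces both. The helpers, the circuit values and the returned path label are illustrative assumptions, not circulax API.

```python
import math

# Hypothetical one-node circuit: 5 V source -> 1 kOhm resistor -> diode to ground.
VS, R, IS, VT = 5.0, 1e3, 1e-14, 0.02585
G_LEAK = 1e-9  # final leakage conductance, standing in for self.g_leak


def residual(v, scale=1.0, g=G_LEAK):
    return (v - scale * VS) / R + IS * (math.exp(v / VT) - 1.0) + g * v


def jacobian(v, g=G_LEAK):
    return 1.0 / R + (IS / VT) * math.exp(v / VT) + g


def newton(v, scale=1.0, g=G_LEAK, tol=1e-12, max_steps=100):
    """Plain Newton. Returns (v, converged)."""
    for _ in range(max_steps):
        dv = -residual(v, scale, g) / jacobian(v, g)
        v += dv
        if abs(dv) < tol:
            return v, True
    return v, False


def solve_dc_auto(v0, g_start=1e-2, n_gmin=10, n_source=10):
    """Direct Newton first; on failure, GMIN stepping then source stepping.

    Returns (v, path), where path records which branch produced the answer.
    """
    v_direct, converged = newton(v0)
    if converged:  # in the JAX version this choice is made by jax.lax.cond
        return v_direct, "direct"
    v = v0
    a, b = math.log10(g_start), math.log10(G_LEAK)
    for k in range(n_gmin):  # log-uniform GMIN schedule, g_start -> G_LEAK
        v, _ = newton(v, g=10.0 ** (a + (b - a) * k / max(n_gmin - 1, 1)))
    for k in range(n_source):  # source schedule, 10 % -> 100 %
        v, _ = newton(v, scale=0.1 + 0.9 * k / max(n_source - 1, 1))
    return v, "rescue"


v_cold, path_cold = solve_dc_auto(0.0)   # flat 0 V start: direct Newton fails
v_warm, path_warm = solve_dc_auto(0.69)  # start near the answer: direct Newton succeeds
```

Both starts end at the same operating point; only the cost differs, which is why the rescue branch is worth paying for only when the direct attempt reports failure.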

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can be inspected eagerly or consumed inside compiled programs:

y, converged = solver.solve_dc_checked(groups, y0)

# Outside JIT: the flag is concrete, so an ordinary Python branch works.
if not converged:
    y = solver.solve_dc_gmin(groups, y0)

# Alternatively, inside JIT, where the flag is a tracer: branch with jax.lax.cond.
y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess vector (shape [N] or [2N]).

required
rtol float

Relative tolerance for optx.fixed_point.

1e-06
atol float

Absolute tolerance for optx.fixed_point.

1e-06
max_steps int

Maximum Newton iterations.

100

Returns:

Type Description
tuple[Array, Array]

(y, converged) — solution vector and boolean success flag.

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each converged solution as the warm start for the next step. The large initial g_leak linearises highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0V start.

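Implemented with jax.lax.scan, so the stepping loop itself is compatible with jax.jit, jax.grad and jax.vmap; whether vmap works end to end also depends on the underlying linear solver.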
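A minimal pure-Python sketch of GMIN stepping on a hypothetical one-node diode circuit follows. A leakage conductance g_leak from the node to ground is swept log-uniformly from g_start down to a small final value, and each Newton solve is warm-started from the previous one. Circuit values and helper names are assumptions for illustration, and a Python loop stands in for jax.lax.scan.

```python
import math

# Hypothetical one-node circuit: 5 V source -> 1 kOhm resistor -> diode to ground.
VS, R, IS, VT = 5.0, 1e3, 1e-14, 0.02585


def residual(v: float, g_leak: float) -> float:
    """KCL at the diode node, plus a leakage conductance g_leak from the node to ground."""
    return (v - VS) / R + IS * (math.exp(v / VT) - 1.0) + g_leak * v


def jacobian(v: float, g_leak: float) -> float:
    return 1.0 / R + (IS / VT) * math.exp(v / VT) + g_leak


def newton(v: float, g_leak: float, tol: float = 1e-12, max_steps: int = 100):
    """Plain, undamped Newton. Returns (v, converged)."""
    for _ in range(max_steps):
        dv = -residual(v, g_leak) / jacobian(v, g_leak)
        v += dv
        if abs(dv) < tol:
            return v, True
    return v, False


def gmin_schedule(g_start: float, g_end: float, n_steps: int) -> list[float]:
    """Log-uniform schedule, like jnp.logspace(log10(g_start), log10(g_end), n_steps)."""
    a, b = math.log10(g_start), math.log10(g_end)
    return [10.0 ** (a + (b - a) * k / max(n_steps - 1, 1)) for k in range(n_steps)]


def gmin_stepping(v0: float, g_start: float = 1e-2, g_final: float = 1e-9, n_steps: int = 10) -> float:
    v = v0
    for g in gmin_schedule(g_start, g_final, n_steps):
        # Warm start from the previous, more heavily regularised solution.
        v, _ = newton(v, g)
    return v


_, direct_ok = newton(0.0, 1e-9)  # direct solve with only the final, tiny leak
v_op = gmin_stepping(0.0)
```

At g_leak = 1e-2 S the leak dominates the resistor, so the first solve is nearly linear; each later solve starts from an answer that is already close, which keeps Newton inside its iteration budget.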

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage conductance (large value, e.g. 1e-2).

0.01
n_steps int

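Number of log-uniform steps from g_start to self.g_leak, both endpoints included. With n_steps=1 only g_start is used, so use n_steps >= 2 for the schedule to reach self.g_leak.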

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector after the full GMIN schedule.

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10 % to 100 % of their netlist values, using each converged solution as the warm start for the next step. This guides Newton through the nonlinear region without the large initial step from 0V to full excitation.

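Implemented with jax.lax.scan, so the stepping loop itself is compatible with jax.jit, jax.grad and jax.vmap; whether vmap works end to end also depends on the underlying linear solver.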

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
n_steps int

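Number of uniformly spaced steps from 0.1 to 1.0, endpoints included. With n_steps=1 only the 0.1 step is taken, so use n_steps >= 2 to reach full amplitude.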

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector at full source amplitude.

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final

KLUSplitLinear

Bases: KLUSplitSolver

KLU split solver paired with Modified Newton (frozen-Jacobian) for linear convergence.

Extends KLUSplitSolver with an explicit numeric factorization step, so the Jacobian can be factored once per time step and reused across all Newton iterations within that step (the Modified Newton, or frozen-Jacobian, scheme). Use together with FactorizedTransientSolver from circulax.solvers.transient.

Best For
  • Large circuits (N > 5000) running on CPU where the Jacobian changes slowly.
  • Transient simulations with many Newton iterations per step.
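The trade-off between the two Newton variants can be seen on a small system. The sketch below is pure Python and rests on stated assumptions: a hypothetical 2×2 nonlinear system F(x) = (x0² + x1² - 4, x0 - x1) with root (√2, √2), and a hand-written 2×2 LU that stands in for the KLU numeric factorization behind factor_jacobian and solve_with_frozen_jacobian. Full Newton refactors at every iterate. Modified Newton factors once and reuses the factors, trading extra (cheap) solves for fewer (expensive) factorizations.

```python
import math


def lu_factor(J):
    """LU factors of a 2x2 matrix without pivoting (stand-in for a KLU numeric factorization)."""
    (a, b), (c, d) = J
    l21 = c / a
    return a, b, l21, d - l21 * b


def lu_solve(lu, rhs):
    """Solve J x = rhs from lu_factor(J): forward, then back substitution."""
    a, b, l21, u22 = lu
    y2 = rhs[1] - l21 * rhs[0]   # L y = rhs, with y1 = rhs[0]
    x2 = y2 / u22                # U x = y
    x1 = (rhs[0] - b * x2) / a
    return x1, x2


def F(x):
    """Hypothetical nonlinear system with root (sqrt(2), sqrt(2))."""
    return x[0] ** 2 + x[1] ** 2 - 4.0, x[0] - x[1]


def J(x):
    return (2.0 * x[0], 2.0 * x[1]), (1.0, -1.0)


def newton(x, frozen, tol=1e-12, max_steps=100):
    """Return (x, iterations, factorizations).

    frozen=False: full Newton, refactor at every iterate (quadratic convergence).
    frozen=True:  Modified Newton, reuse the first factorization (linear convergence).
    """
    lu, n_factor = lu_factor(J(x)), 1
    for it in range(1, max_steps + 1):
        if it > 1 and not frozen:
            lu, n_factor = lu_factor(J(x)), n_factor + 1
        fx = F(x)
        dx = lu_solve(lu, (-fx[0], -fx[1]))
        x = (x[0] + dx[0], x[1] + dx[1])
        if max(abs(dx[0]), abs(dx[1])) < tol:
            return x, it, n_factor
    return x, max_steps, n_factor


x_full, it_full, nf_full = newton((1.5, 1.2), frozen=False)
x_mod, it_mod, nf_mod = newton((1.5, 1.2), frozen=True)
```

Both variants reach the same root. The frozen variant needs more iterations but performs a single factorization, which pays off when factoring dominates the cost and the Jacobian changes slowly between iterates.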

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Stub required by the lineax API; it always raises NotImplementedError. Call _solve_impl directly for internal use.

factor_jacobian

Factor the Jacobian and return a numeric handle for repeated solves.

from_component_groups

Factory method that delegates to KLUSplitSolver.from_component_groups.

init

Initialize the solver state (no-op for stateless solvers).

solve_dc

DC operating point via damped Newton-Raphson.

solve_dc_auto

DC Operating Point with automatic homotopy fallback.

solve_dc_checked

DC Operating Point with convergence status.

solve_dc_gmin

DC Operating Point via GMIN stepping (homotopy rescue).

solve_dc_source

DC Operating Point via source stepping (homotopy rescue).

solve_with_frozen_jacobian

Solve using a pre-computed numeric factorization.

assume_full_rank

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Stub required by the lineax API; it always raises NotImplementedError. Call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

factor_jacobian

factor_jacobian(all_vals: Array) -> Array

Factor the Jacobian and return a numeric handle for repeated solves.

Parameters:

Name Type Description Default
all_vals Array

Flattened non-zero Jacobian values (COO format).

required

Returns:

Type Description
Array

Opaque numeric handle (int32 JAX scalar) to pass to solve_with_frozen_jacobian. Must be freed with klujax.free_numeric after use to avoid C++ memory leaks.

Source code in circulax/solvers/linear.py
def factor_jacobian(self, all_vals: jax.Array) -> jax.Array:
    """Factor the Jacobian and return a numeric handle for repeated solves.

    Args:
        all_vals: Flattened non-zero Jacobian values (COO format).

    Returns:
        Opaque numeric handle (``int32`` JAX scalar) to pass to
        :meth:`solve_with_frozen_jacobian`.  Must be freed with
        ``klujax.free_numeric`` after use to avoid C++ memory leaks.

    """
    g_vals = jnp.full(self.ground_indices.shape[0], GROUND_STIFFNESS, dtype=all_vals.dtype)
    l_vals = jnp.full(self.sys_size, self.g_leak, dtype=all_vals.dtype)

    raw_vals = jnp.concatenate([all_vals, g_vals, l_vals])
    coalesced_vals = jax.ops.segment_sum(raw_vals, self.map_idx, num_segments=self.n_unique)

    return klujax.factor(self.u_rows, self.u_cols, coalesced_vals, self._handle_wrapper.handle)
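Why a large diagonal penalty can stand in for a ground connection can be checked by hand on a two-node example. The pure-Python sketch below uses hypothetical values and is not circulax code. It keeps the ground node 0 as an ordinary unknown, injects 1 mA into node 1 through a 1 kΩ resistor to node 0, and adds GROUND_STIFFNESS to the ground diagonal and g_leak to every diagonal, mirroring the g_vals and l_vals appended above.

```python
GROUND_STIFFNESS = 1e9  # module constant GROUND_STIFFNESS
G_LEAK = 1e-9           # default g_leak

G = 1e-3      # hypothetical 1 kOhm resistor between node 1 and node 0 (ground)
I_INJ = 1e-3  # hypothetical 1 mA current injected into node 1


def solve_2x2(a11, a12, a21, a22, b1, b2):
    """Cramer's rule for a 2x2 system; raises ZeroDivisionError if singular."""
    det = a11 * a22 - a12 * a21
    if det == 0.0:
        raise ZeroDivisionError("singular matrix")
    return (b1 * a22 - a12 * b2) / det, (a11 * b2 - a21 * b1) / det


def node_voltages(penalty: float, leak: float):
    """Resistor stamp plus the diagonal penalty and leak terms; returns (v0, v1)."""
    a11 = G + penalty + leak  # node 0 (ground) diagonal
    a22 = G + leak            # node 1 diagonal
    return solve_2x2(a11, -G, -G, a22, 0.0, I_INJ)


v0, v1 = node_voltages(GROUND_STIFFNESS, G_LEAK)  # penalised ground: v0 ~ I/GROUND_STIFFNESS
v0_floating, _ = node_voltages(0.0, G_LEAK)       # leak only: the "ground" node floats
```

With the penalty, node 0 settles near I_INJ / GROUND_STIFFNESS (about 1e-12 V) and node 1 near 1 V. With neither penalty nor leak the matrix is exactly singular, and with the leak alone it is solvable but node 0 drifts to a very large voltage.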

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any],
    num_vars: int,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> KLUSplitLinear

Factory method that delegates to KLUSplitSolver.from_component_groups.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False, g_leak: float = 1e-9
) -> "KLUSplitLinear":
    """Factory — delegates to :meth:`KLUSplitSolver.from_component_groups`."""
    return super().from_component_groups(  # type: ignore[return-value]
        component_groups, num_vars, is_complex=is_complex, g_leak=g_leak
    )

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess vector (shape [N] or [2N]).

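required
rtol float

Relative tolerance for optx.fixed_point.

1e-06
atol float

Absolute tolerance for optx.fixed_point.

1e-06
max_steps int

Maximum Newton iterations.

100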

Returns:

Type Description
Array

Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, falls back to GMIN stepping followed by source stepping — all inside a single JIT-compiled kernel via jax.lax.cond.

Strategy:

1. _run_newton: plain damped Newton from y_guess.
2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) from the GMIN result.

Because jax.lax.cond traces both branches at trace time but executes only one at runtime, this compiles to a single kernel with no Python-level branching. Under jax.vmap with a batched convergence flag, however, lax.cond is lowered to a select and both branches execute.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage for GMIN stepping (rescue branch).

0.01
n_gmin int

Number of GMIN steps in the rescue branch.

10
n_source int

Number of source steps in the rescue branch.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can be inspected eagerly or consumed inside compiled programs:

y, converged = solver.solve_dc_checked(groups, y0)

# Outside JIT: the flag is concrete, so an ordinary Python branch works.
if not converged:
    y = solver.solve_dc_gmin(groups, y0)

# Alternatively, inside JIT, where the flag is a tracer: branch with jax.lax.cond.
y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess vector (shape [N] or [2N]).

required
rtol float

Relative tolerance for optx.fixed_point.

1e-06
atol float

Absolute tolerance for optx.fixed_point.

1e-06
max_steps int

Maximum Newton iterations.

100

Returns:

Type Description
tuple[Array, Array]

(y, converged) — solution vector and boolean success flag.

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each converged solution as the warm start for the next step. The large initial g_leak linearises highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0V start.

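Implemented with jax.lax.scan, so the stepping loop itself is compatible with jax.jit, jax.grad and jax.vmap; whether vmap works end to end also depends on the underlying linear solver.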

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage conductance (large value, e.g. 1e-2).

0.01
n_steps int

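Number of log-uniform steps from g_start to self.g_leak, both endpoints included. With n_steps=1 only g_start is used, so use n_steps >= 2 for the schedule to reach self.g_leak.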

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector after the full GMIN schedule.

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10% to 100% of their netlist values, using each step's solution as the warm start for the next. This guides Newton through the nonlinear region instead of taking one large step from 0 V to full excitation.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
n_steps int

Number of uniformly spaced source scales from 0.1 to 1.0, both endpoints included.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Solution vector at full source amplitude (per-step convergence flags are discarded).

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final
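A minimal sketch of the source-stepping schedule on a one-node toy circuit. It is not the circulax implementation: the 5 V source, 1 kΩ resistor and diode parameters are hypothetical, a fixed-iteration Newton loop with each update capped at about 0.5 V stands in for _run_newton, and the source scale is passed as an explicit argument rather than through amplitude_param.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 headroom for the diode exponential

# Hypothetical toy circuit: source VS -> resistor R -> node v -> diode to ground.
VS, R, IS, VT = 5.0, 1e3, 1e-14, 0.02585


def residual(v, source_scale):
    """KCL at the single node, with the source amplitude scaled by source_scale."""
    return (v - source_scale * VS) / R + IS * jnp.expm1(v / VT)


def damped_newton(v0, source_scale, iters=60):
    """Fixed-iteration scalar Newton; each update is capped at roughly 0.5 V."""

    def body(_, v):
        f, df = jax.value_and_grad(residual)(v, source_scale)
        dv = -f / df
        return v + jnp.minimum(1.0, 0.5 / (jnp.abs(dv) + 1e-9)) * dv

    return jax.lax.fori_loop(0, iters, body, v0)


def solve_source_stepping(v_guess, n_steps=10):
    scales = jnp.linspace(0.1, 1.0, n_steps)  # 10% to 100% of VS, endpoints included

    def step(v, scale):
        v_new = damped_newton(v, scale)  # warm start from the previous step
        return v_new, v_new              # carry the solution and also record it

    return jax.lax.scan(step, v_guess, scales)


v_op, history = solve_source_stepping(jnp.zeros(()))
```

history holds the node voltage after each source step, and its last entry equals v_op; on this toy the operating point settles near the diode's forward voltage.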

solve_with_frozen_jacobian

solve_with_frozen_jacobian(residual: Array, numeric: Array) -> Solution

Solve using a pre-computed numeric factorization.

Parameters:

Name Type Description Default
residual Array

The right-hand side vector -F(y).

required
numeric Array

Handle returned by factor_jacobian.

required

Returns:

Type Description
Solution

A lineax.Solution holding the Newton step δy; its result field is always lx.RESULTS.successful.

Source code in circulax/solvers/linear.py
def solve_with_frozen_jacobian(self, residual: jax.Array, numeric: jax.Array) -> lx.Solution:
    """Solve using a pre-computed numeric factorization.

    Args:
        residual: The right-hand side vector ``-F(y)``.
        numeric: Handle returned by :meth:`factor_jacobian`.

    Returns:
        :class:`lineax.Solution` with the Newton step ``δy``.

    """
    solution = klujax.solve_with_numeric(numeric, residual, self._handle_wrapper.handle)
    return lx.Solution(
        value=solution.reshape(residual.shape),
        result=lx.RESULTS.successful,
        state=None,
        stats={},
    )

KLUSplitQuadratic

Bases: KLUSplitLinear

KLU split solver paired with full Newton for quadratic convergence via klu_refactor.

Extends KLUSplitLinear with refactor_jacobian, which updates the numeric LU factorization in place using klujax.refactor. The sparsity pattern is fixed for a given circuit topology, so KLU reuses the existing memory allocation and fill-reducing permutation; only the L/U values are recomputed. This gives full Newton (quadratic) convergence at a fraction of the cost of calling klu_factor again at every iteration.

Use together with circulax.solvers.transient.RefactoringTransientSolver.

Best For
  • Large circuits on CPU with nonlinear devices where quadratic convergence is desired.
  • Transient simulations where the Jacobian changes significantly between Newton iterates.
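The trade-off between the frozen-Jacobian strategy of KLUSplitLinear and the refactoring strategy of this class can be sketched with dense LU standing in for KLU. In this sketch, which is not the circulax code path, jax.scipy.linalg.lu_factor plays the role of factor_jacobian (and, when called again, of refactor_jacobian) and lu_solve the role of solve_with_frozen_jacobian; the 2-node residual F is hypothetical. Unlike klujax.refactor, lu_factor recomputes the whole factorization, including pivoting, on every call.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

jax.config.update("jax_enable_x64", True)

# Hypothetical 2-node toy system: linear conductance matrix plus a weak cubic term.
G = jnp.array([[2.0, -1.0], [-1.0, 2.0]])
b = jnp.array([1.0, 1.0])


def F(y):
    return G @ y + 0.1 * y**3 - b


jacobian = jax.jacfwd(F)


def modified_newton(y, tol=1e-10, max_iter=100):
    """Factor once, then reuse the frozen LU (cf. factor_jacobian + solve_with_frozen_jacobian)."""
    lu_piv = lu_factor(jacobian(y))
    for k in range(max_iter):
        r = F(y)
        if jnp.linalg.norm(r) < tol:
            return y, k
        y = y + lu_solve(lu_piv, -r)
    return y, max_iter


def full_newton(y, tol=1e-10, max_iter=100):
    """Refresh the factorization at every iterate (cf. refactor_jacobian)."""
    for k in range(max_iter):
        r = F(y)
        if jnp.linalg.norm(r) < tol:
            return y, k
        y = y + lu_solve(lu_factor(jacobian(y)), -r)
    return y, max_iter


y_mod, n_mod = modified_newton(jnp.zeros(2))
y_full, n_full = full_newton(jnp.zeros(2))
```

On this toy both variants reach the same root, but the refactoring variant needs noticeably fewer iterations; whether that pays off overall depends on how expensive each refactorization is relative to a residual evaluation.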

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Satisfies the lineax API; call _solve_impl directly for internal use.

factor_jacobian

Factor the Jacobian and return a numeric handle for repeated solves.

from_component_groups

Factory method; delegates to KLUSplitSolver.from_component_groups.

init

Initialize the solver state (no-op for stateless solvers).

refactor_jacobian

Update the numeric factorization in-place with new Jacobian values.

solve_dc

DC operating point via damped Newton-Raphson.

solve_dc_auto

DC Operating Point with automatic homotopy fallback.

solve_dc_checked

DC Operating Point with convergence status.

solve_dc_gmin

DC Operating Point via GMIN stepping (homotopy rescue).

solve_dc_source

DC Operating Point via source stepping (homotopy rescue).

solve_with_frozen_jacobian

Solve using a pre-computed numeric factorization.

assume_full_rank

assume_full_rank() -> bool

Indicate whether the solver assumes the operator is full rank (always False).

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Present only to satisfy the lineax API; always raises NotImplementedError. Call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

factor_jacobian

factor_jacobian(all_vals: Array) -> Array

Factor the Jacobian and return a numeric handle for repeated solves.

Parameters:

Name Type Description Default
all_vals Array

Flattened non-zero Jacobian values (COO format).

required

Returns:

Type Description
Array

Opaque numeric handle (int32 JAX scalar) to pass to solve_with_frozen_jacobian. Must be freed with klujax.free_numeric after use to avoid C++ memory leaks.

Source code in circulax/solvers/linear.py
def factor_jacobian(self, all_vals: jax.Array) -> jax.Array:
    """Factor the Jacobian and return a numeric handle for repeated solves.

    Args:
        all_vals: Flattened non-zero Jacobian values (COO format).

    Returns:
        Opaque numeric handle (``int32`` JAX scalar) to pass to
        :meth:`solve_with_frozen_jacobian`.  Must be freed with
        ``klujax.free_numeric`` after use to avoid C++ memory leaks.

    """
    g_vals = jnp.full(self.ground_indices.shape[0], GROUND_STIFFNESS, dtype=all_vals.dtype)
    l_vals = jnp.full(self.sys_size, self.g_leak, dtype=all_vals.dtype)

    raw_vals = jnp.concatenate([all_vals, g_vals, l_vals])
    coalesced_vals = jax.ops.segment_sum(raw_vals, self.map_idx, num_segments=self.n_unique)

    return klujax.factor(self.u_rows, self.u_cols, coalesced_vals, self._handle_wrapper.handle)
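A small illustration of the value-coalescing step in factor_jacobian, using a hypothetical 3-variable system whose variable 0 is the ground node. The index arrays u_rows, u_cols and map_idx are written out by hand here in row-major sorted order; the ordering actually produced by _klu_deduplicate is an assumption. The dense matrix at the end is only a cross-check and is never formed by the solver.

```python
import jax
import jax.numpy as jnp

GROUND_STIFFNESS = 1e9  # module constant from the attribute table
G_LEAK = 1e-9           # default g_leak

# Hypothetical raw COO stamps of two resistors: G1 between 0-1, G2 between 1-2.
G1, G2 = 1e-3, 2e-3
all_vals = jnp.array([G1, -G1, -G1, G1,   # R1 stamps at (0,0) (0,1) (1,0) (1,1)
                      G2, -G2, -G2, G2])  # R2 stamps at (1,1) (1,2) (2,1) (2,2)
ground_indices = jnp.array([0])
sys_size = 3

# Deduplicated pattern and the slot each raw value lands in (assumed ordering).
u_rows = jnp.array([0, 0, 1, 1, 1, 2, 2])
u_cols = jnp.array([0, 1, 0, 1, 2, 1, 2])
map_idx = jnp.array([0, 1, 2, 3, 3, 4, 5, 6,  # component stamps
                     0,                       # ground penalty on (0,0)
                     0, 3, 6])                # leakage on each diagonal entry
n_unique = 7

# Same steps as factor_jacobian: append penalties, then sum duplicates.
g_vals = jnp.full(ground_indices.shape[0], GROUND_STIFFNESS)
l_vals = jnp.full(sys_size, G_LEAK)
raw_vals = jnp.concatenate([all_vals, g_vals, l_vals])
coalesced = jax.ops.segment_sum(raw_vals, map_idx, num_segments=n_unique)

# Cross-check only: scatter the coalesced values into a dense matrix.
dense = jnp.zeros((sys_size, sys_size)).at[u_rows, u_cols].set(coalesced)
```

The two (1,1) stamps and the leakage term collapse into a single entry, and the ground diagonal is dominated by GROUND_STIFFNESS, which is what pins that node to 0 V.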

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any],
    num_vars: int,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> KLUSplitQuadratic

Factory method; delegates to KLUSplitSolver.from_component_groups.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False, g_leak: float = 1e-9
) -> "KLUSplitQuadratic":
    """Factory — delegates to :meth:`KLUSplitSolver.from_component_groups`."""
    return super().from_component_groups(  # type: ignore[return-value]
        component_groups, num_vars, is_complex=is_complex, g_leak=g_leak
    )

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

refactor_jacobian

refactor_jacobian(all_vals: Array, numeric: Array) -> Array

Update the numeric factorization in-place with new Jacobian values.

Reuses the existing memory allocation and fill-reducing permutation from the symbolic analysis; only the L/U values are recomputed. This is faster than calling KLUSplitLinear.factor_jacobian from scratch at each Newton iteration.

Parameters:

Name Type Description Default
all_vals Array

Flattened non-zero Jacobian values (COO format).

required
numeric Array

Existing handle returned by KLUSplitLinear.factor_jacobian.

required

Returns:

Type Description
Array

Refreshed numeric handle (the same underlying C++ object, now connected in the XLA computation graph so the refactor cannot be eliminated as dead code).

Source code in circulax/solvers/linear.py
def refactor_jacobian(self, all_vals: jax.Array, numeric: jax.Array) -> jax.Array:
    """Update the numeric factorization in-place with new Jacobian values.

    Reuses the existing memory allocation and fill-reducing permutation from the
    symbolic analysis; only the L/U values are recomputed.  Faster than calling
    :meth:`~KLUSplitLinear.factor_jacobian` from scratch each Newton iteration.

    Args:
        all_vals: Flattened non-zero Jacobian values (COO format).
        numeric: Existing handle returned by :meth:`~KLUSplitLinear.factor_jacobian`.

    Returns:
        Refreshed numeric handle (same underlying C++ object, now connected in the
        XLA computation graph so the refactor cannot be eliminated as dead code).

    """
    g_vals = jnp.full(self.ground_indices.shape[0], GROUND_STIFFNESS, dtype=all_vals.dtype)
    l_vals = jnp.full(self.sys_size, self.g_leak, dtype=all_vals.dtype)
    raw_vals = jnp.concatenate([all_vals, g_vals, l_vals])
    coalesced_vals = jax.ops.segment_sum(raw_vals, self.map_idx, num_segments=self.n_unique)
    return klujax.refactor(self.u_rows, self.u_cols, coalesced_vals, numeric, self._handle_wrapper.handle)

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess vector (shape [N] or [2N]).

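required
rtol float

Relative tolerance for the Newton iteration.

1e-06
atol float

Absolute tolerance for the Newton iteration.

1e-06
max_steps int

Maximum Newton iterations.

100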

Returns:

Type Description
Array

Solution vector after the Newton iteration (the convergence flag is discarded; use solve_dc_checked to obtain it).

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, falls back to GMIN stepping followed by source stepping — all inside a single JIT-compiled kernel via jax.lax.cond.

Strategy:

  1. _run_newton: plain damped Newton from y_guess.
  2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) from the GMIN result.

Because jax.lax.cond traces both branches but executes only one at runtime, this compiles to a single kernel with no Python-level branching. Under jax.vmap with a batched convergence flag, however, lax.cond is lowered to a select and both branches execute for every batch element.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage for GMIN stepping (rescue branch).

0.01
n_gmin int

Number of GMIN steps in the rescue branch.

10
n_source int

Number of source steps in the rescue branch.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

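Solution vector: the direct Newton result if it converged, otherwise the result of the GMIN-then-source-stepping rescue (whose convergence is not checked).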

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)
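A minimal sketch of the same try-then-rescue pattern on a hypothetical one-node diode circuit. A 5-iteration damped Newton loop plays the role of the direct _run_newton attempt, and a longer 100-iteration run from the same guess stands in for the GMIN and source-stepping rescue; both stand-ins, and the tolerance ATOL, are assumptions made for illustration. The jax.vmap call shows the batched case, where lax.cond becomes a select and both branches run.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical toy circuit: source VS -> resistor R -> node v -> diode to ground.
VS, R, IS, VT = 5.0, 1e3, 1e-14, 0.02585
ATOL = 1e-9  # hypothetical absolute tolerance on the KCL residual, in amperes


def residual(v):
    return (v - VS) / R + IS * jnp.expm1(v / VT)


def damped_newton(v0, iters):
    """Fixed-iteration Newton (updates capped at ~0.5 V); returns (v, converged)."""

    def body(_, v):
        f, df = jax.value_and_grad(residual)(v)
        dv = -f / df
        return v + jnp.minimum(1.0, 0.5 / (jnp.abs(dv) + 1e-9)) * dv

    v = jax.lax.fori_loop(0, iters, body, v0)
    return v, jnp.abs(residual(v)) < ATOL


def solve_auto(v_guess):
    v_direct, converged = damped_newton(v_guess, iters=5)  # cheap first attempt

    def rescue(_):
        # Stand-in for GMIN + source stepping: a longer run from the same guess.
        v_rescued, _ = damped_newton(v_guess, iters=100)
        return v_rescued

    return jax.lax.cond(converged, lambda _: v_direct, rescue, None)


solve_auto_jit = jax.jit(solve_auto)
v_far = solve_auto_jit(jnp.zeros(()))                       # direct attempt fails, rescue runs
v_near = solve_auto_jit(jnp.array(0.7, dtype=jnp.float64))  # direct attempt converges
v_batch = jax.vmap(solve_auto)(jnp.array([0.0, 0.7]))       # batched flag: both branches run
```

All three calls return the same operating point; the difference is only which branch does the work, and under vmap the rescue cost is paid even for elements whose direct attempt succeeded.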

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can also be consumed inside compiled programs:

y, converged = solver.solve_dc_checked(groups, y0)
# Outside JIT: inspect the flag in Python.
if not converged:
    y = solver.solve_dc_gmin(groups, y0)
# Alternatively, inside JIT: branch without Python-level control flow.
y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess vector (shape [N] or [2N]).

required
rtol float

Relative tolerance for optx.fixed_point.

1e-06
atol float

Absolute tolerance for optx.fixed_point.

1e-06
max_steps int

Maximum Newton iterations.

100

Returns:

Type Description
tuple[Array, Array]

(y, converged) — solution vector and boolean success flag.

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each step's solution as the warm start for the next. A large initial g_leak on the diagonal dominates the contributions of highly nonlinear components (diodes above threshold, lasers), so the early solves are close to linear and Newton does not diverge from a flat 0 V start.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage conductance (large value, e.g. 1e-2).

0.01
n_steps int

Number of log-uniformly spaced g_leak values from g_start to self.g_leak, both endpoints included.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Solution vector after the final GMIN step (per-step convergence flags are discarded).

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final
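A minimal sketch of the GMIN schedule on a hypothetical one-node circuit (5 V source, 1 kΩ resistor, diode to ground). The GMIN conductance is passed as an explicit argument instead of being swapped into the solver with eqx.tree_at, and a fixed-iteration Newton loop with each update capped at about 0.5 V stands in for _run_newton. On this toy, damped Newton would also converge without GMIN; the point is the mechanics: the first solve sits near the purely resistive answer of about 0.45 V, and later steps relax toward the diode-clamped operating point.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical toy circuit: source VS -> resistor R -> node v -> diode to ground.
VS, R, IS, VT = 5.0, 1e3, 1e-14, 0.02585


def residual(v, g_leak):
    """KCL at the node plus a GMIN conductance g_leak from the node to ground."""
    return (v - VS) / R + IS * jnp.expm1(v / VT) + g_leak * v


def damped_newton(v0, g_leak, iters=60):
    """Fixed-iteration scalar Newton; each update is capped at roughly 0.5 V."""

    def body(_, v):
        f, df = jax.value_and_grad(residual)(v, g_leak)
        dv = -f / df
        return v + jnp.minimum(1.0, 0.5 / (jnp.abs(dv) + 1e-9)) * dv

    return jax.lax.fori_loop(0, iters, body, v0)


def solve_gmin(v_guess, g_start=1e-2, g_final=1e-9, n_steps=10):
    # Log-uniform schedule from g_start down to g_final, both endpoints included.
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(g_final), n_steps)

    def step(v, g):
        v_new = damped_newton(v, g)  # warm start from the previous g
        return v_new, v_new

    return jax.lax.scan(step, v_guess, g_values)


v_op, history = solve_gmin(jnp.zeros(()))
```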

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10% to 100% of their netlist values, using each step's solution as the warm start for the next. This guides Newton through the nonlinear region instead of taking one large step from 0 V to full excitation.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
n_steps int

Number of uniformly spaced source scales from 0.1 to 1.0, both endpoints included.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Solution vector at full source amplitude (per-step convergence flags are discarded).

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final

solve_with_frozen_jacobian

solve_with_frozen_jacobian(residual: Array, numeric: Array) -> Solution

Solve using a pre-computed numeric factorization.

Parameters:

Name Type Description Default
residual Array

The right-hand side vector -F(y).

required
numeric Array

Handle returned by factor_jacobian.

required

Returns:

Type Description
Solution

A lineax.Solution holding the Newton step δy; its result field is always lx.RESULTS.successful.

Source code in circulax/solvers/linear.py
def solve_with_frozen_jacobian(self, residual: jax.Array, numeric: jax.Array) -> lx.Solution:
    """Solve using a pre-computed numeric factorization.

    Args:
        residual: The right-hand side vector ``-F(y)``.
        numeric: Handle returned by :meth:`factor_jacobian`.

    Returns:
        :class:`lineax.Solution` with the Newton step ``δy``.

    """
    solution = klujax.solve_with_numeric(numeric, residual, self._handle_wrapper.handle)
    return lx.Solution(
        value=solution.reshape(residual.shape),
        result=lx.RESULTS.successful,
        state=None,
        stats={},
    )

KLUSplitSolver

Bases: CircuitLinearSolver

Solves the system using the KLU sparse solver (via klujax) with split interface.

This solver performs the symbolic analysis once, during initialization, and reuses the symbolic handle for every subsequent factorization, which significantly speeds up nonlinear simulations that factor the Jacobian at each Newton-Raphson iteration.

Best For
  • Large circuits (N > 5000) running on CPU.
  • DC Operating Points of massive meshes.

Attributes:

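Name Type Description
u_rows, u_cols Array

Row and column indices of the deduplicated sparsity pattern (fixed structure), including the ground-penalty and leakage diagonal entries.

map_idx Array

Mapping from each raw value (component stamps, ground penalties, leakage diagonal) to its slot in the coalesced value vector.

n_unique int

Number of unique (row, column) entries in the pattern.

_handle_wrapper

Pre-computed KLU symbolic analysis returned by klujax.analyze.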

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Satisfies the lineax API; call _solve_impl directly for internal use.

from_component_groups

Factory method to pre-hash indices for sparse coalescence.

init

Initialize the solver state (no-op for stateless solvers).

solve_dc

DC operating point via damped Newton-Raphson.

solve_dc_auto

DC Operating Point with automatic homotopy fallback.

solve_dc_checked

DC Operating Point with convergence status.

solve_dc_gmin

DC Operating Point via GMIN stepping (homotopy rescue).

solve_dc_source

DC Operating Point via source stepping (homotopy rescue).

assume_full_rank

assume_full_rank() -> bool

Indicate whether the solver assumes the operator is full rank (always False).

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Present only to satisfy the lineax API; always raises NotImplementedError. Call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any],
    num_vars: int,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> KLUSplitSolver

Factory method: builds the Jacobian sparsity pattern, precomputes the index map used to coalesce duplicate entries, and runs the KLU symbolic analysis once.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False, g_leak: float = 1e-9
) -> "KLUSplitSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(component_groups, num_vars, is_complex)
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    symbolic = klujax.analyze(u_rows, u_cols, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        _handle_wrapper=symbolic,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
        g_leak=g_leak,
    )
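The pattern work done here depends only on circuit topology, never on Jacobian values, which is why the symbolic analysis can run once. The sketch below is a hypothetical stand-in for _klu_deduplicate, whose body is not shown on this page: it assumes ground_idxs holds the diagonal positions that receive GROUND_STIFFNESS and that every variable gets one leakage entry, matching the order in which factor_jacobian concatenates all_vals, g_vals and l_vals. The row-major sorted ordering is also an assumption.

```python
import numpy as np


def deduplicate_pattern(rows, cols, ground_idxs, sys_size):
    """Hypothetical stand-in for _klu_deduplicate.

    Returns (u_rows, u_cols, map_idx, n_unique), where map_idx sends each raw
    entry (component stamps, then ground penalties, then leakage diagonal)
    to its slot in the deduplicated, row-major sorted pattern.
    """
    diag = np.arange(sys_size)
    all_rows = np.concatenate([rows, ground_idxs, diag])
    all_cols = np.concatenate([cols, ground_idxs, diag])
    keys = all_rows * sys_size + all_cols  # one integer key per (row, col)
    unique_keys, map_idx = np.unique(keys, return_inverse=True)
    u_rows, u_cols = np.divmod(unique_keys, sys_size)
    return u_rows, u_cols, map_idx.ravel(), int(unique_keys.size)


# Two resistors on a 3-variable system (variable 0 is ground): 0-1 and 1-2.
rows = np.array([0, 0, 1, 1, 1, 1, 2, 2])
cols = np.array([0, 1, 0, 1, 1, 2, 1, 2])
u_rows, u_cols, map_idx, n_unique = deduplicate_pattern(rows, cols, np.array([0]), 3)
```

Twelve raw entries collapse to seven unique positions; only u_rows and u_cols would be handed to the symbolic analysis, while map_idx is reused at every Newton iteration to sum duplicate values.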

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess vector (shape [N] or [2N]).

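required
rtol float

Relative tolerance for the Newton iteration.

1e-06
atol float

Absolute tolerance for the Newton iteration.

1e-06
max_steps int

Maximum Newton iterations.

100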

Returns:

Type Description
Array

Solution vector after the Newton iteration (the convergence flag is discarded; use solve_dc_checked to obtain it).

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, falls back to GMIN stepping followed by source stepping — all inside a single JIT-compiled kernel via jax.lax.cond.

Strategy:

  1. _run_newton: plain damped Newton from y_guess.
  2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) from the GMIN result.

Because jax.lax.cond traces both branches but executes only one at runtime, this compiles to a single kernel with no Python-level branching. Under jax.vmap with a batched convergence flag, however, lax.cond is lowered to a select and both branches execute for every batch element.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage for GMIN stepping (rescue branch).

0.01
n_gmin int

Number of GMIN steps in the rescue branch.

10
n_source int

Number of source steps in the rescue branch.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

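Solution vector: the direct Newton result if it converged, otherwise the result of the GMIN-then-source-stepping rescue (whose convergence is not checked).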

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc, but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can also be consumed inside compiled programs. For example:

y, converged = solver.solve_dc_checked(groups, y0)

# Option 1, outside JIT: inspect the flag in Python.
if not converged:
    y = solver.solve_dc_gmin(groups, y0)

# Option 2, inside JIT: branch with jax.lax.cond instead of Python control flow.
y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

  • component_groups (dict[str, Any], required): Compiled circuit components.
  • y_guess (Array, required): Initial guess vector (shape [N] or [2N]).
  • rtol (float, default 1e-06): Relative tolerance for optx.fixed_point.
  • atol (float, default 1e-06): Absolute tolerance for optx.fixed_point.
  • max_steps (int, default 100): Maximum Newton iterations.

Returns:

  • tuple[Array, Array]: (y, converged), the solution vector and a boolean success flag.

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each converged solution as the warm start for the next step. The large initial g_leak linearises highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0V start.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.
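A minimal sketch of GMIN stepping with jax.lax.scan, again on a hypothetical diode-plus-resistor node. The leak enters the node equation as g_leak * v and is swept log-uniformly from 1e-2 S down to 1e-9 S, with each Newton solve warm-started from the previous one. The fixed iteration count and the 0.1 V step limit are assumptions of this sketch and do not describe circulax's _run_newton.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy circuit: a 5 V source drives a diode through a 1 kOhm resistor.
VS, R, IS, VT = 5.0, 1e3, 1e-14, 0.02585


def residual(v, g_leak):
    # KCL at the diode node, plus a leak conductance g_leak from the node to ground.
    return (v - VS) / R + IS * (jnp.exp(v / VT) - 1.0) + g_leak * v


def newton(v, g_leak, iters=50, max_dv=0.1):
    # Fixed-count Newton with a simple 0.1 V step limit (an assumption of this sketch).
    def body(_, v):
        dv = -residual(v, g_leak) / jax.grad(residual)(v, g_leak)
        return v + jnp.clip(dv, -max_dv, max_dv)

    return jax.lax.fori_loop(0, iters, body, v)


@jax.jit
def solve_gmin(v0):
    # Log-uniform schedule from a large leak (1e-2 S) down to a tiny one (1e-9 S).
    g_values = jnp.logspace(-2.0, -9.0, 10)

    def step(v, g):
        return newton(v, g), None  # warm start: carry the previous solution forward

    v_final, _ = jax.lax.scan(step, v0, g_values)
    return v_final
```

The large first leak keeps the initial solves close to linear, so by the time g_leak is small the warm start is already near the diode's operating point.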

Parameters:

  • component_groups (dict[str, Any], required): Compiled circuit components.
  • y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
  • g_start (float, default 0.01): Starting leakage conductance (a large value, e.g. 1e-2).
  • n_steps (int, default 10): Number of log-uniform steps from g_start to self.g_leak.
  • rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
  • atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
  • max_steps (int, default 100): Maximum Newton iterations per step.

Returns:

  • Array: Solution vector after the full GMIN schedule (the inner Newton convergence flags are discarded).

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10 % to 100 % of their netlist values, using each converged solution as the warm start for the next step. This guides Newton through the nonlinear region without the large initial step from 0V to full excitation.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.
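Source stepping can be sketched the same way on a hypothetical two-node circuit (VS through R1 to node 1, R2 from node 1 to node 2, and a diode from each node to ground). The source term is multiplied by a scale swept from 0.1 to 1.0 with jnp.linspace, and jax.lax.scan carries each solution forward as the next warm start. The residual, the jax.jacfwd-based Newton loop with a 0.1 V per-node step limit, and all constants are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical circuit: VS -R1- node 1 -R2- node 2, with a diode from each node to ground.
VS, R1, R2, IS, VT = 5.0, 1e3, 1e3, 1e-14, 0.02585


def residual(y, scale):
    d = IS * (jnp.exp(y / VT) - 1.0)  # diode currents at both nodes
    i1 = (y[0] - scale * VS) / R1 + (y[0] - y[1]) / R2 + d[0]
    i2 = (y[1] - y[0]) / R2 + d[1]
    return jnp.stack([i1, i2])


def newton(y, scale, iters=50, max_dv=0.1):
    def body(_, y):
        jac = jax.jacfwd(residual)(y, scale)
        dy = jnp.linalg.solve(jac, -residual(y, scale))
        return y + jnp.clip(dy, -max_dv, max_dv)  # per-node voltage-step limit

    return jax.lax.fori_loop(0, iters, body, y)


@jax.jit
def solve_source(y0):
    scales = jnp.linspace(0.1, 1.0, 10)  # 10 % to 100 % of the source amplitude

    def step(y, s):
        return newton(y, s), None  # warm start from the previous scale's solution

    y_final, _ = jax.lax.scan(step, y0, scales)
    return y_final
```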

Parameters:

  • component_groups (dict[str, Any], required): Compiled circuit components.
  • y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
  • n_steps (int, default 10): Number of uniformly spaced steps from 0.1 to 1.0.
  • rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
  • atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
  • max_steps (int, default 100): Maximum Newton iterations per step.

Returns:

  • Array: Solution vector at full source amplitude (the inner Newton convergence flags are discarded).

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final

SparseSolver

Bases: CircuitLinearSolver

Solves the system using JAX's iterative BiCGStab solver.

Best for:
  • Large transient simulations on GPU (the previous timestep's solution is used as the warm start).
  • Systems where N is too large for DenseSolver but vmap support is still needed.
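As a rough illustration of this strategy (not SparseSolver's actual _solve_impl, which is not shown here), the sketch below solves a system stored as (row, col, value) triplets with jax.scipy.sparse.linalg.bicgstab. Duplicate coordinates are summed, mirroring how stamps accumulate. Using the entries where row == col as a Jacobi preconditioner is an assumption about what diag_mask is for, and the x0 argument shows how a previous solution can warm-start the iteration. solve_coo_bicgstab is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab


def solve_coo_bicgstab(rows, cols, vals, b, x0=None):
    """Hypothetical helper: solve A x = b with A given as (row, col, value) triplets."""
    n = b.shape[0]

    def matvec(x):
        # Duplicate (row, col) entries are summed, as when stamps accumulate.
        return jax.ops.segment_sum(vals * x[cols], rows, num_segments=n)

    # Jacobi preconditioner from the diagonal entries (rows == cols).
    diag_mask = rows == cols
    diag = jax.ops.segment_sum(jnp.where(diag_mask, vals, 0.0), rows, num_segments=n)

    def precond(r):
        return r / diag

    # x0 (e.g. the previous timestep's solution) warm-starts the iteration.
    x, _ = bicgstab(matvec, b, x0=x0, tol=1e-6, M=precond)
    return x
```

On a small diagonally dominant system the result agrees with jnp.linalg.solve on the equivalent dense matrix to within the solver tolerance.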

Attributes:

  • diag_mask (Array): Mask to extract diagonal elements for preconditioning.

Methods:

  • assume_full_rank: Indicate if the solver assumes the operator is full rank.
  • compute: Satisfies the lineax API; call _solve_impl directly for internal use.
  • from_component_groups: Factory method to prepare indices and diagonal mask.
  • init: Initialize the solver state (no-op for stateless solvers).
  • solve_dc: DC operating point via damped Newton-Raphson.
  • solve_dc_auto: DC Operating Point with automatic homotopy fallback.
  • solve_dc_checked: DC Operating Point with convergence status.
  • solve_dc_gmin: DC Operating Point via GMIN stepping (homotopy rescue).
  • solve_dc_source: DC Operating Point via source stepping (homotopy rescue).

assume_full_rank

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Not implemented: always raises NotImplementedError. It exists only to satisfy the lineax API; call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any],
    num_vars: int,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> SparseSolver

Factory method that builds the static row/column index arrays, the ground indices, and the diagonal mask (rows == cols) from the component groups.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False, g_leak: float = 1e-9
) -> "SparseSolver":
    """Factory method to prepare indices and diagonal mask."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(component_groups, num_vars, is_complex)
    return cls(
        static_rows=jnp.array(rows),
        static_cols=jnp.array(cols),
        diag_mask=jnp.array(rows == cols),
        sys_size=sys_size,
        ground_indices=jnp.array(ground_idxs),
        is_complex=is_complex,
        g_leak=g_leak,
    )

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.
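A minimal sketch of damped Newton-Raphson using the module constants DAMPING_FACTOR = 0.5 and DAMPING_EPS = 1e-9. Scaling each step so that its largest component is at most DAMPING_FACTOR is one reading of the documented damping formula, not circulax's code; the two-node diode circuit, the jax.jacfwd Jacobian and the step-size convergence test are likewise illustrative. Like solve_dc_checked, it returns the solution together with a JAX boolean flag.

```python
import jax
import jax.numpy as jnp

DAMPING_FACTOR, DAMPING_EPS = 0.5, 1e-9  # values documented for this module

# Hypothetical circuit: VS -R1- node 1 -R2- node 2, with a diode from each node to ground.
VS, R1, R2, IS, VT = 5.0, 1e3, 1e3, 1e-14, 0.02585


def residual(y):
    d = IS * (jnp.exp(y / VT) - 1.0)  # diode currents at both nodes
    i1 = (y[0] - VS) / R1 + (y[0] - y[1]) / R2 + d[0]
    i2 = (y[1] - y[0]) / R2 + d[1]
    return jnp.stack([i1, i2])


@jax.jit
def damped_newton(y0):
    rtol, atol, max_steps = 1e-6, 1e-6, 100

    def keep_going(state):
        _, k, done = state
        return (k < max_steps) & ~done

    def body(state):
        y, k, _ = state
        dy = jnp.linalg.solve(jax.jacfwd(residual)(y), -residual(y))
        # Scale the step so that its largest component is at most DAMPING_FACTOR.
        alpha = jnp.minimum(1.0, DAMPING_FACTOR / (jnp.max(jnp.abs(dy)) + DAMPING_EPS))
        y_new = y + alpha * dy
        done = jnp.max(jnp.abs(alpha * dy)) <= atol + rtol * jnp.max(jnp.abs(y_new))
        return y_new, k + 1, done

    init = (y0, jnp.array(0, dtype=jnp.int32), jnp.array(False))
    y, _, converged = jax.lax.while_loop(keep_going, body, init)
    return y, converged
```

The damping caps how far the early iterates can jump, while near the solution alpha reaches 1 and the iteration reverts to plain Newton.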

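Parameters:

  • component_groups (dict[str, Any], required): Compiled circuit components.
  • y_guess (Array, required): Initial guess vector (shape [N] or [2N]).
  • rtol (float, default 1e-06): Relative tolerance for the Newton solve.
  • atol (float, default 1e-06): Absolute tolerance for the Newton solve.
  • max_steps (int, default 100): Maximum Newton iterations.

Returns:

  • Array: Final Newton iterate. The convergence flag is discarded; use solve_dc_checked to inspect it.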

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

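    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for the Newton solve.
        atol: Absolute tolerance for the Newton solve.
        max_steps: Maximum Newton iterations.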

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, falls back to GMIN stepping followed by source stepping — all inside a single JIT-compiled kernel via jax.lax.cond.

Strategy:

1. _run_newton: plain damped Newton from y_guess.
2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) starting from the GMIN result.

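Because jax.lax.cond traces both branches at compile time but executes only one at runtime, this compiles into a single XLA program with no Python-level branching. Under jax.vmap with a batched convergence flag, however, lax.cond is converted into a select and both branches execute.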

Parameters:

  • component_groups (dict[str, Any], required): Compiled circuit components.
  • y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
  • g_start (float, default 0.01): Starting leakage for GMIN stepping (rescue branch).
  • n_gmin (int, default 10): Number of GMIN steps in the rescue branch.
  • n_source (int, default 10): Number of source steps in the rescue branch.
  • rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
  • atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
  • max_steps (int, default 100): Maximum Newton iterations per step.

Returns:

  • Array: Solution vector. If the direct Newton solve fails, the rescue result is returned without a further convergence check.

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no
    Python-level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc, but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can also be consumed inside compiled programs. For example:

y, converged = solver.solve_dc_checked(groups, y0)

# Option 1, outside JIT: inspect the flag in Python.
if not converged:
    y = solver.solve_dc_gmin(groups, y0)

# Option 2, inside JIT: branch with jax.lax.cond instead of Python control flow.
y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

  • component_groups (dict[str, Any], required): Compiled circuit components.
  • y_guess (Array, required): Initial guess vector (shape [N] or [2N]).
  • rtol (float, default 1e-06): Relative tolerance for optx.fixed_point.
  • atol (float, default 1e-06): Absolute tolerance for optx.fixed_point.
  • max_steps (int, default 100): Maximum Newton iterations.

Returns:

  • tuple[Array, Array]: (y, converged), the solution vector and a boolean success flag.

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each converged solution as the warm start for the next step. The large initial g_leak linearises highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0V start.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

  • component_groups (dict[str, Any], required): Compiled circuit components.
  • y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
  • g_start (float, default 0.01): Starting leakage conductance (a large value, e.g. 1e-2).
  • n_steps (int, default 10): Number of log-uniform steps from g_start to self.g_leak.
  • rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
  • atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
  • max_steps (int, default 100): Maximum Newton iterations per step.

Returns:

  • Array: Solution vector after the full GMIN schedule (the inner Newton convergence flags are discarded).

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10 % to 100 % of their netlist values, using each converged solution as the warm start for the next step. This guides Newton through the nonlinear region without the large initial step from 0V to full excitation.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

  • component_groups (dict[str, Any], required): Compiled circuit components.
  • y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
  • n_steps (int, default 10): Number of uniformly spaced steps from 0.1 to 1.0.
  • rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
  • atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
  • max_steps (int, default 100): Maximum Newton iterations per step.

Returns:

  • Array: Solution vector at full source amplitude (the inner Newton convergence flags are discarded).

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final

analyze_circuit

analyze_circuit(
    groups: list,
    num_vars: int,
    backend: str = "default",
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> CircuitLinearSolver

Initializes a linear solver strategy for circuit analysis.

Factory that selects and configures the numerical backend for solving the linear system derived from a circuit's topology.

Parameters:

  • groups (list, required): A list of component groups that define the circuit's structure and properties.
  • num_vars (int, required): The total number of variables in the linear system.
  • backend (str, default 'default'): The name of the solver backend to use. Supported backends are 'klu', 'klu_split', 'dense', and 'sparse'; 'default' uses 'klu_split'.
  • is_complex (bool, default False): Whether the circuit analysis involves complex numbers.
  • g_leak (float, default 1e-09): Diagonal regularisation (leakage) conductance passed to the selected solver's from_component_groups.

Returns:

  • CircuitLinearSolver: An instance of a circuit linear solver strategy configured for the specified backend and circuit parameters.

Raises:

  • ValueError: If the specified backend is not supported.

Source code in circulax/solvers/linear.py
def analyze_circuit(
    groups: list, num_vars: int, backend: str = "default", *, is_complex: bool = False, g_leak: float = 1e-9
) -> CircuitLinearSolver:
    """Initializes a linear solver strategy for circuit analysis.

    Factory that selects and configures the numerical backend for solving the
    linear system derived from a circuit's topology.

    Args:
        groups (list): A list of component groups that define the circuit's
            structure and properties.
        num_vars (int): The total number of variables in the linear system.
        backend (str, optional): The name of the solver backend to use.
            Supported backends are 'klu', 'klu_split', 'dense', and 'sparse'.
            Defaults to 'default', which uses 'klu_split'.
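        is_complex (bool, optional): A flag indicating whether the circuit
            analysis involves complex numbers. Defaults to False.
        g_leak (float, optional): Diagonal regularisation (leakage)
            conductance passed to the solver. Defaults to 1e-9.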

    Returns:
        CircuitLinearSolver: An instance of a circuit linear solver strategy
        configured for the specified backend and circuit parameters.

    Raises:
        ValueError: If the specified backend is not supported.

    """
    solver_class = backends.get(backend)
    if solver_class is None:
        msg = f"Unknown backend: '{backend}'. Available backends are {list(backends.keys())}"
        raise ValueError(msg)

    linear_strategy = solver_class.from_component_groups(groups, num_vars, is_complex=is_complex, g_leak=g_leak)

    return linear_strategy