
solvers

Root finding and transient solvers.

Modules:

ac_sweep: AC small-signal frequency sweep returning S-parameters.
assembly: Assembly functions for the transient circuit solver.
circuit_diffeq: Stripped-down ODE integrator for circuit simulation.
harmonic_balance: Harmonic Balance solver for periodic steady-state circuit analysis.
linear: Circuit linear solver strategies.
transient: Transient solvers to be used with Diffrax.

Classes:

BDF2FactorizedTransientSolver: BDF2 upgrade of FactorizedTransientSolver (frozen-Jacobian Newton).
BDF2RefactoringTransientSolver: BDF2 upgrade of RefactoringTransientSolver (KLU refactor per iteration).
BDF2VectorizedTransientSolver: BDF2 upgrade of VectorizedTransientSolver.
CircuitLinearSolver: Abstract base for all circuit linear solvers.
DenseSolver: Solves the system using dense matrix factorization (LU).
FactorizedTransientSolver: Transient solver using a Modified Newton (frozen-Jacobian) scheme.
KLUSolver: Solves the system using the KLU sparse solver (via klujax).
KLUSplitLinear: KLU split solver paired with Modified Newton (frozen-Jacobian) for linear convergence.
KLUSplitQuadratic: KLU split solver paired with full Newton for quadratic convergence via klu_refactor.
RefactoringTransientSolver: Transient solver with full Newton (quadratic) convergence using klujax.refactor.
SDIRK3FactorizedTransientSolver: 3rd-order A-stable SDIRK3 solver with frozen-Jacobian Newton across all stages.
SDIRK3RefactoringTransientSolver: 3rd-order A-stable SDIRK3 solver with KLU refactor at each Newton iteration.
SDIRK3VectorizedTransientSolver: 3rd-order A-stable SDIRK3 solver using full Newton-Raphson at each stage.
SparseSolver: Solves the system using JAX's iterative BiCGStab solver.
VectorizedTransientSolver: Transient solver that works strictly on flat, real-valued vectors.

Functions:

analyze_circuit: Initialize a linear solver strategy for circuit analysis.
assemble_system_complex: Assemble the residual vectors and effective Jacobian values for an unrolled complex system.
assemble_system_real: Assemble the residual vectors and effective Jacobian values for a real system.
circuit_diffeqsolve: Stripped-down diffrax.diffeqsolve for circuit simulation.
setup_ac_sweep: Configure and return a callable for AC small-signal S-parameter sweep.
setup_harmonic_balance: Configure and return a Harmonic Balance callable.
setup_transient: Configure and return a function for executing transient analysis.

BDF2FactorizedTransientSolver

Bases: FactorizedTransientSolver

BDF2 upgrade of FactorizedTransientSolver (frozen-Jacobian Newton).

Factors J_eff once at the predictor state and reuses that factorization across all Newton iterations. This trades quadratic convergence for linear convergence at a lower per-iteration cost. The BDF2 Jacobian scaling α₀/h is applied when factoring.
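The difference between a frozen Jacobian and refactoring at every iteration shows up even on a tiny dense problem. The sketch below is illustrative only. It uses jax.scipy.linalg.lu_factor/lu_solve in place of KLU, and the 2-node residual, the tolerance, and the helper name `newton` are hypothetical rather than part of circulax. In the transient solvers the residual would be the time-discretised one, whose Jacobian carries the α₀/h scaling.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

# Hypothetical 2-node test system (not a circulax netlist): a linear
# conductance matrix plus a weak cubic nonlinearity.
G = jnp.array([[2.0, -1.0], [-1.0, 2.0]])
b = jnp.array([1.0, 0.5])


def residual(y):
    return G @ y + 0.1 * y**3 - b


def newton(f, y0, refactor, tol=1e-5, max_iter=50):
    """Return (y, iterations). refactor=False keeps the LU of J(y0) throughout."""
    jac = jax.jacfwd(f)
    lu_piv = lu_factor(jac(y0))  # factor once at the predictor state
    y = y0
    for k in range(max_iter):
        r = f(y)
        if float(jnp.linalg.norm(r)) < tol:
            return y, k
        if refactor and k > 0:
            lu_piv = lu_factor(jac(y))  # full Newton: fresh factorization each iteration
        y = y - lu_solve(lu_piv, r)  # solve J dy = r, then y <- y - dy
    return y, max_iter


y_frozen, n_frozen = newton(residual, jnp.zeros(2), refactor=False)
y_full, n_full = newton(residual, jnp.zeros(2), refactor=True)
print(n_frozen, n_full)
```

The frozen variant performs only triangular solves after the first iteration, but it needs more iterations than the refactoring variant. This is the same trade-off as between FactorizedTransientSolver and RefactoringTransientSolver.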

BDF2RefactoringTransientSolver

Bases: RefactoringTransientSolver

BDF2 upgrade of RefactoringTransientSolver (KLU refactor per iteration).

Full quadratic Newton convergence via klujax.refactor at each iteration, combined with BDF2 time discretisation for 2nd-order accuracy.

BDF2VectorizedTransientSolver

Bases: VectorizedTransientSolver

BDF2 upgrade of VectorizedTransientSolver.

Implements variable-step BDF2 via the companion method. On the first step Backward Euler is used automatically; from step 2 onward BDF2 is activated. The Jacobian scaling changes from 1/h (BE) to α₀/h (BDF2) where α₀ = (1 + 2ω)/(1 + ω) and ω = h_n/h_{n-1}.

solver_state is a 3-tuple (y_nm1, h_nm1, q_nm1). h_nm1 is initialised to +inf so that ω = 0 on the first step, making the BDF2 formula reduce to Backward Euler via IEEE 754 arithmetic (no branching). q_nm1 caches Q(y_{n-1}) to avoid recomputing it each step.
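For reference, the standard variable-step BDF2 coefficients can be written as a small helper. Only α₀ is stated above. The α₁ and α₂ below are the textbook variable-step BDF2 coefficients, and whether the companion model uses exactly this form is an assumption. The name `bdf2_coeffs` is hypothetical. Passing h_nm1 = inf reproduces the IEEE 754 trick: ω becomes 0.0 and the coefficients collapse to Backward Euler without a branch.

```python
import math


def bdf2_coeffs(h_n: float, h_nm1: float) -> tuple[float, float, float]:
    """Variable-step BDF2 coefficients for (a0*q_n + a1*q_{n-1} + a2*q_{n-2}) / h_n.

    Hypothetical helper; a0 / h_n is the factor that scales the capacitive
    (dQ/dy) part of the effective Jacobian.
    """
    w = h_n / h_nm1  # step ratio omega; h_nm1 = inf gives exactly 0.0
    a0 = (1 + 2 * w) / (1 + w)
    a1 = -(1 + w)
    a2 = w * w / (1 + w)
    return a0, a1, a2


print(bdf2_coeffs(1e-3, math.inf))  # (1.0, -1.0, 0.0): Backward Euler on the first step
print(bdf2_coeffs(1e-3, 1e-3))      # (1.5, -2.0, 0.5): classic constant-step BDF2
```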

CircuitLinearSolver

Bases: AbstractLinearSolver

Abstract base for all circuit linear solvers.

Attributes:

ground_indices (Array): Indices of nodes connected to ground (forced to 0 V).
is_complex (bool): If True, the system is 2N×2N (real and imaginary parts unrolled); otherwise N×N.
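To make the 2N×2N unrolling concrete: a complex system (A_r + iA_i)(x_r + ix_i) = b_r + ib_i is equivalent to the real block system built below. This is a sketch with a hypothetical helper name. The [real; imaginary] block ordering is an assumption, and circulax may order or interleave the unknowns differently.

```python
import jax.numpy as jnp


def unroll_complex(A, b):
    """Hypothetical helper: rewrite complex A x = b as a real 2N x 2N system.

    Assumes the block layout [Re; Im] for both unknowns and right-hand side.
    """
    Ar, Ai = A.real, A.imag
    A_real = jnp.block([[Ar, -Ai], [Ai, Ar]])  # Re: Ar xr - Ai xi, Im: Ai xr + Ar xi
    b_real = jnp.concatenate([b.real, b.imag])
    return A_real, b_real


A = jnp.array([[2.0 + 1.0j, -1.0], [-1.0, 2.0 - 0.5j]])
b = jnp.array([1.0 + 0.0j, 0.0 + 1.0j])
A_r, b_r = unroll_complex(A, b)
z = jnp.linalg.solve(A_r, b_r)  # purely real solve
n = A.shape[0]
x = z[:n] + 1j * z[n:]  # fold back into a complex vector
```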

Methods:

assume_full_rank: Indicate if the solver assumes the operator is full rank.
compute: Present to satisfy the lineax API; always raises NotImplementedError (call _solve_impl directly for internal use).
init: Initialize the solver state (no-op for stateless solvers).
solve_dc: DC operating point via damped Newton-Raphson.
solve_dc_auto: DC Operating Point with automatic homotopy fallback.
solve_dc_checked: DC Operating Point with convergence status.
solve_dc_gmin: DC Operating Point via GMIN stepping (homotopy rescue).
solve_dc_source: DC Operating Point via source stepping (homotopy rescue).

assume_full_rank

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Present to satisfy the lineax API; always raises NotImplementedError. Call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

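Parameters:

component_groups (dict[str, Any], required): Compiled circuit components.
y_guess (Array, required): Initial guess vector (shape [N] or [2N]).
rtol (float, default 1e-06): Relative tolerance for the Newton solve.
atol (float, default 1e-06): Absolute tolerance for the Newton solve.
max_steps (int, default 100): Maximum Newton iterations.

Returns:

Array: Converged solution vector.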

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, it falls back to GMIN stepping followed by source stepping, all inside a single JIT-compiled kernel via jax.lax.cond.

Strategy:

1. _run_newton: plain damped Newton from y_guess.
2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) starting from the GMIN result.

Because jax.lax.cond traces both branches but executes only one at runtime, this compiles to a single kernel with no Python-level branching. Under jax.vmap with a batched convergence flag, however, cond is lowered to a select and both branches execute.
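The control flow can be illustrated on a scalar toy problem. This is a sketch, not the circulax implementation. The residual arctan(y), the helper names, and the rescue branch are all hypothetical; the rescue is a step-clipped Newton standing in for the GMIN-then-source homotopy. Plain Newton on arctan(y) diverges when started from |y0| greater than about 1.39, so y0 = 2.0 exercises the rescue branch while y0 = 0.3 takes the direct branch.

```python
import jax
import jax.numpy as jnp


def f(y):
    # Hypothetical toy residual with its root at y = 0.
    return jnp.arctan(y)


def newton(y0, damp, max_steps=30, tol=1e-6):
    """Fixed-count Newton loop returning (y, converged) as JAX values."""
    df = jax.grad(f)

    def body(_, y):
        dy = f(y) / df(y)
        if damp:  # Python-level flag, resolved at trace time
            dy = jnp.clip(dy, -0.5, 0.5)  # limit each update to 0.5
        return y - dy

    y = jax.lax.fori_loop(0, max_steps, body, y0)
    converged = jnp.isfinite(y) & (jnp.abs(f(y)) < tol)
    return y, converged


@jax.jit
def solve_auto(y0):
    y_direct, ok = newton(y0, damp=False)
    # Both branches are traced and compiled; only one executes per (unbatched) call.
    return jax.lax.cond(ok, lambda _: y_direct, lambda _: newton(y0, damp=True)[0], None)


print(solve_auto(jnp.float32(0.3)), solve_auto(jnp.float32(2.0)))
```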

Parameters:

component_groups (dict[str, Any], required): Compiled circuit components.
y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
g_start (float, default 0.01): Starting leakage for GMIN stepping (rescue branch).
n_gmin (int, default 10): Number of GMIN steps in the rescue branch.
n_source (int, default 10): Number of source steps in the rescue branch.
rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
max_steps (int, default 100): Max Newton iterations per step.

Returns:

Array: Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can also be consumed inside compiled programs. The two alternatives below show how the flag might be used:

y, converged = solver.solve_dc_checked(groups, y0)
# Outside JIT: inspect the flag in Python.
if not converged:
    y = solver.solve_dc_gmin(groups, y0)
# Inside JIT: branch without Python-level control flow.
y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

component_groups (dict[str, Any], required): Compiled circuit components.
y_guess (Array, required): Initial guess vector (shape [N] or [2N]).
rtol (float, default 1e-06): Relative tolerance for optx.fixed_point.
atol (float, default 1e-06): Absolute tolerance for optx.fixed_point.
max_steps (int, default 100): Maximum Newton iterations.

Returns:

tuple[Array, Array]: (y, converged), the solution vector and a boolean success flag.

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each converged solution as the warm start for the next step. The large initial g_leak linearises highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0V start.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.
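A scalar sketch of the same schedule follows. The one-node residual arctan(y − 3) + g·y, the helper names, and the schedule endpoints are hypothetical. The schedule runs g from 10 down to 1e-6 over 15 steps, because this toy needs a larger starting leakage than the 1e-2 default. A plain Newton solve from y = 0 on this residual overshoots badly, with its first update landing near y = 12.5. Each GMIN step, by contrast, warm-starts close enough to its own root to converge.

```python
import jax
import jax.numpy as jnp


def residual(y, g_leak):
    # Hypothetical one-node residual: an arctan-shaped device whose root is
    # y = 3, plus a leakage conductance g_leak from the node to ground.
    return jnp.arctan(y - 3.0) + g_leak * y


def newton(y, g_leak, max_steps=20):
    """Fixed-count plain Newton for a given leakage value."""
    df = jax.grad(residual)  # derivative with respect to y (argument 0)

    def body(_, y):
        return y - residual(y, g_leak) / df(y, g_leak)

    return jax.lax.fori_loop(0, max_steps, body, y)


def solve_gmin(y_guess, g_start=10.0, g_end=1e-6, n_steps=15):
    # Log-uniform schedule from a large leakage down to the final one.
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(g_end), n_steps)

    def step(y, g):
        return newton(y, g), None  # warm start from the previous converged point

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final


y = jax.jit(solve_gmin)(jnp.float32(0.0))
print(y)
```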

Parameters:

component_groups (dict[str, Any], required): Compiled circuit components.
y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
g_start (float, default 0.01): Starting leakage conductance (large value, e.g. 1e-2).
n_steps (int, default 10): Number of log-uniform steps from g_start to self.g_leak.
rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
max_steps (int, default 100): Max Newton iterations per step.

Returns:

Array: Converged solution vector after the full GMIN schedule.

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10 % to 100 % of their netlist values, using each converged solution as the warm start for the next step. This guides Newton through the nonlinear region without the large initial step from 0V to full excitation.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.
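The ramp can be sketched on a scalar toy whose solution scales with the source amplitude. Everything here is hypothetical: the residual arctan(y − 3·scale), V_SRC, and the helper names. The sketch only mirrors the linspace-plus-scan structure shown in the source. At full amplitude, a single Newton solve from y = 0 jumps to about y = 12.5 on its first update and diverges. With the ramp, each 10 % increment moves the root by only 0.3.

```python
import jax
import jax.numpy as jnp

V_SRC = 3.0  # hypothetical netlist source amplitude


def residual(y, scale):
    # Hypothetical node equation whose solution is y = scale * V_SRC.
    return jnp.arctan(y - scale * V_SRC)


def newton(y, scale, max_steps=20):
    """Fixed-count plain Newton at a given source scale."""
    df = jax.grad(residual)  # derivative with respect to y

    def body(_, y):
        return y - residual(y, scale) / df(y, scale)

    return jax.lax.fori_loop(0, max_steps, body, y)


def solve_source(y_guess, n_steps=10):
    scales = jnp.linspace(0.1, 1.0, n_steps)  # 10 % ... 100 % of the source

    def step(y, scale):
        return newton(y, scale), None  # each solve warm-starts the next

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final


y = jax.jit(solve_source)(jnp.float32(0.0))
print(y)
```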

Parameters:

component_groups (dict[str, Any], required): Compiled circuit components.
y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
n_steps (int, default 10): Number of uniformly spaced steps from 0.1 to 1.0.
rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
max_steps (int, default 100): Max Newton iterations per step.

Returns:

Array: Converged solution vector at full source amplitude.

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final

DenseSolver

Bases: CircuitLinearSolver

Solves the system using dense matrix factorization (LU).

Best for:
  • Small to medium circuits (N < 2000).
  • Wavelength sweeps (AC analysis) on GPU.
  • Systems where vmap parallelism is critical.

Attributes:

static_rows (Array): Row indices for placing values into the dense matrix.
static_cols (Array): Column indices for placing values into the dense matrix.
g_leak (float): Leakage conductance added to the diagonal to prevent singularity.
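The role of precomputed row/column indices and a diagonal leakage term can be illustrated with a small dense assembly. This is a hedged sketch, not DenseSolver's internals. The 3-unknown sparsity pattern and the helper `dense_solve` are hypothetical, and jnp.linalg.solve stands in for whatever LU routine the class uses. Because the index arrays are fixed, the whole assemble-and-solve step vmaps cleanly over a batch of value vectors, which is the kind of batching the "Best for" list refers to.

```python
import jax
import jax.numpy as jnp

# Hypothetical sparsity pattern for a 3-unknown system, computed once up front
# (the role static_rows / static_cols play in the solver).
rows = jnp.array([0, 0, 1, 1, 1, 2, 2])
cols = jnp.array([0, 1, 0, 1, 2, 1, 2])
N = 3
G_LEAK = 1e-9  # diagonal leakage, as in the g_leak attribute


def dense_solve(vals, rhs):
    # Scatter-add the triplets; repeated (row, col) pairs would accumulate.
    A = jnp.zeros((N, N), dtype=vals.dtype).at[rows, cols].add(vals)
    A = A + G_LEAK * jnp.eye(N, dtype=vals.dtype)  # guard against singularity
    return jnp.linalg.solve(A, rhs)


# One value vector per sweep point; the pattern itself never changes.
batch_vals = jnp.array([
    [2.0, -1.0, -1.0, 2.0, -1.0, -1.0, 2.0],
    [3.0, -1.0, -1.0, 3.0, -1.0, -1.0, 3.0],
])
rhs = jnp.array([1.0, 0.0, 0.0])
x = jax.vmap(dense_solve, in_axes=(0, None))(batch_vals, rhs)
```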

Methods:

assume_full_rank: Indicate if the solver assumes the operator is full rank.
compute: Present to satisfy the lineax API; always raises NotImplementedError (call _solve_impl directly for internal use).
from_component_groups: Factory method to pre-calculate indices for the dense matrix.
init: Initialize the solver state (no-op for stateless solvers).
solve_dc: DC operating point via damped Newton-Raphson.
solve_dc_auto: DC Operating Point with automatic homotopy fallback.
solve_dc_checked: DC Operating Point with convergence status.
solve_dc_gmin: DC Operating Point via GMIN stepping (homotopy rescue).
solve_dc_source: DC Operating Point via source stepping (homotopy rescue).

assume_full_rank

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Present to satisfy the lineax API; always raises NotImplementedError. Call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any],
    num_vars: int,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> DenseSolver

Factory method to pre-calculate indices for the dense matrix.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False, g_leak: float = 1e-9
) -> "DenseSolver":
    """Factory method to pre-calculate indices for the dense matrix."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(component_groups, num_vars, is_complex)
    return cls(
        static_rows=jnp.array(rows),
        static_cols=jnp.array(cols),
        sys_size=sys_size,
        ground_indices=jnp.array(ground_idxs),
        is_complex=is_complex,
        g_leak=g_leak,
    )

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

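Parameters:

component_groups (dict[str, Any], required): Compiled circuit components.
y_guess (Array, required): Initial guess vector (shape [N] or [2N]).
rtol (float, default 1e-06): Relative tolerance for the Newton solve.
atol (float, default 1e-06): Absolute tolerance for the Newton solve.
max_steps (int, default 100): Maximum Newton iterations.

Returns:

Array: Converged solution vector.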

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, it falls back to GMIN stepping followed by source stepping, all inside a single JIT-compiled kernel via jax.lax.cond.

Strategy:

1. _run_newton: plain damped Newton from y_guess.
2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) starting from the GMIN result.

Because jax.lax.cond traces both branches but executes only one at runtime, this compiles to a single kernel with no Python-level branching. Under jax.vmap with a batched convergence flag, however, cond is lowered to a select and both branches execute.

Parameters:

component_groups (dict[str, Any], required): Compiled circuit components.
y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
g_start (float, default 0.01): Starting leakage for GMIN stepping (rescue branch).
n_gmin (int, default 10): Number of GMIN steps in the rescue branch.
n_source (int, default 10): Number of source steps in the rescue branch.
rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
max_steps (int, default 100): Max Newton iterations per step.

Returns:

Array: Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can also be consumed inside compiled programs. The two alternatives below show how the flag might be used:

y, converged = solver.solve_dc_checked(groups, y0)
# Outside JIT: inspect the flag in Python.
if not converged:
    y = solver.solve_dc_gmin(groups, y0)
# Inside JIT: branch without Python-level control flow.
y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

component_groups (dict[str, Any], required): Compiled circuit components.
y_guess (Array, required): Initial guess vector (shape [N] or [2N]).
rtol (float, default 1e-06): Relative tolerance for optx.fixed_point.
atol (float, default 1e-06): Absolute tolerance for optx.fixed_point.
max_steps (int, default 100): Maximum Newton iterations.

Returns:

tuple[Array, Array]: (y, converged), the solution vector and a boolean success flag.

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each converged solution as the warm start for the next step. The large initial g_leak linearises highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0V start.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

component_groups (dict[str, Any], required): Compiled circuit components.
y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
g_start (float, default 0.01): Starting leakage conductance (large value, e.g. 1e-2).
n_steps (int, default 10): Number of log-uniform steps from g_start to self.g_leak.
rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
max_steps (int, default 100): Max Newton iterations per step.

Returns:

Array: Converged solution vector after the full GMIN schedule.

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10% to 100% of their netlist values in n_steps uniform increments, using each converged solution as the warm start for the next step. This guides Newton through the nonlinear region instead of taking one large step from 0 V to full excitation.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.
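The following sketch illustrates source stepping on a hypothetical one-node circuit: a 5 V source feeds a diode through a 1 kΩ resistor. It uses the same schedule as above, jnp.linspace(0.1, 1.0, n_steps) driven by jax.lax.scan. The residual, the fixed-budget newton helper and all constants are assumptions made for the example. In particular, the helper is not circulax's _run_newton, which is damped and tolerance-driven.

```python
import functools

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # exp() in the diode law needs float64 headroom

# Hypothetical toy circuit: source V_S drives node v through resistor R,
# and a diode connects v to ground.
V_S, R, I_S, V_T = 5.0, 1e3, 1e-14, 0.025


def residual(v, scale):
    """KCL at node v with the source amplitude multiplied by `scale`."""
    return (v - scale * V_S) / R + I_S * (jnp.exp(v / V_T) - 1.0)


def newton(v0, scale, n_iter=50):
    """Plain scalar Newton with a fixed iteration budget (no damping)."""
    d_res = jax.grad(residual)  # derivative of the residual w.r.t. v

    def body(_, v):
        return v - residual(v, scale) / d_res(v, scale)

    return jax.lax.fori_loop(0, n_iter, body, v0)


@functools.partial(jax.jit, static_argnames="n_steps")
def solve_source_stepping(v0, n_steps=10):
    scales = jnp.linspace(0.1, 1.0, n_steps)  # 10% ... 100% of the source

    def step(v, scale):
        # Warm-start each Newton solve from the previous, weaker-source solution.
        return newton(v, scale), None

    v_final, _ = jax.lax.scan(step, jnp.asarray(v0, dtype=jnp.float64), scales)
    return v_final


v_stepped = solve_source_stepping(0.0)
v_direct = newton(jnp.zeros(()), 1.0)  # same 50-iteration budget, full source at once
```

Starting from a flat 0 V, the direct solve's first Newton update jumps to about 5 V, where the diode current is astronomically large. Each later update backs off by only about V_T, so 50 iterations are not enough. Ramping the source instead keeps every stage's starting point close to that stage's solution.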

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
n_steps int

Number of uniformly-spaced steps from 0.1 to 1.0.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector at full source amplitude.

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final

FactorizedTransientSolver

Bases: VectorizedTransientSolver

Transient solver using a Modified Newton (frozen-Jacobian) scheme.

At each timestep the system Jacobian is assembled and factored once at a predicted state, then reused across all Newton iterations. Compared with a full Newton-Raphson solver, this trades quadratic convergence for a much cheaper per-iteration cost. Each iteration needs only forward and back triangular solves with the stored factors, not a fresh factorisation. This makes the scheme efficient for circuits whose Jacobian varies slowly between steps.

Convergence is linear rather than quadratic, so newton_max_steps is set higher than a standard Newton solver would require. Adaptive damping stabilises convergence in stiff or strongly nonlinear regions. At each iteration the update δy is scaled by min(1, 0.5 / max|δy|), which caps the largest component of the applied step at 0.5.
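The sketch below shows the frozen-Jacobian scheme together with the damping rule quoted above. Dense jax.scipy.linalg.lu_factor and lu_solve stand in for the KLU factorisation. The two-unknown system F is hypothetical and plays the role of one implicit time step; the real solver assembles its residual and Jacobian from the circuit instead.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

jax.config.update("jax_enable_x64", True)

# Hypothetical 2-unknown nonlinear system standing in for one implicit time step.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])


def F(y):
    return A @ y + 0.1 * y**3 - b


def modified_newton(y_pred, tol=1e-12, max_steps=50):
    # Assemble and factor the Jacobian ONCE, at the predicted state.
    lu_piv = lu_factor(jax.jacfwd(F)(y_pred))

    def cond(state):
        _, step_size, k = state
        return (step_size > tol) & (k < max_steps)

    def body(state):
        y, _, k = state
        dy = lu_solve(lu_piv, -F(y))  # reuse the stored LU factors: triangular solves only
        step_size = jnp.max(jnp.abs(dy))
        alpha = jnp.minimum(1.0, 0.5 / step_size)  # damping caps max|alpha * dy| at 0.5
        return y + alpha * dy, step_size, k + 1

    init = (y_pred, jnp.array(jnp.inf, dtype=y_pred.dtype), jnp.array(0, dtype=jnp.int32))
    y, step_size, k = jax.lax.while_loop(cond, body, init)
    return y, step_size <= tol, k


y, converged, n_iter = modified_newton(jnp.zeros(2))
```

The factorisation happens once, outside the loop, and each iteration only calls lu_solve. Because the Jacobian is never refreshed, the iteration converges linearly, so it needs a larger step budget than full Newton would.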

Both real and complex assembly paths are supported; the complex path concatenates real and imaginary parts into a single real-valued vector, allowing purely real linear algebra kernels to be reused for frequency-domain-style analyses.
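The real-valued treatment of complex systems rests on a standard identity. Write A = A_r + iA_i, z = z_r + iz_i and b = b_r + ib_i. The N × N complex system A z = b is then equivalent to a 2N × 2N real block system with blocks [[A_r, -A_i], [A_i, A_r]] acting on [z_r; z_i] and right-hand side [b_r; b_i]. The sketch assumes this [Re; Im] stacking order, which is consistent with the [2N]-shaped vectors mentioned on this page. The exact layout circulax uses internally is an assumption here.

```python
import jax.numpy as jnp


def complex_to_real(A, b):
    """Rewrite the N x N complex system A z = b as a 2N x 2N real system."""
    M = jnp.block([[A.real, -A.imag],
                   [A.imag, A.real]])
    rhs = jnp.concatenate([b.real, b.imag])  # assumed layout: [Re(b); Im(b)]
    return M, rhs


A = jnp.array([[2.0 + 1.0j, 0.5 - 0.3j],
               [0.1 + 0.2j, 1.5 - 1.0j]])
b = jnp.array([1.0 + 0.0j, 0.5 + 0.5j])

M, rhs = complex_to_real(A, b)
x = jnp.linalg.solve(M, rhs)  # a purely real solve of size 2N
n = b.shape[0]
z = x[:n] + 1j * x[n:]        # reassemble the complex solution
```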

Requires a KLU split solver with an explicit factorisation step (circulax.solvers.linear.KLUSplitLinear) as the linear_solver; use analyze_circuit(..., backend="klu_split_factor").

KLUSolver

Bases: CircuitLinearSolver

Solves the system using the KLU sparse solver (via klujax).

Best For
  • Large circuits (N > 5000) running on CPU.
  • DC Operating Points of massive meshes.
  • Cases where DenseSolver runs out of memory (OOM).
Note

Does NOT support vmap (batching) automatically.

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Required by the lineax API but not implemented here (raises NotImplementedError); call _solve_impl directly for internal use.

from_component_groups

Factory method to pre-hash indices for sparse coalescence.

init

Initialize the solver state (no-op for stateless solvers).

solve_dc

DC operating point via damped Newton-Raphson.

solve_dc_auto

DC Operating Point with automatic homotopy fallback.

solve_dc_checked

DC Operating Point with convergence status.

solve_dc_gmin

DC Operating Point via GMIN stepping (homotopy rescue).

solve_dc_source

DC Operating Point via source stepping (homotopy rescue).

assume_full_rank

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Required by the lineax API but not implemented here (raises NotImplementedError); call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any],
    num_vars: int,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> KLUSolver

Factory method to pre-hash indices for sparse coalescence.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False, g_leak: float = 1e-9
) -> "KLUSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(component_groups, num_vars, is_complex)
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
        g_leak=g_leak,
    )
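The "pre-hash indices for sparse coalescence" step can be sketched with NumPy and jax.ops.segment_sum. factor_jacobian, documented further down this page, uses the same primitive together with map_idx. The stamp pattern below is hypothetical. The internals of _build_index_arrays and _klu_deduplicate are not shown in this documentation, so treat this as an illustration of the idea rather than their implementation.

```python
import numpy as np
import jax
import jax.numpy as jnp

n = 3  # number of unknowns in this hypothetical netlist

# Stamp pattern of two resistors, 0-1 and 1-2. Both stamp (1, 1), so the
# COO triplets contain a duplicate entry.
rows = np.array([0, 0, 1, 1, 1, 1, 2, 2])
cols = np.array([0, 1, 0, 1, 1, 2, 1, 2])

# One-off, outside JIT: hash each (row, col) pair to one integer and dedupe.
keys = rows * n + cols
u_keys, map_idx = np.unique(keys, return_inverse=True)
u_rows, u_cols = u_keys // n, u_keys % n
n_unique = int(u_keys.size)


@jax.jit
def coalesce(vals):
    # Every iteration: sum the raw stamp values that share a (row, col) slot.
    return jax.ops.segment_sum(vals, jnp.asarray(map_idx), num_segments=n_unique)


g1, g2 = 1.0, 2.0
vals = jnp.array([g1, -g1, -g1, g1, g2, -g2, -g2, g2])
coalesced = coalesce(vals)  # coalesced == [1, -1, -1, 3, -2, -2, 2]
```

The expensive part, sorting and deduplicating the index pairs, runs once at construction time. The per-iteration work is a single segment_sum with a static number of segments, which keeps it JIT-friendly.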

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
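y_guess Array

Initial guess vector (shape [N] or [2N]).

required
rtol float

Relative tolerance for the Newton solve.

1e-06
atol float

Absolute tolerance for the Newton solve.

1e-06
max_steps int

Maximum Newton iterations.

100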

Returns:

Type Description
Array

Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, falls back to GMIN stepping followed by source stepping — all inside a single JIT-compiled kernel via jax.lax.cond.

Strategy:

1. _run_newton: plain damped Newton from y_guess.
2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) from the GMIN result.

Because jax.lax.cond traces both branches but executes only one at runtime, this compiles to a single kernel with no Python-level branching. Under jax.vmap with a batched convergence flag, however, lax.cond is converted to a select and both branches execute.
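The control flow of solve_dc_auto can be reproduced on a toy problem. The sketch uses three hypothetical pieces. The first is a diode-plus-resistor circuit. The second is a fixed-budget run_newton that returns a (value, converged) pair in the spirit of solve_dc_checked. The third is source stepping, used as the only rescue stage; the real method runs GMIN stepping first.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical toy: source V_S -> resistor R -> node v -> diode -> ground.
V_S, R, I_S, V_T = 5.0, 1e3, 1e-14, 0.025


def residual(v, scale=1.0):
    return (v - scale * V_S) / R + I_S * (jnp.exp(v / V_T) - 1.0)


def run_newton(v0, scale=1.0, n_iter=50, tol=1e-12):
    """Fixed-budget Newton returning (v, converged), mirroring a checked solve."""
    d_res = jax.grad(residual)
    v = jax.lax.fori_loop(0, n_iter, lambda _, v: v - residual(v, scale) / d_res(v, scale), v0)
    return v, jnp.abs(residual(v, scale)) < tol


@jax.jit
def solve_auto(v0):
    v_direct, converged = run_newton(v0)

    def rescue(_):
        # Source stepping from the ORIGINAL guess, used as the fallback branch.
        def step(v, s):
            return run_newton(v, s)[0], None

        v_final, _ = jax.lax.scan(step, v0, jnp.linspace(0.1, 1.0, 10))
        return v_final

    # Both branches are traced into one program; outside vmap only one executes.
    return jax.lax.cond(converged, lambda _: v_direct, rescue, None)


v = solve_auto(jnp.zeros(()))
```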

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage for GMIN stepping (rescue branch).

0.01
n_gmin int

Number of GMIN steps in the rescue branch.

10
n_source int

Number of source steps in the rescue branch.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can be consumed inside compiled programs:

y, converged = solver.solve_dc_checked(groups, y0)

# Option 1, outside JIT: inspect the flag in Python.
if not converged:
    y = solver.solve_dc_gmin(groups, y0)

# Option 2, inside JIT: branch without Python-level control flow.
y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess vector (shape [N] or [2N]).

required
rtol float

Relative tolerance for optx.fixed_point.

1e-06
atol float

Absolute tolerance for optx.fixed_point.

1e-06
max_steps int

Maximum Newton iterations.

100

Returns:

Type Description
tuple[Array, Array]

(y, converged) — solution vector and boolean success flag.

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each converged solution as the warm start for the next step. The large initial g_leak makes the system strongly diagonally dominant. This tames highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0 V start.

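Implemented with jax.lax.scan, so the stepping loop itself is JIT-, grad- and vmap-compatible. End-to-end batching still depends on the linear solver, and KLUSolver does not support vmap automatically (see the note above).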
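The GMIN schedule can be sketched on a toy circuit. A leak conductance from the node to ground is added to the residual and swept down with jnp.logspace inside jax.lax.scan. The residual, G_LEAK_FINAL (standing in for self.g_leak) and the 0.5 V step cap in damped_newton are all assumptions for illustration, not circulax's implementation.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical toy: source V_S -> resistor R -> node v -> diode -> ground,
# plus a leak conductance g_leak from v to ground.
V_S, R, I_S, V_T = 5.0, 1e3, 1e-14, 0.025
G_LEAK_FINAL = 1e-12  # stand-in for self.g_leak


def residual(v, g_leak):
    return (v - V_S) / R + I_S * (jnp.exp(v / V_T) - 1.0) + g_leak * v


def damped_newton(v0, g_leak, n_iter=50):
    d_res = jax.grad(residual)

    def body(_, v):
        dv = -residual(v, g_leak) / d_res(v, g_leak)
        alpha = jnp.minimum(1.0, 0.5 / jnp.abs(dv))  # cap each update at 0.5 V
        return v + alpha * dv

    return jax.lax.fori_loop(0, n_iter, body, v0)


def solve_gmin(v0, g_start=1e-2, n_steps=10):
    # Log-uniform schedule from g_start down to the final (tiny) leak.
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(G_LEAK_FINAL), n_steps)

    def step(v, g):
        return damped_newton(v, g), None  # warm start from the previous leak value

    v_final, _ = jax.lax.scan(step, v0, g_values)
    return v_final


v_dc = solve_gmin(jnp.zeros(()))
```

On this small example a damped direct solve would also converge. The sketch is about the mechanics: each stage starts from the previous stage's solution, and the final stage solves the circuit with only the tiny leak left.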

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage conductance (large value, e.g. 1e-2).

0.01
n_steps int

Number of log-uniform steps from g_start to self.g_leak.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector after the full GMIN schedule.

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10% to 100% of their netlist values in n_steps uniform increments, using each converged solution as the warm start for the next step. This guides Newton through the nonlinear region instead of taking one large step from 0 V to full excitation.

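Implemented with jax.lax.scan, so the stepping loop itself is JIT-, grad- and vmap-compatible. End-to-end batching still depends on the linear solver, and KLUSolver does not support vmap automatically (see the note above).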

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
n_steps int

Number of uniformly-spaced steps from 0.1 to 1.0.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector at full source amplitude.

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final

KLUSplitLinear

Bases: KLUSplitSolver

KLU split solver paired with Modified Newton (frozen-Jacobian) for linear convergence.

Extends KLUSplitSolver with an explicit numeric factorization step, so the Jacobian can be factored once per time step and reused across all Newton iterations within that step (Modified Newton / frozen-Jacobian scheme). Use together with circulax.solvers.transient.FactorizedTransientSolver.

Best For
  • Large circuits (N > 5000) running on CPU where the Jacobian changes slowly.
  • Transient simulations with many Newton iterations per step.

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Required by the lineax API but not implemented here (raises NotImplementedError); call _solve_impl directly for internal use.

factor_jacobian

Factor the Jacobian and return a numeric handle for repeated solves.

from_component_groups

Factory; delegates to KLUSplitSolver.from_component_groups.

init

Initialize the solver state (no-op for stateless solvers).

solve_dc

DC operating point via damped Newton-Raphson.

solve_dc_auto

DC Operating Point with automatic homotopy fallback.

solve_dc_checked

DC Operating Point with convergence status.

solve_dc_gmin

DC Operating Point via GMIN stepping (homotopy rescue).

solve_dc_source

DC Operating Point via source stepping (homotopy rescue).

solve_with_frozen_jacobian

Solve using a pre-computed numeric factorization.

assume_full_rank

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Required by the lineax API but not implemented here (raises NotImplementedError); call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

factor_jacobian

factor_jacobian(all_vals: Array) -> Array

Factor the Jacobian and return a numeric handle for repeated solves.

Parameters:

Name Type Description Default
all_vals Array

Flattened non-zero Jacobian values (COO format).

required

Returns:

Type Description
Array

Opaque numeric handle (int32 JAX scalar) to pass to solve_with_frozen_jacobian. Must be freed with klujax.free_numeric after use to avoid C++ memory leaks.

Source code in circulax/solvers/linear.py
def factor_jacobian(self, all_vals: jax.Array) -> jax.Array:
    """Factor the Jacobian and return a numeric handle for repeated solves.

    Args:
        all_vals: Flattened non-zero Jacobian values (COO format).

    Returns:
        Opaque numeric handle (``int32`` JAX scalar) to pass to
        :meth:`solve_with_frozen_jacobian`.  Must be freed with
        ``klujax.free_numeric`` after use to avoid C++ memory leaks.

    """
    g_vals = jnp.full(self.ground_indices.shape[0], GROUND_STIFFNESS, dtype=all_vals.dtype)
    l_vals = jnp.full(self.sys_size, self.g_leak, dtype=all_vals.dtype)

    raw_vals = jnp.concatenate([all_vals, g_vals, l_vals])
    coalesced_vals = jax.ops.segment_sum(raw_vals, self.map_idx, num_segments=self.n_unique)

    return klujax.factor(self.u_rows, self.u_cols, coalesced_vals, self._handle_wrapper.handle)
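A dense toy matrix shows why factor_jacobian appends GROUND_STIFFNESS and g_leak entries before factoring. The sketch assumes both are added on the diagonal: the ground entries at the ground rows, and g_leak on every unknown. The value used for GROUND_STIFFNESS is a placeholder, because its real value is not shown on this page.

```python
import numpy as np

GROUND_STIFFNESS = 1e9  # placeholder; the actual constant is defined in circulax
G_LEAK = 1e-9

# Toy conductance matrix with 3 unknowns: index 0 is the ground node, and a 1 S
# resistor joins nodes 1 and 2 with no path from that pair to ground.
G = np.array([[0.0, 0.0, 0.0],
              [0.0, 1.0, -1.0],
              [0.0, -1.0, 1.0]])
ground_idx = 0


def regularise(G, g_leak, ground_idx):
    M = G + g_leak * np.eye(G.shape[0])            # diagonal leak on every unknown
    M[ground_idx, ground_idx] += GROUND_STIFFNESS  # pin the ground unknown near 0
    return M


i_inj = np.array([0.0, 1e-3, -1e-3])  # 1 mA injected into node 1, drawn out of node 2
x = np.linalg.solve(regularise(G, G_LEAK, ground_idx), i_inj)
```

Without regularisation G is singular: the ground row is empty, and the floating pair's common-mode voltage is undefined. With the leak and ground entries, the solve succeeds and the ground unknown is pinned to 0. The 1 mA current then produces the expected 1 mV drop across the 1 S resistor.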

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any],
    num_vars: int,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> KLUSplitLinear

Factory; delegates to KLUSplitSolver.from_component_groups.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False, g_leak: float = 1e-9
) -> "KLUSplitLinear":
    """Factory — delegates to :meth:`KLUSplitSolver.from_component_groups`."""
    return super().from_component_groups(  # type: ignore[return-value]
        component_groups, num_vars, is_complex=is_complex, g_leak=g_leak
    )

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
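y_guess Array

Initial guess vector (shape [N] or [2N]).

required
rtol float

Relative tolerance for the Newton solve.

1e-06
atol float

Absolute tolerance for the Newton solve.

1e-06
max_steps int

Maximum Newton iterations.

100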

Returns:

Type Description
Array

Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, falls back to GMIN stepping followed by source stepping — all inside a single JIT-compiled kernel via jax.lax.cond.

Strategy:

1. _run_newton: plain damped Newton from y_guess.
2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) from the GMIN result.

Because jax.lax.cond traces both branches but executes only one at runtime, this compiles to a single kernel with no Python-level branching. Under jax.vmap with a batched convergence flag, however, lax.cond is converted to a select and both branches execute.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage for GMIN stepping (rescue branch).

0.01
n_gmin int

Number of GMIN steps in the rescue branch.

10
n_source int

Number of source steps in the rescue branch.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can be consumed inside compiled programs:

y, converged = solver.solve_dc_checked(groups, y0)

# Option 1, outside JIT: inspect the flag in Python.
if not converged:
    y = solver.solve_dc_gmin(groups, y0)

# Option 2, inside JIT: branch without Python-level control flow.
y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess vector (shape [N] or [2N]).

required
rtol float

Relative tolerance for optx.fixed_point.

1e-06
atol float

Absolute tolerance for optx.fixed_point.

1e-06
max_steps int

Maximum Newton iterations.

100

Returns:

Type Description
tuple[Array, Array]

(y, converged) — solution vector and boolean success flag.

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each converged solution as the warm start for the next step. The large initial g_leak linearises highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0V start.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.
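A minimal JAX sketch of GMIN stepping on a hypothetical one-node toy circuit (a 5 V source driving a diode through 1 kΩ), not on circulax's compiled component groups. Only the log-uniform schedule and the lax.scan warm-start structure mirror the source below. The damped inner Newton loop borrows the 0.5 step cap described for RefactoringTransientSolver, and the final leak of 1e-12 is an illustrative stand-in for self.g_leak.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 gives exp() headroom in the diode model

# Hypothetical toy circuit (not circulax's API): 5 V source -> 1 kOhm -> diode -> ground.
V_S, R, I_S, V_T = 5.0, 1e3, 1e-14, 0.025


def residual(v, g_leak):
    """KCL at the diode node: resistor + diode + GMIN leakage currents."""
    return (v - V_S) / R + I_S * (jnp.exp(v / V_T) - 1.0) + g_leak * v


def newton(v0, g_leak, n_iter=50):
    """Fixed-budget damped Newton; each update is capped at 0.5 V."""
    d_residual = jax.grad(residual)  # derivative with respect to v

    def body(_, v):
        dv = -residual(v, g_leak) / d_residual(v, g_leak)
        return v + jnp.minimum(1.0, 0.5 / jnp.abs(dv)) * dv

    return jax.lax.fori_loop(0, n_iter, body, v0)


def solve_dc_gmin_sketch(v0, g_start=1e-2, g_final=1e-12, n_steps=10):
    # Log-uniform schedule from g_start down to g_final, as in solve_dc_gmin.
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(g_final), n_steps)

    def step(v, g):
        v_new = newton(v, g)  # warm start from the previous converged solution
        return v_new, v_new   # carry forward, and record the path

    return jax.lax.scan(step, v0, g_values)


v_final, path = solve_dc_gmin_sketch(jnp.array(0.0))
```

Each scan step starts Newton from the previous step's solution, so by the time the leak is negligible the iterate already sits near the diode's operating point (roughly 0.67 V for this toy circuit).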

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage conductance (large value, e.g. 1e-2).

0.01
n_steps int

Number of log-uniform steps from g_start to self.g_leak.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector after the full GMIN schedule.

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10 % to 100 % of their netlist values, using each converged solution as the warm start for the next step. This guides Newton through the nonlinear region without the large initial step from 0V to full excitation.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.
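The same idea applied to source stepping, again as a sketch on a hypothetical diode-resistor circuit rather than circulax's amplitude_param mechanism. Here the leak stays fixed and the source amplitude is ramped from 10 % to 100 % with jnp.linspace, warm-starting each Newton solve inside lax.scan.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical toy circuit (not circulax's API): source -> 1 kOhm -> diode -> ground.
V_S, R, I_S, V_T, G_LEAK = 5.0, 1e3, 1e-14, 0.025, 1e-12


def residual(v, source_scale):
    """KCL at the diode node with the source amplitude multiplied by source_scale."""
    return (v - source_scale * V_S) / R + I_S * (jnp.exp(v / V_T) - 1.0) + G_LEAK * v


def newton(v0, source_scale, n_iter=50):
    """Fixed-budget damped Newton; each update is capped at 0.5 V."""
    d_residual = jax.grad(residual)

    def body(_, v):
        dv = -residual(v, source_scale) / d_residual(v, source_scale)
        return v + jnp.minimum(1.0, 0.5 / jnp.abs(dv)) * dv

    return jax.lax.fori_loop(0, n_iter, body, v0)


def solve_dc_source_sketch(v0, n_steps=10):
    scales = jnp.linspace(0.1, 1.0, n_steps)  # 10 % ... 100 % of the netlist amplitude

    def step(v, s):
        v_new = newton(v, s)  # warm start from the previous amplitude's solution
        return v_new, v_new

    return jax.lax.scan(step, v0, scales)


v_final, path = solve_dc_source_sketch(jnp.array(0.0))
```

The recorded path rises monotonically with the source amplitude, which is what keeps each Newton solve close to its starting point.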

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
n_steps int

Number of uniformly spaced steps from 0.1 to 1.0.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector at full source amplitude.

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final

solve_with_frozen_jacobian

solve_with_frozen_jacobian(residual: Array, numeric: Array) -> Solution

Solve using a pre-computed numeric factorization.
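The factor-once, solve-many pattern can be sketched with jax.scipy.linalg.lu_factor and lu_solve as a dense stand-in for klujax.factor and klujax.solve_with_numeric, on a hypothetical two-node network with a weak cubic nonlinearity. A dense LU returns ordinary arrays, so there is no handle to free; with KLU the numeric handle still has to be released with klujax.free_numeric.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

jax.config.update("jax_enable_x64", True)

G = jnp.array([[2.0, -1.0], [-1.0, 2.0]])  # hypothetical conductance matrix
b = jnp.array([1.0, 0.0])                  # hypothetical source vector


def F(y):
    """Linear network plus a weak cubic nonlinearity on each node."""
    return G @ y + 0.1 * y**3 - b


def frozen_newton(y0, n_iter=30):
    lu_piv = lu_factor(jax.jacfwd(F)(y0))  # factor once (cf. factor_jacobian)

    def body(_, y):
        # Reuse the same factors for every iterate (cf. solve_with_frozen_jacobian).
        return y + lu_solve(lu_piv, -F(y))

    return jax.lax.fori_loop(0, n_iter, body, y0)


y = frozen_newton(jnp.zeros(2))
```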

Parameters:

Name Type Description Default
residual Array

The right-hand side vector -F(y).

required
numeric Array

Handle returned by factor_jacobian.

required

Returns:

Type Description
Solution

A lineax.Solution holding the Newton step δy.

Source code in circulax/solvers/linear.py
def solve_with_frozen_jacobian(self, residual: jax.Array, numeric: jax.Array) -> lx.Solution:
    """Solve using a pre-computed numeric factorization.

    Args:
        residual: The right-hand side vector ``-F(y)``.
        numeric: Handle returned by :meth:`factor_jacobian`.

    Returns:
        :class:`lineax.Solution` with the Newton step ``δy``.

    """
    solution = klujax.solve_with_numeric(numeric, residual, self._handle_wrapper.handle)
    return lx.Solution(
        value=solution.reshape(residual.shape),
        result=lx.RESULTS.successful,
        state=None,
        stats={},
    )

KLUSplitQuadratic

Bases: KLUSplitLinear

KLU split solver paired with full Newton for quadratic convergence via klu_refactor.

Extends KLUSplitLinear with refactor_jacobian, which updates the numeric LU factorization in place using klujax.refactor. The sparsity pattern is fixed for a given circuit topology, so KLU reuses the existing memory allocation and fill-reducing permutation and recomputes only the L/U values. This gives full Newton (quadratic) convergence at lower cost than calling klu_factor from scratch at every iteration.

Use together with circulax.solvers.transient.RefactoringTransientSolver.

Best For
  • Large circuits on CPU with nonlinear devices where quadratic convergence is desired.
  • Transient simulations where the Jacobian changes significantly between Newton iterates.
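The practical difference between KLUSplitLinear and KLUSplitQuadratic is frozen versus refreshed Jacobian values. This sketch counts Newton iterations for both strategies on a hypothetical, more strongly nonlinear two-node system, again using dense jax.scipy.linalg LU as a stand-in for klujax: the frozen variant converges linearly, while refactoring at every iterate converges quadratically and needs fewer (but individually more expensive) iterations.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

jax.config.update("jax_enable_x64", True)

G = jnp.array([[2.0, -1.0], [-1.0, 2.0]])  # hypothetical 2-node conductance matrix
b = jnp.array([1.0, 0.0])


def F(y):
    return G @ y + 0.5 * y**3 - b  # strong cubic nonlinearity per node


J = jax.jacfwd(F)


def count_iterations(update, y0, tol=1e-10, max_iter=100):
    """Apply `update` until max|F(y)| < tol; return (iterations, y)."""
    y = y0
    for k in range(1, max_iter + 1):
        y = update(y)
        if float(jnp.max(jnp.abs(F(y)))) < tol:
            return k, y
    return max_iter, y


y0 = jnp.zeros(2)
frozen_lu = lu_factor(J(y0))  # factored once at the starting point


def frozen_step(y):  # analogue of factor_jacobian + solve_with_frozen_jacobian
    return y + lu_solve(frozen_lu, -F(y))


def full_step(y):  # analogue of refactor_jacobian: fresh L/U values for J(y_k)
    return y + lu_solve(lu_factor(J(y)), -F(y))


n_frozen, _ = count_iterations(frozen_step, y0)
n_full, y_full = count_iterations(full_step, y0)
```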

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Satisfies the lineax API; call _solve_impl directly for internal use.

factor_jacobian

Factor the Jacobian and return a numeric handle for repeated solves.

from_component_groups

Factory; delegates to KLUSplitSolver.from_component_groups.

init

Initialize the solver state (no-op for stateless solvers).

refactor_jacobian

Update the numeric factorization in-place with new Jacobian values.

solve_dc

DC operating point via damped Newton-Raphson.

solve_dc_auto

DC Operating Point with automatic homotopy fallback.

solve_dc_checked

DC Operating Point with convergence status.

solve_dc_gmin

DC Operating Point via GMIN stepping (homotopy rescue).

solve_dc_source

DC Operating Point via source stepping (homotopy rescue).

solve_with_frozen_jacobian

Solve using a pre-computed numeric factorization.

assume_full_rank

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Satisfies the lineax API but always raises NotImplementedError; call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

factor_jacobian

factor_jacobian(all_vals: Array) -> Array

Factor the Jacobian and return a numeric handle for repeated solves.

Parameters:

Name Type Description Default
all_vals Array

Flattened non-zero Jacobian values (COO format).

required

Returns:

Type Description
Array

Opaque numeric handle (int32 JAX scalar) to pass to solve_with_frozen_jacobian. Must be freed with klujax.free_numeric after use to avoid C++ memory leaks.

Source code in circulax/solvers/linear.py
def factor_jacobian(self, all_vals: jax.Array) -> jax.Array:
    """Factor the Jacobian and return a numeric handle for repeated solves.

    Args:
        all_vals: Flattened non-zero Jacobian values (COO format).

    Returns:
        Opaque numeric handle (``int32`` JAX scalar) to pass to
        :meth:`solve_with_frozen_jacobian`.  Must be freed with
        ``klujax.free_numeric`` after use to avoid C++ memory leaks.

    """
    g_vals = jnp.full(self.ground_indices.shape[0], GROUND_STIFFNESS, dtype=all_vals.dtype)
    l_vals = jnp.full(self.sys_size, self.g_leak, dtype=all_vals.dtype)

    raw_vals = jnp.concatenate([all_vals, g_vals, l_vals])
    coalesced_vals = jax.ops.segment_sum(raw_vals, self.map_idx, num_segments=self.n_unique)

    return klujax.factor(self.u_rows, self.u_cols, coalesced_vals, self._handle_wrapper.handle)
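The coalescing step in factor_jacobian can be reproduced on a hypothetical three-node netlist. This sketch assumes map_idx, u_rows and u_cols come from a np.unique pass over the raw (row, col) pairs (the source does not show how they are built), and the GROUND_STIFFNESS value is a placeholder. It appends the ground and leakage entries in the same order as factor_jacobian and checks that segment_sum yields the same matrix as scatter-adding every raw entry.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

N = 3                   # hypothetical 3-node system
GROUND_STIFFNESS = 1e9  # placeholder value; circulax defines its own constant
G_LEAK = 1e-9
ground_indices = np.array([0])

# COO stamps of two resistors (1 S between nodes 0-1, 0.5 S between nodes 1-2).
# Entry (1, 1) appears twice because both resistors touch node 1.
rows = np.array([0, 0, 1, 1, 1, 1, 2, 2])
cols = np.array([0, 1, 0, 1, 1, 2, 1, 2])
vals = jnp.array([1.0, -1.0, -1.0, 1.0, 0.5, -0.5, -0.5, 0.5])

# Append ground-stiffness and leakage entries, in the same order as factor_jacobian.
all_rows = np.concatenate([rows, ground_indices, np.arange(N)])
all_cols = np.concatenate([cols, ground_indices, np.arange(N)])

# Done once per topology: map each raw entry to a unique (row, col) slot.
unique_keys, map_idx = np.unique(all_rows * N + all_cols, return_inverse=True)
u_rows, u_cols = unique_keys // N, unique_keys % N
n_unique = unique_keys.size

# Done for every Jacobian evaluation: sum duplicate entries into their slots.
g_vals = jnp.full(ground_indices.shape[0], GROUND_STIFFNESS)
l_vals = jnp.full(N, G_LEAK)
raw_vals = jnp.concatenate([vals, g_vals, l_vals])
coalesced = jax.ops.segment_sum(raw_vals, map_idx, num_segments=n_unique)

# Sanity check: placing the coalesced values equals scatter-adding the raw ones.
dense_coalesced = jnp.zeros((N, N)).at[u_rows, u_cols].set(coalesced)
dense_direct = jnp.zeros((N, N)).at[all_rows, all_cols].add(raw_vals)
```

Twelve raw entries collapse to seven unique slots here, which is the compact value array a sparse factorization would consume.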

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any],
    num_vars: int,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> KLUSplitQuadratic

Factory; delegates to KLUSplitSolver.from_component_groups.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False, g_leak: float = 1e-9
) -> "KLUSplitQuadratic":
    """Factory — delegates to :meth:`KLUSplitSolver.from_component_groups`."""
    return super().from_component_groups(  # type: ignore[return-value]
        component_groups, num_vars, is_complex=is_complex, g_leak=g_leak
    )

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

refactor_jacobian

refactor_jacobian(all_vals: Array, numeric: Array) -> Array

Update the numeric factorization in-place with new Jacobian values.

Reuses the existing memory allocation and fill-reducing permutation from the symbolic analysis; only the L/U values are recomputed. This is faster than calling KLUSplitLinear.factor_jacobian from scratch at each Newton iteration.

Parameters:

Name Type Description Default
all_vals Array

Flattened non-zero Jacobian values (COO format).

required
numeric Array

Existing handle returned by KLUSplitLinear.factor_jacobian.

required

Returns:

Type Description
Array

Refreshed numeric handle (the same underlying C++ object, now connected in the XLA computation graph so the refactor cannot be eliminated as dead code).

Source code in circulax/solvers/linear.py
def refactor_jacobian(self, all_vals: jax.Array, numeric: jax.Array) -> jax.Array:
    """Update the numeric factorization in-place with new Jacobian values.

    Reuses the existing memory allocation and fill-reducing permutation from the
    symbolic analysis; only the L/U values are recomputed.  Faster than calling
    :meth:`~KLUSplitLinear.factor_jacobian` from scratch each Newton iteration.

    Args:
        all_vals: Flattened non-zero Jacobian values (COO format).
        numeric: Existing handle returned by :meth:`~KLUSplitLinear.factor_jacobian`.

    Returns:
        Refreshed numeric handle (same underlying C++ object, now connected in the
        XLA computation graph so the refactor cannot be eliminated as dead code).

    """
    g_vals = jnp.full(self.ground_indices.shape[0], GROUND_STIFFNESS, dtype=all_vals.dtype)
    l_vals = jnp.full(self.sys_size, self.g_leak, dtype=all_vals.dtype)
    raw_vals = jnp.concatenate([all_vals, g_vals, l_vals])
    coalesced_vals = jax.ops.segment_sum(raw_vals, self.map_idx, num_segments=self.n_unique)
    return klujax.refactor(self.u_rows, self.u_cols, coalesced_vals, numeric, self._handle_wrapper.handle)

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess vector (shape [N] or [2N]).

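required
rtol float

Relative tolerance for optx.fixed_point.

1e-06
atol float

Absolute tolerance for optx.fixed_point.

1e-06
max_steps int

Maximum Newton iterations.

100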

Returns:

Type Description
Array

Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

    Args:
        component_groups: Compiled circuit components.
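        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.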

    Returns:
        Converged solution vector.

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, falls back to GMIN stepping followed by source stepping — all inside a single JIT-compiled kernel via jax.lax.cond.

Strategy:

1. _run_newton: plain damped Newton from y_guess.
2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) from the GMIN result.

Because jax.lax.cond traces both branches but executes only the selected one at runtime, this compiles to a single kernel with no Python-level branching. Under jax.vmap with a batched convergence flag, however, cond is lowered to a select and both branches run.
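A sketch of the same try-then-rescue structure on a hypothetical diode-resistor circuit, not on circulax's solver objects. The direct attempt is an undamped, fixed-budget Newton that overshoots to about 5 V on the first step and does not recover within 20 iterations; the rescue branch is a GMIN scan with damped Newton. Tolerances, budgets and the 1e-12 final leak are illustrative choices.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical toy circuit (not circulax's API): 5 V source -> 1 kOhm -> diode -> ground.
V_S, R, I_S, V_T, G_FINAL = 5.0, 1e3, 1e-14, 0.025, 1e-12


def residual(v, g_leak):
    return (v - V_S) / R + I_S * (jnp.exp(v / V_T) - 1.0) + g_leak * v


def newton(v0, g_leak, n_iter, damp):
    """Fixed-budget Newton returning (v, converged); `damp` is a static Python flag."""
    d_residual = jax.grad(residual)

    def body(_, v):
        dv = -residual(v, g_leak) / d_residual(v, g_leak)
        if damp:  # resolved at trace time, so no runtime branch
            dv = jnp.minimum(1.0, 0.5 / jnp.abs(dv)) * dv
        return v + dv

    v = jax.lax.fori_loop(0, n_iter, body, v0)
    return v, jnp.abs(residual(v, g_leak)) < 1e-9  # a NaN residual compares False


@jax.jit
def solve_dc_auto_sketch(v0):
    # 1. Direct attempt: undamped Newton at the final leak value.
    v_direct, converged = newton(v0, G_FINAL, n_iter=20, damp=False)

    # 2. Rescue: GMIN stepping from 1e-2 down to G_FINAL, starting again from v0.
    def rescue(_):
        def step(v, g):
            v_new, _ = newton(v, g, n_iter=50, damp=True)
            return v_new, None

        v_gmin, _ = jax.lax.scan(step, v0, jnp.logspace(-2.0, -12.0, 10))
        return v_gmin

    v = jax.lax.cond(converged, lambda _: v_direct, rescue, None)
    return v, converged


v, converged = solve_dc_auto_sketch(jnp.array(0.0))
```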

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage for GMIN stepping (rescue branch).

0.01
n_gmin int

Number of GMIN steps in the rescue branch.

10
n_source int

Number of source steps in the rescue branch.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc, but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can be consumed inside compiled programs:

y, converged = solver.solve_dc_checked(groups, y0)
# Outside JIT — inspect in Python:
if not converged:
    y = solver.solve_dc_gmin(groups, y0)
# Inside JIT — branch without Python-level control flow:
y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess vector (shape [N] or [2N]).

required
rtol float

Relative tolerance for optx.fixed_point.

1e-06
atol float

Absolute tolerance for optx.fixed_point.

1e-06
max_steps int

Maximum Newton iterations.

100

Returns:

Type Description
tuple[Array, Array]

(y, converged) — solution vector and boolean success flag.

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each converged solution as the warm start for the next step. The large initial g_leak linearises highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0V start.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
g_start float

Starting leakage conductance (large value, e.g. 1e-2).

0.01
n_steps int

Number of log-uniform steps from g_start to self.g_leak.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector after the full GMIN schedule.

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10 % to 100 % of their netlist values, using each converged solution as the warm start for the next step. This guides Newton through the nonlinear region without the large initial step from 0V to full excitation.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

Name Type Description Default
component_groups dict[str, Any]

Compiled circuit components.

required
y_guess Array

Initial guess (typically jnp.zeros(sys_size)).

required
n_steps int

Number of uniformly spaced steps from 0.1 to 1.0.

10
rtol float

Relative tolerance for each inner Newton solve.

1e-06
atol float

Absolute tolerance for each inner Newton solve.

1e-06
max_steps int

Max Newton iterations per step.

100

Returns:

Type Description
Array

Converged solution vector at full source amplitude.

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final

solve_with_frozen_jacobian

solve_with_frozen_jacobian(residual: Array, numeric: Array) -> Solution

Solve using a pre-computed numeric factorization.

Parameters:

Name Type Description Default
residual Array

The right-hand side vector -F(y).

required
numeric Array

Handle returned by factor_jacobian.

required

Returns:

Type Description
Solution

A lineax.Solution holding the Newton step δy.

Source code in circulax/solvers/linear.py
def solve_with_frozen_jacobian(self, residual: jax.Array, numeric: jax.Array) -> lx.Solution:
    """Solve using a pre-computed numeric factorization.

    Args:
        residual: The right-hand side vector ``-F(y)``.
        numeric: Handle returned by :meth:`factor_jacobian`.

    Returns:
        :class:`lineax.Solution` with the Newton step ``δy``.

    """
    solution = klujax.solve_with_numeric(numeric, residual, self._handle_wrapper.handle)
    return lx.Solution(
        value=solution.reshape(residual.shape),
        result=lx.RESULTS.successful,
        state=None,
        stats={},
    )

RefactoringTransientSolver

Bases: FactorizedTransientSolver

Transient solver with full Newton (quadratic) convergence using klujax.refactor.

At each timestep the Jacobian is factored once at the predicted state to allocate the numeric handle. Each Newton iteration then calls klujax.refactor — which reuses the existing memory and fill-reducing permutation but recomputes L/U values for the current iterate J(y_k) — followed by a triangular solve. This gives full quadratic Newton convergence at a fraction of the cost of re-factoring from scratch each iteration.

Because convergence is quadratic, newton_max_steps is set to 20, matching VectorizedTransientSolver. Adaptive damping min(1, 0.5 / max|δy|) is applied at each iteration to stabilise convergence in stiff or strongly nonlinear regions.

Requires circulax.solvers.linear.KLUSplitQuadratic as the linear_solver; use analyze_circuit(..., backend="klu_split").
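The damping rule min(1, 0.5 / max|δy|) can be written as a small standalone helper; this is a sketch of the formula as stated, not the solver's internal code. Scaling the whole vector by one factor keeps the Newton direction and only shortens the step.

```python
import jax.numpy as jnp


def damped_update(y, dy, max_step=0.5):
    """Apply a Newton step scaled by min(1, max_step / max|dy|).

    The whole vector is scaled by one factor, so the step direction is kept.
    An all-zero step gives max_step / 0 = inf, and the scale falls back to 1.
    """
    scale = jnp.minimum(1.0, max_step / jnp.max(jnp.abs(dy)))
    return y + scale * dy


small = damped_update(jnp.zeros(3), jnp.array([0.1, -0.2, 0.05]))  # max|dy| = 0.2: applied as-is
large = damped_update(jnp.zeros(3), jnp.array([2.0, -4.0, 1.0]))   # max|dy| = 4: scaled by 0.125
```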

SDIRK3FactorizedTransientSolver

Bases: FactorizedTransientSolver

3rd-order A-stable SDIRK3 solver with frozen-Jacobian Newton across all stages.

Factors J_eff = dF/dy + (1/(γh))·dQ/dy once at the predictor state, then reuses it for all Newton iterations in all three SDIRK stages. This is the recommended backend for large sparse circuits: because SDIRK uses the same diagonal coefficient γ in every stage, the effective Jacobian has the same shift 1/(γh) at every stage, so a single factorisation can be shared across all of them.
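Assuming the circuit equations take the charge-oriented form d Q(y)/dt + F(y) = 0 (an assumption; the residual convention is not shown on this page) and an SDIRK tableau with a_ii = γ, each stage equation can be rearranged to show where J_eff comes from:

```latex
% Stage i of an SDIRK method applied to dQ(y)/dt + F(y) = 0, with a_{ii} = \gamma:
Q(Y_i) = Q(y_n) - h \sum_{j=1}^{i} a_{ij}\, F(Y_j)
\quad\Longrightarrow\quad
R_i(Y_i) = F(Y_i) + \frac{Q(Y_i) - Q(y_n) + h \sum_{j<i} a_{ij}\, F(Y_j)}{\gamma h} = 0,
\qquad
\frac{\partial R_i}{\partial Y_i} = \frac{\partial F}{\partial y} + \frac{1}{\gamma h}\,\frac{\partial Q}{\partial y} = J_{\mathrm{eff}}.
```

Only γ and h enter the shift, and both are fixed within a step, so every stage's Newton matrix has the same form; freezing it at the predictor is what lets one factorisation serve all three stages.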

SDIRK3RefactoringTransientSolver

Bases: RefactoringTransientSolver

3rd-order A-stable SDIRK3 solver with KLU refactor at each Newton iteration.

Provides full quadratic Newton convergence via klujax.refactor within each stage, combined with SDIRK3 time discretisation for 3rd-order accuracy.

SDIRK3VectorizedTransientSolver

Bases: VectorizedTransientSolver

3rd-order A-stable SDIRK3 solver using full Newton-Raphson at each stage.

Uses Alexander's L-stable 3-stage SDIRK tableau with the companion method. Each timestep performs 3 sequential Newton solves (one per stage) with the Jacobian reassembled at every iteration. The same solver_state 2-tuple (y_prev, dt_prev) as Backward Euler is used — SDIRK3 is a one-step method.
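A plain-Python sketch of Alexander's 3-stage SDIRK tableau applied to the scalar test equation y' = λy, useful for checking the order and L-stability claims. Each stage is solved in closed form, which only works for this linear test problem; the circulax solvers instead run a Newton solve per stage on the circuit residual. The γ value and b coefficients are the standard ones for this method.

```python
import math

# Alexander's SDIRK3: gamma is the root of x^3 - 3x^2 + 3x/2 - 1/6 = 0 in (1/6, 1/2).
GAMMA = 0.43586652150845899
B1 = -(6 * GAMMA**2 - 16 * GAMMA + 1) / 4
B2 = (6 * GAMMA**2 - 20 * GAMMA + 5) / 4
A = [
    [GAMMA, 0.0, 0.0],
    [(1 - GAMMA) / 2, GAMMA, 0.0],
    [B1, B2, GAMMA],  # last row equals b: the method is stiffly accurate
]
B = A[2]


def sdirk3_step(y, lam, h):
    """One SDIRK3 step for the linear test equation y' = lam * y."""
    k = []
    for i in range(3):
        # Stage i is implicit only through the diagonal entry GAMMA.
        explicit = y + h * sum(A[i][j] * k[j] for j in range(i))
        k.append(lam * explicit / (1 - h * lam * GAMMA))
    return y + h * sum(B[i] * k[i] for i in range(3))


def global_error(n_steps, lam=-1.0, t_end=1.0):
    h, y = t_end / n_steps, 1.0
    for _ in range(n_steps):
        y = sdirk3_step(y, lam, h)
    return abs(y - math.exp(lam * t_end))


ratio = global_error(20) / global_error(40)    # near 2**3 = 8 for a 3rd-order method
stiff_decay = abs(sdirk3_step(1.0, -1e6, 1.0))  # L-stability: R(z) -> 0 as z -> -inf
```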

SparseSolver

Bases: CircuitLinearSolver

Solves the system using JAX's iterative BiCGStab solver.

Best For
  • Large transient simulations on GPU (the previous timestep's solution is used as a warm start).
  • Systems where N is too large for DenseSolver but vmap support is still needed.

Attributes:

Name Type Description
diag_mask Array

Mask to extract diagonal elements for preconditioning.
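A sketch of how a diagonal mask can feed a preconditioned BiCGStab solve, assuming a Jacobi (diagonal) preconditioner, which is what the diag_mask attribute suggests but the source here does not confirm. The COO system is hypothetical; jax.scipy.sparse.linalg.bicgstab accepts both the matrix-vector product and the preconditioner as functions, and x0 plays the role of the warm start.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.sparse.linalg import bicgstab

jax.config.update("jax_enable_x64", True)

n = 5
# Hypothetical nonsymmetric tridiagonal system in COO form (diagonal, lower, upper).
rows = np.concatenate([np.arange(n), np.arange(1, n), np.arange(n - 1)])
cols = np.concatenate([np.arange(n), np.arange(n - 1), np.arange(1, n)])
vals = jnp.concatenate([jnp.full(n, 4.0), jnp.full(n - 1, -1.0), jnp.full(n - 1, -0.5)])
diag_mask = jnp.array(rows == cols)

# Jacobi preconditioner: keep only diagonal entries, then sum them per row.
diag = jax.ops.segment_sum(jnp.where(diag_mask, vals, 0.0), rows, num_segments=n)


def matvec(x):
    """Sparse A @ x computed directly from the COO triplets."""
    return jax.ops.segment_sum(vals * x[cols], rows, num_segments=n)


def jacobi(r):
    """Apply M ~= A^{-1} using only the diagonal of A."""
    return r / diag


b = jnp.arange(1.0, n + 1.0)
x_prev = jnp.zeros(n)  # in a transient run, this would be the previous timestep's solution
x, _ = bicgstab(matvec, b, x0=x_prev, M=jacobi, tol=1e-12)
```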

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Satisfies the lineax API; call _solve_impl directly for internal use.

from_component_groups

Factory method to prepare indices and diagonal mask.

init

Initialize the solver state (no-op for stateless solvers).

solve_dc

DC operating point via damped Newton-Raphson.

solve_dc_auto

DC Operating Point with automatic homotopy fallback.

solve_dc_checked

DC Operating Point with convergence status.

solve_dc_gmin

DC Operating Point via GMIN stepping (homotopy rescue).

solve_dc_source

DC Operating Point via source stepping (homotopy rescue).

assume_full_rank

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Satisfies the lineax API but always raises NotImplementedError; call _solve_impl directly for internal use.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfies the lineax API; call ``_solve_impl`` directly for internal use."""
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any],
    num_vars: int,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> SparseSolver

Factory method to prepare indices and diagonal mask.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False, g_leak: float = 1e-9
) -> "SparseSolver":
    """Factory method to prepare indices and diagonal mask."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(component_groups, num_vars, is_complex)
    return cls(
        static_rows=jnp.array(rows),
        static_cols=jnp.array(cols),
        diag_mask=jnp.array(rows == cols),
        sys_size=sys_size,
        ground_indices=jnp.array(ground_idxs),
        is_complex=is_complex,
        g_leak=g_leak,
    )

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (no-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (no-op for stateless solvers)."""
    return None

solve_dc

solve_dc(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC operating point via damped Newton-Raphson.

Parameters:

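  component_groups (dict[str, Any], required): Compiled circuit components.
  y_guess (Array, required): Initial guess vector (shape [N] or [2N]).
  rtol (float, default 1e-06): Relative tolerance for the Newton solve.
  atol (float, default 1e-06): Absolute tolerance for the Newton solve.
  max_steps (int, default 100): Maximum Newton iterations.

Returns:

  Array: Final Newton iterate. The convergence flag is discarded; use solve_dc_checked to obtain it.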

Source code in circulax/solvers/linear.py
def solve_dc(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC operating point via damped Newton-Raphson.

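    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for the Newton solve.
        atol: Absolute tolerance for the Newton solve.
        max_steps: Maximum Newton iterations.

    Returns:
        Final Newton iterate. The convergence flag is discarded; use
        :meth:`solve_dc_checked` to obtain it.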

    """
    y, _ = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)
    return y
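`_run_newton` is not shown in this section, so here is a hypothetical stand-alone sketch of damped Newton-Raphson on a one-node circuit: a 5 V source feeding a diode through 1 kΩ. "Damping" is implemented here as step limiting, which caps each iteration's voltage change at `max_dv`. This is one common choice and an assumption about what circulax does. The names `damped_newton`, `diode_residual`, and `max_dv` are illustrative. The loop returns a `(y, converged)` pair, mirroring what `solve_dc_checked` exposes.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # diode exponentials need float64 headroom


def damped_newton(residual, y0, rtol=1e-6, atol=1e-6, max_steps=100, max_dv=0.1):
    """Newton-Raphson with per-iteration step limiting (hypothetical helper)."""
    jac = jax.jacfwd(residual)

    def cond(state):
        _, k, done = state
        return (k < max_steps) & ~done

    def body(state):
        y, k, _ = state
        dy = jnp.linalg.solve(jac(y), -residual(y))
        dy = jnp.clip(dy, -max_dv, max_dv)  # damping: cap each node's change per iteration
        y_new = y + dy
        done = jnp.all(jnp.abs(dy) <= atol + rtol * jnp.abs(y_new))
        return y_new, k + 1, done

    init = (y0, jnp.array(0, dtype=jnp.int32), jnp.array(False))
    y, _, converged = jax.lax.while_loop(cond, body, init)
    return y, converged


def diode_residual(y):
    """KCL at one node: 5 V source through 1 kOhm into a diode (Is = 1e-14 A, Vt = 25.852 mV)."""
    v = y[0]
    i_r = (v - 5.0) / 1e3
    i_d = 1e-14 * jnp.expm1(v / 0.025852)
    return jnp.stack([i_r + i_d])


y, converged = damped_newton(diode_residual, jnp.zeros(1))
```

Without the clip, the first step from 0 V would jump to about 5 V, where the diode current is astronomically large. Newton would then crawl back down by roughly one thermal voltage per iteration.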

solve_dc_auto

solve_dc_auto(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point with automatic homotopy fallback.

Attempts a direct Newton solve first. If it fails to converge, falls back to GMIN stepping followed by source stepping, all inside a single JIT-compiled kernel via jax.lax.cond.

Strategy:
  1. _run_newton: plain damped Newton from y_guess.
  2. On failure: solve_dc_gmin (GMIN stepping) starting from y_guess, then solve_dc_source (source stepping) from the GMIN result. The failed direct result is discarded.

Because jax.lax.cond traces both branches at trace time but executes only one at runtime, this compiles to a single kernel with no Python-level branching. Under jax.vmap with a batched convergence flag, however, lax.cond is lowered to a select, so the rescue branch is computed for every batch element.

Parameters:

  component_groups (dict[str, Any], required): Compiled circuit components.
  y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
  g_start (float, default 0.01): Starting leakage for GMIN stepping (rescue branch).
  n_gmin (int, default 10): Number of GMIN steps in the rescue branch.
  n_source (int, default 10): Number of source steps in the rescue branch.
  rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
  atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
  max_steps (int, default 100): Max Newton iterations per step.

Returns:

  Array: Converged solution vector.

Source code in circulax/solvers/linear.py
def solve_dc_auto(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_gmin: int = 10,
    n_source: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point with automatic homotopy fallback.

    Attempts a direct Newton solve first.  If it fails to converge, falls
    back to GMIN stepping followed by source stepping — all inside a single
    JIT-compiled kernel via ``jax.lax.cond``.

    Strategy:
    1.  ``_run_newton`` — plain damped Newton from ``y_guess``.
    2.  On failure: ``solve_dc_gmin`` (GMIN stepping) starting from
        ``y_guess``, then ``solve_dc_source`` (source stepping) from the
        GMIN result.

    Because ``jax.lax.cond`` evaluates both branches at *trace* time but
    only one at *runtime*, this compiles to a single kernel with no Python-
    level branching.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage for GMIN stepping (rescue branch).
        n_gmin: Number of GMIN steps in the rescue branch.
        n_source: Number of source steps in the rescue branch.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector.

    """
    y_direct, converged = self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

    def rescue(_: None) -> jax.Array:
        y_gmin = self.solve_dc_gmin(
            component_groups,
            y_guess,
            g_start=g_start,
            n_steps=n_gmin,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )
        return self.solve_dc_source(
            component_groups,
            y_gmin,
            n_steps=n_source,
            rtol=rtol,
            atol=atol,
            max_steps=max_steps,
        )

    return jax.lax.cond(converged, lambda _: y_direct, rescue, None)

solve_dc_checked

solve_dc_checked(
    component_groups: dict[str, Any],
    y_guess: Array,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> tuple[Array, Array]

DC Operating Point with convergence status.

Identical to solve_dc but additionally returns a boolean JAX scalar indicating whether the Newton-Raphson fixed-point iteration reported success. Because the flag is a JAX array (not a Python bool), it can also be consumed inside compiled programs:

    y, converged = solver.solve_dc_checked(groups, y0)

    # Outside JIT: inspect the flag in Python.
    if not converged:
        y = solver.solve_dc_gmin(groups, y0)

    # Inside JIT (alternative to the `if` above): branch without Python-level control flow.
    y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

Parameters:

  component_groups (dict[str, Any], required): Compiled circuit components.
  y_guess (Array, required): Initial guess vector (shape [N] or [2N]).
  rtol (float, default 1e-06): Relative tolerance for optx.fixed_point.
  atol (float, default 1e-06): Absolute tolerance for optx.fixed_point.
  max_steps (int, default 100): Maximum Newton iterations.

Returns:

  tuple[Array, Array]: (y, converged), the solution vector and boolean success flag.

Source code in circulax/solvers/linear.py
def solve_dc_checked(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """DC Operating Point with convergence status.

    Identical to :meth:`solve_dc` but additionally returns a boolean JAX
    scalar indicating whether the Newton-Raphson fixed-point iteration
    reported success.  Because the flag is a JAX array (not a Python bool)
    it can be consumed inside compiled programs:

    .. code-block:: python

        y, converged = solver.solve_dc_checked(groups, y0)
        # Outside JIT — inspect in Python:
        if not converged:
            y = solver.solve_dc_gmin(groups, y0)
        # Inside JIT — branch without Python-level control flow:
        y = jax.lax.cond(converged, lambda: y, lambda: solver.solve_dc_gmin(groups, y0))

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess vector (shape ``[N]`` or ``[2N]``).
        rtol: Relative tolerance for ``optx.fixed_point``.
        atol: Absolute tolerance for ``optx.fixed_point``.
        max_steps: Maximum Newton iterations.

    Returns:
        ``(y, converged)`` — solution vector and boolean success flag.

    """
    return self._run_newton(component_groups, y_guess, rtol=rtol, atol=atol, max_steps=max_steps)

solve_dc_gmin

solve_dc_gmin(
    component_groups: dict[str, Any],
    y_guess: Array,
    g_start: float = 0.01,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via GMIN stepping (homotopy rescue).

Steps the diagonal regularisation conductance g_leak logarithmically from g_start down to self.g_leak, using each step's result as the warm start for the next. The large initial g_leak makes the system nearly linear and well conditioned. This tames highly nonlinear components (diodes above threshold, lasers) that would otherwise cause Newton to diverge from a flat 0 V start. The inner convergence flags are not checked: each step's final Newton iterate is passed on unconditionally.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

  component_groups (dict[str, Any], required): Compiled circuit components.
  y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
  g_start (float, default 0.01): Starting leakage conductance (large value, e.g. 1e-2).
  n_steps (int, default 10): Number of log-uniform steps from g_start to self.g_leak.
  rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
  atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
  max_steps (int, default 100): Max Newton iterations per step.

Returns:

  Array: Converged solution vector after the full GMIN schedule.

Source code in circulax/solvers/linear.py
def solve_dc_gmin(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    g_start: float = 1e-2,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via GMIN stepping (homotopy rescue).

    Steps the diagonal regularisation conductance ``g_leak`` logarithmically
    from ``g_start`` down to ``self.g_leak``, using each converged solution
    as the warm start for the next step.  The large initial ``g_leak``
    linearises highly nonlinear components (diodes above threshold, lasers)
    that would otherwise cause Newton to diverge from a flat 0V start.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        g_start: Starting leakage conductance (large value, e.g. ``1e-2``).
        n_steps: Number of log-uniform steps from ``g_start`` to
            ``self.g_leak``.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector after the full GMIN schedule.

    """
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(self.g_leak), n_steps)

    def step(y: jax.Array, g_leak_val: jax.Array) -> tuple[jax.Array, None]:
        stepped_solver = eqx.tree_at(lambda s: s.g_leak, self, g_leak_val)
        y_new, _ = stepped_solver._run_newton(  # noqa: SLF001
            component_groups, y, rtol=rtol, atol=atol, max_steps=max_steps
        )
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, g_values)
    return y_final
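The schedule is easiest to see on a toy problem, because `_run_newton` and the residual assembly are not shown here. The sketch below assumes a single node driven by a 1 mA current source into a diode to ground; all names and constants are hypothetical. At 0 V the diode conductance is only about Is/Vt ≈ 3.9e-13 S, so without a leak the first Newton step would be I/(Is/Vt), about 2.6e9 V. The leak term `g * v` bounds that first step to roughly I/g, and `jax.lax.scan` warm-starts each leak value from the previous solution, as in `solve_dc_gmin`. The inner solve is a fixed-iteration, step-limited Newton standing in for `_run_newton`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

I_SRC, I_S, V_T = 1e-3, 1e-14, 0.025852  # hypothetical: 1 mA source driving one diode to ground


def residual(v, g_leak):
    """KCL at the diode node, plus the GMIN leak g_leak * v to ground."""
    return I_S * jnp.expm1(v / V_T) - I_SRC + g_leak * v


def newton(v, g_leak, iters=30, max_dv=0.1):
    """Fixed-iteration, step-limited Newton (stand-in for the solver's inner Newton solve)."""
    d_residual = jax.grad(residual)

    def body(_, v):
        dv = -residual(v, g_leak) / d_residual(v, g_leak)
        return v + jnp.clip(dv, -max_dv, max_dv)

    return jax.lax.fori_loop(0, iters, body, v)


def solve_gmin(v_guess, g_start=1e-2, g_final=1e-9, n_steps=10):
    g_values = jnp.logspace(jnp.log10(g_start), jnp.log10(g_final), n_steps)

    def step(v, g):
        return newton(v, g), None  # warm-start each leak value from the previous solution

    v_final, _ = jax.lax.scan(step, v_guess, g_values)
    return v_final


v = solve_gmin(jnp.asarray(0.0))
```

With the defaults, consecutive leak values differ by a factor of 10^(7/9), about 6. Each intermediate solution is therefore a reasonable starting point for the next.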

solve_dc_source

solve_dc_source(
    component_groups: dict[str, Any],
    y_guess: Array,
    n_steps: int = 10,
    rtol: float = 1e-06,
    atol: float = 1e-06,
    max_steps: int = 100,
) -> Array

DC Operating Point via source stepping (homotopy rescue).

Ramps all source amplitudes (components tagged with amplitude_param) from 10% to 100% of their netlist values, using each step's result as the warm start for the next. This guides Newton through the nonlinear region instead of taking one large initial step from 0 V to full excitation.

Implemented with jax.lax.scan — fully JIT/grad/vmap-compatible.

Parameters:

  component_groups (dict[str, Any], required): Compiled circuit components.
  y_guess (Array, required): Initial guess (typically jnp.zeros(sys_size)).
  n_steps (int, default 10): Number of uniformly spaced steps from 0.1 to 1.0.
  rtol (float, default 1e-06): Relative tolerance for each inner Newton solve.
  atol (float, default 1e-06): Absolute tolerance for each inner Newton solve.
  max_steps (int, default 100): Max Newton iterations per step.

Returns:

  Array: Converged solution vector at full source amplitude.

Source code in circulax/solvers/linear.py
def solve_dc_source(
    self,
    component_groups: dict[str, Any],
    y_guess: jax.Array,
    n_steps: int = 10,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> jax.Array:
    """DC Operating Point via source stepping (homotopy rescue).

    Ramps all source amplitudes (components tagged with ``amplitude_param``)
    from 10 % to 100 % of their netlist values, using each converged
    solution as the warm start for the next step.  This guides Newton
    through the nonlinear region without the large initial step from 0V to
    full excitation.

    Implemented with ``jax.lax.scan`` — fully JIT/grad/vmap-compatible.

    Args:
        component_groups: Compiled circuit components.
        y_guess: Initial guess (typically ``jnp.zeros(sys_size)``).
        n_steps: Number of uniformly-spaced steps from 0.1 to 1.0.
        rtol: Relative tolerance for each inner Newton solve.
        atol: Absolute tolerance for each inner Newton solve.
        max_steps: Max Newton iterations per step.

    Returns:
        Converged solution vector at full source amplitude.

    """
    scales = jnp.linspace(0.1, 1.0, n_steps)

    def step(y: jax.Array, scale: jax.Array) -> tuple[jax.Array, None]:
        y_new, _ = self._run_newton(component_groups, y, source_scale=scale, rtol=rtol, atol=atol, max_steps=max_steps)
        return y_new, None

    y_final, _ = jax.lax.scan(step, y_guess, scales)
    return y_final
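Source stepping can be sketched on a toy circuit in the same spirit: a 5 V source, a 1 kΩ resistor, and a diode to ground. All names and constants here are hypothetical. The sketch assumes the source amplitude is the only parameter scaled, as `amplitude_param` suggests, and that each scale is solved by a short step-limited Newton loop standing in for `_run_newton`. The `jnp.linspace(0.1, 1.0, n_steps)` schedule matches the code above.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

V_SRC, R, I_S, V_T = 5.0, 1e3, 1e-14, 0.025852  # hypothetical: source -> resistor -> diode -> ground


def residual(v, scale):
    """KCL at the diode node with the source amplitude multiplied by `scale`."""
    return (v - scale * V_SRC) / R + I_S * jnp.expm1(v / V_T)


def newton(v, scale, iters=20, max_dv=0.1):
    """Fixed-iteration, step-limited Newton (stand-in for the solver's inner Newton solve)."""
    d_residual = jax.grad(residual)

    def body(_, v):
        dv = -residual(v, scale) / d_residual(v, scale)
        return v + jnp.clip(dv, -max_dv, max_dv)

    return jax.lax.fori_loop(0, iters, body, v)


def solve_source_stepping(v_guess, n_steps=10):
    scales = jnp.linspace(0.1, 1.0, n_steps)  # 10% ... 100% of the netlist amplitude

    def step(v, scale):
        return newton(v, scale), None  # previous solution is the next warm start

    v_final, _ = jax.lax.scan(step, v_guess, scales)
    return v_final


v = solve_source_stepping(jnp.asarray(0.0))
```

At 10% amplitude (0.5 V) the diode barely conducts, so the first solve is nearly linear. Every later step starts from an operating point close to the new one.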

VectorizedTransientSolver

Bases: AbstractSolver

Transient solver that works strictly on flat, real-valued state vectors.

Handling of complex-valued systems is delegated to the linear_solver strategy.
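The step logic is not reproduced in this section. The sketch below assumes this base solver performs Backward Euler steps: the BDF2 and SDIRK3 classes are described as upgrades of it, and `alpha = 1` in the assembly functions corresponds to Backward Euler. It also assumes the circuit equations take the form dq(y)/dt + f(y) = 0. Under those assumptions, each step solves (q(y) − q(y_prev))/dt + f(y) = 0 for a flat real vector y with Newton, using the effective Jacobian df/dy + (1/dt)·dq/dy. The function `backward_euler_step` and the RC example are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def backward_euler_step(f, q, y_prev, dt, newton_iters=5):
    """One BE step: solve (q(y) - q(y_prev)) / dt + f(y) = 0 for a flat real vector y."""
    df, dq = jax.jacfwd(f), jax.jacfwd(q)
    q_prev = q(y_prev)

    def body(_, y):
        res = (q(y) - q_prev) / dt + f(y)
        j_eff = df(y) + (1.0 / dt) * dq(y)  # effective Jacobian with alpha = 1
        return y - jnp.linalg.solve(j_eff, res)

    return jax.lax.fori_loop(0, newton_iters, body, y_prev)  # warm start from y_prev


# Hypothetical RC node: capacitor C and resistor R from the node to ground.
R, C = 1e3, 1e-6


def f(y):
    return y / R  # resistive current through R


def q(y):
    return C * y  # charge stored on C


y0 = jnp.array([1.0])
y1 = backward_euler_step(f, q, y0, dt=1e-4)
```

For this linear circuit, Newton converges in one iteration to the closed-form Backward Euler update y1 = y0 / (1 + dt/(RC)).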

analyze_circuit

analyze_circuit(
    groups: list,
    num_vars: int,
    backend: str = "default",
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09
) -> CircuitLinearSolver

Initializes a linear solver strategy for circuit analysis.

Factory that selects and configures the numerical backend for solving the linear system derived from a circuit's topology.

Parameters:

  groups (list, required): A list of component groups that define the circuit's structure and properties.
  num_vars (int, required): The total number of variables in the linear system.
  backend (str, default 'default'): The name of the solver backend to use. Supported backends are 'klu', 'klu_split', 'dense', and 'sparse'; 'default' selects 'klu_split'.
  is_complex (bool, default False): Whether the circuit analysis involves complex numbers.
  g_leak (float, default 1e-09): Diagonal leakage conductance forwarded to the solver's from_component_groups.

Returns:

  CircuitLinearSolver: An instance of a circuit linear solver strategy configured for the specified backend and circuit parameters.

Raises:

  ValueError: If the specified backend is not supported.

Source code in circulax/solvers/linear.py
def analyze_circuit(
    groups: list, num_vars: int, backend: str = "default", *, is_complex: bool = False, g_leak: float = 1e-9
) -> CircuitLinearSolver:
    """Initializes a linear solver strategy for circuit analysis.

    Factory that selects and configures the numerical backend for solving the
    linear system derived from a circuit's topology.

    Args:
        groups (list): A list of component groups that define the circuit's
            structure and properties.
        num_vars (int): The total number of variables in the linear system.
        backend (str, optional): The name of the solver backend to use.
            Supported backends are 'klu', 'klu_split', 'dense', and 'sparse'.
            Defaults to 'default', which uses 'klu_split'.
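        is_complex (bool, optional): A flag indicating whether the circuit
            analysis involves complex numbers. Defaults to False.
        g_leak (float, optional): Diagonal leakage conductance forwarded to
            the solver's ``from_component_groups``. Defaults to 1e-9.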

    Returns:
        CircuitLinearSolver: An instance of a circuit linear solver strategy
        configured for the specified backend and circuit parameters.

    Raises:
        ValueError: If the specified backend is not supported.

    """
    solver_class = backends.get(backend)
    if solver_class is None:
        msg = f"Unknown backend: '{backend}'. Available backends are {list(backends.keys())}"
        raise ValueError(msg)

    linear_strategy = solver_class.from_component_groups(groups, num_vars, is_complex=is_complex, g_leak=g_leak)

    return linear_strategy

assemble_system_complex

assemble_system_complex(
    y_guess: Array,
    component_groups: dict,
    t1: float,
    dt: float,
    source_scale: float = 1.0,
    alpha: float = 1.0,
) -> tuple[Array, Array, Array]

Assemble the residual vectors and effective Jacobian values for an unrolled complex system.

The complex state vector is stored in unrolled (block) format: the first half of y_guess holds the real parts of all node voltages/states, the second half holds the imaginary parts. This avoids JAX's limited support for complex-valued sparse linear solvers by keeping all arithmetic real.

The Jacobian is split into four real blocks — RR, RI, IR, II — representing the partial derivatives of the real and imaginary residual components with respect to the real and imaginary state components respectively. The blocks are concatenated in RR→RI→IR→II order to match the sparsity index layout produced during compilation.
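The block layout can be checked with `jax.jacfwd` on a toy linear admittance. The sketch below uses a hypothetical 2-port admittance matrix `Y` and writes I = Y·V in the unrolled real format. It confirms that the real Jacobian equals [[Yr, −Yi], [Yi, Yr]], the RR, RI, IR, II blocks appended by the f-domain branch of `assemble_system_complex`. It also shows that solving the unrolled real system reproduces the complex solution, which is what lets the solvers stay real-valued.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

Y = jnp.array([[1.0 + 2.0j, -0.5j], [-0.5j, 3.0 - 1.0j]])  # hypothetical 2-port admittance
n = Y.shape[0]


def unrolled_current(y):
    """Map unrolled [Re V, Im V] (shape (2n,)) to unrolled [Re I, Im I] for I = Y V."""
    v = y[:n] + 1j * y[n:]
    i = Y @ v
    return jnp.concatenate([i.real, i.imag])


y = jnp.array([0.3, -1.2, 0.7, 0.1])
jac = jax.jacfwd(unrolled_current)(y)  # real (2n, 2n) Jacobian

expected = jnp.block([[Y.real, -Y.imag], [Y.imag, Y.real]])  # [[RR, RI], [IR, II]]

# Solving the real block system reproduces the complex solve.
i_src = jnp.array([1.0 + 0.5j, -2.0 + 0.0j])
v_complex = jnp.linalg.solve(Y, i_src)
v_unrolled = jnp.linalg.solve(expected, jnp.concatenate([i_src.real, i_src.imag]))
```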

Parameters:

  y_guess (Array, required): Unrolled state vector of shape (2 * num_vars,), where y_guess[:num_vars] are real parts and y_guess[num_vars:] are imaginary parts.
  component_groups (dict, required): Compiled component groups returned by compile_netlist, keyed by group name.
  t1 (float, required): Time at which the system is being evaluated.
  dt (float, required): Timestep duration, used to scale the reactive Jacobian blocks.
  source_scale (float, default 1.0): Multiplicative scale applied to source amplitudes (components whose amplitude_param is set). Use 1.0 for a standard evaluation and values in (0, 1) during DC homotopy source stepping.
  alpha (float, default 1.0): Jacobian scaling factor for the reactive blocks. Use 1.0 for Backward Euler, the variable-step BDF2 α₀ coefficient for BDF2, or 1/γ for SDIRK3 stages.

Returns:

  tuple[Array, Array, Array]: A three-tuple (total_f, total_q, jac_vals) where:
    • total_f: assembled resistive residual in unrolled format, shape (2 * num_vars,).
    • total_q: assembled reactive residual in unrolled format, shape (2 * num_vars,).
    • jac_vals: concatenated non-zero values of the four effective Jacobian blocks (RR, RI, IR, II) in group-sorted order.
Source code in circulax/solvers/assembly.py
def assemble_system_complex(
    y_guess: Array,
    component_groups: dict,
    t1: float,
    dt: float,
    source_scale: float = 1.0,
    alpha: float = 1.0,
) -> tuple[Array, Array, Array]:
    """Assemble the residual vectors and effective Jacobian values for an unrolled complex system.

    The complex state vector is stored in unrolled (block) format: the first
    half of ``y_guess`` holds the real parts of all node voltages/states, the
    second half holds the imaginary parts. This avoids JAX's limited support
    for complex-valued sparse linear solvers by keeping all arithmetic real.

    The Jacobian is split into four real blocks — RR, RI, IR, II — representing
    the partial derivatives of the real and imaginary residual components with
    respect to the real and imaginary state components respectively. The blocks
    are concatenated in RR→RI→IR→II order to match the sparsity index layout
    produced during compilation.

    Args:
        y_guess: Unrolled state vector of shape ``(2 * num_vars,)``, where
            ``y_guess[:num_vars]`` are real parts and ``y_guess[num_vars:]``
            are imaginary parts.
        component_groups: Compiled component groups returned by
            :func:`compile_netlist`, keyed by group name.
        t1: Time at which the system is being evaluated.
        dt: Timestep duration, used to scale the reactive Jacobian blocks.
        source_scale: Multiplicative scale applied to source amplitudes
            (components whose ``amplitude_param`` is set).  Use ``1.0``
            for a standard evaluation and values in ``(0, 1)`` during
            DC homotopy source stepping.
        alpha: Jacobian scaling factor for the reactive blocks.  Use ``1.0``
            for Backward Euler, the variable-step BDF2 ``α₀`` coefficient for
            BDF2, or ``1/γ`` for SDIRK3 stages.

    Returns:
        A three-tuple ``(total_f, total_q, jac_vals)`` where:

        - **total_f** — assembled resistive residual in unrolled format,
            shape ``(2 * num_vars,)``.
        - **total_q** — assembled reactive residual in unrolled format,
            shape ``(2 * num_vars,)``.
        - **jac_vals** — concatenated non-zero values of the four effective
            Jacobian blocks (RR, RI, IR, II) in group-sorted order.

    """
    sys_size = y_guess.shape[0]
    half_size = sys_size // 2
    y_real, y_imag = y_guess[:half_size], y_guess[half_size:]

    total_f = jnp.zeros(sys_size, dtype=jnp.float64)
    total_q = jnp.zeros(sys_size, dtype=jnp.float64)

    vals_blocks: list[list[Array]] = [[], [], [], []]

    for k in sorted(component_groups.keys()):
        group = component_groups[k]

        if group.is_fdomain:
            # F-domain component: evaluate admittance at f=0 (DC) — complex circuit path.
            v_r, v_i = y_real[group.var_indices], y_imag[group.var_indices]
            v_c = v_r + 1j * v_i  # (N, n_ports) complex
            Y_mats = jax.vmap(lambda p: group.physics_func(0.0, p))(group.params)
            i_c = jnp.einsum("nij,nj->ni", Y_mats, v_c)  # (N, n_ports) complex
            idx_r, idx_i = group.eq_indices, group.eq_indices + half_size
            total_f = total_f.at[idx_r].add(i_c.real).at[idx_i].add(i_c.imag)
            # I = Y @ V is linear in V = Vr + j*Vi, so its real Jacobian blocks are:
            # dIr/dVr = Yr, dIr/dVi = -Yi, dIi/dVr = Yi, dIi/dVi = Yr
            Yr = Y_mats.real  # (N, n_ports, n_ports)
            Yi = Y_mats.imag
            vals_blocks[0].append(Yr.reshape(-1))  # RR: dIr/dVr
            vals_blocks[1].append((-Yi).reshape(-1))  # RI: dIr/dVi
            vals_blocks[2].append(Yi.reshape(-1))  # IR: dIi/dVr
            vals_blocks[3].append(Yr.reshape(-1))  # II: dIi/dVi
            continue

        v_r, v_i = y_real[group.var_indices], y_imag[group.var_indices]

        ap = group.amplitude_param
        params = (
            eqx.tree_at(lambda p, _ap=ap: getattr(p, _ap), group.params, getattr(group.params, ap) * source_scale)
            if ap
            else group.params
        )

        physics_split = functools.partial(_complex_physics, group=group, t1=t1)

        (fr, fi, qr, qi), (dfr_r, dfi_r, dqr_r, dqi_r), (dfr_i, dfi_i, dqr_i, dqi_i) = jax.vmap(
            functools.partial(_primal_and_jac_complex, physics_split)
        )(v_r, v_i, params)

        idx_r, idx_i = group.eq_indices, group.eq_indices + half_size
        total_f = total_f.at[idx_r].add(fr).at[idx_i].add(fi)
        total_q = total_q.at[idx_r].add(qr).at[idx_i].add(qi)

        vals_blocks[0].append((dfr_r + (alpha / dt) * dqr_r).reshape(-1))  # RR
        vals_blocks[1].append((dfr_i + (alpha / dt) * dqr_i).reshape(-1))  # RI
        vals_blocks[2].append((dfi_r + (alpha / dt) * dqi_r).reshape(-1))  # IR
        vals_blocks[3].append((dfi_i + (alpha / dt) * dqi_i).reshape(-1))  # II

    all_vals = jnp.concatenate([jnp.concatenate(b) for b in vals_blocks])
    return total_f, total_q, all_vals

assemble_system_real

assemble_system_real(
    y_guess: Array,
    component_groups: dict,
    t1: float,
    dt: float,
    source_scale: float = 1.0,
    alpha: float = 1.0,
) -> tuple[Array, Array, Array]

Assemble the residual vectors and effective Jacobian values for a real system.

For each component group, evaluates the physics at t1 and computes the forward-mode Jacobian via jax.jacfwd. The effective Jacobian combines the resistive and reactive contributions as J_eff = df/dy + (alpha/dt) * dq/dy, where alpha=1 recovers Backward Euler and alpha=3/2 (uniform step) gives BDF2.
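For non-uniform steps, the transient solvers' own BDF2 code is not shown here. The sketch below assumes the standard variable-step BDF2 formula with step ratio ω = dt_n / dt_{n−1}. In that formulation the reactive Jacobian scale is α₀ = (1 + 2ω)/(1 + ω), which reduces to the 3/2 quoted above when ω = 1. The helpers `bdf2_coefficients` and `check_order` are hypothetical; exact fractions make the order conditions easy to verify.

```python
from fractions import Fraction


def bdf2_coefficients(omega):
    """Variable-step BDF2 coefficients (a0, a1, a2) with omega = dt_n / dt_{n-1}.

    The step solves (a0*q_{n+1} + a1*q_n + a2*q_{n-1}) / dt_n + f(y_{n+1}) = 0,
    so the reactive Jacobian scale is alpha = a0.
    """
    a0 = (1 + 2 * omega) / (1 + omega)
    a1 = -(1 + omega)
    a2 = omega**2 / (1 + omega)
    return a0, a1, a2


def check_order(omega):
    """Return (sum a_j, sum a_j s_j, sum a_j s_j^2); second order requires (0, 1, 0)."""
    a = bdf2_coefficients(omega)
    # Node offsets from t_{n+1} in units of dt_n: t_{n+1} -> 0, t_n -> -1, t_{n-1} -> -(1 + 1/omega).
    s = [0, -1, -(1 + 1 / omega)]
    return (
        sum(a),
        sum(ai * si for ai, si in zip(a, s)),
        sum(ai * si**2 for ai, si in zip(a, s)),
    )


uniform = bdf2_coefficients(Fraction(1))  # 3/2, -2, 1/2
uneven = check_order(Fraction(3, 7))      # exact order check for a shrinking step
```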

Components are processed in sorted key order to ensure a deterministic non-zero layout in the sparse Jacobian, which is required for the factorisation step.

Parameters:

  y_guess (Array, required): Current state vector of shape (sys_size,).
  component_groups (dict, required): Compiled component groups returned by compile_netlist, keyed by group name.
  t1 (float, required): Time at which the system is being evaluated.
  dt (float, required): Timestep duration, used to scale the reactive Jacobian block.
  source_scale (float, default 1.0): Multiplicative scale applied to source amplitudes (components whose amplitude_param is set). Use 1.0 for a standard evaluation and values in (0, 1) during DC homotopy source stepping.
  alpha (float, default 1.0): Jacobian scaling factor for the reactive block. Use 1.0 for Backward Euler, the variable-step BDF2 α₀ coefficient for BDF2, or 1/γ for SDIRK3 stages.

Returns:

  tuple[Array, Array, Array]: A three-tuple (total_f, total_q, jac_vals) where:
    • total_f: assembled resistive residual, shape (sys_size,).
    • total_q: assembled reactive residual, shape (sys_size,).
    • jac_vals: concatenated non-zero values of the effective Jacobian in group-sorted order, ready to be passed to the sparse linear solver.
Source code in circulax/solvers/assembly.py
def assemble_system_real(
    y_guess: Array,
    component_groups: dict,
    t1: float,
    dt: float,
    source_scale: float = 1.0,
    alpha: float = 1.0,
) -> tuple[Array, Array, Array]:
    """Assemble the residual vectors and effective Jacobian values for a real system.

    For each component group, evaluates the physics at ``t1`` and computes the
    forward-mode Jacobian via ``jax.jacfwd``. The effective Jacobian combines
    the resistive and reactive contributions as ``J_eff = df/dy + (alpha/dt) * dq/dy``,
    where ``alpha=1`` recovers Backward Euler and ``alpha=3/2`` (uniform step) gives BDF2.

    Components are processed in sorted key order to ensure a deterministic
    non-zero layout in the sparse Jacobian, which is required for the
    factorisation step.

    Args:
        y_guess: Current state vector of shape ``(sys_size,)``.
        component_groups: Compiled component groups returned by
            :func:`compile_netlist`, keyed by group name.
        t1: Time at which the system is being evaluated.
        dt: Timestep duration, used to scale the reactive Jacobian block.
        source_scale: Multiplicative scale applied to source amplitudes
            (components whose ``amplitude_param`` is set).  Use ``1.0``
            for a standard evaluation and values in ``(0, 1)`` during
            DC homotopy source stepping.
        alpha: Jacobian scaling factor for the reactive block.  Use ``1.0``
            for Backward Euler, the variable-step BDF2 ``α₀`` coefficient for
            BDF2, or ``1/γ`` for SDIRK3 stages.

    Returns:
        A three-tuple ``(total_f, total_q, jac_vals)`` where:

        - **total_f** — assembled resistive residual, shape ``(sys_size,)``.
        - **total_q** — assembled reactive residual, shape ``(sys_size,)``.
        - **jac_vals** — concatenated non-zero values of the effective Jacobian
            in group-sorted order, ready to be passed to the sparse linear solver.

    """
    sys_size = y_guess.shape[0]
    total_f = jnp.zeros(sys_size, dtype=y_guess.dtype)
    total_q = jnp.zeros(sys_size, dtype=y_guess.dtype)
    vals_list = []

    for k in sorted(component_groups.keys()):
        group = component_groups[k]

        if group.is_fdomain:
            # F-domain component: evaluate admittance at f=0 (DC).
            v_locs = y_guess[group.var_indices]
            Y_mats = jax.vmap(lambda p: group.physics_func(0.0, p))(group.params)
            Y_real = Y_mats.real  # (N, n_ports, n_ports)
            f_l = jnp.einsum("nij,nj->ni", Y_real, v_locs)  # (N, n_ports)
            total_f = total_f.at[group.eq_indices].add(f_l)
            vals_list.append(Y_real.reshape(-1))  # Jacobian = Y at DC
            continue

        v_locs = y_guess[group.var_indices]

        ap = group.amplitude_param
        params = (
            eqx.tree_at(lambda p, _ap=ap: getattr(p, _ap), group.params, getattr(group.params, ap) * source_scale)
            if ap
            else group.params
        )

        physics_at_t1 = functools.partial(_real_physics, group=group, t1=t1)

        (f_l, q_l), (df_l, dq_l) = jax.vmap(functools.partial(_primal_and_jac_real, physics_at_t1))(v_locs, params)

        total_f = total_f.at[group.eq_indices].add(f_l)
        total_q = total_q.at[group.eq_indices].add(q_l)
        j_eff = df_l + (alpha / dt) * dq_l
        vals_list.append(j_eff.reshape(-1))

    return total_f, total_q, jnp.concatenate(vals_list)
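The private helpers `_real_physics` and `_primal_and_jac_real` are not shown in this section. The following minimal sketch shows the same pattern: `jax.vmap` over instances, `jax.jacfwd` for each local Jacobian, and a scatter-add into the global residual. It uses hypothetical two-terminal resistor and capacitor groups. Here the row and column indices are rebuilt only so the result can be densified and inspected; in the library they come from the sparsity layout prepared during compilation.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def resistor(v, g):
    i = g * (v[0] - v[1])                    # current from port 0 to port 1
    return jnp.stack([i, -i]), jnp.zeros(2)  # (f, q): purely resistive


def capacitor(v, c):
    qc = c * (v[0] - v[1])
    return jnp.zeros(2), jnp.stack([qc, -qc])  # (f, q): purely reactive


def assemble(y, groups, dt, alpha=1.0):
    """Scatter per-instance (f, q) and J_eff = df/dy + (alpha/dt) dq/dy, in sorted group order."""
    n = y.shape[0]
    total_f, total_q = jnp.zeros(n), jnp.zeros(n)
    vals, rows, cols = [], [], []
    for name in sorted(groups):
        physics, idx, params = groups[name]  # idx: (N, 2) node indices per instance
        v_locs = y[idx]
        f_l, q_l = jax.vmap(physics)(v_locs, params)
        df_l, dq_l = jax.vmap(jax.jacfwd(physics))(v_locs, params)  # each (N, 2, 2)
        total_f = total_f.at[idx].add(f_l)
        total_q = total_q.at[idx].add(q_l)
        vals.append((df_l + (alpha / dt) * dq_l).reshape(-1))
        rows.append(jnp.repeat(idx, 2, axis=1).reshape(-1))  # equation index of each entry
        cols.append(jnp.tile(idx, (1, 2)).reshape(-1))       # variable index of each entry
    return total_f, total_q, jnp.concatenate(vals), jnp.concatenate(rows), jnp.concatenate(cols)


groups = {
    "C": (capacitor, jnp.array([[0, 1]]), jnp.array([0.5])),
    "R": (resistor, jnp.array([[0, 1]]), jnp.array([2.0])),
}
f, q, jac_vals, rows, cols = assemble(jnp.array([1.0, 0.0]), groups, dt=0.1)
J = jnp.zeros((2, 2)).at[rows, cols].add(jac_vals)  # densify only to inspect the stamps
```

The result is the familiar stamp (g + alpha·C/dt)·[[1, −1], [−1, 1]], here 7·[[1, −1], [−1, 1]]. Duplicate (row, col) pairs from different groups are summed, which is also how a COO-format sparse solver interprets them.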

circuit_diffeqsolve

circuit_diffeqsolve(
    terms: PyTree[AbstractTerm],
    solver,
    t0: RealScalarLike,
    t1: RealScalarLike,
    dt0: RealScalarLike,
    y0: PyTree[ArrayLike],
    args: PyTree[Any] = None,
    *,
    saveat: SaveAt = SaveAt(t1=True),
    stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
    max_steps: int | None = 4096,
    throw: bool = True,
    checkpoints: int | None = None
) -> Solution

Stripped-down diffrax.diffeqsolve for circuit simulation.

Identical calling convention to diffrax.diffeqsolve except:

  • adjoint, event, progress_meter, solver_state, controller_state, and made_jump arguments are absent.
  • t0 < t1 is assumed; no backward-time integration.
  • No SDE/CDE terms.
  • saveat.dense and saveat.steps are not supported.

checkpoints controls the number of binomial checkpoints used by RecursiveCheckpointAdjoint (None = auto from max_steps).

Source code in circulax/solvers/circuit_diffeq.py
@eqx.filter_jit
def circuit_diffeqsolve(
    terms: PyTree[AbstractTerm],
    solver,
    t0: RealScalarLike,
    t1: RealScalarLike,
    dt0: RealScalarLike,
    y0: PyTree[ArrayLike],
    args: PyTree[Any] = None,
    *,
    saveat: SaveAt = SaveAt(t1=True),
    stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
    max_steps: int | None = 4096,
    throw: bool = True,
    checkpoints: int | None = None,
) -> Solution:
    """Stripped-down :func:`diffrax.diffeqsolve` for circuit simulation.

    Identical calling convention to ``diffrax.diffeqsolve`` except:

    - ``adjoint``, ``event``, ``progress_meter``, ``solver_state``,
      ``controller_state``, and ``made_jump`` arguments are absent.
    - ``t0 < t1`` is assumed; no backward-time integration.
    - No SDE/CDE terms.
    - ``saveat.dense`` and ``saveat.steps`` are not supported.

    ``checkpoints`` controls the number of binomial checkpoints used by
    ``RecursiveCheckpointAdjoint`` (``None`` = auto from ``max_steps``).
    """
    # ------------------------------------------------------------------
    # dtype promotion for times (same logic as diffrax)
    # ------------------------------------------------------------------
    timelikes = [t0, t1, dt0] + [s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat) if s.ts is not None]
    with jax.numpy_dtype_promotion("standard"):
        time_dtype = jnp.result_type(*timelikes)
    if jnp.issubdtype(time_dtype, jnp.integer):
        time_dtype = lxi.default_floating_dtype()

    t0 = jnp.asarray(t0, dtype=time_dtype)
    t1 = jnp.asarray(t1, dtype=time_dtype)
    dt0 = jnp.asarray(dt0, dtype=time_dtype)

    # Cast save ts to the same dtype
    def _cast_ts(saveat):
        out = [s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat)]
        return [x for x in out if x is not None]

    saveat = eqx.tree_at(_cast_ts, saveat, replace_fn=lambda ts: ts.astype(time_dtype))

    # Promote y0 dtype to be consistent with time (avoids weak-dtype issues)
    def _promote(yi):
        with jax.numpy_dtype_promotion("standard"):
            _dtype = jnp.result_type(yi, time_dtype)
        return jnp.asarray(yi, dtype=_dtype)

    y0 = jtu.tree_map(_promote, y0)

    # ------------------------------------------------------------------
    # Wrap terms with direction=1 (forward-only; WrapTerm still needed
    # so that solver.step receives a properly wrapped term object)
    # ------------------------------------------------------------------
    direction = jnp.asarray(1, dtype=time_dtype)

    def _wrap(term):
        assert isinstance(term, AbstractTerm)
        return WrapTerm(term, direction)

    terms = jtu.tree_map(
        _wrap,
        terms,
        is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, diffrax._term.MultiTerm),
    )

    # ------------------------------------------------------------------
    # Propagate PIDController tolerances into implicit solver root finder
    # ------------------------------------------------------------------
    if isinstance(solver, diffrax.AbstractImplicitSolver):
        from diffrax._root_finder import use_stepsize_tol

        def _get_tols(x):
            outs = []
            for attr in ("rtol", "atol", "norm"):
                if getattr(solver.root_finder, attr) is use_stepsize_tol:
                    outs.append(getattr(x, attr))
            return tuple(outs)

        if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
            solver = eqx.tree_at(
                lambda s: _get_tols(s.root_finder),
                solver,
                _get_tols(stepsize_controller),
            )

    # ------------------------------------------------------------------
    # Validate save timestamps
    # ------------------------------------------------------------------
    def _check_ts(ts):
        ts = eqxi.error_if(ts, ts[1:] < ts[:-1], "saveat.ts must be increasing.")
        ts = eqxi.error_if(ts, (ts > t1) | (ts < t0), "saveat.ts must lie between t0 and t1.")
        return ts

    saveat = eqx.tree_at(_cast_ts, saveat, replace_fn=_check_ts)

    # Wrap custom fn with direction (identity when direction=1, but kept for correctness)
    def _wrap_fn(x):
        if _is_subsaveat(x) and x.fn is not save_y:
            direction_fn = lambda t, y, a: x.fn(direction * t, y, a)
            return eqx.tree_at(lambda x: x.fn, x, direction_fn)
        return x

    saveat = jtu.tree_map(_wrap_fn, saveat, is_leaf=_is_subsaveat)

    # ------------------------------------------------------------------
    # Initialise solver & controller states
    # ------------------------------------------------------------------
    tprev = t0
    error_order = solver.error_order(terms)
    tnext, controller_state = stepsize_controller.init(terms, t0, t1, y0, dt0, args, solver.func, error_order)
    tnext = jnp.minimum(tnext, t1)
    solver_state = solver.init(terms, t0, tnext, y0, args)

    # ------------------------------------------------------------------
    # Allocate output buffers
    # ------------------------------------------------------------------
    def _allocate(subsaveat: SubSaveAt) -> SaveState:
        out_size = 0
        if subsaveat.t0:
            out_size += 1
        if subsaveat.ts is not None:
            out_size += len(subsaveat.ts)
        if subsaveat.t1:
            out_size += 1
        struct = eqx.filter_eval_shape(subsaveat.fn, t0, y0, args)
        ts = jnp.full(out_size, jnp.inf, dtype=time_dtype)
        ys = jtu.tree_map(lambda y: jnp.full((out_size,) + y.shape, jnp.inf, dtype=y.dtype), struct)
        return SaveState(ts=ts, ys=ys, save_index=0, saveat_ts_index=0)

    save_state = jtu.tree_map(_allocate, saveat.subs, is_leaf=_is_subsaveat)

    # ------------------------------------------------------------------
    # Build initial CircuitState
    # ------------------------------------------------------------------
    init_state = CircuitState(
        y=y0,
        tprev=tprev,
        tnext=tnext,
        solver_state=solver_state,
        controller_state=controller_state,
        result=RESULTS.successful,
        num_steps=0,
        num_accepted_steps=0,
        num_rejected_steps=0,
        save_state=save_state,
    )

    # ------------------------------------------------------------------
    # Choose while-loop variant: checkpointed (reverse-mode differentiable)
    # when max_steps bounds the loop, plain lax when max_steps is None
    # ------------------------------------------------------------------
    if max_steps is None:
        inner_while_loop = ft.partial(_inner_loop, kind="lax")
        outer_while_loop = ft.partial(_outer_loop, kind="lax")
    else:
        inner_while_loop = ft.partial(_inner_loop, kind="checkpointed")
        outer_while_loop = ft.partial(_outer_loop, kind="checkpointed", checkpoints=checkpoints)

    # ------------------------------------------------------------------
    # Run the integration
    # ------------------------------------------------------------------
    final_state = _circuit_loop(
        solver=solver,
        stepsize_controller=stepsize_controller,
        saveat=saveat,
        t0=t0,
        t1=t1,
        max_steps=max_steps,
        terms=terms,
        args=args,
        init_state=init_state,
        inner_while_loop=inner_while_loop,
        outer_while_loop=outer_while_loop,
    )

    # ------------------------------------------------------------------
    # Build the Solution (compatible with diffrax.Solution)
    # ------------------------------------------------------------------
    ts = jtu.tree_map(lambda s: s.ts, final_state.save_state, is_leaf=_is_save_state)
    ys = jtu.tree_map(lambda s: s.ys, final_state.save_state, is_leaf=_is_save_state)

    stats = {
        "num_steps": final_state.num_steps,
        "num_accepted_steps": final_state.num_accepted_steps,
        "num_rejected_steps": final_state.num_rejected_steps,
        "max_steps": max_steps,
    }

    sol = Solution(
        t0=t0,
        t1=t1,
        ts=ts,
        ys=ys,
        interpolation=None,
        stats=stats,
        result=final_state.result,
        solver_state=None,
        controller_state=None,
        made_jump=None,
        event_mask=None,
    )

    if throw:
        sol = final_state.result.error_if(sol, jnp.invert(is_okay(final_state.result)))

    return sol
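The _allocate helper above sizes each save buffer as one row for t0 (if requested), one row per entry of ts, and one row for t1 (if requested), and pre-fills every slot with inf. A NumPy sketch of that sizing rule, where allocate_save_buffer is a hypothetical name and not part of circulax:

```python
import numpy as np


def allocate_save_buffer(save_t0, ts, save_t1, y_shape):
    """Hypothetical helper mirroring the buffer-sizing rule of _allocate."""
    out_size = 0
    if save_t0:
        out_size += 1          # one row for t0
    if ts is not None:
        out_size += len(ts)    # one row per requested save time
    if save_t1:
        out_size += 1          # one row for t1
    # Every slot starts as +inf and is only overwritten when its save time is reached.
    ts_buf = np.full(out_size, np.inf)
    ys_buf = np.full((out_size,) + tuple(y_shape), np.inf)
    return ts_buf, ys_buf


# Equivalent of SaveAt(t0=True, ts=[0.1, 0.2, 0.3], t1=True) for a 4-element state
ts_buf, ys_buf = allocate_save_buffer(True, [0.1, 0.2, 0.3], True, (4,))
print(ts_buf.shape, ys_buf.shape)  # (5,) (5, 4)
```

Because slots are only overwritten when their save time is reached, inf entries in a returned Solution indicate save points the integration did not reach, for example after an early stop with throw=False.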

setup_ac_sweep

setup_ac_sweep(
    groups: dict[str, Any], num_vars: int, port_nodes: list[int], *, z0: float = 50.0
) -> Callable[[Array, Array], Array]

Configure and return a callable for AC small-signal S-parameter sweep.

Linearises the circuit DAE at the DC operating point and sweeps over an array of frequencies, returning the complex S-parameter matrix at each frequency. The returned callable is compatible with jax.jit and jax.vmap.

The analysis solves Y(jω) · V = RHS at each frequency, where:

Y(jω) = G + jωC + Y_fdomain(f) + port_terminations + ground_penalty

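G = ∂F/∂y and C = ∂Q/∂y are extracted once at the DC operating point. Y_fdomain(f) is the admittance contribution from frequency-domain components, re-evaluated at each frequency. port_terminations adds 1/z0 to the diagonal entry of each port node, and ground_penalty adds a large conductance on the ground node's diagonal so that V[ground] ≈ 0. RHS has one column per port, holding 2/z0 at that port's node and zero elsewhere, and the sweep reports S = V_ports − I.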
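To make the formula concrete, here is a NumPy sketch (not the circulax implementation) that assembles Y(jω) for a hypothetical one-port: a resistor R and a capacitor C in parallel from the port node to ground. The ground_stiffness value is an assumption standing in for the library's GROUND_STIFFNESS constant. It uses the same RHS convention as run_ac (2/z0 at the port node), reads S11 as V_port - 1, and solves all frequencies in one batched call:

```python
import numpy as np


def s11_parallel_rc(freqs, R=100.0, C=1e-12, z0=50.0, ground_stiffness=1e9):
    """S11 of a parallel RC one-port via Y(jw) = G + jwC + port term + ground penalty.

    Toy netlist (hypothetical): node 0 is ground, node 1 is the port, and R and C
    both connect node 1 to node 0. There are no frequency-domain components, so
    the Y_fdomain(f) term is zero. ground_stiffness is an assumed value.
    """
    n = 2  # unknowns: [V_ground, V_port]
    gnd, port = 0, 1
    stamp = np.array([[1.0, -1.0], [-1.0, 1.0]])  # two-terminal element between nodes 0 and 1
    G = stamp / R   # dF/dy: resistor conductance
    Cm = stamp * C  # dQ/dy: capacitance

    omega = 2.0 * np.pi * np.asarray(freqs, dtype=float)
    # One admittance matrix per frequency, shape (N_freqs, n, n).
    Y = G[None, :, :] + 1j * omega[:, None, None] * Cm[None, :, :]
    Y[:, port, port] += 1.0 / z0          # port termination
    Y[:, gnd, gnd] += ground_stiffness    # ground penalty pins V_ground near 0

    rhs = np.zeros((omega.size, n, 1), dtype=complex)
    rhs[:, port, 0] = 2.0 / z0            # unit incident wave at the port
    V = np.linalg.solve(Y, rhs)           # batched solve, shape (N_freqs, n, 1)
    return V[:, port, 0] - 1.0            # S11 = V_port - 1


freqs = np.array([1e6, 1e9, 1e10])
Z = 1.0 / (1.0 / 100.0 + 1j * 2.0 * np.pi * freqs * 1e-12)  # parallel RC impedance
print(np.abs(s11_parallel_rc(freqs) - (Z - 50.0) / (Z + 50.0)).max())
```

Because the ground node is only pinned by a large finite conductance, the result matches the closed form (Z - Z0)/(Z + Z0) up to a small error that shrinks as the stiffness grows.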

S-parameter convention — matched-load verification:

  • Matched load (Z_circuit = Z0) → S11 = 0
  • Open circuit (Z_circuit → ∞) → S11 = +1
  • Short circuit (Z_circuit = 0) → S11 = −1
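The three cases follow from the one-port form of the system. With Y_c = 1/Z_c the circuit admittance at the port node, and the ground penalty neglected, the port termination 1/z0 and the RHS entry 2/z0 give the derivation below. The 2/z0 source is the Norton equivalent of a 2 V source behind z0, i.e. a unit incident voltage wave, which is why the identity is subtracted from the port voltages.

```latex
\begin{aligned}
\left(Y_c + \frac{1}{Z_0}\right) V_p &= \frac{2}{Z_0}
  &&\Longrightarrow\quad V_p = \frac{2 Z_c}{Z_c + Z_0}, \\
S_{11} &= V_p - 1 = \frac{Z_c - Z_0}{Z_c + Z_0}.
\end{aligned}
```

Setting Z_c = Z_0 gives 0, letting Z_c → ∞ gives +1, and Z_c = 0 gives −1.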

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| groups | dict[str, Any] | Compiled component groups from circulax.compile_netlist. | required |
| num_vars | int | Total number of scalar unknowns (second return value of circulax.compile_netlist). | required |
| port_nodes | list[int] | Global node indices for each circuit port, in the desired port ordering. Obtain them from the port-to-node map returned by circulax.compile_netlist: `_, _, pmap = compile_netlist(net_dict, models)`, then `port_nodes = [pmap["R1,p1"], pmap["C1,p1"]]`. The ground node (index 0) is not allowed and raises ValueError. | required |
| z0 | float | Reference impedance in ohms, applied uniformly to all ports. | 50.0 |

Returns:

| Type | Description |
| --- | --- |
| Callable[[Array, Array], Array] | A callable run_ac(y_dc, freqs) -> S, compatible with jax.jit and with jax.vmap over y_dc. |

Here:

  • y_dc: DC operating point, shape (num_vars,).
  • freqs: frequencies in Hz, shape (N_freqs,).
  • S: S-parameter matrix, shape (N_freqs, N_ports, N_ports), complex128.

Source code in circulax/solvers/ac_sweep.py
def setup_ac_sweep(
    groups: dict[str, Any],
    num_vars: int,
    port_nodes: list[int],
    *,
    z0: float = 50.0,
) -> Callable[[Array, Array], Array]:
    """Configure and return a callable for AC small-signal S-parameter sweep.

    Linearises the circuit DAE at the DC operating point and sweeps over an
    array of frequencies, returning the complex S-parameter matrix at each
    frequency.  The returned callable is compatible with :func:`jax.jit` and
    :func:`jax.vmap`.

    The analysis solves ``Y(jω) · V = RHS`` at each frequency, where::

        Y(jω) = G + jωC + Y_fdomain(f) + port_terminations + ground_penalty

    ``G = ∂F/∂y`` and ``C = ∂Q/∂y`` are extracted once at the DC operating
    point.  ``Y_fdomain(f)`` is the admittance contribution from
    frequency-domain components, re-evaluated at each frequency.

    **S-parameter convention** — matched-load verification:

    - Matched load (Z_circuit = Z0) → S11 = 0
    - Open circuit (Z_circuit → ∞) → S11 = +1
    - Short circuit (Z_circuit = 0) → S11 = −1

    Args:
        groups: Compiled component groups from :func:`~circulax.compile_netlist`.
        num_vars: Total number of scalar unknowns (second return value of
            :func:`~circulax.compile_netlist`).
        port_nodes: Global node indices for each circuit port, in the desired
            port ordering.  Obtain from the port-to-node map returned by
            :func:`~circulax.compile_netlist`::

                _, _, pmap = compile_netlist(net_dict, models)
                port_nodes = [pmap["R1,p1"], pmap["C1,p1"]]

        z0: Reference impedance in ohms, applied uniformly to all ports.
            Defaults to 50.0.

    Returns:
        A callable ``run_ac(y_dc, freqs) -> S`` where:

        - **y_dc** — DC operating point, shape ``(num_vars,)``.
        - **freqs** — frequencies in Hz, shape ``(N_freqs,)``.
        - **S** — S-parameter matrix, shape ``(N_freqs, N_ports, N_ports)``
            complex128.

        Compatible with :func:`jax.jit` and :func:`jax.vmap` over ``y_dc``.

    """
    if 0 in port_nodes:
        msg = "Port node cannot be the ground node (index 0)."
        raise ValueError(msg)

    # --- Pre-compute static COO index arrays (captured in closure) -----------
    static_rows, static_cols, ground_idxs, _ = _build_index_arrays(groups, num_vars, is_complex=False)
    static_rows_jax = jnp.array(static_rows)
    static_cols_jax = jnp.array(static_cols)
    ground_indices = jnp.array(ground_idxs)

    N_ports = len(port_nodes)
    port_nodes_arr = jnp.array(port_nodes, dtype=jnp.int32)

    # Pre-compute fdomain COO scatter index arrays (static integers — avoids
    # re-creating constant arrays on every trace inside vmap).
    fdomain_scatter: dict[str, tuple[Array, Array]] = {
        gk: (
            jnp.array(groups[gk].jac_rows).reshape(-1),
            jnp.array(groups[gk].jac_cols).reshape(-1),
        )
        for gk in sorted(groups)
        if groups[gk].is_fdomain
    }

    # -------------------------------------------------------------------------
    def run_ac(y_dc: Array, freqs: Array) -> Array:
        """Sweep AC frequencies and return the S-parameter matrix.

        Args:
            y_dc: DC operating point, shape ``(num_vars,)``.
            freqs: Frequencies in Hz, shape ``(N_freqs,)``.

        Returns:
            S-parameter matrix, shape ``(N_freqs, N_ports, N_ports)`` complex128.

        """
        G_vals, C_vals = assemble_gc_real(y_dc, groups)

        G_mat = jnp.zeros((num_vars, num_vars), dtype=jnp.float64)
        G_mat = G_mat.at[static_rows_jax, static_cols_jax].add(G_vals)

        C_mat = jnp.zeros((num_vars, num_vars), dtype=jnp.float64)
        C_mat = C_mat.at[static_rows_jax, static_cols_jax].add(C_vals)

        # RHS column p: 2/z0 at port_nodes[p], zero elsewhere.
        RHS = jnp.zeros((num_vars, N_ports), dtype=jnp.complex128)
        RHS = RHS.at[port_nodes_arr, jnp.arange(N_ports)].set(2.0 / z0)

        def _solve_one_freq(f: Array) -> Array:
            omega = 2.0 * jnp.pi * f
            Y = G_mat.astype(jnp.complex128) + 1j * omega * C_mat.astype(jnp.complex128)

            # Add frequency-domain component admittances.
            # The Python loop is over static strings — safe inside vmap.
            for gk, (rows_fd, cols_fd) in fdomain_scatter.items():
                group_fd = groups[gk]
                Y_mats = jax.vmap(functools.partial(group_fd.physics_func, f))(group_fd.params)
                Y = Y.at[rows_fd, cols_fd].add(Y_mats.reshape(-1))

            # Port terminations: Y[port, port] += 1/z0 for each port.
            Y = Y.at[port_nodes_arr, port_nodes_arr].add(1.0 / z0)

            # Ground stiffness: enforces V[ground] ≈ 0.
            Y = Y.at[ground_indices, ground_indices].add(GROUND_STIFFNESS)

            # Single batched solve: factor once, N_ports back-substitutions.
            V = jnp.linalg.solve(Y, RHS)  # (num_vars, N_ports)

            V_ports = V[port_nodes_arr, :]  # (N_ports, N_ports)
            return V_ports - jnp.eye(N_ports, dtype=jnp.complex128)

        return jax.vmap(_solve_one_freq)(freqs)

    return run_ac

setup_harmonic_balance

setup_harmonic_balance(
    groups: dict[str, Any],
    num_vars: int,
    freq: float,
    num_harmonics: int = 5,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09,
    osc_node: int | None = None,
    amplitude_tries: Array | None = None
) -> Callable[[Array], tuple[Array, Array]]

Configure and return a Harmonic Balance callable.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| groups | dict[str, Any] | Compiled component groups from circulax.compiler.compile_netlist. | required |
| num_vars | int | Total number of state variables. | required |
| freq | float | Fundamental frequency in Hz. | required |
| num_harmonics | int | Number of harmonics N; the periodic waveform is sampled at K = 2N + 1 time points per period. | 5 |
| is_complex | bool | True for photonic circuits, whose state is unrolled into real and imaginary halves. | False |
| g_leak | float | Diagonal regularisation added to the HB Jacobian to prevent singular Jacobians at floating nodes. | 1e-09 |
| osc_node | int \| None | State-vector index of the oscillator node (from net_map). Enables an automatic multi-start over amplitude_tries. Has no effect when y_flat_init is passed to the returned run_hb. | None |
| amplitude_tries | Array \| None | Amplitudes (V) to try when osc_node is set. None means [0.3, 0.7, 1.5, 3.0, 7.0, 20.0]. | None |
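Inside run_hb, each Newton step adds g_leak times the identity to the HB Jacobian and then limits the update with a voltage-damping factor. A scalar sketch of that step on a hypothetical diode-plus-resistor circuit follows; all component values, and the DAMPING_FACTOR and DAMPING_EPS values, are assumptions rather than the library's constants:

```python
import math

# Hypothetical toy circuit: diode in series with resistor R, driven by VS.
IS, VT, R, VS = 1e-14, 0.02585, 1e3, 5.0
# Assumed values; the library's DAMPING_FACTOR and DAMPING_EPS are not shown here.
DAMPING_FACTOR, DAMPING_EPS = 0.5, 1e-12


def residual(v):
    # KCL at the diode anode: diode current minus resistor current.
    return IS * (math.exp(v / VT) - 1.0) - (VS - v) / R


def jacobian(v):
    return IS / VT * math.exp(v / VT) + 1.0 / R


def solve(v0, g_leak=1e-9, damped=True, tol=1e-12, max_iter=500):
    """Damped, regularised Newton iteration; returns (solution, iterations used)."""
    v = v0
    for it in range(1, max_iter + 1):
        delta = -residual(v) / (jacobian(v) + g_leak)  # g_leak regularises the Jacobian only
        damping = min(1.0, DAMPING_FACTOR / (abs(delta) + DAMPING_EPS)) if damped else 1.0
        v_new = v + damping * delta
        if abs(v_new - v) < tol:                       # stop on a small update
            return v_new, it
        v = v_new
    return v, max_iter


v_damped, n_damped = solve(0.0)
v_plain, n_plain = solve(0.0, damped=False)
print(v_damped, n_damped, n_plain)
```

Because g_leak only modifies the Jacobian used to compute the step, and not the residual, a converged point still satisfies residual = 0; a larger g_leak just slows the final convergence. Without damping, the first step overshoots deep into the exponential region and the iteration then walks back by roughly one thermal voltage per step.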

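Returns:

| Type | Description |
| --- | --- |
| Callable[[Array], tuple[Array, Array]] | A callable run_hb(y_dc, *, y_flat_init=None, max_iter=50, tol=1e-6) -> (y_time, y_freq), compatible with jax.jit. |

y_time holds the periodic waveform samples, shape (K, sys_size). y_freq holds the Fourier coefficients normalised by K, shape (num_harmonics + 1, sys_size), with y_freq[0] the DC component and y_freq[1] the fundamental.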
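run_hb computes y_freq as rfft(y_time) / K, which yields N + 1 rows for K = 2N + 1 samples. The NumPy sketch below, using a single made-up waveform rather than circuit output, checks the reading rule: y_freq[0] is the DC value and 2 * |y_freq[k]| is the peak amplitude of harmonic k >= 1.

```python
import numpy as np

N = 5                        # num_harmonics
K = 2 * N + 1                # samples per period, as in setup_harmonic_balance
freq = 1e3                   # fundamental frequency in Hz (arbitrary)
omega = 2.0 * np.pi * freq
t = np.linspace(0.0, 1.0 / freq, K, endpoint=False)

# A made-up waveform for one state variable: DC + fundamental + 3rd harmonic.
y_time = (0.5 + 1.2 * np.cos(omega * t) + 0.3 * np.sin(3 * omega * t))[:, None]  # (K, 1)

y_freq = np.fft.rfft(y_time, axis=0) / K   # same normalisation as run_hb
print(y_freq.shape)                        # (6, 1)
print(y_freq[0, 0].real)                   # DC value
print(2 * np.abs(y_freq[1:, 0]))           # peak amplitude of harmonics 1..N
```

The factor of 2 appears because a real cosine of amplitude A splits its energy equally between harmonics +k and -k, and rfft keeps only the non-negative half.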

Source code in circulax/solvers/harmonic_balance.py
def setup_harmonic_balance(
    groups: dict[str, Any],
    num_vars: int,
    freq: float,
    num_harmonics: int = 5,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-9,
    osc_node: int | None = None,
    amplitude_tries: Array | None = None,
) -> Callable[[Array], tuple[Array, Array]]:
    """Configure and return a Harmonic Balance callable.

    Args:
        groups: Compiled component groups from :func:`~circulax.compiler.compile_netlist`.
        num_vars: Total number of state variables.
        freq: Fundamental frequency in Hz.
        num_harmonics: Number of harmonics; uses K = 2*N+1 time points.
        is_complex: ``True`` for photonic circuits (unrolled ``[re | im]`` state).
        g_leak: Diagonal regularisation to prevent singular Jacobians at floating nodes.
        osc_node: State-vector index of the oscillator node (from ``net_map``).
            Enables automatic multi-start across ``amplitude_tries``. No effect
            if ``y_flat_init`` is supplied.
        amplitude_tries: Amplitudes (V) to try when ``osc_node`` is set.
            Defaults to ``[0.3, 0.7, 1.5, 3.0, 7.0, 20.0]``.

    Returns:
        ``run_hb(y_dc, *, y_flat_init=None, max_iter=50, tol=1e-6) -> (y_time, y_freq)``,
        compatible with ``jax.jit``.

    """
    _, _, ground_idxs, sys_size = _build_index_arrays(groups, num_vars, is_complex=is_complex)
    ground_indices = jnp.array(ground_idxs)

    K = 2 * num_harmonics + 1
    omega = 2.0 * jnp.pi * freq
    t_points = jnp.linspace(0.0, 1.0 / freq, K, endpoint=False)

    _amplitude_tries: Array = (
        amplitude_tries
        if amplitude_tries is not None
        else jnp.array([0.3, 0.7, 1.5, 3.0, 7.0, 20.0], dtype=jnp.float64)
    )
    _phase = 2.0 * jnp.pi * jnp.arange(K, dtype=jnp.float64) / K  # (K,)

    def run_hb(
        y_dc: Array,
        *,
        y_flat_init: Array | None = None,
        max_iter: int = 50,
        tol: float = 1e-6,
    ) -> tuple[Array, Array]:
        """Find the periodic steady state via Newton–Raphson on the HB residual.

        Args:
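            y_dc: DC operating point, shape ``(sys_size,)``. Used as the
                initial guess (zero AC amplitude) unless the oscillator
                multi-start is active. Obtain from
                :meth:`~circulax.solvers.CircuitLinearSolver.solve_dc`.
            y_flat_init: Optional flat initial waveform, shape ``(K * sys_size,)``.
                Overrides the automatic multi-start strategy even when
                ``osc_node`` was set at setup time.
            max_iter: Maximum number of Newton iterations. If it is reached
                without convergence, the last iterate is returned without
                raising (the fixed-point solve uses ``throw=False``).
            tol: ``rtol`` and ``atol`` of the ``optx.FixedPointIteration``
                driving the Newton steps; iteration stops when the damped
                Newton update is small relative to these tolerances.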

        Returns:
            A two-tuple ``(y_time, y_freq)`` where:

            - **y_time** -- periodic waveform samples, shape ``(K, sys_size)``.
              The k-th row is the state at time ``t_k = k*T/K``.
            - **y_freq** -- normalised Fourier coefficients, shape
              ``(N_harm+1, sys_size)`` complex. ``y_freq[0]`` is the DC
              component, ``y_freq[1]`` is the fundamental, and so on.
              The peak amplitude of harmonic k>=1 is ``2 * |y_freq[k]|``.

        """

        def newton_step(y_flat: Array, grps: Any) -> Array:
            def _res(y: Array) -> Array:
                return _hb_residual(
                    y.reshape(K, sys_size), grps, t_points, omega, ground_indices, is_complex=is_complex
                ).flatten()

            r = _res(y_flat)
            J = jax.jacobian(_res)(y_flat)
            # Regularise: prevents singular Jacobian when floating nodes have
            # no DC path to ground (mirrors the DC solver's g_leak).
            J = J + g_leak * jnp.eye(J.shape[0], dtype=J.dtype)
            delta = jnp.linalg.solve(J, -r)
            # Voltage damping mirrors the DC solver: limits the maximum step to
            # avoid crashing exponential nonlinearities (diodes, transistors).
            max_change = jnp.max(jnp.abs(delta))
            damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))
            return y_flat + delta * damping

        def _solve(y_flat: Array) -> tuple[Array, Array]:
            # groups MUST be passed via args= — ImplicitAdjoint differentiates
            # only through explicit args; closure-captured variables give zero gradients.
            hb_solver = optx.FixedPointIteration(rtol=tol, atol=tol)
            sol = optx.fixed_point(
                newton_step, hb_solver, y_flat, args=groups,
                max_steps=max_iter, throw=False,
            )
            y_flat_sol = sol.value
            y_time_sol = y_flat_sol.reshape(K, sys_size)
            # Normalise by K: y_freq[0] is the DC value and, for k >= 1,
            # 2 * |y_freq[k]| is the peak amplitude of harmonic k.
            y_freq_sol = jnp.fft.rfft(y_time_sol, axis=0) / K
            return y_time_sol, y_freq_sol

        # Autonomous multi-start: y=0 is always a trivial fixed point for oscillators,
        # so low-amplitude starts converge to zero. Try several amplitudes and keep
        # the one with the largest fundamental — at least one will be above the basin.
        if osc_node is not None and y_flat_init is None:
            def _single_start(A: Array) -> tuple[Array, Array]:
                y0 = (jnp.zeros(K * sys_size, dtype=jnp.float64)
                      .at[jnp.arange(K) * sys_size + osc_node].set(A * jnp.sin(_phase)))
                return _solve(y0)

            y_times, y_freqs = jax.vmap(_single_start)(_amplitude_tries)
            best = jnp.argmax(jnp.abs(y_freqs[:, 1, osc_node]))
            return jnp.take(y_times, best, axis=0), jnp.take(y_freqs, best, axis=0)

        y_flat = y_flat_init if y_flat_init is not None else jnp.tile(y_dc, K)
        return _solve(y_flat)

    return run_hb
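The multi-start branch exists because y = 0 is always a fixed point of an autonomous oscillator's HB equations. The scalar toy below has the same structure but is not a circuit model; the residual x - tanh(2x) and the starting amplitudes are assumptions chosen for illustration. It shows why run_hb tries several amplitudes and keeps the solution with the largest fundamental:

```python
import math


def toy_residual(x):
    # Scalar stand-in (not a circuit model): x = 0 is always a solution,
    # like the zero-amplitude fixed point of an oscillator's HB equations.
    return x - math.tanh(2.0 * x)


def toy_newton(x, max_iter=100, tol=1e-12):
    for _ in range(max_iter):
        slope = 1.0 - 2.0 / math.cosh(2.0 * x) ** 2   # d(toy_residual)/dx
        x_new = x - toy_residual(x) / slope
        if abs(x_new - x) < tol:
            return x_new
        x = x_new
    return x


amplitude_tries = [0.05, 0.7, 1.5]               # hypothetical starting amplitudes
solutions = [toy_newton(a) for a in amplitude_tries]
best = max(solutions, key=abs)                   # keep the largest-amplitude solution
print(solutions, best)
```

The smallest start falls into the basin of the trivial solution, while the larger ones reach the non-zero one; selecting by magnitude mirrors the argmax over the fundamental at osc_node.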

setup_transient

setup_transient(
    groups: dict[str, Any],
    linear_strategy: CircuitLinearSolver,
    transient_solver: AbstractSolver = None,
) -> Callable[..., Solution]

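Configures and returns a function for executing transient analysis.

This function acts as a factory: it prepares a transient solver pre-configured with the circuit's linear strategy and returns a callable that runs the time-domain simulation with circuit_diffeqsolve, the stripped-down diffrax.diffeqsolve documented above. It raises RuntimeError if any component group is frequency-domain, since time-domain convolution is not supported; use setup_harmonic_balance for such circuits.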

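Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| groups | dict[str, Any] | Compiled component groups that define the circuit. | required |
| linear_strategy | CircuitLinearSolver | The configured linear solver strategy, typically obtained from analyze_circuit. | required |
| transient_solver | optional | A transient solver class, instantiated as transient_solver(linear_solver=linear_strategy), or an existing solver instance. If None, a BDF2 variant is chosen from the linear strategy: BDF2RefactoringTransientSolver if it provides refactor_jacobian, otherwise BDF2FactorizedTransientSolver if it provides factor_jacobian, otherwise BDF2VectorizedTransientSolver. | None |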
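When transient_solver is None, setup_transient chooses the default by duck typing on the linear strategy, checking refactor_jacobian before factor_jacobian. A self-contained sketch of that rule, with hypothetical stand-in classes in place of real CircuitLinearSolver subclasses (which real strategies expose which method is not asserted here):

```python
def pick_default_solver(linear_strategy) -> str:
    """Mirror of setup_transient's default choice; returns the class name only."""
    if hasattr(linear_strategy, "refactor_jacobian"):
        return "BDF2RefactoringTransientSolver"   # full Newton via refactorisation
    if hasattr(linear_strategy, "factor_jacobian"):
        return "BDF2FactorizedTransientSolver"    # frozen-Jacobian Newton
    return "BDF2VectorizedTransientSolver"        # fallback


# Hypothetical stand-ins exposing only the probed methods.
class RefactorCapable:
    def factor_jacobian(self, *args): ...
    def refactor_jacobian(self, *args): ...


class FactorOnly:
    def factor_jacobian(self, *args): ...


class NeitherMethod:
    pass


for strategy in (RefactorCapable(), FactorOnly(), NeitherMethod()):
    print(type(strategy).__name__, "->", pick_default_solver(strategy))
```

A strategy that offers both methods gets the refactoring solver, because that check runs first.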

Returns:

| Type | Description |
| --- | --- |
| Callable[..., Solution] | A function that executes the transient analysis and returns a diffrax.Solution. |

The returned function accepts the following keyword-only arguments:

  • t0 (float): The start time of the simulation.
  • t1 (float): The end time of the simulation.
  • dt0 (float): The initial time step for the solver.
  • y0 (ArrayLike): The initial state vector of the system.
  • saveat (diffrax.SaveAt, optional): Specifies time points at which to save the solution. Defaults to None.
  • max_steps (int, optional): The maximum number of steps the solver can take. Defaults to 100000.
  • throw (bool, optional): If True, the solver raises an error on failure. Defaults to False.
  • term (diffrax.AbstractTerm, optional): The term defining the ODE. Defaults to an ODETerm whose vector field is zero.
  • solver (optional): Overrides the pre-configured transient solver.
  • args (optional): Overrides the arguments passed to the solver. Defaults to (groups, sys_size), where sys_size is the linear strategy's system size (halved for complex circuits).
  • stepsize_controller (diffrax.AbstractStepSizeController, optional): The step size controller. Defaults to ConstantStepSize().
  • checkpoints (int, optional): Number of binomial checkpoints, forwarded to circuit_diffeqsolve. Defaults to None.

Any other keyword arguments are accepted but ignored; they are not forwarded to circuit_diffeqsolve.

Source code in circulax/solvers/transient.py
def setup_transient(
    groups: dict[str, Any], linear_strategy: CircuitLinearSolver, transient_solver: AbstractSolver = None
) -> Callable[..., diffrax.Solution]:
    """Configures and returns a function for executing transient analysis.

    This function acts as a factory: it prepares a transient solver that is
    pre-configured with the circuit's linear strategy and returns a callable
    that runs the time-domain simulation with `circuit_diffeqsolve`.

    Args:
        groups (dict[str, Any]): Compiled component groups that define the
            circuit.
        linear_strategy (CircuitLinearSolver): The configured linear solver
            strategy, typically obtained from `analyze_circuit`.
        transient_solver (optional): A transient solver class, instantiated as
            ``transient_solver(linear_solver=linear_strategy)``, or an existing
            solver instance. If None, a BDF2 variant is chosen from the linear
            strategy: `BDF2RefactoringTransientSolver` if it provides
            ``refactor_jacobian``, else `BDF2FactorizedTransientSolver` if it
            provides ``factor_jacobian``, else `BDF2VectorizedTransientSolver`.

    Returns:
        Callable[..., diffrax.Solution]: A function that executes the
        transient analysis. It accepts the following keyword-only arguments:

            t0 (float): The start time of the simulation.
            t1 (float): The end time of the simulation.
            dt0 (float): The initial time step for the solver.
            y0 (ArrayLike): The initial state vector of the system.
            saveat (diffrax.SaveAt, optional): Specifies time points at which
                to save the solution. Defaults to None.
            max_steps (int, optional): The maximum number of steps the solver
                can take. Defaults to 100000.
            throw (bool, optional): If True, the solver will raise an error on
                failure. Defaults to False.
            term (diffrax.AbstractTerm, optional): The term defining the ODE.
                Defaults to an ODETerm whose vector field is zero.
            solver (optional): Overrides the pre-configured transient solver.
            args (optional): Overrides the solver arguments. Defaults to
                ``(groups, sys_size)``.
            stepsize_controller (diffrax.AbstractStepSizeController, optional):
                The step size controller. Defaults to `ConstantStepSize()`.
            checkpoints (int, optional): Number of binomial checkpoints,
                forwarded to `circuit_diffeqsolve`. Defaults to None.

        Any other keyword arguments are ignored.

    Raises:
        RuntimeError: If any group is a frequency-domain component.

    """
    fdomain_names = [g.name for g in groups.values() if getattr(g, "is_fdomain", False)]
    if fdomain_names:
        msg = (
            "Frequency-domain components cannot be used in transient simulation "
            "(time-domain convolution is not supported). "
            f"Offending groups: {fdomain_names}. "
            "Use setup_harmonic_balance() instead."
        )
        raise RuntimeError(msg)

    if transient_solver is None:
        # Pick the best BDF2 variant the linear solver supports.
        if hasattr(linear_strategy, "refactor_jacobian"):
            transient_solver = BDF2RefactoringTransientSolver
        elif hasattr(linear_strategy, "factor_jacobian"):
            transient_solver = BDF2FactorizedTransientSolver
        else:
            transient_solver = BDF2VectorizedTransientSolver

    import inspect
    tsolver = transient_solver(linear_solver=linear_strategy) if inspect.isclass(transient_solver) else transient_solver

    sys_size = linear_strategy.sys_size // 2 if linear_strategy.is_complex else linear_strategy.sys_size

    def _execute_transient(
        *,
        t0: float,
        t1: float,
        dt0: float,
        y0: ArrayLike,
        saveat: diffrax.SaveAt = None,
        max_steps: int = 100000,
        throw: bool = False,
        **kwargs: Any,
    ) -> diffrax.Solution:
        """Executes the transient simulation for the pre-configured circuit."""
        term = kwargs.pop("term", diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)))
        solver = kwargs.pop("solver", tsolver)
        args = kwargs.pop("args", (groups, sys_size))
        stepsize_controller = kwargs.pop("stepsize_controller", ConstantStepSize())
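        checkpoints = kwargs.pop("checkpoints", None)
        # Any remaining kwargs are ignored.

        if saveat is None:
            # circuit_diffeqsolve dereferences saveat, so None must be replaced;
            # mirror its own default and save only the final state.
            saveat = diffrax.SaveAt(t1=True)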

        sol = circuit_diffeqsolve(
            terms=term,
            solver=solver,
            t0=t0,
            t1=t1,
            dt0=dt0,
            y0=y0,
            args=args,
            saveat=saveat,
            max_steps=max_steps,
            throw=throw,
            stepsize_controller=stepsize_controller,
            checkpoints=checkpoints,
        )

        return sol

    return _execute_transient