
harmonic_balance

Harmonic Balance solver for periodic steady-state circuit analysis.

Harmonic Balance (HB) finds the periodic steady-state solution of a nonlinear circuit driven by periodic sources, without time-stepping to steady state.

The circuit DAE F(y) + dQ(y)/dt = 0 is solved by representing y(t) as K equally spaced time samples over one period T = 1/freq. The HB residual at harmonic k is:

R_k = FFT{F(y(t))}[k] + j·k·ω₀·FFT{Q(y(t))}[k] = 0,   k = 0, 1, ..., N,   where ω₀ = 2π·freq, N = num_harmonics and K = 2N+1.
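The residual formula can be sanity-checked on a hypothetical one-node circuit: a current-driven parallel RC, which is not taken from circulax. The sketch below builds the exact steady state from phasor analysis and confirms that its HB residual vanishes to rounding level, while a wrong waveform does not. It uses only jax.numpy. The circuit, its values and the helper name hb_residual are illustrative assumptions.

```python
import jax.numpy as jnp

# Hypothetical one-node circuit (not from circulax): a current source
# I(t) = I0*cos(w0*t) driving a resistor R_ohm in parallel with a capacitor C.
# KCL at the node gives F(v, t) = v/R_ohm - I(t) and Q(v) = C*v.
R_ohm, C, I0, f = 1e3, 1e-9, 1e-3, 1e6
w0 = 2 * jnp.pi * f
N = 3                        # number of harmonics
K = 2 * N + 1                # time samples per period
t = jnp.arange(K) / (K * f)  # t_n = n*T/K
k = jnp.arange(N + 1)        # harmonic indices returned by rfft


def hb_residual(v):
    """R_k = FFT{F}[k] + j*k*w0*FFT{Q}[k] for k = 0..N (hypothetical helper)."""
    F = v / R_ohm - I0 * jnp.cos(w0 * t)
    Q = C * v
    return jnp.fft.rfft(F) + 1j * k * w0 * jnp.fft.rfft(Q)


# Exact steady state from phasor analysis: V = I0 / (1/R + j*w0*C).
V = I0 / (1 / R_ohm + 1j * w0 * C)
v_exact = jnp.real(V * jnp.exp(1j * w0 * t))
# Wrong guess: ignore the capacitor, i.e. v = I0*R*cos(w0*t).
v_wrong = I0 * R_ohm * jnp.cos(w0 * t)

print(jnp.max(jnp.abs(hb_residual(v_exact))))  # rounding level
print(jnp.max(jnp.abs(hb_residual(v_wrong))))  # clearly nonzero
```

The j·k·ω₀ factor matches numpy's rfft sign convention (e^{-j2πkn/K}). Under that convention, differentiating a sampled periodic signal multiplies its k-th coefficient by j·k·ω₀.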

JAX keeps this compact. jax.vmap evaluates the circuit physics at all K time points in parallel, and jnp.fft.rfft/irfft transform between time samples and harmonics. jax.jacobian computes the exact Newton Jacobian by automatic differentiation, with no manual derivation.
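The following minimal end-to-end sketch combines those three ingredients on a hypothetical one-node circuit. The circuit has a series resistor, a shunt capacitor and a cubic nonlinear conductance; all values are illustrative assumptions, and none of it is circulax code. To keep the system square, it maps R_k back to K real equations with irfft. That is one possible formulation, not necessarily what circulax's internal _hb_residual does. The plain undamped Newton loop is likewise a simplification of the damped update used in setup_harmonic_balance.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-node circuit (illustrative values, not circulax code):
# source A*cos(w0*t) -> series resistor Rs -> node v, with a capacitor C and
# a cubic nonlinear conductance i = g1*v + g3*v**3 from v to ground.
Rs, C, g1, g3, A, f = 1.0, 0.1, 1.0, 0.5, 1.0, 1.0
w0 = 2 * jnp.pi * f
N = 7                         # harmonics
K = 2 * N + 1                 # time samples per period
t = jnp.arange(K) / (K * f)   # t_k = k*T/K
k = jnp.arange(N + 1)         # harmonic indices returned by rfft


def F_point(v, tk):
    # Static (resistive) KCL current leaving the node at one time sample.
    return (v - A * jnp.cos(w0 * tk)) / Rs + g1 * v + g3 * v**3


def Q_point(v):
    # Capacitor charge at one time sample.
    return C * v


def hb_residual(v):
    # vmap evaluates the device equations at all K samples at once.
    F = jax.vmap(F_point)(v, t)
    Q = jax.vmap(Q_point)(v)
    R = jnp.fft.rfft(F) + 1j * k * w0 * jnp.fft.rfft(Q)  # R_k, k = 0..N
    # Map the N+1 complex equations back to K real ones (square system).
    return jnp.fft.irfft(R, n=K)


def newton(v, iters=20):
    for _ in range(iters):
        J = jax.jacobian(hb_residual)(v)  # exact (K, K) Jacobian by autodiff
        v = v - jnp.linalg.solve(J, hb_residual(v))
    return v


v_ss = newton(jnp.zeros(K))
Y = jnp.fft.rfft(v_ss) / K        # normalised coefficients, like y_freq
amp = 2 * jnp.abs(Y)              # peak amplitudes (meaningful for k >= 1)
print(jnp.max(jnp.abs(hb_residual(v_ss))))  # near zero after convergence
print(amp[1:4])                   # fundamental, 2nd and 3rd harmonic
```

The drive is a pure cosine and the nonlinearity is odd, so the 2nd harmonic stays near rounding level. The cubic term produces a clearly visible 3rd harmonic.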

Example:

import jax
import jax.numpy as jnp
from circulax.compiler import compile_netlist
from circulax.solvers.harmonic_balance import setup_harmonic_balance
# netlist, models and analyze_circuit are assumed to be defined/imported elsewhere.

groups, num_vars, _ = compile_netlist(netlist, models)
linear_strat = analyze_circuit(groups, num_vars)
y_dc = linear_strat.solve_dc(groups, jnp.zeros(num_vars))

run_hb = setup_harmonic_balance(groups, num_vars, freq=1e6, num_harmonics=5)
y_time, y_freq = run_hb(y_dc)

# JIT-able for repeated calls with different initial conditions:
y_time, y_freq = jax.jit(run_hb)(y_dc)

# y_time: shape (K, num_vars), waveform at K = 2N+1 equally spaced time points
# y_freq: shape (N+1, num_vars), complex normalised Fourier coefficients (N = num_harmonics)
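With the normalisation used by run_hb (rfft divided by K), y_freq[0] is the mean of the waveform. For k >= 1, 2*|y_freq[k]| is the peak amplitude of harmonic k and angle(y_freq[k]) is its cosine phase. The sketch below checks this on a synthetic waveform, an assumption standing in for one column of y_time. It then rebuilds the band-limited waveform on a finer time grid from the coefficients.

```python
import jax.numpy as jnp

# Synthetic stand-in for one column of y_time (an assumption): 0.5 V DC,
# a 1.0 V fundamental and a 0.2 V third harmonic with 0.3 rad phase.
N = 5                          # num_harmonics
K = 2 * N + 1
f = 1e6
w0 = 2 * jnp.pi * f
t = jnp.arange(K) / (K * f)    # t_k = k*T/K
v = 0.5 + 1.0 * jnp.cos(w0 * t) + 0.2 * jnp.cos(3 * w0 * t + 0.3)

Y = jnp.fft.rfft(v) / K        # same normalisation as y_freq, shape (N+1,)
dc = jnp.real(Y[0])            # mean value
amp = 2 * jnp.abs(Y[1:])       # peak amplitude of harmonics 1..N
phase = jnp.angle(Y[1:])       # cosine phase of harmonics 1..N

# Band-limited reconstruction on a finer grid of M points per period.
M = 200
t_fine = jnp.arange(M) / (M * f)
kk = jnp.arange(1, N + 1)[:, None]
v_fine = dc + jnp.sum(amp[:, None] * jnp.cos(kk * w0 * t_fine + phase[:, None]), axis=0)
print(dc, amp[0], amp[2], phase[2])  # approximately 0.5 1.0 0.2 0.3
```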

Functions:

setup_harmonic_balance: Configure and return a Harmonic Balance callable.

setup_harmonic_balance

setup_harmonic_balance(
    groups: dict[str, Any],
    num_vars: int,
    freq: float,
    num_harmonics: int = 5,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-09,
    osc_node: int | None = None,
    amplitude_tries: Array | None = None
) -> Callable[[Array], tuple[Array, Array]]

Configure and return a Harmonic Balance callable.

Parameters:

groups (dict[str, Any], required): Compiled component groups from circulax.compiler.compile_netlist.
num_vars (int, required): Total number of state variables.
freq (float, required): Fundamental frequency in Hz.
num_harmonics (int, default 5): Number of harmonics N; uses K = 2*N+1 time points.
is_complex (bool, default False): True for photonic circuits (unrolled [re | im] state).
g_leak (float, default 1e-09): Diagonal regularisation to prevent singular Jacobians at floating nodes.
osc_node (int | None, default None): State-vector index of the oscillator node (from net_map). Enables automatic multi-start across amplitude_tries. Ignored when y_flat_init is passed to run_hb.
amplitude_tries (Array | None, default None): Amplitudes (V) to try when osc_node is set. When None, [0.3, 0.7, 1.5, 3.0, 7.0, 20.0] is used.
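When osc_node is set, each entry of amplitude_tries seeds one Newton solve. The sketch below shows how such a seed is laid out in the flat (K * sys_size,) state vector, mirroring _single_start in the source listing. The sizes and the oscillator index are illustrative assumptions. The sketch runs in JAX's default float32, whereas the source requests float64.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (assumptions): 4 state variables, oscillator at index 2.
num_harmonics, sys_size, osc_node, freq = 5, 4, 2, 1e6
K = 2 * num_harmonics + 1                                    # 11 time samples
t_points = jnp.linspace(0.0, 1.0 / freq, K, endpoint=False)  # t_k = k*T/K
phase = 2.0 * jnp.pi * jnp.arange(K) / K                     # w0 * t_k


def seed(A):
    # Flat layout: sample k occupies entries [k*sys_size, (k+1)*sys_size),
    # so the oscillator node of sample k sits at k*sys_size + osc_node.
    idx = jnp.arange(K) * sys_size + osc_node
    return jnp.zeros(K * sys_size).at[idx].set(A * jnp.sin(phase))


tries = jnp.array([0.3, 0.7, 1.5, 3.0, 7.0, 20.0])
seeds = jax.vmap(seed)(tries)                       # one seed per amplitude
print(seeds.shape)                                  # (6, 44)
print(seeds[1].reshape(K, sys_size)[:, osc_node])   # 0.7 * sin(w0 * t_k)
```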

Returns:

Callable[[Array], tuple[Array, Array]]: run_hb(y_dc, *, y_flat_init=None, max_iter=50, tol=1e-6) -> (y_time, y_freq), compatible with jax.jit.

Source code in circulax/solvers/harmonic_balance.py
def setup_harmonic_balance(
    groups: dict[str, Any],
    num_vars: int,
    freq: float,
    num_harmonics: int = 5,
    *,
    is_complex: bool = False,
    g_leak: float = 1e-9,
    osc_node: int | None = None,
    amplitude_tries: Array | None = None,
) -> Callable[[Array], tuple[Array, Array]]:
    """Configure and return a Harmonic Balance callable.

    Args:
        groups: Compiled component groups from :func:`~circulax.compiler.compile_netlist`.
        num_vars: Total number of state variables.
        freq: Fundamental frequency in Hz.
        num_harmonics: Number of harmonics; uses K = 2*N+1 time points.
        is_complex: ``True`` for photonic circuits (unrolled ``[re | im]`` state).
        g_leak: Diagonal regularisation to prevent singular Jacobians at floating nodes.
        osc_node: State-vector index of the oscillator node (from ``net_map``).
            Enables automatic multi-start across ``amplitude_tries``. Ignored
            when ``y_flat_init`` is passed to ``run_hb``.
        amplitude_tries: Amplitudes (V) to try when ``osc_node`` is set.
            Defaults to ``[0.3, 0.7, 1.5, 3.0, 7.0, 20.0]``.

    Returns:
        ``run_hb(y_dc, *, y_flat_init=None, max_iter=50, tol=1e-6) -> (y_time, y_freq)``,
        compatible with ``jax.jit``.

    """
    _, _, ground_idxs, sys_size = _build_index_arrays(groups, num_vars, is_complex=is_complex)
    ground_indices = jnp.array(ground_idxs)

    K = 2 * num_harmonics + 1
    omega = 2.0 * jnp.pi * freq
    t_points = jnp.linspace(0.0, 1.0 / freq, K, endpoint=False)

    _amplitude_tries: Array = (
        amplitude_tries
        if amplitude_tries is not None
        else jnp.array([0.3, 0.7, 1.5, 3.0, 7.0, 20.0], dtype=jnp.float64)
    )
    _phase = 2.0 * jnp.pi * jnp.arange(K, dtype=jnp.float64) / K  # (K,)

    def run_hb(
        y_dc: Array,
        *,
        y_flat_init: Array | None = None,
        max_iter: int = 50,
        tol: float = 1e-6,
    ) -> tuple[Array, Array]:
        """Find the periodic steady state via Newton–Raphson on the HB residual.

        Args:
            y_dc: DC operating point, shape ``(sys_size,)``. Tiled across all
                K samples to form the initial guess (zero AC amplitude) when
                neither ``osc_node`` nor ``y_flat_init`` is given. Obtain from
                :meth:`~circulax.solvers.CircuitLinearSolver.solve_dc`.
            y_flat_init: Optional flat initial waveform, shape ``(K * sys_size,)``.
                Overrides the automatic multi-start strategy even when
                ``osc_node`` was set at setup time.
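            max_iter: Maximum number of Newton iterations. The solve runs with
                ``throw=False``, so hitting this limit does not raise; the last
                iterate is returned.
            tol: ``rtol`` and ``atol`` of the ``optx.FixedPointIteration`` solver.
                Convergence is judged on the scaled change between successive
                damped Newton iterates, not directly on the residual norm.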

        Returns:
            A two-tuple ``(y_time, y_freq)`` where:

            - **y_time** -- periodic waveform samples, shape ``(K, sys_size)``.
              The k-th row is the state at time ``t_k = k*T/K``.
            - **y_freq** -- normalised Fourier coefficients, shape
              ``(N_harm+1, sys_size)`` complex. ``y_freq[0]`` is the DC
              component, ``y_freq[1]`` is the fundamental, and so on.
              Peak amplitude of harmonic k>=1 is ``2 * |y_freq[k]|``.

        """

        def newton_step(y_flat: Array, grps: Any) -> Array:
            def _res(y: Array) -> Array:
                return _hb_residual(
                    y.reshape(K, sys_size), grps, t_points, omega, ground_indices, is_complex=is_complex
                ).flatten()

            r = _res(y_flat)
            J = jax.jacobian(_res)(y_flat)
            # Regularise: prevents singular Jacobian when floating nodes have
            # no DC path to ground (mirrors the DC solver's g_leak).
            J = J + g_leak * jnp.eye(J.shape[0], dtype=J.dtype)
            delta = jnp.linalg.solve(J, -r)
            # Voltage damping mirrors the DC solver: limits the maximum step to
            # avoid crashing exponential nonlinearities (diodes, transistors).
            max_change = jnp.max(jnp.abs(delta))
            damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))
            return y_flat + delta * damping

        def _solve(y_flat: Array) -> tuple[Array, Array]:
            # groups MUST be passed via args=: ImplicitAdjoint differentiates
            # only through explicit args; closure-captured variables give zero gradients.
            hb_solver = optx.FixedPointIteration(rtol=tol, atol=tol)
            sol = optx.fixed_point(
                newton_step, hb_solver, y_flat, args=groups,
                max_steps=max_iter, throw=False,
            )
            y_flat_sol = sol.value
            y_time_sol = y_flat_sol.reshape(K, sys_size)
            # Normalise by K: y_freq[0] is the DC mean and, for k >= 1,
            # 2*|y_freq[k]| is the peak amplitude of harmonic k.
            y_freq_sol = jnp.fft.rfft(y_time_sol, axis=0) / K
            return y_time_sol, y_freq_sol

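        # Autonomous multi-start: for an oscillator, the non-oscillating (constant)
        # solution is always a trivial HB solution, so low-amplitude starts tend to
        # converge to it. Try several amplitudes in parallel and keep the result with
        # the largest fundamental at osc_node, on the assumption that at least one
        # start lies outside the basin of attraction of the trivial solution.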
        if osc_node is not None and y_flat_init is None:
            def _single_start(A: Array) -> tuple[Array, Array]:
                y0 = (jnp.zeros(K * sys_size, dtype=jnp.float64)
                      .at[jnp.arange(K) * sys_size + osc_node].set(A * jnp.sin(_phase)))
                return _solve(y0)

            y_times, y_freqs = jax.vmap(_single_start)(_amplitude_tries)
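            fund = jnp.abs(y_freqs[:, 1, osc_node])
            # Ignore starts that diverged to NaN/inf; jnp.argmax would otherwise select them.
            best = jnp.argmax(jnp.where(jnp.isfinite(fund), fund, -jnp.inf))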
            return jnp.take(y_times, best, axis=0), jnp.take(y_freqs, best, axis=0)

        y_flat = y_flat_init if y_flat_init is not None else jnp.tile(y_dc, K)
        return _solve(y_flat)

    return run_hb
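The damped, regularised Newton update performed by newton_step can be tried in isolation on a small problem. In the sketch below, DAMPING_FACTOR and DAMPING_EPS are hypothetical stand-ins: their real values are defined elsewhere in circulax and are not shown here. The exponential residual is a toy imitation of a diode, not a circuit model.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins: the real DAMPING_FACTOR / DAMPING_EPS constants are
# defined elsewhere in circulax and their values are not shown in this listing.
DAMPING_FACTOR = 0.5
DAMPING_EPS = 1e-12
G_LEAK = 1e-9


def damped_newton_step(res_fn, y):
    r = res_fn(y)
    J = jax.jacobian(res_fn)(y)
    J = J + G_LEAK * jnp.eye(J.shape[0], dtype=J.dtype)  # diagonal regularisation
    delta = jnp.linalg.solve(J, -r)
    # Scale the whole step so that no component changes by more than DAMPING_FACTOR.
    damping = jnp.minimum(1.0, DAMPING_FACTOR / (jnp.max(jnp.abs(delta)) + DAMPING_EPS))
    return y + delta * damping


# Toy diode-like residual (an assumption): expm1(y) - b = 0, root y = log1p(b).
b = jnp.array([10.0, 100.0])


def res(y):
    return jnp.expm1(y) - b


y_first = damped_newton_step(res, jnp.zeros(2))
print(y_first)           # approximately [0.05, 0.5]: the full step (10, 100) is scaled down

y = jnp.zeros(2)
for _ in range(40):
    y = damped_newton_step(res, y)
print(y, jnp.log1p(b))   # the two should agree closely
```

Scaling the whole step by one factor keeps the Newton direction and only shortens it, so large exponential arguments are approached gradually instead of overflowing. Once every component of the step is below DAMPING_FACTOR, the update becomes a full Newton step and quadratic convergence resumes.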