Van der Pol Oscillator Tuning via Backpropagation through Harmonic Balance¤

This notebook shows how automatic differentiation through the Harmonic Balance solver can solve oscillator design problems that are impractical with traditional tools.

The problem¤

Given a Van der Pol oscillator (a nonlinear LC tank with negative resistance), tune its parameters so the oscillation hits a precise target frequency and amplitude. This is a joint nonlinear design problem with no closed-form solution.

Traditional approach: Sweep μ over a grid of values, run a transient simulation to steady state for each, measure the amplitude, and read off the closest match. For a two-parameter sweep (frequency + amplitude) the cost is quadratic in grid resolution.

circulax approach: jax.grad through the HB Newton solver via Optimistix's implicit differentiation. The solver finds the periodic steady state in ~15 Newton steps; the gradient is computed by differentiating through that fixed-point computation. A short gradient-based run (80 steps in Part 2, 200 Adam steps in Part 3) replaces the entire parameter sweep.

Circuit: Van der Pol oscillator¤

The Van der Pol element has a cubic I–V characteristic:

\[I(V) = -\mu G_0 V + \frac{G_0}{3} V^3\]
  • For small \(|V|\): slope \(= -\mu G_0 < 0\), a negative conductance, so the oscillation grows
  • The cubic term limits the fundamental amplitude at \(|V| \approx 2\sqrt{\mu}\), giving a stable limit cycle

The element is placed in parallel with an LC tank:

top_node ─── VDP   ─── GND    (Van der Pol: negative resistance + cubic saturation)
top_node ─── L1    ─── GND    (inductor)
top_node ─── C1    ─── GND    (capacitor)
top_node ─── Rdamp ─── GND    (small tank loss, models coil resistance)

The tank resonates at \(f_0 = 1/(2\pi\sqrt{LC})\). The VDP element pumps energy in at small amplitudes (negative conductance dominates) and absorbs energy at large amplitudes (cubic term dominates), creating a stable limit cycle.
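The energy-balance argument can be made quantitative with a short derivation. This is a single-harmonic sketch: it assumes a near-sinusoidal tank voltage \(V(t) = A\cos\omega t\) and uses the cubic coefficient \(G_0/3\) from the component code below; it is not the full multi-harmonic solve. Using \(\cos^3\theta = \tfrac{3}{4}\cos\theta + \tfrac{1}{4}\cos 3\theta\), the fundamental component of the element current is

\[I_1(A) = \left(-\mu G_0 + \frac{G_0}{4} A^2\right) A\]

A steady oscillation requires the net conductance at the fundamental, including the tank loss \(1/R_{\text{damp}}\), to vanish:

\[-\mu G_0 + \frac{G_0}{4} A^2 + \frac{1}{R_{\text{damp}}} = 0 \quad\Longrightarrow\quad A = 2\sqrt{\mu - \frac{1}{G_0 R_{\text{damp}}}}\]

With \(\mu = 2\), \(G_0 = 0.01\) S and \(R_{\text{damp}} = 10\) kΩ this gives \(A = 2\sqrt{1.99} \approx 2.82\) V, and it reduces to \(A = 2\sqrt{\mu}\) when the tank loss is neglected.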

import jax
import jax.numpy as jnp
import numpy as np
import optax
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

from circulax import compile_circuit, setup_harmonic_balance
from circulax.components.base_component import PhysicsReturn, Signals, States, component
from circulax.components.electronic import Capacitor, Inductor, Resistor
from circulax.utils import update_group_params, update_params_dict

# 64-bit precision is important: HB Newton requires accurate Jacobians, and
# gradient-based optimisation accumulates floating-point error across steps.
jax.config.update("jax_enable_x64", True)

pio.templates.default = "plotly_white"
pio.renderers.default = "png"

Defining the Van der Pol component¤

Custom components are plain Python functions decorated with @component. The decorator generates an Equinox module whose parameters (mu, G0) are JAX-traceable leaves, making the component compatible with jax.vmap, jax.jacfwd, and jax.grad.

@component(ports=("p1", "p2"))
def VanDerPolElement(signals: Signals, s: States, mu: float = 2.0, G0: float = 0.01) -> PhysicsReturn:
    """Nonlinear two-terminal element with cubic I-V characteristic.

    I(V) = -mu*G0*V + (G0/3)*V^3

    The first term is a negative conductance (-mu*G0 < 0), which supplies
    energy to the tank when |V| is small.  The cubic term (G0/3)*V^3 saturates
    the gain at large amplitudes, producing a stable oscillation.

    Limit-cycle amplitude (fundamental, from harmonic balance energy balance):
        -mu*G0*A + G0/4*A^3 = 0  -->  A = 2*sqrt(mu)
    With G0=0.01 and mu=2.0: A = 2*sqrt(2) ≈ 2.83 V

    Note: mu appears only in the linear term so that the limit-cycle amplitude
    A = 2*sqrt(mu) is strongly tunable — a factor-of-4 change in mu moves A by
    2×.  If mu were also in the cubic term, amplitude would saturate at ~2 V
    regardless of mu.
    """
    v = signals.p1 - signals.p2
    i = -mu * G0 * v + (G0 / 3.0) * v**3
    return {"p1": i, "p2": -i}, {}


# Quick smoke test: compare the current at V = 0.1 V with the closed-form I(V) (dominated by the -mu*G0*V term)
vdp_test = VanDerPolElement(mu=2.0, G0=0.01)
f_test, _ = vdp_test(p1=0.1, p2=0.0)
print(f"VDP I at V=0.1 V : {f_test['p1']:.6f} A  (expected {-2.0*0.01*0.1 + (0.01/3)*0.1**3:.6f} A)")

print(f"Expected limit-cycle amplitude: 2*sqrt(mu) = {2*np.sqrt(2.0):.3f} V  (at mu=2, G0=0.01)")
VDP I at V=0.1 V : -0.001997 A  (expected -0.001997 A)
Expected limit-cycle amplitude: 2*sqrt(mu) = 2.828 V  (at mu=2, G0=0.01)

Building and compiling the circuit¤

All four elements connect between the same top_node and GND, so they are in parallel. The LC tank sets the oscillation frequency; Rdamp models the finite Q of a real inductor (10 kΩ is intentionally large so that the tank loss conductance 1/Rdamp = 0.1 mS is much smaller than the VDP negative conductance μG0 = 20 mS, ensuring that oscillation starts).

The tuple syntax for connections ("GND,p1": ("VDP,p2", "C1,p2", ...)) joins multiple ports to the same node in a single line — equivalent to writing each pair separately.
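To illustrate how such a connection table collapses into circuit nodes, here is a minimal sketch in plain Python. It is not circulax's implementation: `assign_nodes` is a hypothetical helper that merges connected ports with a union-find and numbers the ground node first, which is one plausible way to arrive at the two-node map this circuit produces.

```python
def assign_nodes(connections, ground="GND"):
    """Hypothetical sketch: merge connected ports into nodes, ground gets index 0."""
    parent = {}

    def find(port):
        parent.setdefault(port, port)
        while parent[port] != port:
            parent[port] = parent[parent[port]]  # path halving
            port = parent[port]
        return port

    # Each key is joined to every port in its tuple (or to a single port string).
    for src, targets in connections.items():
        if isinstance(targets, str):
            targets = (targets,)
        for dst in targets:
            parent[find(dst)] = find(src)

    # Visit ground ports first so that their node receives index 0.
    ports = sorted(parent, key=lambda p: not p.startswith(ground + ","))
    node_of_root, port_map = {}, {}
    for port in ports:
        root = find(port)
        node_of_root.setdefault(root, len(node_of_root))
        port_map[port] = node_of_root[root]
    return port_map


connections = {
    "GND,p1": ("VDP,p2", "L1,p2", "C1,p2", "Rdamp,p2"),
    "VDP,p1": ("L1,p1", "C1,p1", "Rdamp,p1"),
}
port_map = assign_nodes(connections)
```

All `p2` ports share node 0 with ground and all `p1` ports share node 1. Internal states such as the inductor current are extra unknowns that this sketch does not model.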

# ── Circuit parameters ────────────────────────────────────────────────────────
L_val  = 1e-6    # H  (1 µH)
C_val  = 1e-9    # F  (1 nF)
R_damp = 1e4     # Ω  (10 kΩ tank loss — small, Q ≈ R_damp * sqrt(C/L) ≈ 316)

f0 = 1.0 / (2.0 * np.pi * np.sqrt(L_val * C_val))
print(f"Tank resonant frequency : {f0 / 1e6:.4f} MHz")
print(f"Tank Q-factor (Rdamp)   : {R_damp * np.sqrt(C_val / L_val):.1f}")
print(f"VDP negative conductance: {-2.0 * 0.01:.4f} S  (at mu=2, G0=0.01)")
print(f"Tank loss conductance   : {1.0 / R_damp:.6f} S  (1/Rdamp)")
print(f"Net gain margin         : {abs(-2.0 * 0.01) / (1.0 / R_damp):.0f}×  >> 1, oscillation guaranteed")

# ── Netlist ──────────────────────────────────────────────────────────────────
# All four elements are in parallel between top_node and GND.
# GND is the special reserved name — any port touching "GND" is assigned node 0.
vdp_net = {
    "instances": {
        "VDP":   {"component": "vdp",       "settings": {"mu": 2.0,    "G0": 0.01}},
        "L1":    {"component": "inductor",   "settings": {"L": L_val}},
        "C1":    {"component": "capacitor",  "settings": {"C": C_val}},
        "Rdamp": {"component": "resistor",   "settings": {"R": R_damp}},
    },
    "connections": {
        # Connect all negative terminals to GND (node 0)
        "GND,p1":  ("VDP,p2", "L1,p2", "C1,p2", "Rdamp,p2"),
        # Connect all positive terminals to the same top_node
        "VDP,p1":  ("L1,p1", "C1,p1", "Rdamp,p1"),
    },
}

models = {
    "vdp":      VanDerPolElement,
    "inductor":  Inductor,
    "capacitor": Capacitor,
    "resistor":  Resistor,
}

# compile_circuit runs once: builds ComponentGroup objects with batched JAX arrays
# for parameters, and pre-computes index arrays for residual assembly.
circuit = compile_circuit(vdp_net, models)
groups = circuit.groups
num_vars = circuit.sys_size
net_map = circuit.port_map

print(f"\nSystem size : {num_vars} unknowns")
print(f"Node map    : {net_map}")
print(f"Groups      : {list(groups.keys())}")

# The oscillator node — all parallel elements share this port
osc_node = net_map["VDP,p1"]
print(f"Oscillator node index: {osc_node}")

# DC operating point: V=0 is the only fixed point (the VDP element has I(0)=0)
y_dc = circuit()
print(f"DC operating point: max|y_dc| = {float(jnp.max(jnp.abs(y_dc))):.2e} V  (trivially zero)")
Tank resonant frequency : 5.0329 MHz
Tank Q-factor (Rdamp)   : 316.2
VDP negative conductance: -0.0200 S  (at mu=2, G0=0.01)
Tank loss conductance   : 0.000100 S  (1/Rdamp)
Net gain margin         : 200×  >> 1, oscillation guaranteed



System size : 3 unknowns
Node map    : {'VDP,p1': 1, 'C1,p1': 1, 'Rdamp,p1': 1, 'L1,p1': 1, 'VDP,p2': 0, 'C1,p2': 0, 'GND,p1': 0, 'L1,p2': 0, 'Rdamp,p2': 0, 'L1,i_L': 2}
Groups      : ['vdp', 'inductor', 'capacitor', 'resistor']
Oscillator node index: 1


DC operating point: max|y_dc| = 0.00e+00 V  (trivially zero)

Part 1 — Harmonic Balance finds the limit cycle¤

Transient simulation would need to integrate forward in time until the oscillation envelope settles — often hundreds of RF cycles. Harmonic Balance finds the periodic steady state directly by solving for the Fourier coefficients of the waveform.

setup_harmonic_balance builds a residual function over K = 2N+1 equally-spaced time samples per period and solves it with Newton–Raphson (via Optimistix's FixedPointIteration backed by jax.lax.while_loop). The result is JIT-compatible and differentiable end-to-end.
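The idea can be sketched with a single unknown in plain Python. This is an illustration under simplifying assumptions, not circulax's solver: the waveform is taken to be V = A·cos(θ) at the tank resonance (where the fundamental currents of L and C cancel), the resistive current is sampled at K = 15 points, projected onto the fundamental, and Newton is run on the amplitude A alone. The function names are hypothetical.

```python
import math

MU, G0, R_DAMP = 2.0, 0.01, 1e4  # values used in this notebook
K = 15                           # 2*N + 1 samples for N = 7 harmonics


def fundamental_current(amp, mu=MU, g0=G0, r=R_DAMP, k=K):
    """In-phase fundamental of the resistive current for V = amp*cos(theta)."""
    total = 0.0
    for n in range(k):
        theta = 2.0 * math.pi * n / k
        v = amp * math.cos(theta)
        i = -mu * g0 * v + (g0 / 3.0) * v**3 + v / r  # VDP element + Rdamp
        total += i * math.cos(theta)
    return 2.0 * total / k  # one-sided cosine coefficient


def newton_amplitude(amp0, steps=50, tol=1e-12, h=1e-6):
    """Newton iteration on fundamental_current(A) = 0 with a central-difference slope."""
    amp = amp0
    for _ in range(steps):
        res = fundamental_current(amp)
        if abs(res) < tol:
            break
        slope = (fundamental_current(amp + h) - fundamental_current(amp - h)) / (2.0 * h)
        amp -= res / slope
    return amp


a_limit = newton_amplitude(3.5)    # start above the limit cycle
a_trivial = newton_amplitude(0.0)  # start at the DC solution
```

Started at 3.5 V, the iteration settles on the single-harmonic limit cycle at about 2.82 V; started at 0 V, it never leaves the trivial solution, which is why a nonzero initial amplitude matters. The multi-harmonic solve in the notebook need not agree exactly with this one-harmonic estimate.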

N_harm = 7   # 7 harmonics → K = 15 time points per period
K      = 2 * N_harm + 1

# Pass osc_node so setup_harmonic_balance automatically tries several sinusoidal
# initial amplitudes via jax.vmap and selects the limit-cycle solution.
# Users never need to choose a starting amplitude or think about the trivial y=0
# fixed point — the multi-start strategy handles it transparently.
run_hb = setup_harmonic_balance(groups, num_vars, freq=f0, num_harmonics=N_harm, osc_node=osc_node)
y_time, y_freq = jax.jit(run_hb)(y_dc)

print(f"y_time shape : {y_time.shape}  (K={K} time samples × {num_vars} nodes)")
print(f"y_freq shape : {y_freq.shape}  ({N_harm+1} harmonics × {num_vars} nodes)")

# y_freq[0] is the DC component, y_freq[1] is the fundamental, etc.
# One-sided (peak) amplitude at harmonic k>=1 is 2 * |y_freq[k]|  (rfft keeps only non-negative freqs)
A_dc   = float(jnp.abs(y_freq[0, osc_node]))
A_fund = float(2.0 * jnp.abs(y_freq[1, osc_node]))
A_2nd  = float(2.0 * jnp.abs(y_freq[2, osc_node]))
A_3rd  = float(2.0 * jnp.abs(y_freq[3, osc_node]))

print(f"\nOscillator node (index {osc_node}) harmonic amplitudes:")
print(f"  DC (0f0)       : {A_dc:.4f} V")
print(f"  Fundamental f0 : {A_fund:.4f} V")
print(f"  2nd harmonic   : {A_2nd:.4f} V  ({A_2nd/A_fund*100:.1f}% of fundamental)")
print(f"  3rd harmonic   : {A_3rd:.4f} V  ({A_3rd/A_fund*100:.1f}% of fundamental)")
print(f"\nExpected amplitude (VdP limit cycle): ~2*sqrt(mu) = {2*np.sqrt(2.0):.3f} V")
y_time shape : (15, 3)  (K=15 time samples × 3 nodes)
y_freq shape : (8, 3)  (8 harmonics × 3 nodes)

Oscillator node (index 1) harmonic amplitudes:
  DC (0f0)       : 0.0000 V
  Fundamental f0 : 2.9139 V
  2nd harmonic   : 0.0000 V  (0.0% of fundamental)
  3rd harmonic   : 0.2501 V  (8.6% of fundamental)

Expected amplitude (VdP limit cycle): ~2*sqrt(mu) = 2.828 V
# ── Plot: time-domain waveform and harmonic spectrum ─────────────────────────
T     = 1.0 / f0
t_ns  = np.linspace(0.0, T * 1e9, K, endpoint=False)
v_osc = np.array(y_time[:, osc_node])

harmonics   = np.arange(N_harm + 1)
scale       = np.where(harmonics == 0, 1.0, 2.0)
spectrum    = scale * np.abs(np.array(y_freq[:, osc_node]))
tick_labels = ["DC" if k == 0 else f"{k}f\u2080" for k in harmonics]

fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=(
        f"Van der Pol limit cycle — f\u2080 = {f0/1e6:.3f} MHz",
        "Harmonic spectrum",
    ),
)

fig.add_trace(
    go.Scatter(x=t_ns, y=v_osc, mode="lines+markers",
               marker=dict(size=7), line=dict(width=1.5),
               name=f"V_osc (HB, K={K})"),
    row=1, col=1,
)
fig.add_hline(y=0, line_dash="dash", line_color="grey", line_width=0.7, row=1, col=1)

fig.add_trace(
    go.Bar(x=tick_labels, y=spectrum, marker_color="#636EFA", opacity=0.85, name="spectrum"),
    row=1, col=2,
)

fig.update_xaxes(title_text="Time (ns)", row=1, col=1)
fig.update_yaxes(title_text="Voltage (V)", row=1, col=1)
fig.update_xaxes(title_text="Harmonic", row=1, col=2)
fig.update_yaxes(title_text="Amplitude (V)", row=1, col=2)
fig.update_layout(margin=dict(t=80, b=60, l=60, r=60), height=450, showlegend=True)
fig.show()

print("The odd-harmonic dominance (f0, 3f0, 5f0) is a hallmark of the symmetric cubic nonlinearity.")
print("Even harmonics (2f0, 4f0) are near-zero because I(-V) = -I(V) — the VDP element is odd.")

[Figure: Van der Pol limit-cycle waveform (left) and harmonic spectrum (right)]

The odd-harmonic dominance (f0, 3f0, 5f0) is a hallmark of the symmetric cubic nonlinearity.
Even harmonics (2f0, 4f0) are near-zero because I(-V) = -I(V) — the VDP element is odd.
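The odd-symmetry argument can be checked numerically without the solver. The sketch below drives the element's I(V) with a pure cosine, samples one period at K = 15 points, and takes a DFT normalised by 1/K. That normalisation is an assumption chosen to match the `2 * |y_freq[k]|` convention used above, and `dft_coeff` is a hypothetical helper.

```python
import cmath
import math

G0, MU, AMP, K = 0.01, 2.0, 2.0, 15


def dft_coeff(samples, k):
    """Normalised DFT coefficient Y[k] = (1/K) * sum_n y[n] * exp(-2j*pi*k*n/K)."""
    n_samples = len(samples)
    return sum(
        y * cmath.exp(-2j * math.pi * k * n / n_samples)
        for n, y in enumerate(samples)
    ) / n_samples


# Element current over one period of V = AMP * cos(theta).
current = []
for n in range(K):
    v = AMP * math.cos(2.0 * math.pi * n / K)
    current.append(-MU * G0 * v + (G0 / 3.0) * v**3)

# One-sided amplitudes 2*|Y[k]| for harmonics k = 1..4.
amps = [2.0 * abs(dft_coeff(current, k)) for k in range(1, 5)]
```

Here `amps[1]` and `amps[3]` (2f0 and 4f0) vanish to rounding error, while `amps[2]` equals G0·AMP³/12, the third-harmonic current generated by the cubic term.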

Part 2 — Gradient-based amplitude tuning¤

Goal: find the value of μ such that the fundamental amplitude equals a target \(A_{\text{target}}\).

The key insight: setup_harmonic_balance returns a function that is end-to-end differentiable with respect to any JAX-traceable quantity captured in groups. The parameter update is done with update_group_params (uses eqx.tree_at under the hood — a pure functional update with no in-place mutation), so the entire loss computation is a valid JAX program that jax.grad can differentiate through.

Optimistix's FixedPointIteration implements implicit differentiation: the gradient of the fixed-point solution with respect to μ is computed via the implicit function theorem rather than unrolling the Newton iterations.
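A scalar example shows what the implicit function theorem buys. The sketch below uses the single-harmonic residual g(A, μ) = (1/R − μG0)·A + (G0/4)·A³ as a stand-in for the full HB residual (an assumption made for illustration; the function names are hypothetical). The gradient dA*/dμ = −(∂g/∂μ)/(∂g/∂A) is evaluated only at the converged amplitude, with no differentiation through the Newton iterations, and is compared against a finite difference.

```python
import math

G0, R_DAMP = 0.01, 1e4


def residual(amp, mu):
    """Single-harmonic balance residual g(A, mu); g = 0 on the limit cycle."""
    return (1.0 / R_DAMP - mu * G0) * amp + 0.25 * G0 * amp**3


def solve_amplitude(mu, amp0=3.5, steps=60):
    """Newton solve of g(A, mu) = 0 from a nonzero starting amplitude."""
    amp = amp0
    for _ in range(steps):
        dg_damp = (1.0 / R_DAMP - mu * G0) + 0.75 * G0 * amp**2
        amp -= residual(amp, mu) / dg_damp
    return amp


def implicit_grad(mu):
    """dA*/dmu = -(dg/dmu) / (dg/dA), using the converged solution only."""
    amp = solve_amplitude(mu)
    dg_dmu = -G0 * amp
    dg_damp = (1.0 / R_DAMP - mu * G0) + 0.75 * G0 * amp**2
    return -dg_dmu / dg_damp


mu0, eps = 2.0, 1e-6
grad_ift = implicit_grad(mu0)
grad_fd = (solve_amplitude(mu0 + eps) - solve_amplitude(mu0 - eps)) / (2.0 * eps)
```

The two gradients agree, and the implicit one costs a single extra linear solve (here a scalar division) regardless of how many Newton steps the forward solve took.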

A_target = 1.5  # V  (requires mu* = (A_target/2)^2 = 0.5625)

# Fixed warm start for the loss function: sinusoidal at 3.5 V at osc_node, all
# other variables zero.  3.5 V is above the limit-cycle amplitude for all mu in
# [0.5, 3] (A = 2*sqrt(mu) spans [1.41 V, 3.46 V]), so Newton always descends to
# the limit cycle, never to the trivial y=0 fixed point.  Using a constant here
# keeps the gradient path free of argmax (Optimistix's implicit diff works cleanly
# through the fixed point).
_phase = 2.0 * jnp.pi * jnp.arange(K, dtype=jnp.float64) / K
y_flat_warmstart = (
    jnp.zeros(K * num_vars, dtype=jnp.float64)
    .at[jnp.arange(K) * num_vars + osc_node].set(3.5 * jnp.sin(_phase))
)


def loss_fn_mu(mu: jax.Array) -> jax.Array:
    """Squared error between fundamental amplitude and target, as a function of mu."""
    groups_new = update_group_params(groups, "vdp", "mu", mu)
    run_hb_new = setup_harmonic_balance(groups_new, num_vars, freq=f0, num_harmonics=N_harm)
    _, y_freq_new = run_hb_new(y_dc, y_flat_init=y_flat_warmstart)
    A_fund = 2.0 * jnp.abs(y_freq_new[1, osc_node])
    return (A_fund - A_target) ** 2


# Test: evaluate loss and gradient at the default mu=2.0
mu_test = jnp.array(2.0)
loss_val, grad_val = jax.value_and_grad(loss_fn_mu)(mu_test)
print(f"At mu=2.0:  loss = {float(loss_val):.4f} V²,  grad = {float(grad_val):.4f} V²/[mu]")
print(f"  Current amplitude ≈ {float(jnp.sqrt(loss_val)) + A_target:.3f} V,  target = {A_target} V")
print("  Positive gradient → decreasing mu reduces amplitude toward target")

# ── Gradient descent — scalar parameter, no need for Adam ───────────────────
mu = jnp.array(3.0)  # start above the solution (A=2*sqrt(3)≈3.46 V)
lr = 0.05

mu_history    = [float(mu)]
loss_history  = []

val_and_grad_jit = jax.jit(jax.value_and_grad(loss_fn_mu))

print(f"\nGradient descent (lr={lr}, starting from mu={float(mu):.1f}):")
for step in range(80):
    loss, g = val_and_grad_jit(mu)
    loss_history.append(float(loss))
    mu = mu - lr * g
    mu_history.append(float(mu))
    if step % 20 == 0 or step == 79:
        groups_cur = update_group_params(groups, "vdp", "mu", mu)
        _, yf_cur  = jax.jit(setup_harmonic_balance(groups_cur, num_vars, freq=f0, num_harmonics=N_harm, osc_node=osc_node))(y_dc)
        A_cur = float(2.0 * jnp.abs(yf_cur[1, osc_node]))
        print(f"  Step {step:3d}: mu = {float(mu):.4f},  loss = {float(loss):.6f},  A_fund = {A_cur:.4f} V")

mu_opt = float(mu)
print(f"\nOptimised mu = {mu_opt:.4f}  (analytical: {(A_target/2)**2:.4f})")
At mu=2.0:  loss = 1.9991 V²,  grad = 3.4346 V²/[mu]
  Current amplitude ≈ 2.914 V,  target = 1.5 V
  Positive gradient → decreasing mu reduces amplitude toward target

Gradient descent (lr=0.05, starting from mu=3.0):


  Step   0: mu = 2.5133,  loss = 4.027752,  A_fund = 3.2511 V


  Step  20: mu = 0.5763,  loss = 0.013855,  A_fund = 1.5303 V


  Step  40: mu = 0.5078,  loss = 0.000000,  A_fund = 1.5000 V


  Step  60: mu = 0.5076,  loss = 0.000000,  A_fund = 1.4998 V


  Step  79: mu = 0.5076,  loss = 0.000000,  A_fund = 1.4998 V

Optimised mu = 0.5076  (analytical: 0.5625)
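The "analytical" value printed above can be reproduced by running the same update rule on the closed-form single-harmonic model A = 2·sqrt(μ). This is a stand-in for the HB solve, written in plain Python with hypothetical names; it ignores tank loss and higher harmonics, which is presumably why the HB-based run settles at a different μ.

```python
import math

A_TARGET, LR, STEPS = 1.5, 0.05, 80


def amplitude(mu):
    """Closed-form single-harmonic amplitude (no tank loss, no harmonics)."""
    return 2.0 * math.sqrt(mu)


def loss_and_grad(mu):
    """Loss (A - A_target)^2 and its derivative with respect to mu."""
    err = amplitude(mu) - A_TARGET
    d_amp = 1.0 / math.sqrt(mu)  # dA/dmu for A = 2*sqrt(mu)
    return err**2, 2.0 * err * d_amp


mu = 3.0  # same starting point as the notebook loop
for _ in range(STEPS):
    _, grad = loss_and_grad(mu)
    mu -= LR * grad

mu_closed_form = (A_TARGET / 2.0) ** 2
```

After 80 steps `mu` sits within rounding distance of `mu_closed_form` = 0.5625, the fixed point of the simplified model.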
# ── Compute waveforms before and after optimisation ───────────────────────────
groups_before = update_group_params(groups, "vdp", "mu", jnp.array(2.0))
groups_after  = update_group_params(groups, "vdp", "mu", jnp.array(mu_opt))

yt_before, yf_before = jax.jit(setup_harmonic_balance(groups_before, num_vars, freq=f0, num_harmonics=N_harm, osc_node=osc_node))(y_dc)
yt_after, yf_after = jax.jit(setup_harmonic_balance(groups_after, num_vars, freq=f0, num_harmonics=N_harm, osc_node=osc_node))(y_dc)

A_before = float(2.0 * jnp.abs(yf_before[1, osc_node]))
A_after  = float(2.0 * jnp.abs(yf_after[1, osc_node]))

# Use the waveform solved for groups_before rather than the Part 1 global y_time
v_before = np.array(yt_before[:, osc_node])
v_after  = np.array(yt_after[:, osc_node])

fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=(
        "Before and after \u03bc tuning",
        f"Amplitude tuning convergence (target {A_target} V)",
    ),
)

fig.add_trace(
    go.Scatter(x=t_ns, y=v_before, mode="lines", line=dict(width=2),
               name=f"\u03bc=2.00 \u2192 A={A_before:.2f} V (initial)"),
    row=1, col=1,
)
fig.add_trace(
    go.Scatter(x=t_ns, y=v_after, mode="lines", line=dict(width=2.5, dash="dash"),
               name=f"\u03bc={mu_opt:.2f} \u2192 A={A_after:.2f} V (optimised)"),
    row=1, col=1,
)
fig.add_hline(y= A_target, line_dash="dot", line_color="grey", line_width=0.8, row=1, col=1)
fig.add_hline(y=-A_target, line_dash="dot", line_color="grey", line_width=0.8, row=1, col=1)

fig.add_trace(
    go.Scatter(x=list(range(len(loss_history))), y=loss_history, mode="lines",
               line=dict(width=2), name="loss", showlegend=False),
    row=1, col=2,
)

fig.update_xaxes(title_text="Time (ns)", row=1, col=1)
fig.update_yaxes(title_text="Voltage (V)", row=1, col=1)
fig.update_xaxes(title_text="Gradient descent step", row=1, col=2)
fig.update_yaxes(title_text="Loss (V\u00b2)", type="log", row=1, col=2)
fig.update_layout(margin=dict(t=80, b=100, l=60, r=60), height=520, legend=dict(orientation="h", y=-0.25, x=0.5, xanchor="center"))
fig.show()

print(f"Amplitude error after optimisation: {abs(A_after - A_target)*1000:.2f} mV")

[Figure: waveforms before and after μ tuning (left) and loss convergence on a log scale (right)]

Amplitude error after optimisation: 0.23 mV

Part 3 — Joint tuning: frequency and amplitude¤

Now we solve the harder problem: simultaneously tune L, C, and μ so the oscillator hits a new target frequency (8 MHz, shifted from 5 MHz) and a target amplitude (1.0 V).

This is a joint nonlinear design problem. The traditional approach would be a 3D sweep over (L, C, μ) — thousands of HB solves just to map the landscape. With jax.grad we get the exact gradient with one HB solve, and Adam navigates directly to the solution.

Log-space parameterisation keeps all three parameters positive and balances their gradients: a 10% change in L feels the same as a 10% change in μ, regardless of the absolute scale. We also pass the resonant frequency (derived from L and C) directly to setup_harmonic_balance, so the HB time grid always spans one period of the current tank resonance, which is essential for convergence when L and C are changing. The frequency term of the loss is therefore a function of L and C alone; only the amplitude term depends on the HB solution.
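The gradient-balancing effect of the log-space parameterisation can be seen on the frequency term alone. The sketch below is plain Python with finite differences instead of jax.grad, and its helper names are hypothetical. It compares gradients with respect to log L and log C against gradients with respect to the raw values, which differ in scale by three orders of magnitude.

```python
import math

F_TARGET = 8e6


def freq_loss(log_l, log_c):
    """Relative squared frequency error as a function of log(L) and log(C)."""
    f_res = 1.0 / (2.0 * math.pi * math.sqrt(math.exp(log_l) * math.exp(log_c)))
    return ((f_res - F_TARGET) / F_TARGET) ** 2


def central_diff(fn, x, h):
    """Central finite-difference derivative of fn at x."""
    return (fn(x + h) - fn(x - h)) / (2.0 * h)


L0, C0 = 1e-6, 1e-9
log_l0, log_c0 = math.log(L0), math.log(C0)

# Gradients with respect to the log-parameters.
g_log_l = central_diff(lambda x: freq_loss(x, log_c0), log_l0, 1e-6)
g_log_c = central_diff(lambda x: freq_loss(log_l0, x), log_c0, 1e-6)

# Gradients with respect to the raw parameters: d/dp = (1/p) * d/dlog(p).
g_raw_l = g_log_l / L0
g_raw_c = g_log_c / C0
```

In log-space the two gradients are equal, so one learning rate suits both parameters. In raw units the gradient with respect to C is 1000 times larger than the one with respect to L, which would force per-parameter step sizes.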

f_target = 8e6    # Hz — 8 MHz target (up from ~5 MHz)
A_target_j = 1.0  # V  — 1 V fundamental amplitude

LC_target = 1.0 / (2.0 * np.pi * f_target) ** 2
print(f"Target frequency    : {f_target/1e6:.1f} MHz")
print(f"Required L*C product: {LC_target:.3e} H·F")
print(f"Starting L={L_val*1e6:.2f} µH, C={C_val*1e9:.2f} nF  → f0={f0/1e6:.3f} MHz")

# Fixed sinusoidal warm start for the joint optimisation loss: amplitude 3.5 V at osc_node
# (the value set below), all other nodes zero.  This is above the basin-of-attraction
# saddle (~1.15 V) for all (L, C, mu) values explored during optimisation.  Using a constant
# here keeps the gradient path free of argmax, so Optimistix's implicit diff works cleanly.
_phase_k = 2.0 * jnp.pi * jnp.arange(K, dtype=jnp.float64) / K
y_flat_hb_init = (
    jnp.zeros(K * num_vars, dtype=jnp.float64)
    .at[jnp.arange(K) * num_vars + osc_node].set(3.5 * jnp.sin(_phase_k))
)


def loss_joint(log_params: jax.Array) -> jax.Array:
    """Joint loss over (log_L, log_C, log_mu) for frequency + amplitude targets."""
    log_L, log_C, log_mu = log_params
    L   = jnp.exp(log_L)
    C   = jnp.exp(log_C)
    mu  = jnp.exp(log_mu)

    f_resonant = 1.0 / (2.0 * jnp.pi * jnp.sqrt(L * C))

    grps = update_params_dict(groups, "inductor",  "L1", "L", L)
    grps = update_params_dict(grps,   "capacitor", "C1", "C", C)
    grps = update_group_params(grps,  "vdp", "mu", mu)

    run_hb_j = setup_harmonic_balance(grps, num_vars, freq=f_resonant, num_harmonics=N_harm)
    _, y_freq_j = run_hb_j(jnp.zeros(num_vars), y_flat_init=y_flat_hb_init)

    A_fund = 2.0 * jnp.abs(y_freq_j[1, osc_node])

    loss_freq = ((f_resonant - f_target) / f_target) ** 2
    loss_amp  = ((A_fund - A_target_j) / A_target_j) ** 2

    return loss_freq + loss_amp


# Sanity check
log_params_0 = jnp.log(jnp.array([L_val, C_val, 2.0]))
loss_0 = loss_joint(log_params_0)
print(f"\nInitial joint loss: {float(loss_0):.4f}  (freq error + amplitude error)")

# ── Adam optimisation ─────────────────────────────────────────────────────────
optimizer   = optax.adam(0.05)
log_params  = log_params_0
opt_state   = optimizer.init(log_params)
losses_joint = []
param_log_hist = [np.array(log_params)]

val_grad_joint = jax.jit(jax.value_and_grad(loss_joint))

print("\nAdam optimisation (200 steps, lr=0.05):")
for i in range(200):
    loss, grads = val_grad_joint(log_params)
    losses_joint.append(float(loss))
    updates, opt_state = optimizer.update(grads, opt_state)
    log_params = optax.apply_updates(log_params, updates)
    param_log_hist.append(np.array(log_params))
    if i % 50 == 0 or i == 199:
        L_c, C_c, mu_c = np.exp(np.array(log_params))
        f_c_cur = 1.0 / (2.0 * np.pi * np.sqrt(L_c * C_c))
        print(f"  Step {i:3d}: loss={float(loss):.5f},  f={f_c_cur/1e6:.3f} MHz,  L={L_c*1e6:.3f} µH,  C={C_c*1e9:.3f} nF,  mu={mu_c:.3f}")

L_opt, C_opt, mu_opt_j = np.exp(np.array(log_params))
f_opt = 1.0 / (2.0 * np.pi * np.sqrt(L_opt * C_opt))
print(f"\nFinal: f={f_opt/1e6:.4f} MHz  (target {f_target/1e6:.1f} MHz),  L={L_opt*1e6:.4f} µH,  C={C_opt*1e9:.4f} nF,  mu={mu_opt_j:.4f}")
Target frequency    : 8.0 MHz
Required L*C product: 3.958e-16 H·F
Starting L=1.00 µH, C=1.00 nF  → f0=5.033 MHz



Initial joint loss: 3.8006  (freq error + amplitude error)

Adam optimisation (200 steps, lr=0.05):


  Step   0: loss=3.80059,  f=5.033 MHz,  L=0.951 µH,  C=1.051 nF,  mu=1.902


  Step  50: loss=0.08283,  f=8.239 MHz,  L=0.234 µH,  C=1.594 nF,  mu=0.412


  Step 100: loss=0.03842,  f=8.014 MHz,  L=0.222 µH,  C=1.777 nF,  mu=0.304


  Step 150: loss=0.00080,  f=8.002 MHz,  L=0.220 µH,  C=1.800 nF,  mu=0.270


  Step 199: loss=0.00008,  f=8.000 MHz,  L=0.219 µH,  C=1.805 nF,  mu=0.261

Final: f=7.9998 MHz  (target 8.0 MHz),  L=0.2192 µH,  C=1.8055 nF,  mu=0.2613


Pre-computing 51 HB solutions (stride=4)...


  [  1/51] step   0 — f=5.03 MHz, A=2.914 V, mu=2.000


  [ 11/51] step  40 — f=7.65 MHz, A=1.483 V, mu=0.496


  [ 21/51] step  80 — f=7.92 MHz, A=1.238 V, mu=0.330


  [ 31/51] step 120 — f=8.00 MHz, A=1.164 V, mu=0.286


  [ 41/51] step 160 — f=8.00 MHz, A=1.133 V, mu=0.268


  [ 51/51] step 200 — f=8.00 MHz, A=1.120 V, mu=0.261
Done.



Saved → examples/inverse_design/oscillator_optimisation.gif
Frames: 51   Duration: 4.2s at 12 fps
Tip: drag-and-drop directly into PowerPoint (Insert → Pictures).

[Animation: oscillator_optimisation.gif, 51 frames showing the HB solution at every 4th Adam step]

Summary¤

Starting from a Van der Pol oscillator running at ~5 MHz with ~2.9 V amplitude, Adam simultaneously tuned L, C, and μ toward the 8 MHz, 1.0 V target in 200 steps, reaching 7.9998 MHz and a joint loss of about 8e-5, with no grid sweeps and no manual iteration.

What made this possible?¤

| Step | Tool | Role |
|---|---|---|
| Component definition | @component decorator | Generates JAX-traceable Equinox module |
| Netlist compilation | compile_circuit | Runs once; produces vmappable ComponentGroup objects |
| Differentiable parameter update | update_params_dict / update_group_params | eqx.tree_at functional update; no recompilation |
| Periodic steady state | setup_harmonic_balance | Finds limit cycle in ~15 Newton steps; JIT-compatible |
| Exact gradients | jax.grad | Differentiates through the HB Newton solver via implicit differentiation |
| Optimisation | optax.adam | First-order optimiser in log-space; 200 steps to convergence |

Going further¤

The same pattern extends directly to:

  • Phase noise minimisation — add a noise source and minimise the HB-estimated phase noise spectral density as a function of tank Q and bias point.
  • Injection locking — add a small-signal source and tune the free-running frequency to lock at the injection frequency with maximum locking range.
  • Coupled oscillator arrays — vmap the HB solve over an array of weakly coupled oscillators and optimise the coupling network for synchronisation.
  • CMOS VCO design — replace the VDP element with a transistor-level cross-coupled pair (using NMOS / PMOS components from circulax.components.electronic) and sweep the varactor capacitance as a differentiable parameter.