Resonator frequency estimation models
This example demonstrates how to estimate the resonance frequencies of superconducting microwave resonators using JAX.
Probeline weakly coupled to a \(\lambda/4\) resonator
Create a probeline weakly coupled to a quarter-wave resonator.
The resonance frequency is first estimated with the `resonator_frequency` function and then compared with the resonance frequency observed in the coupled circuit.
# Geometry shared by the analytical estimate and the simulated circuit, so that
# both are guaranteed to describe the same resonator.
CPW_WIDTH = 10
CPW_GAP = 6
RESONATOR_LENGTH = 4000  # µm

# Effective permittivity and characteristic impedance of the CPW cross-section.
ep_eff, z0 = cpw_parameters(width=CPW_WIDTH, gap=CPW_GAP)
logger.info(f"{ep_eff=!r}")
logger.info(f"{z0=!r}")  # Characteristic impedance

# Analytical estimate for an isolated quarter-wave resonator.
res_freq = resonator_frequency(
    length=RESONATOR_LENGTH,
    epsilon_eff=float(jnp.real(ep_eff)),
    is_quarter_wave=True,
)
logger.info(f"Resonance frequency (quarter-wave): {float(res_freq) / 1e9} GHz")

# Two-port circuit whose external "in"/"out" ports are the resonator's
# coupling ports.
circuit, info = sax.circuit(
    netlist={
        "instances": {
            "R1": "quarter_wave_resonator",
        },
        "connections": {},
        "ports": {
            "in": "R1,coupling_o1",
            "out": "R1,coupling_o2",
        },
    },
    models={
        "quarter_wave_resonator": partial(
            quarter_wave_resonator_coupled,
            cross_section=coplanar_waveguide(width=CPW_WIDTH, gap=CPW_GAP),
        )
    },
)
frequencies = jnp.linspace(1e9, 10e9, 5001)
S = circuit(f=frequencies, length=float(RESONATOR_LENGTH))
logger.info(info)
# The plotted quantity is the power transmission |S21|^2, so label it as such.
plt.plot(frequencies / 1e9, abs(S["in", "out"]) ** 2)
plt.xlabel("f [GHz]")
plt.ylabel("$|S_{21}|^2$")
def _mark_resonance_frequency(x_value: float, color: str, label: str):
    """Overlay a dashed vertical marker at a resonance frequency on the current plot.

    Args:
        x_value: Frequency in Hz. It is converted to GHz because the x-axes of
            the plots in this example are in GHz.
        color: Matplotlib colour of the marker line.
        label: Legend entry for the marker line.
    """
    frequency_ghz = float(x_value) / 1e9
    plt.axvline(frequency_ghz, color=color, linestyle="--", label=label)
# The simulated resonance is the frequency with the deepest transmission dip.
coupled_dip_index = jnp.argmin(abs(S["in", "out"]))
actual_freq = frequencies[coupled_dip_index]
_mark_resonance_frequency(res_freq, "red", "Predicted resonance Frequency")
logger.info(f"Coupled resonance frequency: {float(actual_freq) / 1e9} GHz")
_mark_resonance_frequency(actual_freq, "green", "Coupled resonance Frequency")
plt.legend()
2026-04-02 17:57:56.129 | INFO | __main__:<module>:2 - ep_eff=6.065191244680569
2026-04-02 17:57:56.129 | INFO | __main__:<module>:3 - z0=49.312798842593466
2026-04-02 17:57:56.189 | INFO | __main__:<module>:10 - Resonance frequency (quarter-wave): 7.6081395616176914 GHz
2026-04-02 17:57:59.953 | INFO | __main__:<module>:33 - CircuitInfo(dag=<networkx.classes.digraph.DiGraph object at 0x7f7592072990>, models={'quarter_wave_resonator': functools.partial(<function quarter_wave_resonator_coupled at 0x7f75921699e0>, cross_section=CrossSection(sections=(Section(width=10.0, offset=0.0, insets=None, layer=<LayerMapQPDK.M1_DRAW: 1>, port_names=('o1', 'o2'), port_types=('optical', 'optical'), name='_default', hidden=False, simplify=None, width_function=None, offset_function=None), Section(width=6.0, offset=8.0, insets=None, layer=<LayerMapQPDK.M1_ETCH: 47>, port_names=(None, None), port_types=('optical', 'optical'), name='etch_offset_pos', hidden=False, simplify=None, width_function=None, offset_function=None), Section(width=6.0, offset=-8.0, insets=None, layer=<LayerMapQPDK.M1_ETCH: 47>, port_names=(None, None), port_types=('optical', 'optical'), name='etch_offset_neg', hidden=False, simplify=None, width_function=None, offset_function=None), Section(width=10.0, offset=0, insets=None, layer=<LayerMapQPDK.WG: 59>, port_names=(None, None), port_types=('optical', 'optical'), name='waveguide', hidden=False, simplify=None, width_function=None, offset_function=None)), components_along_path=(), radius=100.0, radius_min=11.0, bbox_layers=None, bbox_offsets=None)), 'top_level': <function _flat_circuit.<locals>._circuit at 0x7f75902602c0>}, backend='klu')
2026-04-02 17:58:00.172 | INFO | __main__:<module>:51 - Coupled resonance frequency: 7.5718 GHz
<matplotlib.legend.Legend at 0x7f7544fd74d0>
Optimizer for a given resonance frequency
Find the resonator length that yields a desired resonance frequency using an optimizer; here we use Optax together with JAX’s automatic differentiation. Instead of evaluating the entire frequency band in every iteration, we simply minimize the transmission \(|S_{21}|^2\) at the target frequency. JAX computes the exact gradient of this transmission with respect to the resonator length via automatic differentiation, which makes each optimization step fast and precise.
TARGET_FREQUENCY = 6.0e9  # Desired resonance frequency, in Hz
@jax.jit
def loss_fn(params: dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Return the power transmission |S21|^2 at ``TARGET_FREQUENCY``.

    Driving this value towards zero moves the resonance dip onto the target
    frequency, so it serves as the optimization objective.

    Args:
        params: Mapping whose ``"length"`` entry holds the resonator length in
            micrometers.

    Returns:
        The loss as a scalar JAX array.
    """
    # A single-element frequency array: only the target frequency is needed.
    target_band = jnp.array([TARGET_FREQUENCY])
    s_params = circuit(f=target_band, length=params["length"])
    transmission = s_params["in", "out"][0]
    # |S21|^2 computed as the real part of S21 * conj(S21).
    return jnp.real(transmission * jnp.conj(transmission))
# Spot-check the loss at two lengths before optimizing.
for probe_length in (4000.0, 5900.0):
    logger.info(f"Loss at {probe_length:.0f} um: {float(loss_fn({'length': jnp.array(probe_length)}))}")
# To ensure we start within the narrow resonance dip, we first evaluate a coarse array of lengths.
# The centre of the window rescales the 4000 µm reference length (used for res_freq)
# by f0 / f_target, which assumes the resonance scales inversely with length.
frequency_ratio = res_freq / TARGET_FREQUENCY
expected_length = float(4000.0 * frequency_ratio)
sweep_half_width = 1000
coarse_lengths = jnp.linspace(
    expected_length - sweep_half_width,
    expected_length + sweep_half_width,
    2001,
)
@jax.jit
@jax.vmap
def sweep_loss(length_val: jnp.ndarray) -> jnp.ndarray:
    """Compute the loss for one candidate length (vmapped over a batch of lengths).

    Args:
        length_val: Resonator length in micrometers. Under ``jax.vmap`` this is
            one element of the batch passed to ``sweep_loss``.

    Returns:
        The loss for that length.
    """
    candidate_params = {"length": length_val}
    return loss_fn(candidate_params)
# Evaluate every candidate length in one vectorised call, and start the
# optimizer from the candidate with the lowest transmission at the target.
coarse_losses = sweep_loss(coarse_lengths)
best_index = jnp.argmin(coarse_losses)
best_initial_length = float(coarse_lengths[best_index])
logger.info(f"Best initial guess from sweep: {best_initial_length:.2f} µm")
# Initialize the Adam optimizer from the sweep's best length.
optimizer = optax.adam(learning_rate=0.1)
params = {"length": jnp.array(best_initial_length)}
opt_state = optimizer.init(params)
@jax.jit
def step(params: dict[str, jnp.ndarray], opt_state):
    """Perform a single Optax optimization step.

    Args:
        params: Dictionary containing the current parameters (e.g. ``"length"``).
        opt_state: Current state of the optimizer.

    Returns:
        A 3-tuple ``(params, opt_state, loss_value)``:

        - the updated parameters;
        - the updated optimizer state;
        - the loss evaluated at the *incoming* parameters, i.e. before this
          step's update is applied.
    """
    # Loss and gradient at the current parameters.
    loss_value, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value
for _ in trange(100, desc="Optimizing"):
    params, opt_state, loss_value = step(params, opt_state)
length_val = float(jax.device_get(params["length"]))
logger.info(f"Optimized Length: {length_val:.2f} µm")

# Evaluate over a wide range of frequencies to verify.
verification_frequencies = jnp.linspace(0.5e9, 10e9, 1001)
optimal_S = circuit(f=verification_frequencies, length=params["length"])

# The wide grid is spaced ~9.5 MHz apart, which may be coarser than the dip of a
# weakly coupled resonator. Locate the dip on that grid first, then refine it on
# a dense grid spanning one coarse step either side of the coarse minimum.
coarse_step = verification_frequencies[1] - verification_frequencies[0]
coarse_dip = verification_frequencies[jnp.argmin(abs(optimal_S["in", "out"]))]
refined_frequencies = jnp.linspace(coarse_dip - coarse_step, coarse_dip + coarse_step, 1001)
refined_S = circuit(f=refined_frequencies, length=params["length"])
optimal_freq = refined_frequencies[jnp.argmin(abs(refined_S["in", "out"]))]
optimal_freq_val = float(jax.device_get(optimal_freq))
logger.info(f"Achieved Resonance Frequency: {optimal_freq_val / 1e9:.4f} GHz")

# Plot the power transmission |S21|^2 over the wide band.
plt.close()
plt.plot(verification_frequencies / 1e9, abs(optimal_S["in", "out"]) ** 2)
plt.xlabel("f [GHz]")
plt.ylabel("$|S_{21}|^2$")
_mark_resonance_frequency(optimal_freq_val, "blue", "Optimized resonance Frequency")
_mark_resonance_frequency(TARGET_FREQUENCY, "orange", "Target resonance Frequency")
plt.legend()
plt.show()
# ruff: enable[E402]
2026-04-02 17:58:02.021 | INFO | __main__:<module>:24 - Loss at 4000 um: 0.9999909577500611
2026-04-02 17:58:02.024 | INFO | __main__:<module>:25 - Loss at 5900 um: 0.9999916460723687
2026-04-02 17:58:03.805 | INFO | __main__:<module>:48 - Best initial guess from sweep: 5053.09 µm
2026-04-02 17:58:08.780 | INFO | __main__:<module>:77 - Optimized Length: 5053.01 µm
2026-04-02 17:58:10.605 | INFO | __main__:<module>:84 - Achieved Resonance Frequency: 6.00 GHz