Resonator frequency estimation models

This example demonstrates how to estimate the resonance frequencies of superconducting microwave resonators using scikit-rf and JAX.

Hide code cell source

import math
import os
from functools import partial

import jax.numpy as jnp
import sax
import skrf

from qpdk.models.resonator import (
    cpw_media_skrf,
    quarter_wave_resonator_coupled_to_probeline,
    resonator_frequency,
)

Probeline weakly coupled to a \(\lambda/4\) resonator

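Creates a probeline weakly coupled to a quarter-wave resonator. The resonance frequency of the isolated resonator is first estimated with the resonator_frequency function. It is then compared to the resonance observed in the coupled circuit, where the coupling capacitance loads the resonator and shifts its frequency slightly downward.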

if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # One CPW media factory (10 µm centre conductor, 6 µm gaps) and one length
    # (µm) are shared by the analytic estimate and the circuit model, so the
    # two results are guaranteed to describe the same geometry.
    cpw_media = cpw_media_skrf(width=10, gap=6)
    resonator_length = 4000

    cpw = cpw_media(frequency=skrf.Frequency(2, 9, 101, unit="GHz"))
    print(f"{cpw=!r}")
    print(f"{cpw.z0.mean().real=!r}")  # Characteristic impedance

    # Analytic estimate for the isolated quarter-wave resonator.
    res_freq = resonator_frequency(
        length=resonator_length, media=cpw, is_quarter_wave=True
    )
    print("Resonance frequency (quarter-wave):", res_freq / 1e9, "GHz")

    # Single-instance circuit: the resonator model exposes the probeline ports.
    circuit, info = sax.circuit(
        netlist={
            "instances": {"R1": "quarter_wave_resonator"},
            "connections": {},
            "ports": {"in": "R1,o1", "out": "R1,o2"},
        },
        models={
            "quarter_wave_resonator": partial(
                quarter_wave_resonator_coupled_to_probeline,
                media=cpw_media,
                length=resonator_length,
                coupling_capacitance=15e-15,
            )
        },
    )

    frequencies = jnp.linspace(1e9, 10e9, 5001)
    S = circuit(f=frequencies)
    print(info)

    s21_mag = abs(S["in", "out"])
    plt.plot(frequencies / 1e9, s21_mag**2)
    plt.xlabel("f [GHz]")
    # The plotted quantity is the transmitted power |S21|^2, not S21 itself.
    plt.ylabel("$|S_{21}|^2$")

    def _mark_resonance_frequency(x_value: float, color: str, label: str) -> None:
        """Draw a vertical dashed line on the current plot at a resonance frequency.

        Args:
            x_value: Frequency in Hz; converted to GHz to match the x-axis.
            color: Matplotlib colour of the line.
            label: Legend label of the line.
        """
        plt.axvline(x_value / 1e9, color=color, linestyle="--", label=label)

    _mark_resonance_frequency(res_freq, "red", "Predicted resonance Frequency")
    # The coupled resonance is where transmission past the resonator dips.
    actual_freq = frequencies[jnp.argmin(s21_mag)]
    print("Coupled resonance frequency:", actual_freq / 1e9, "GHz")
    _mark_resonance_frequency(actual_freq, "green", "Coupled resonance Frequency")

    plt.legend()
    # plt.show()
cpw=Coplanar Waveguide Media.  2.0-9.0 GHz. 101 points
 W= 1.00e-05m, S= 6.00e-06m
cpw.z0.mean().real=np.float64(49.27868570939533)
Resonance frequency (quarter-wave): 7.6081394901138095 GHz
CircuitInfo(dag=<networkx.classes.digraph.DiGraph object at 0x7f463a5237a0>, models={'quarter_wave_resonator': functools.partial(<function quarter_wave_resonator_coupled_to_probeline at 0x7f463a49cea0>, media=functools.partial(<class 'skrf.media.cpw.CPW'>, w=9.999999999999999e-06, s=6e-06, h=0.0005, t=1.9999999999999996e-07, ep_r=11.45, rho=1e-100, tand=0), length=4000, coupling_capacitance=1.5e-14), 'top_level': <function _flat_circuit.<locals>._circuit at 0x7f4639f8ccc0>}, backend='klu')
Coupled resonance frequency: 7.4404 GHz
../_images/61498b9850a8aa863ec89fb6a61912d8f69684534873406039d8c2ccb6bc1738.png

Optimizing the resonator length for a target resonance frequency

Find the resonator length that yields a desired resonance frequency, using Ray Tune with an Optuna search algorithm.


if __name__ == "__main__":
    import ray
    import ray.tune
    import ray.tune.search.optuna

    frequencies = jnp.linspace(0.5e9, 10e9, 1001)
    TARGET_FREQUENCY = 6e9  # Target resonance frequency in Hz

    def loss_fn(config: dict[str, float]) -> dict[str, float]:
        """Measure how far the coupled resonance is from ``TARGET_FREQUENCY``.

        The resonance is taken as the grid point of minimum |S21| on
        ``frequencies``, so its resolution is limited by the grid spacing.

        Args:
            config: Dictionary containing the resonator length in micrometers
                under the key ``"length"``.

        Returns:
            A dict with ``"l1_loss_ghz"`` (absolute error in GHz) and ``"mse"``
            (squared error in Hz^2). Ray Tune minimizes ``"mse"``.
        """
        length = config["length"]
        # The length is passed as a top-level circuit setting.
        S = circuit(f=frequencies, length=length)
        coupled_freq = float(frequencies[jnp.argmin(abs(S["in", "out"]))])
        error_hz = coupled_freq - TARGET_FREQUENCY
        return {
            "l1_loss_ghz": abs(error_hz) / 1e9,
            "mse": error_hz**2,
        }

    # Sanity-check the loss function before launching the search.
    print(f"{loss_fn(dict(length=4000.0))=}")
    print(f"{loss_fn(dict(length=5900.0))=}")

    # os.cpu_count() returns None when the CPU count cannot be determined.
    cpu_count = os.cpu_count() or 1

    tuner = ray.tune.Tuner(
        loss_fn,
        param_space={
            "length": ray.tune.uniform(1000.0, 9000.0),
        },
        tune_config=ray.tune.TuneConfig(
            metric="mse",
            mode="min",
            num_samples=10,
            max_concurrent_trials=math.ceil(cpu_count / 4),
            reuse_actors=True,
            search_alg=ray.tune.search.optuna.OptunaSearch(),
        ),
    )
    results = tuner.fit()
    best_trial = results.get_best_result()
    length = best_trial.config["length"]
    print(f"Best trial config: {best_trial.config}")

    # Re-evaluate the circuit at the best length found by the search.
    print(f"Optimized Length: {length:.2f} µm")
    optimal_S = circuit(f=frequencies, length=length)
    optimal_freq = frequencies[jnp.argmin(abs(optimal_S["in", "out"]))]
    print(f"Achieved Resonance Frequency: {optimal_freq / 1e9:.2f} GHz")

    # Start a new figure so that, when run as a script, this curve is not
    # drawn onto the axes of the previous (un-shown) plot.
    plt.figure()
    plt.plot(frequencies / 1e9, abs(optimal_S["in", "out"]) ** 2)
    plt.xlabel("f [GHz]")
    plt.ylabel("$|S_{21}|^2$")
    _mark_resonance_frequency(optimal_freq, "blue", "Optimized resonance Frequency")
    _mark_resonance_frequency(TARGET_FREQUENCY, "orange", "Target resonance Frequency")
    plt.legend()
    plt.show()