# Source code for qpdk.models.waveguides

"""S-parameter model for a straight waveguide."""

from typing import TypedDict, Unpack

import jax
import jax.numpy as jnp
import sax
from jax.typing import ArrayLike
from skrf import Frequency

from qpdk.models.media import MediaCallable, cpw_media_skrf


class StraightModelKwargs(TypedDict, total=False):
    """Type definition for straight S-parameter model keyword arguments.

    Mirrors the keyword parameters accepted by ``straight`` so that wrapper
    models can annotate their ``**kwargs`` with ``Unpack[StraightModelKwargs]``.
    Declared with ``total=False``, so every key is optional.
    """

    # Frequency points; ``straight`` documents these as being in Hz.
    f: ArrayLike
    # Physical length; ``straight`` documents this as being in µm.
    length: int | float
    # Callable that ``straight`` invokes with ``frequency=...`` to obtain the
    # media object from which the transmission line is built.
    media: MediaCallable


# JIT disabled for now due to scikit-rf internals not being JAX-compatible
# @partial(jax.jit, static_argnames=["media"])
def straight(
    f: ArrayLike = jnp.array([5e9]),
    length: int | float = 1000,
    media: MediaCallable | None = cpw_media_skrf(),
) -> sax.SType:
    """S-parameter model for a straight waveguide.

    See `scikit-rf <skrf>`_ for details on analytical formulæ.

    Args:
        f: Array of frequency points in Hz
        length: Physical length in µm
        media: Function returning a scikit-rf :class:`~Media` object after
            called with ``frequency=f``. If None, uses default CPW media.

    Returns:
        sax.SType: S-parameters dictionary

    .. _skrf: https://scikit-rf.org/
    """
    # Honour the documented contract: an explicit ``None`` selects the same
    # default CPW media that the signature default provides.
    if media is None:
        media = cpw_media_skrf()

    # scikit-rf does the transmission-line computation; the results are only
    # converted to JAX arrays when the S-dictionary is assembled below.
    skrf_media = media(frequency=Frequency.from_f(f, unit="Hz"))
    transmission_line = skrf_media.line(d=length, unit="um")

    # Read the S-matrix once; it is indexed as [frequency, port, port].
    s_matrix = transmission_line.s
    sdict = {
        ("o1", "o1"): jnp.array(s_matrix[:, 0, 0]),
        ("o1", "o2"): jnp.array(s_matrix[:, 0, 1]),
        ("o2", "o2"): jnp.array(s_matrix[:, 1, 1]),
    }
    # Only one off-diagonal term is listed; sax.reciprocal supplies the
    # mirrored entry (("o2", "o1")).
    return sax.reciprocal(sdict)
def bend_circular(
    *args: ArrayLike | int | float | MediaCallable,
    **kwargs: Unpack[StraightModelKwargs],
) -> sax.SType:
    """S-parameter model for a circular bend.

    The bend is modelled exactly like a straight section: all positional and
    keyword arguments are forwarded unchanged to :func:`~straight`.
    """
    s_params = straight(*args, **kwargs)  # pyrefly: ignore[bad-keyword-argument]
    return s_params
def bend_euler(
    *args: ArrayLike | int | float | MediaCallable,
    **kwargs: Unpack[StraightModelKwargs],
) -> sax.SType:
    """S-parameter model for an Euler bend.

    The bend is modelled exactly like a straight section: all positional and
    keyword arguments are forwarded unchanged to :func:`~straight`.
    """
    s_params = straight(*args, **kwargs)  # pyrefly: ignore[bad-keyword-argument]
    return s_params
def bend_s(
    *args: ArrayLike | int | float | MediaCallable,
    **kwargs: Unpack[StraightModelKwargs],
) -> sax.SType:
    """S-parameter model for an S-bend.

    The bend is modelled exactly like a straight section: all positional and
    keyword arguments are forwarded unchanged to :func:`~straight`.
    """
    s_params = straight(*args, **kwargs)  # pyrefly: ignore[bad-keyword-argument]
    return s_params
if __name__ == "__main__":
    # Ad-hoc benchmark: time ``straight`` against a local copy of its body and
    # confirm that both produce the same S21.
    # NOTE: the ``jax.jit`` decorator on ``straight`` is currently commented
    # out, so both timed functions execute the same un-jitted code path.
    import time

    from tqdm import tqdm

    cpw = cpw_media_skrf(width=10, gap=6)

    def straight_no_jit(
        f: ArrayLike = jnp.array([5e9]),
        length: int | float = 1000,
        media: MediaCallable = cpw_media_skrf(),
    ) -> sax.SType:
        """Reference copy of ``straight`` kept free of just-in-time compilation."""
        line = media(frequency=Frequency.from_f(f, unit="Hz")).line(
            d=length, unit="um"
        )
        return sax.reciprocal(
            {
                ("o1", "o1"): jnp.array(line.s[:, 0, 0]),
                ("o1", "o2"): jnp.array(line.s[:, 0, 1]),
                ("o2", "o2"): jnp.array(line.s[:, 1, 1]),
            }
        )

    test_freq = jnp.linspace(0.5e9, 9e9, 200001)
    test_length = 1000
    n_runs = 10

    def _time_model(model, label):
        """Return the wall-clock duration in seconds of each of ``n_runs`` calls."""
        durations = []
        for _ in tqdm(range(n_runs), desc=label, ncols=80, unit="run"):
            tic = time.perf_counter()
            result = model(f=test_freq, length=test_length, media=cpw)
            # Wait for the result so the measurement covers the whole evaluation.
            result["o2", "o1"].block_until_ready()
            durations.append(time.perf_counter() - tic)
        return durations

    print("Benchmarking jitted vs non-jitted performance…")
    jit_times = _time_model(straight, "With jax.jit")
    no_jit_times = _time_model(straight_no_jit, "Without jax.jit")

    # The first run is reported separately and left out of the average
    # (presumably to discount one-off warm-up cost).
    steady_state = jit_times[1:]
    avg_jit = sum(steady_state) / len(steady_state)
    avg_no_jit = sum(no_jit_times) / len(no_jit_times)
    speedup = avg_no_jit / avg_jit

    print(f"Jitted: {avg_jit:.4f}s avg (excl. first), {jit_times[0]:.3f}s first run")
    print(f"Non-jitted: {avg_no_jit:.4f}s avg")
    print(f"Speedup: {speedup:.1f}x")

    # Cross-check that both implementations agree on S21.
    s_reference = straight(f=test_freq, length=test_length, media=cpw)
    s_copy = straight_no_jit(f=test_freq, length=test_length, media=cpw)
    max_diff = jnp.max(jnp.abs(s_reference["o2", "o1"] - s_copy["o2", "o1"]))
    print(f"Max absolute difference in results: {max_diff:.2e}")

    # Best-effort GPU probe: any failure is reported as "no GPU".
    try:
        s21_gpu = jax.device_put(s_reference["o2", "o1"], jax.devices("gpu")[0])
        print(f"GPU available: {s21_gpu.device}")
    except Exception:
        print("GPU not available, using CPU")