# Source code for gplugins.sax.plot_model

"""Useful plot functions."""

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
from pydantic import validate_call
from sax.saxtypes import Model


@validate_call
def plot_model(
    model: Model,
    port1: str = "o1",
    ports2: tuple[str, ...] | None = None,
    logscale: bool = True,
    fig=None,
    wavelength_start: float = 1.5,
    wavelength_stop: float = 1.6,
    wavelength_points: int = 2000,
    phase: bool = False,
) -> plt.Axes:
    """Plot Model Sparameters Magnitude.

    Args:
        model: function that returns SDict as function of wavelength.
        port1: input port name.
        ports2: list of ports. Defaults to every port found in the model
            result, in sorted order.
        logscale: plots in dB logarithmic scale.
        fig: matplotlib Figure or Axes to draw on. When a Figure is given its
            current axes are used; when None a new subplot is created.
        wavelength_start: wavelength min (um).
        wavelength_stop: wavelength max (um).
        wavelength_points: number of wavelength steps.
        phase: plot phase instead of magnitude.

    Returns:
        The axes the curves were drawn on.

    Raises:
        ValueError: if ``port1`` or any entry of ``ports2`` is not among the
            ports (first key element) of the model result.

    .. plot::
        :include-source:

        import gplugins.sax as gs

        gs.plot_model(gs.models.straight, phase=True, port1="o1")
    """
    wavelengths = np.linspace(wavelength_start, wavelength_stop, wavelength_points)
    sdict = model(wl=wavelengths)

    # Port names are taken from the first element of each S-parameter key.
    ports = {key[0] for key in sdict.keys()}
    # Sort the default so curve and legend order is reproducible between runs
    # (iteration order of a set of strings is not).
    ports2 = ports2 or tuple(sorted(ports))

    if port1 not in ports:
        raise ValueError(f"port1 {port1!r} not in {sorted(ports)}")
    for port in ports2:
        if port not in ports:
            raise ValueError(f"port2 {port!r} not in {sorted(ports)}")

    # Resolve the drawing target. ``Figure.axes`` is a *list* of axes, so a
    # Figure must be handled separately; for an Axes, ``.axes`` is the Axes
    # itself, which keeps the original behaviour for that case.
    if fig is None:
        ax = plt.subplot()
    elif isinstance(fig, plt.Figure):
        ax = fig.gca()
    else:
        ax = fig.axes

    # The y label depends only on the flags, so compute it up front; this also
    # keeps it defined when none of the requested port pairs is in ``sdict``.
    if phase:
        ylabel = "angle (rad)"
    elif logscale:
        ylabel = "|S| (dB)"
    else:
        ylabel = "|S|"

    for port2 in ports2:
        if (port1, port2) not in sdict:
            # Sparse result: this pair has no entry, nothing to plot.
            continue
        s = sdict[(port1, port2)]
        if phase:
            y = np.angle(s)
        else:
            y = np.abs(s)
            if logscale:
                y = 20 * np.log10(y)
        ax.plot(wavelengths, y, label=port2)

    ax.set_title(port1)
    ax.set_xlabel("wavelength (um)")
    ax.set_ylabel(ylabel)
    # Attach the legend to the axes we drew on, not to pyplot's current axes.
    ax.legend()
    return ax
if __name__ == "__main__":
    # Demo: plot the phase of the straight model from port o1 to o1 and o2.
    import gplugins.sax as gs

    demo_ports = ("o1", "o2")
    plot_model(
        gs.models.straight,
        port1="o1",
        ports2=demo_ports,
        phase=True,
    )
    plt.show()