Source code for gplugins.sax.plot_model
"""Useful plot functions."""
from __future__ import annotations
import matplotlib.pyplot as plt
import numpy as np
from pydantic import validate_call
from sax.saxtypes import Model
[docs]
@validate_call
def plot_model(
model: Model,
port1: str = "o1",
ports2: tuple[str, ...] | None = None,
logscale: bool = True,
min_db_range: float = 0.5,
fig=None,
wavelength_start: float = 1.5,
wavelength_stop: float = 1.6,
wavelength_points: int = 2000,
phase: bool = False,
title: str | None = None,
) -> None:
"""Plot Model Sparameters Magnitude.
Args:
model: function that returns SDict as function of wavelength.
port1: input port name.
ports2: list of ports.
logscale: plots in dB logarithmic scale.
min_db_range: minimum dB range. Set to 0 to disable.
fig: matplotlib figure.
wavelength_start: wavelength min (µm).
wavelength_stop: wavelength max (µm).
wavelength_points: number of wavelength steps.
phase: plot phase instead of magnitude.
title: plot title.
.. plot::
:include-source:
import gplugins.sax as gs
gs.plot_model(gs.models.straight, phase=True, port1="o1")
"""
wavelengths = np.linspace(wavelength_start, wavelength_stop, wavelength_points)
sdict = model(wl=wavelengths)
ports = {ports[0] for ports in sdict.keys()}
ports2 = ports2 or ports
if port1 not in ports:
raise ValueError(f"port1 {port1!r} not in {list(ports)}")
for port in ports2:
if port not in ports:
raise ValueError(f"port2 {port!r} not in {list(ports)}")
fig = fig or plt.subplot()
ax = fig.axes
for port2 in ports2:
if (port1, port2) in sdict:
if phase:
y = np.angle(sdict[(port1, port2)])
ylabel = "angle (rad)"
else:
y = np.abs(sdict[(port1, port2)])
y = 20 * np.log10(y) if logscale else y
ylabel = "|S (dB)|" if logscale else "|S|"
ax.plot(wavelengths, y, label=f"{port1}→{port2}")
if logscale:
current_ylim = ax.get_ylim()
if current_ylim[1] - current_ylim[0] < min_db_range:
ax.set_ylim(y.mean() - min_db_range / 2, y.mean() + min_db_range / 2)
if title:
ax.set_title(title)
else:
# Handle functools.partial objects
if hasattr(model, "func"):
# It's a partial object
model_name = getattr(model.func, "__name__", "model")
else:
# Regular function
model_name = getattr(model, "__name__", "model")
ax.set_title(f"{model_name} S-Parameters")
ax.set_xlabel("wavelength (µm)")
ax.set_ylabel(ylabel)
plt.legend()
return ax
if __name__ == "__main__":
import gplugins.sax as gs
plot_model(gs.models.straight, phase=True, port1="o1", ports2=("o1", "o2"))
plt.show()