SAX circuit simulator#

SAX is a circuit solver written in JAX. Writing your component models in SAX gives you not only the function values but also their gradients, which is useful for circuit optimization.

This tutorial has been adapted from the SAX tutorial.

Note that SAX does not work on Windows, so if you use Windows you’ll need to run it from WSL or in Docker.

You can install sax with pip

! pip install sax
[1]:
import gdsfactory as gf
import gdsfactory.simulation.sax as gs
import gdsfactory.simulation.modes as gm
import matplotlib.pyplot as plt
import sax
2022-06-28 17:04:29.360 | INFO     | gdsfactory.config:<module>:52 - Load '/home/runner/work/gdsfactory/gdsfactory/gdsfactory' 5.11.4
2022-06-28 17:04:30.961 | INFO     | gdsfactory.simulation.gmeep:<module>:28 - Meep '1.23.0' installed at ['/usr/share/miniconda/envs/anaconda-client-env/lib/python3.9/site-packages/meep']
Using MPI version 4.0, 1 processes

Scatter dictionaries#

The core data structure for specifying scatter parameters in SAX is a dictionary. More specifically, it is a dictionary that maps a port combination (a 2-tuple) to a scatter parameter (or to an array of scatter parameters when considering multiple wavelengths, for example). Such a dictionary mapping is called an SDict in SAX (SDict = Dict[Tuple[str, str], float]).

Dictionaries are in fact much better suited for characterizing S-parameters than, say, (jax-)numpy arrays due to the inherent sparse nature of scatter parameters. Moreover, dictionaries allow for string indexing, which makes them much more pleasant to use in this context.

o2            o3
   \        /
    ========
   /        \
o1            o4
[2]:
coupling = 0.5
kappa = coupling**0.5
tau = (1 - coupling) ** 0.5
coupler_dict = {
    ("o1", "o4"): tau,
    ("o4", "o1"): tau,
    ("o1", "o3"): 1j * kappa,
    ("o3", "o1"): 1j * kappa,
    ("o2", "o4"): 1j * kappa,
    ("o4", "o2"): 1j * kappa,
    ("o2", "o3"): tau,
    ("o3", "o2"): tau,
}
coupler_dict
[2]:
{('o1', 'o4'): 0.7071067811865476,
 ('o4', 'o1'): 0.7071067811865476,
 ('o1', 'o3'): 0.7071067811865476j,
 ('o3', 'o1'): 0.7071067811865476j,
 ('o2', 'o4'): 0.7071067811865476j,
 ('o4', 'o2'): 0.7071067811865476j,
 ('o2', 'o3'): 0.7071067811865476,
 ('o3', 'o2'): 0.7071067811865476}

It can still be tedious to specify every port in the circuit manually. SAX therefore offers the reciprocal function, which auto-fills the reverse connection if the forward connection exists. For example:

[3]:
coupler_dict = sax.reciprocal(
    {
        ("o1", "o4"): tau,
        ("o1", "o3"): 1j * kappa,
        ("o2", "o4"): 1j * kappa,
        ("o2", "o3"): tau,
    }
)

coupler_dict
[3]:
{('o1', 'o4'): 0.7071067811865476,
 ('o1', 'o3'): 0.7071067811865476j,
 ('o2', 'o4'): 0.7071067811865476j,
 ('o2', 'o3'): 0.7071067811865476,
 ('o4', 'o1'): 0.7071067811865476,
 ('o3', 'o1'): 0.7071067811865476j,
 ('o4', 'o2'): 0.7071067811865476j,
 ('o3', 'o2'): 0.7071067811865476}
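
To make this behaviour concrete, here is a minimal pure-Python sketch of what such a reciprocity helper does. It is an illustration only and not SAX's implementation; the name reciprocal_sketch is hypothetical.

```python
def reciprocal_sketch(sdict):
    """Hypothetical helper: mirror every (p1, p2) entry to (p2, p1)."""
    out = dict(sdict)  # keep the forward entries untouched
    for (p1, p2), value in sdict.items():
        # only fill in the reverse entry when it was not given explicitly
        out.setdefault((p2, p1), value)
    return out


tau = kappa = 0.5**0.5
forward = {
    ("o1", "o4"): tau,
    ("o1", "o3"): 1j * kappa,
    ("o2", "o4"): 1j * kappa,
    ("o2", "o3"): tau,
}
full = reciprocal_sketch(forward)
print(len(full))  # 8 entries: 4 forward + 4 reverse
```

Using setdefault means an explicitly specified reverse entry is never overwritten, which is one reasonable design choice for a non-reciprocal component.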

Parametrized Models#

Constructing such an SDict is easy; however, we are usually more interested in having parametrized models for our components. To parametrize the coupler SDict, just wrap it in a function to obtain a SAX Model, which is a keyword-only function mapping to an SDict:

[4]:
def coupler(coupling=0.5) -> sax.SDict:
    kappa = coupling**0.5
    tau = (1 - coupling) ** 0.5
    return sax.reciprocal(
        {
            ("o1", "o4"): tau,
            ("o1", "o3"): 1j * kappa,
            ("o2", "o4"): 1j * kappa,
            ("o2", "o3"): tau,
        }
    )


coupler(coupling=0.3)
[4]:
{('o1', 'o4'): 0.8366600265340756,
 ('o1', 'o3'): 0.5477225575051661j,
 ('o2', 'o4'): 0.5477225575051661j,
 ('o2', 'o3'): 0.8366600265340756,
 ('o4', 'o1'): 0.8366600265340756,
 ('o3', 'o1'): 0.5477225575051661j,
 ('o4', 'o2'): 0.5477225575051661j,
 ('o3', 'o2'): 0.8366600265340756}
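
A quick sanity check for a lossless coupler model like this one is that the bar and cross powers add up to one. The sketch below repeats the same formulas in plain Python; the helper name coupler_amplitudes is hypothetical and only used for this check.

```python
def coupler_amplitudes(coupling=0.5):
    """Bar (tau) and cross (1j * kappa) amplitudes, as in the coupler model above."""
    kappa = coupling**0.5
    tau = (1 - coupling) ** 0.5
    return tau, 1j * kappa


for coupling in (0.1, 0.3, 0.5):
    bar, cross = coupler_amplitudes(coupling)
    total_power = abs(bar) ** 2 + abs(cross) ** 2
    # bar and cross powers should add up to 1 for a lossless coupler
    print(coupling, round(total_power, 12))
```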
[5]:
import jax.numpy as jnp


def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0) -> sax.SDict:
    dwl = wl - wl0
    dneff_dwl = (ng - neff) / wl0
    neff = neff - dwl * dneff_dwl
    phase = 2 * jnp.pi * neff * length / wl
    transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)
    return sax.reciprocal(
        {
            ("o1", "o2"): transmission,
        }
    )
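
The waveguide model above applies a first-order dispersion correction to neff, accumulates the phase 2π·neff·length/wl, and converts a loss into a field amplitude through 10^(-loss·length/20). The following standard-library sketch repeats the same arithmetic for a single wavelength, assuming loss is given in dB per unit length (an interpretation of the formula, not something the model states). The name waveguide_transmission is hypothetical.

```python
import cmath
import math


def waveguide_transmission(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0):
    """Scalar version of the waveguide model above (hypothetical helper)."""
    dwl = wl - wl0
    dneff_dwl = (ng - neff) / wl0          # first-order dispersion slope
    neff_wl = neff - dwl * dneff_dwl       # effective index at wl
    phase = 2 * math.pi * neff_wl * length / wl
    amplitude = 10 ** (-loss * length / 20)  # assumed dB per unit length -> field amplitude
    return amplitude * cmath.exp(1j * phase)


t = waveguide_transmission(length=10.0, loss=0.3)
power_db = 20 * math.log10(abs(t))
print(round(power_db, 6))  # 0.3 dB/um over 10 um gives -3.0 dB
```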

Component Models#

Waveguide model#

You can create a dispersive waveguide model in SAX.

Let's compute the effective index neff and group index ng for a 500nm straight waveguide at 1550nm.

[6]:
m = gm.find_mode_dispersion(wavelength=1.55)
print(m.neff, m.ng)
2.3638584286954814 4.251242165908622
[7]:
straight_sc = gf.partial(gs.models.straight, neff=m.neff, ng=m.ng)
[8]:
gs.plot_model(straight_sc)
plt.ylim(-1, 1)
/usr/share/miniconda/envs/anaconda-client-env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:1882: UserWarning: Explicitly requested dtype <class 'complex'> requested in asarray is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax_internal._check_user_dtype_supported(dtype, "asarray")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[8]:
(-1.0, 1.0)
../../../_images/notebooks_plugins_sax_sax_13_2.png
[9]:
gs.plot_model(straight_sc, phase=True)
[9]:
<AxesSubplot:title={'center':'o1'}, xlabel='wavelength (nm)', ylabel='angle (rad)'>
../../../_images/notebooks_plugins_sax_sax_14_1.png

Coupler model#

[10]:
gm.find_coupling_vs_gap?
[11]:
df = gm.find_coupling_vs_gap()
df
100%|██████████| 12/12 [00:00<00:00, 374.44it/s]
[11]:
gap ne no lc dn
0 0.200000 2.457894 2.437607 38.201198 0.020287
1 0.218182 2.459028 2.441121 43.278893 0.017907
2 0.236364 2.452195 2.437659 53.318206 0.014535
3 0.254545 2.451946 2.439549 62.514018 0.012397
4 0.272727 2.451248 2.441058 76.053858 0.010190
5 0.290909 2.451491 2.442325 84.553535 0.009166
6 0.309091 2.449447 2.441517 97.723696 0.007931
7 0.327273 2.447256 2.440698 118.168160 0.006558
8 0.345455 2.451715 2.446248 141.751936 0.005467
9 0.363636 2.451128 2.446210 157.608710 0.004917
10 0.381818 2.445577 2.441408 185.893679 0.004169
11 0.400000 2.445128 2.441678 224.665375 0.003450

For a 200nm gap the effective index difference dn is 0.02, which means that there is 100% power coupling over 38.2um
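
The lc column is consistent with the usual supermode relation lc = wavelength / (2 · dn). The short check below reproduces the first table row, assuming the sweep was run at a wavelength of 1.55um (an assumption; the call above does not state it).

```python
wavelength = 1.55  # um, assumed simulation wavelength
dn = 0.020287      # ne - no for the 0.2 um gap (first table row)

lc = wavelength / (2 * dn)  # length for 100% power transfer, in um
print(round(lc, 1))  # 38.2
```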

[12]:
coupler_sc = gf.partial(gs.models.coupler, dn=0.02, length=0, coupling0=0)
gs.plot_model(coupler_sc)
[12]:
<AxesSubplot:title={'center':'o1'}, xlabel='wavelength (nm)', ylabel='|S (dB)|'>
../../../_images/notebooks_plugins_sax_sax_19_1.png

If we ignore the coupling from the bends (coupling0 = 0), a 3dB coupler needs half of the length lc, which is the length needed to couple 100% of the power.

[13]:
coupler_sc = gf.partial(gs.models.coupler, dn=0.02, length=38.2 / 2, coupling0=0)
gs.plot_model(coupler_sc)
[13]:
<AxesSubplot:title={'center':'o1'}, xlabel='wavelength (nm)', ylabel='|S (dB)|'>
../../../_images/notebooks_plugins_sax_sax_21_1.png
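
The half-length rule can be checked with a small sketch. It assumes the standard two-supermode picture, in which the coupled power fraction is sin²(π · length / (2 · lc)), and it ignores bend coupling. It is not taken from the gs.models.coupler source, and the function name power_coupling is hypothetical.

```python
import math


def power_coupling(length, lc=38.2):
    """Fraction of power transferred to the other waveguide (hypothetical helper)."""
    return math.sin(math.pi * length / (2 * lc)) ** 2


print(round(power_coupling(0.0), 3))       # 0.0
print(round(power_coupling(38.2 / 2), 3))  # 0.5, i.e. a 3dB coupler
print(round(power_coupling(38.2), 3))      # 1.0
```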

FDTD Sparameters model#

You can also create a model from S-parameter FDTD simulation data.

[14]:
from gdsfactory.simulation.get_sparameters_path import get_sparameters_path_lumerical

filepath = get_sparameters_path_lumerical(gf.c.mmi1x2)
mmi1x2 = gf.partial(gs.read.sdict_from_csv, filepath=filepath)
gs.plot_model(mmi1x2)
[14]:
<AxesSubplot:title={'center':'o1'}, xlabel='wavelength (nm)', ylabel='|S (dB)|'>
../../../_images/notebooks_plugins_sax_sax_23_1.png

Circuit Models#

You can combine component models into a circuit using sax.circuit, which creates a new Model function.

Let's define a Mach-Zehnder interferometer (MZI) with:

  • two couplers (rgt, lft) right and left

  • two waveguides (top, bot)

           _________
          |  top    |
          |         |
    lft===|         |===rgt
          |         |
          |_________|
             bot

               o1    top   o2
                 ----------
o2            o3           o2            o3
   \        /                 \        /
    ========                   ========
   /        \                 /        \
o1     lft    o4           o1    rgt     o4
                 ----------
               o1   bot    o2
[15]:
waveguide = straight_sc
coupler = coupler_sc

mzi = sax.circuit(
    instances={
        "lft": coupler,
        "top": waveguide,
        "bot": waveguide,
        "rgt": coupler,
    },
    connections={
        "lft,o4": "bot,o1",
        "bot,o2": "rgt,o1",
        "lft,o3": "top,o1",
        "top,o2": "rgt,o2",
    },
    ports={
        "o1": "lft,o1",
        "o2": "lft,o2",
        "o4": "rgt,o4",
        "o3": "rgt,o3",
    },
)

The circuit function creates a function similar to the ones we wrote for the waveguide and the coupler, but instead of taking parameters directly it takes a parameter dictionary for each instance in the circuit, keyed by the instance name given in instances (lft, top, bot, rgt). The keys in these parameter dictionaries should correspond to the keyword arguments of each individual subcomponent.
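
For this particular feed-forward topology the circuit response can also be written out by hand, which gives a useful cross-check. The sketch below follows the two paths from o1 through the lft, top, bot and rgt instances, assuming ideal lossless couplers and a non-dispersive waveguide with a fixed neff (both simplifications relative to the models used in this tutorial). The name ideal_mzi is hypothetical.

```python
import cmath
import math


def ideal_mzi(wl, length_top, length_bot, neff=2.34, coupling=0.5):
    """Return (S[o1, o3], S[o1, o4]) for the lft/top/bot/rgt circuit (hypothetical helper)."""
    kappa = coupling**0.5
    tau = (1 - coupling) ** 0.5
    t_top = cmath.exp(2j * math.pi * neff * length_top / wl)
    t_bot = cmath.exp(2j * math.pi * neff * length_bot / wl)
    # path 1: o1 -> lft,o4 -> bot -> rgt,o1; path 2: o1 -> lft,o3 -> top -> rgt,o2
    s13 = tau * t_bot * 1j * kappa + 1j * kappa * t_top * tau
    s14 = tau * t_bot * tau + 1j * kappa * t_top * 1j * kappa
    return s13, s14


s13, s14 = ideal_mzi(wl=1.55, length_top=10.0, length_bot=10.0)
print(round(abs(s13) ** 2, 6), round(abs(s14) ** 2, 6))  # balanced arms: 1.0 0.0
```

With balanced arms all power leaves through o3, and for any arm lengths the two output powers add up to one.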

You can simulate the MZI

[16]:
%time mzi()
/usr/share/miniconda/envs/anaconda-client-env/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:1882: UserWarning: Explicitly requested dtype <class 'complex'> requested in asarray is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  lax_internal._check_user_dtype_supported(dtype, "asarray")
CPU times: user 2.22 s, sys: 0 ns, total: 2.22 s
Wall time: 2.22 s
[16]:
{('o2', 'o2'): DeviceArray(0.+0.j, dtype=complex64),
 ('o2', 'o1'): DeviceArray(0.+0.j, dtype=complex64),
 ('o1', 'o2'): DeviceArray(0.+0.j, dtype=complex64),
 ('o1', 'o1'): DeviceArray(0.+0.j, dtype=complex64),
 ('o2', 'o4'): DeviceArray(0.9997418+0.00439812j, dtype=complex64),
 ('o2', 'o3'): DeviceArray(-9.807199e-05+0.02229288j, dtype=complex64),
 ('o1', 'o4'): DeviceArray(-9.807199e-05+0.02229288j, dtype=complex64),
 ('o1', 'o3'): DeviceArray(0.9997418+0.00439812j, dtype=complex64),
 ('o4', 'o2'): DeviceArray(0.9997418+0.00439812j, dtype=complex64),
 ('o4', 'o1'): DeviceArray(-9.807199e-05+0.02229288j, dtype=complex64),
 ('o4', 'o4'): DeviceArray(0.+0.j, dtype=complex64),
 ('o4', 'o3'): DeviceArray(0.+0.j, dtype=complex64),
 ('o3', 'o2'): DeviceArray(-9.807199e-05+0.02229288j, dtype=complex64),
 ('o3', 'o1'): DeviceArray(0.9997418+0.00439812j, dtype=complex64),
 ('o3', 'o4'): DeviceArray(0.+0.j, dtype=complex64),
 ('o3', 'o3'): DeviceArray(0.+0.j, dtype=complex64)}
[17]:
import jax
import jax.example_libraries.optimizers as opt
import jax.numpy as jnp
import matplotlib.pyplot as plt  # plotting

mzi2 = jax.jit(mzi)
[18]:
%time mzi2()
CPU times: user 2.81 s, sys: 16.2 ms, total: 2.82 s
Wall time: 2.78 s
[18]:
{('o1', 'o1'): DeviceArray(0.+0.j, dtype=complex64),
 ('o1', 'o2'): DeviceArray(0.+0.j, dtype=complex64),
 ('o1', 'o3'): DeviceArray(0.9997418+0.00439812j, dtype=complex64),
 ('o1', 'o4'): DeviceArray(-9.807199e-05+0.02229288j, dtype=complex64),
 ('o2', 'o1'): DeviceArray(0.+0.j, dtype=complex64),
 ('o2', 'o2'): DeviceArray(0.+0.j, dtype=complex64),
 ('o2', 'o3'): DeviceArray(-9.807199e-05+0.02229288j, dtype=complex64),
 ('o2', 'o4'): DeviceArray(0.9997418+0.00439812j, dtype=complex64),
 ('o3', 'o1'): DeviceArray(0.9997418+0.00439812j, dtype=complex64),
 ('o3', 'o2'): DeviceArray(-9.807199e-05+0.02229288j, dtype=complex64),
 ('o3', 'o3'): DeviceArray(0.+0.j, dtype=complex64),
 ('o3', 'o4'): DeviceArray(0.+0.j, dtype=complex64),
 ('o4', 'o1'): DeviceArray(-9.807199e-05+0.02229288j, dtype=complex64),
 ('o4', 'o2'): DeviceArray(0.9997418+0.00439812j, dtype=complex64),
 ('o4', 'o3'): DeviceArray(0.+0.j, dtype=complex64),
 ('o4', 'o4'): DeviceArray(0.+0.j, dtype=complex64)}
[19]:
mzi(top={"length": 25.0}, btm={"length": 15.0})
[19]:
{('o2', 'o2'): DeviceArray(0.+0.j, dtype=complex64),
 ('o2', 'o1'): DeviceArray(0.+0.j, dtype=complex64),
 ('o1', 'o2'): DeviceArray(0.+0.j, dtype=complex64),
 ('o1', 'o1'): DeviceArray(0.+0.j, dtype=complex64),
 ('o2', 'o4'): DeviceArray(0.85720646-0.34735364j, dtype=complex64),
 ('o2', 'o3'): DeviceArray(0.35958472-0.12345627j, dtype=complex64),
 ('o1', 'o4'): DeviceArray(-0.34409368+0.1616854j, dtype=complex64),
 ('o1', 'o3'): DeviceArray(0.85720634-0.34735364j, dtype=complex64),
 ('o4', 'o2'): DeviceArray(0.85720646-0.34735364j, dtype=complex64),
 ('o4', 'o1'): DeviceArray(-0.34409368+0.1616854j, dtype=complex64),
 ('o4', 'o4'): DeviceArray(0.+0.j, dtype=complex64),
 ('o4', 'o3'): DeviceArray(0.+0.j, dtype=complex64),
 ('o3', 'o2'): DeviceArray(0.35958472-0.12345627j, dtype=complex64),
 ('o3', 'o1'): DeviceArray(0.85720634-0.34735364j, dtype=complex64),
 ('o3', 'o4'): DeviceArray(0.+0.j, dtype=complex64),
 ('o3', 'o3'): DeviceArray(0.+0.j, dtype=complex64)}
[20]:
wl = jnp.linspace(1.51, 1.59, 1000)
%time S = mzi(wl=wl, top={"length": 25.0}, btm={"length": 15.0})
CPU times: user 2.46 s, sys: 7.59 ms, total: 2.46 s
Wall time: 2.46 s
[21]:
plt.plot(wl * 1e3, abs(S["o1", "o3"]) ** 2, label="o3")
plt.plot(wl * 1e3, abs(S["o1", "o4"]) ** 2, label="o4")
plt.ylim(-0.05, 1.05)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.legend()
plt.show()
../../../_images/notebooks_plugins_sax_sax_32_0.png

Optimization#

You can optimize an MZI to get T=0 at 1550nm. To do this, you need to define a loss function for the circuit at 1550nm. This function should take the parameters that you want to optimize as positional arguments:

[22]:
@jax.jit
def loss(delta_length):
    S = mzi(wl=1.55, top={"length": 15.0 + delta_length}, btm={"length": 15.0})
    return (abs(S["o1", "o4"]) ** 2).mean()
[23]:
%time loss(10.0)
CPU times: user 2.66 s, sys: 11.9 ms, total: 2.67 s
Wall time: 2.64 s
[23]:
DeviceArray(0.14453728, dtype=float32)

You can use this loss function to define a grad function which works on the parameters of the loss function:

[24]:
grad = jax.jit(
    jax.grad(
        loss,
        argnums=0,  # JAX gradient function for the first positional argument, jitted
    )
)
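
A central finite difference is a simple way to sanity-check a gradient function. The sketch below uses a stand-in loss, the cross-port power sin²(π · neff · delta_length / wl) of an ideal MZI with 50/50 couplers and a non-dispersive neff, rather than the SAX circuit itself. The names toy_loss, toy_grad and finite_difference are hypothetical.

```python
import math

NEFF, WL = 2.34, 1.55  # assumed non-dispersive effective index and wavelength (um)


def toy_loss(delta_length):
    """Cross-port power of an ideal MZI (stand-in for the circuit loss)."""
    return math.sin(math.pi * NEFF * delta_length / WL) ** 2


def toy_grad(delta_length):
    """Analytic derivative of toy_loss with respect to delta_length."""
    k = math.pi * NEFF / WL
    return k * math.sin(2 * k * delta_length)


def finite_difference(f, x, h=1e-5):
    """Central-difference approximation of df/dx."""
    return (f(x + h) - f(x - h)) / (2 * h)


x0 = 10.0
print(toy_grad(x0), finite_difference(toy_loss, x0))  # the two values should agree closely
```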

Next, you need to define a JAX optimizer, which on its own is nothing more than three functions:

  1. an initialization function with which to initialize the optimizer state

  2. an update function which will update the optimizer state (and with it the model parameters)

  3. a function that returns the model parameters given the optimizer state
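
The three-function pattern can be illustrated without JAX. The sketch below implements plain gradient descent with the same (init, update, get_params) shape and uses it to minimize (x - 3)². It is a hypothetical stand-in for illustration, not the Adam optimizer used in this tutorial.

```python
def sgd(step_size):
    """Hypothetical minimal optimizer exposing an (init, update, get_params) triple."""

    def init(params):
        return params  # the state is just the parameter value

    def update(step, grad, state):
        return state - step_size * grad  # one gradient-descent step

    def get_params(state):
        return state

    return init, update, get_params


init, update, get_params = sgd(step_size=0.1)
state = init(0.0)
for step in range(200):
    x = get_params(state)
    grad = 2 * (x - 3.0)  # derivative of (x - 3)**2
    state = update(step, grad, state)

x_opt = get_params(state)
print(round(x_opt, 6))  # 3.0
```

Keeping the state separate from the parameters matters for optimizers such as Adam, whose state also carries running moment estimates.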

[25]:
initial_delta_length = 10.0
optim_init, optim_update, optim_params = opt.adam(step_size=0.1)
optim_state = optim_init(initial_delta_length)
[26]:
def train_step(step, optim_state):
    settings = optim_params(optim_state)
    lossvalue = loss(settings)
    gradvalue = grad(settings)
    optim_state = optim_update(step, gradvalue, optim_state)
    return lossvalue, optim_state
[27]:
import tqdm

range_ = tqdm.trange(300)
for step in range_:
    lossvalue, optim_state = train_step(step, optim_state)
    range_.set_postfix(loss=f"{lossvalue:.6f}")
100%|██████████| 300/300 [00:15<00:00, 19.78it/s, loss=0.000497]
[28]:
delta_length = optim_params(optim_state)
delta_length
[28]:
DeviceArray(10.081275, dtype=float32)
[29]:
S = mzi(wl=wl, top={"length": 15.0 + delta_length}, btm={"length": 15.0})
plt.plot(wl * 1e3, abs(S["o1", "o4"]) ** 2)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.ylim(-0.05, 1.05)
plt.plot([1550, 1550], [0, 1])
plt.show()
../../../_images/notebooks_plugins_sax_sax_43_0.png

The minimum of the MZI is now located at 1550nm.

Model fit#

You can fit a SAX model to S-parameter FDTD simulation data.

[30]:
import tqdm
import jax
import jax.numpy as jnp
import jax.example_libraries.optimizers as opt
import matplotlib.pyplot as plt

import gdsfactory as gf
import gdsfactory.simulation.modes as gm
import gdsfactory.simulation.sax as gs
[31]:
gf.config.sparameters_path
[31]:
PosixPath('/home/runner/work/gdsfactory/gdsfactory/gdslib/sp')
[32]:
sd = gs.read.sdict_from_csv(
    gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv",
    xkey="wavelength_nm",
    prefix="S",
    xunits=1e-3,
)
[33]:
coupler_fdtd = gf.partial(
    gs.read.sdict_from_csv,
    filepath=gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv",
    xkey="wavelength_nm",
    prefix="S",
    xunits=1e-3,
)
[34]:
gs.plot_model(coupler_fdtd)
[34]:
<AxesSubplot:title={'center':'o1'}, xlabel='wavelength (nm)', ylabel='|S (dB)|'>
../../../_images/notebooks_plugins_sax_sax_50_1.png

Model fit (linear regression)#

Let's fit the coupler spectrum with a linear regression from sklearn, using powers of the wavelength as features.

[35]:
import sax
import gdsfactory as gf
import gdsfactory.simulation.sax as gs
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scipy.constants import c
from sklearn.linear_model import LinearRegression
[36]:
f = jnp.linspace(c / 1.0e-6, c / 2.0e-6, 500) * 1e-12  # THz
wl = c / (f * 1e12) * 1e6  # um

filepath = gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv"
coupler_fdtd = gf.partial(
    gs.read.sdict_from_csv, filepath, xkey="wavelength_nm", prefix="S", xunits=1e-3
)
sd = coupler_fdtd(wl=wl)

k = sd["o1", "o3"]
t = sd["o1", "o4"]
s = t + k
a = t - k

Let's fit the symmetric (t+k) and antisymmetric (t-k) transmission.
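
The decomposition is lossless: t and k are recovered from the symmetric and antisymmetric parts as t = (s + a) / 2 and k = (s - a) / 2. A minimal check with made-up complex amplitudes (the numbers are hypothetical, not FDTD data):

```python
t = 0.6 + 0.2j  # hypothetical through amplitude
k = 0.1 - 0.7j  # hypothetical cross amplitude

s = t + k  # symmetric part
a = t - k  # antisymmetric part

t_back = 0.5 * (s + a)
k_back = 0.5 * (s - a)
print(abs(t_back - t) < 1e-12, abs(k_back - k) < 1e-12)
```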

Symmetric#

[37]:
plt.plot(wl, jnp.abs(s))
plt.grid(True)
plt.xlabel("Wavelength [um]")
plt.ylabel("Transmission")
plt.title("symmetric (transmission + coupling)")
plt.show()
../../../_images/notebooks_plugins_sax_sax_55_1.png
[38]:
plt.plot(wl, jnp.abs(a))
plt.grid(True)
plt.xlabel("Wavelength [um]")
plt.ylabel("Transmission")
plt.title("anti-symmetric (transmission - coupling)")
plt.show()
../../../_images/notebooks_plugins_sax_sax_56_1.png
[39]:
r = LinearRegression()
# artificially create more 'features' (wl**0, wl**1, wl**2, ...)
fX = lambda x, _order=8: x[:, None] ** (jnp.arange(_order)[None, :])
X = fX(wl)
r.fit(X, jnp.abs(s))
asm, bsm = r.coef_, r.intercept_
fsm = lambda x: fX(x) @ asm + bsm  # fitted symmetric magnitude

plt.plot(wl, jnp.abs(s), label="fdtd")
plt.plot(wl, fsm(wl), label="fit")
plt.grid(True)
plt.xlabel("Wavelength [um]")
plt.ylabel("Transmission")
plt.legend()
plt.show()
../../../_images/notebooks_plugins_sax_sax_57_1.png
[40]:
r = LinearRegression()
r.fit(X, jnp.unwrap(jnp.angle(s)))
asp, bsp = r.coef_, r.intercept_
fsp = lambda x: fX(x) @ asp + bsp  # fit symmetric phase

plt.plot(wl, jnp.unwrap(jnp.angle(s)), label="fdtd")
plt.plot(wl, fsp(wl), label="fit")
plt.grid(True)
plt.xlabel("Wavelength [um]")
plt.ylabel("Angle [rad]")
plt.legend()
plt.show()
../../../_images/notebooks_plugins_sax_sax_58_1.png
[41]:
fs = lambda x: fsm(x) * jnp.exp(1j * fsp(x))

Anti-Symmetric#

[42]:
r = LinearRegression()
r.fit(X, jnp.abs(a))
aam, bam = r.coef_, r.intercept_
fam = lambda x: fX(x) @ aam + bam

plt.plot(wl, jnp.abs(a), label="fdtd")
plt.plot(wl, fam(wl), label="fit")
plt.grid(True)
plt.xlabel("Wavelength [um]")
plt.ylabel("Transmission")
plt.legend()
plt.show()
../../../_images/notebooks_plugins_sax_sax_61_1.png
[43]:
r = LinearRegression()
r.fit(X, jnp.unwrap(jnp.angle(a)))
aap, bap = r.coef_, r.intercept_
fap = lambda x: fX(x) @ aap + bap

plt.plot(wl, jnp.unwrap(jnp.angle(a)), label="fdtd")
plt.plot(wl, fap(wl), label="fit")
plt.grid(True)
plt.xlabel("Wavelength [um]")
plt.ylabel("Angle [rad]")
plt.legend()
plt.show()
../../../_images/notebooks_plugins_sax_sax_62_1.png
[44]:
fa = lambda x: fam(x) * jnp.exp(1j * fap(x))

Total#

[45]:
t_ = 0.5 * (fs(wl) + fa(wl))

plt.plot(wl, jnp.abs(t))
plt.plot(wl, jnp.abs(t_))
plt.xlabel("Wavelength [um]")
plt.ylabel("Transmission")
[45]:
Text(0, 0.5, 'Transmission')
../../../_images/notebooks_plugins_sax_sax_65_1.png
[46]:
k_ = 0.5 * (fs(wl) - fa(wl))

plt.plot(wl, jnp.abs(k))
plt.plot(wl, jnp.abs(k_))
plt.xlabel("Wavelength [um]")
plt.ylabel("Coupling")
[46]:
Text(0, 0.5, 'Coupling')
../../../_images/notebooks_plugins_sax_sax_66_1.png
[47]:
@jax.jit
def coupler(wl=1.5):
    wl = jnp.asarray(wl)
    wl_shape = wl.shape
    wl = wl.ravel()
    t = (0.5 * (fs(wl) + fa(wl))).reshape(*wl_shape)
    k = (0.5 * (fs(wl) - fa(wl))).reshape(*wl_shape)
    sdict = {
        ("o1", "o4"): t,
        ("o1", "o3"): k,
        ("o2", "o3"): k,
        ("o2", "o4"): t,
    }
    return sax.reciprocal(sdict)
[48]:
f = jnp.linspace(c / 1.0e-6, c / 2.0e-6, 500) * 1e-12  # THz
wl = c / (f * 1e12) * 1e6  # um

filepath = gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv"
coupler_fdtd = gf.partial(
    gs.read.sdict_from_csv, filepath, xkey="wavelength_nm", prefix="S", xunits=1e-3
)
sd = coupler_fdtd(wl=wl)
sd_ = coupler(wl=wl)

T = jnp.abs(sd["o1", "o4"]) ** 2
K = jnp.abs(sd["o1", "o3"]) ** 2
T_ = jnp.abs(sd_["o1", "o4"]) ** 2
K_ = jnp.abs(sd_["o1", "o3"]) ** 2
dP = jnp.unwrap(jnp.angle(sd["o1", "o3"]) - jnp.angle(sd["o1", "o4"]))
dP_ = jnp.unwrap(jnp.angle(sd_["o1", "o3"]) - jnp.angle(sd_["o1", "o4"]))

plt.figure(figsize=(12, 3))
plt.plot(wl, T, label="T (fdtd)", c="C0", ls=":", lw="6")
plt.plot(wl, T_, label="T (model)", c="C0")

plt.plot(wl, K, label="K (fdtd)", c="C1", ls=":", lw="6")
plt.plot(wl, K_, label="K (model)", c="C1")

plt.ylim(-0.05, 1.05)
plt.grid(True)

plt.twinx()
plt.plot(wl, dP, label="ΔΦ (fdtd)", color="C2", ls=":", lw="6")
plt.plot(wl, dP_, label="ΔΦ (model)", color="C2")

plt.xlabel("Wavelength [um]")
plt.ylabel("Transmission")
plt.figlegend(bbox_to_anchor=(1.08, 0.9))
plt.savefig("fdtd_vs_model.png", bbox_inches="tight")
plt.show()
../../../_images/notebooks_plugins_sax_sax_68_0.png

SAX gdsfactory Compatibility#

From Layout to Circuit Model

If you define SAX S-parameter models for your components, you can simulate your circuits directly from gdsfactory.

[49]:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from omegaconf import OmegaConf
import sax
from pprint import pprint

import gdsfactory as gf
from gdsfactory.get_netlist import get_netlist as _get_netlist
[50]:
mzi = gf.components.mzi(delta_length=10)
mzi
../../../_images/notebooks_plugins_sax_sax_71_0.png
[50]:
mzi_delta_length10: uid 6, ports ['o1', 'o2'], aliases [], 0 polygons, 20 references
[51]:
netlist = mzi.get_netlist_dict()
pprint(netlist["connections"])
{'bend_euler_20p624_5p5,o1': 'mmi1x2_2p75_0p0,o2',
 'bend_euler_20p624_5p5,o2': 'straight_length2p0_25p5_11p625,o1',
 'bend_euler_20p624_m5p5,o1': 'mmi1x2_2p75_0p0,o3',
 'bend_euler_20p624_m5p5,o2': 'straight_length7p0_25p5_m14p125,o1',
 'bend_euler_30p376_17p75,o1': 'straight_length0p1_35p55_22p625,o1',
 'bend_euler_30p376_17p75,o2': 'straight_length2p0_25p5_11p625,o2',
 'bend_euler_30p376_m22p749,o1': 'straight_length7p0_25p5_m14p125,o2',
 'bend_euler_30p376_m22p749,o2': 'straight_length0p1_35p55_m27p625,o1',
 'bend_euler_40p816_17p75,o1': 'straight_381a0e44_35p646_22p625,o2',
 'bend_euler_40p816_17p75,o2': 'straight_45efaf1e_45p691_11p625,o1',
 'bend_euler_40p816_m22p749,o1': 'straight_381a0e44_35p646_m27p625,o2',
 'bend_euler_40p816_m22p749,o2': 'straight_4af39711_45p691_m14p125,o1',
 'bend_euler_50p567_5p5,o1': 'straight_45efaf1e_45p691_11p625,o2',
 'bend_euler_50p567_5p5,o2': 'straight_77b336e2_55p696_0p625,o1',
 'bend_euler_50p567_m5p5,o1': 'straight_4af39711_45p691_m14p125,o2',
 'bend_euler_50p567_m5p5,o2': 'straight_77b336e2_55p696_m0p625,o1',
 'mmi1x2_68p451_0p0,o2': 'straight_77b336e2_55p696_0p625,o2',
 'mmi1x2_68p451_0p0,o3': 'straight_77b336e2_55p696_m0p625,o2',
 'straight_381a0e44_35p646_22p625,o1': 'straight_length0p1_35p55_22p625,o2',
 'straight_381a0e44_35p646_m27p625,o1': 'straight_length0p1_35p55_m27p625,o2'}

The netlist has three different components:

  1. straight

  2. mmi1x2

  3. bend_euler

You need a model for each subcomponent to simulate the Component.

[52]:
def straight(wl=1.5, length=10.0, neff=2.4) -> sax.SDict:
    wl0 = 1.5  # center wavelength for which the waveguide model is defined
    return sax.reciprocal({("o1", "o2"): jnp.exp(2j * jnp.pi * neff * length / wl)})


def mmi1x2():
    """Assumes a perfect 1x2 splitter"""
    return sax.reciprocal(
        {
            ("o1", "o2"): 0.5**0.5,
            ("o1", "o3"): 0.5**0.5,
        }
    )


def bend_euler(wl=1.5, length=20.0):
    """Let's assume a reduced transmission for the Euler bend compared to a straight."""
    return {k: 0.99 * v for k, v in straight(wl=wl, length=length).items()}
[53]:
circuit = sax.circuit_from_netlist(
    netlist=netlist,
    models={
        "bend_euler": bend_euler,
        "mmi1x2": mmi1x2,
        "straight": straight,
    },
)
[54]:
wl = np.linspace(1.5, 1.6)
S = circuit(wl=wl)

plt.figure(figsize=(14, 4))
plt.title("MZI")
plt.plot(1e3 * wl, jnp.abs(S["o1", "o2"]) ** 2)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.grid(True)
plt.show()
../../../_images/notebooks_plugins_sax_sax_76_0.png
[55]:
mzi = gf.components.mzi(delta_length=20)  # Double the length, reduces FSR by 1/2
mzi
../../../_images/notebooks_plugins_sax_sax_77_0.png
[55]:
mzi_delta_length20: uid 20, ports ['o1', 'o2'], aliases [], 0 polygons, 20 references
[56]:
circuit = sax.circuit_from_netlist(
    netlist=mzi.get_netlist_dict(),
    models={
        "bend_euler": bend_euler,
        "mmi1x2": mmi1x2,
        "straight": straight,
    },
)

wl = np.linspace(1.5, 1.6, 256)
S = circuit(wl=wl)

plt.figure(figsize=(14, 4))
plt.title("MZI")
plt.plot(1e3 * wl, jnp.abs(S["o1", "o2"]) ** 2)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.grid(True)
plt.show()
../../../_images/notebooks_plugins_sax_sax_78_0.png
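
The halving of the free spectral range (FSR) when delta_length is doubled follows from the usual estimate FSR ≈ wl² / (ng · delta_length). The sketch below evaluates it for both cases, assuming delta_length is the full path-length difference between the arms and taking ng = 2.4, because the straight model above uses a wavelength-independent neff of 2.4. The name fsr_estimate is hypothetical.

```python
def fsr_estimate(wl, delta_length, ng=2.4):
    """Approximate free spectral range in um (hypothetical helper)."""
    return wl**2 / (ng * delta_length)


fsr_10 = fsr_estimate(1.55, 10.0)
fsr_20 = fsr_estimate(1.55, 20.0)
print(round(1e3 * fsr_10, 1), round(1e3 * fsr_20, 1))  # FSR in nm: 100.1 50.1
```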