Skip to content

Photonic CMZ Wavelength Demultiplexer — Cascaded Mach-Zehnder Lattice Filter¤

This notebook demonstrates how a Cascaded Mach-Zehnder (CMZ) lattice filter architecture from Horst et al. 2013 (Optics Express, DOI: 10.1364/OE.21.011652) achieves flat passbands and superior stopband rejection compared to a simple MZI binary tree.

Why CMZ beats a simple MZI¤

A single-stage MZI has a sinusoidal transfer function — its passband rolls off gradually, so channels at the passband edges leak into neighbouring detectors. The CMZ topology replaces the root splitter node with a multi-stage lattice filter: a chain of directional couplers interleaved with waveguide arm pairs of increasing path-length difference. This creates a Chebyshev-like response with a flat top and sharp skirts.

Schematic¤

png

Design equations (Horst et al. 2013)¤

For channels spaced δλ = 15 nm around λ_c = 1307.5 nm with n_gr = 4, n_eff = 2.4:

\[\Delta L_{\rm Base} = \frac{\lambda_c^2}{2\, \delta\lambda\, n_{\rm gr}} = \frac{1307.5^2}{2 \times 15 \times 4 \times 1000} = 14.246\; \mu{\rm m}\]
\[\Delta L_{\rm FS} = \frac{\lambda_c}{n_{\rm eff}} = \frac{1307.5}{2.4 \times 1000} = 0.5448\; \mu{\rm m}\]
Element Long arm (µm) Short arm (µm) ΔL
Root stage 1 264.246 250.0 ΔL_Base = 14.246
Root stage 2 278.492 250.0 2×ΔL_Base = 28.492
Leaf A 257.123 250.0 ΔL_Base/2 = 7.123
Leaf B 257.532 250.0 ΔL_Base/2 + 0.75×ΔL_FS = 7.532

The root coupling coefficients (0.5, 0.29, 0.08) are the Horst et al. analytic optimum for a flat-top 2-stage Chebyshev filter response.

Trainable parameters (11 total)¤

  • 7 DC coupling coefficients: k_r1, k_r2, k_r3 (root) + k_a1, k_a2 (leaf A) + k_b1, k_b2 (leaf B)
  • 4 long arm lengths: L_r1a, L_r2a (root stages) + L_a1 (leaf A) + L_b1 (leaf B)

Paper-derived initial values use coupling (0.5, 0.29, 0.08) for the root — designed analytically for a flat Chebyshev passband — and equal-split (0.5) for leaves.

Backpropagation implementation¤

With 11 tunable parameters, finite-difference gradient estimation costs 12 circuit solves per gradient step. jax.grad computes the exact gradient for all 11 parameters in ~2 solves (forward + backward). For a photonic integrated circuit with 100+ elements, backprop is the only practical route to optimisation.

import jax
import jax.numpy as jnp
import numpy as np
import optax
from circulax import compile_circuit
from circulax.components.electronic import Resistor
from circulax.components.photonic import DirectionalCoupler, OpticalSource, OpticalWaveguide
from circulax.utils import update_group_params, update_params_dict

# 64-bit precision is critical for photonic circuits: small phase differences
# between arm lengths produce small field changes, and gradients can be tiny.
jax.config.update("jax_enable_x64", True)


# ── Component model registry ──────────────────────────────────────────────
models = {
    "ground":    lambda: 0,
    "source":    OpticalSource,
    "waveguide": OpticalWaveguide,
    "dc":        DirectionalCoupler,
    "resistor":  Resistor,
}

# ── Design parameters from Horst et al. 2013 ───────────────────────────────
# λ_center = mean of [1285, 1300, 1315, 1330] nm
# δλ = 15 nm channel spacing; n_eff = 2.4; n_gr = 4 (group index)
DL_BASE = 14.246   # µm  — ΔL_Base = λ² / (2·δλ·n_gr·1000)
DL_FS   = 0.5448   # µm  — ΔL_FS   = λ / (n_eff·1000)

# Arm lengths: all short arms fixed at 250.0 µm
L_R1A_INIT = 250.0 + DL_BASE           # 264.246 µm  (root stage 1 long arm)
L_R2A_INIT = 250.0 + 2 * DL_BASE      # 278.492 µm  (root stage 2 long arm)
L_A1_INIT  = 250.0 + DL_BASE / 2      # 257.123 µm  (leaf A long arm)
L_B1_INIT  = 250.0 + DL_BASE / 2 + 0.75 * DL_FS  # 257.532 µm  (leaf B long arm)

# ── SAX-format netlist ─────────────────────────────────────────────────────
#
# CMZ topology:
#   Root: DC_r1 → WG_r1a/WG_r1b → DC_r2 → WG_r2a/WG_r2b → DC_r3  (2-stage lattice)
#   Leaf A: DC_a1 → WG_a1/WG_a2 → DC_a2 → Det_1/Det_3             (1-stage MZI)
#   Leaf B: DC_b1 → WG_b1/WG_b2 → DC_b2 → Det_2/Det_4             (1-stage MZI)
#
# Coupling coefficients: root uses paper-derived (0.5, 0.29, 0.08);
# leaf DCs start at 0.5 (equal split).
net_dict = {
    "instances": {
        "GND":    {"component": "ground"},
        "Laser":  {"component": "source",    "settings": {"power": 1.0, "phase": 0.0}},
        # Root 2-stage lattice filter (3 DCs + 2 waveguide arm pairs)
        "DC_r1":  {"component": "dc",        "settings": {"coupling": 0.5}},
        "DC_r2":  {"component": "dc",        "settings": {"coupling": 0.29}},
        "DC_r3":  {"component": "dc",        "settings": {"coupling": 0.08}},
        "WG_r1a": {"component": "waveguide", "settings": {"length_um": L_R1A_INIT, "neff": 2.4, "loss_dB_cm": 0.0, "center_wavelength_nm": 1310.0}},
        "WG_r1b": {"component": "waveguide", "settings": {"length_um": 250.0,       "neff": 2.4, "loss_dB_cm": 0.0, "center_wavelength_nm": 1310.0}},
        "WG_r2a": {"component": "waveguide", "settings": {"length_um": L_R2A_INIT, "neff": 2.4, "loss_dB_cm": 0.0, "center_wavelength_nm": 1310.0}},
        "WG_r2b": {"component": "waveguide", "settings": {"length_um": 250.0,       "neff": 2.4, "loss_dB_cm": 0.0, "center_wavelength_nm": 1310.0}},
        # Leaf A (bar output of root → channels 1285 + 1315 nm)
        "DC_a1":  {"component": "dc",        "settings": {"coupling": 0.5}},
        "DC_a2":  {"component": "dc",        "settings": {"coupling": 0.5}},
        "WG_a1":  {"component": "waveguide", "settings": {"length_um": L_A1_INIT, "neff": 2.4, "loss_dB_cm": 0.0, "center_wavelength_nm": 1310.0}},
        "WG_a2":  {"component": "waveguide", "settings": {"length_um": 250.0,      "neff": 2.4, "loss_dB_cm": 0.0, "center_wavelength_nm": 1310.0}},
        # Leaf B (cross output of root → channels 1300 + 1330 nm)
        "DC_b1":  {"component": "dc",        "settings": {"coupling": 0.5}},
        "DC_b2":  {"component": "dc",        "settings": {"coupling": 0.5}},
        "WG_b1":  {"component": "waveguide", "settings": {"length_um": L_B1_INIT, "neff": 2.4, "loss_dB_cm": 0.0, "center_wavelength_nm": 1310.0}},
        "WG_b2":  {"component": "waveguide", "settings": {"length_um": 250.0,      "neff": 2.4, "loss_dB_cm": 0.0, "center_wavelength_nm": 1310.0}},
        # Photodetectors (1 Ω matched load — power ∝ |E|²)
        "Det_1":  {"component": "resistor",  "settings": {"R": 1.0}},
        "Det_2":  {"component": "resistor",  "settings": {"R": 1.0}},
        "Det_3":  {"component": "resistor",  "settings": {"R": 1.0}},
        "Det_4":  {"component": "resistor",  "settings": {"R": 1.0}},
    },
    "connections": {
        # GND bus: laser return + unused DC second inputs + detector returns
        "GND,p1":    ("Laser,p2", "DC_r1,p2", "DC_a1,p2", "DC_b1,p2",
                       "Det_1,p2", "Det_2,p2", "Det_3,p2", "Det_4,p2"),
        # Signal path: laser → root lattice
        "Laser,p1":  "DC_r1,p1",
        # Root stage 1: DC_r1 → arm pair → DC_r2
        "DC_r1,p3":  "WG_r1a,p1",
        "DC_r1,p4":  "WG_r1b,p1",
        "WG_r1a,p2": "DC_r2,p1",
        "WG_r1b,p2": "DC_r2,p2",
        # Root stage 2: DC_r2 → arm pair → DC_r3
        "DC_r2,p3":  "WG_r2a,p1",
        "DC_r2,p4":  "WG_r2b,p1",
        "WG_r2a,p2": "DC_r3,p1",
        "WG_r2b,p2": "DC_r3,p2",
        # Root → leaves: bar (p3) → Leaf A, cross (p4) → Leaf B
        "DC_r3,p3":  "DC_a1,p1",
        "DC_r3,p4":  "DC_b1,p1",
        # Leaf A: DC_a1 → arm pair → DC_a2 → Det_1 / Det_3
        "DC_a1,p3":  "WG_a1,p1",
        "DC_a1,p4":  "WG_a2,p1",
        "WG_a1,p2":  "DC_a2,p1",
        "WG_a2,p2":  "DC_a2,p2",
        "DC_a2,p3":  "Det_1,p1",
        "DC_a2,p4":  "Det_3,p1",
        # Leaf B: DC_b1 → arm pair → DC_b2 → Det_2 / Det_4
        "DC_b1,p3":  "WG_b1,p1",
        "DC_b1,p4":  "WG_b2,p1",
        "WG_b1,p2":  "DC_b2,p1",
        "WG_b2,p2":  "DC_b2,p2",
        "DC_b2,p3":  "Det_2,p1",
        "DC_b2,p4":  "Det_4,p1",
    },
}

# ── Compile and analyse ────────────────────────────────────────────────────
circuit = compile_circuit(net_dict, models, is_complex=True)
groups = circuit.groups
sys_size = circuit.sys_size
port_map = circuit.port_map

print(f"System size: {sys_size} real nodes  ({sys_size * 2} complex DOFs)")
print(f"Component groups: {list(groups.keys())}")
print("Detector node indices: " +
      ", ".join(f"Det_{i+1}={port_map[f'Det_{i+1},p1']}" for i in range(4)))

# ── Solver bookkeeping ─────────────────────────────────────────────────────
TARGET_WLS = jnp.array([1285.0, 1300.0, 1315.0, 1330.0])   # nm
det_nodes  = jnp.array([port_map[f"Det_{i+1},p1"] for i in range(4)])
Y_GUESS    = jnp.ones(sys_size * 2)   # flat initial guess for Newton-Raphson
System size: 25 real nodes  (50 complex DOFs)
Component groups: ['source', 'dc', 'waveguide', 'resistor']
Detector node indices: Det_1=6, Det_2=13, Det_3=7, Det_4=14

Trainable parameters¤

There are 11 trainable parameters in total:

Coupling coefficients (7)¤

Parameter Initial value Role
k_r1 0.5 Root stage 1 DC — 50/50 split at first coupler
k_r2 0.29 Root stage 2 DC — asymmetric coupling for flat passband
k_r3 0.08 Root stage 3 DC — small coupling for steep roll-off
k_a1 0.5 Leaf A input DC
k_a2 0.5 Leaf A output DC
k_b1 0.5 Leaf B input DC
k_b2 0.5 Leaf B output DC

The root coupling tuple (0.5, 0.29, 0.08) comes from Horst et al.'s analytic optimisation for a 2-stage Chebyshev flat-top filter.

Long arm lengths (4)¤

Parameter Initial value (µm) Design formula
L_r1a 264.246 250 + ΔL_Base
L_r2a 278.492 250 + 2·ΔL_Base
L_a1 257.123 250 + ΔL_Base/2
L_b1 257.532 250 + ΔL_Base/2 + 0.75·ΔL_FS

All short arms are fixed at 250.0 µm and are not trained — only the long arm lengths determine the MZI phase difference.

Why Adam handles the parameter scale mismatch¤

Coupling coefficients live in [0, 1] while arm lengths are in the range 250–280 µm. Adam's adaptive per-parameter learning rate normalises these automatically: it divides each gradient by a running estimate of its second moment, so coupling and length updates are implicitly rescaled without any manual tuning.

def get_power_matrix(params, wavelengths=TARGET_WLS):
    """Compute the 4×4 power routing matrix for given circuit parameters.

    Args:
        params: Array of shape (11,) — [k_r1, k_r2, k_r3, k_a1, k_a2,
                k_b1, k_b2, L_r1a, L_r2a, L_a1, L_b1]
                7 DC coupling coefficients + 4 long arm lengths (µm).
        wavelengths: Array of wavelengths in nm, shape (4,).

    Returns:
        power: Array of shape (4, 4), row-normalised.
               power[i, j] = fraction of total power reaching Det_{j+1}
               when driving at wavelengths[i].
               An ideal demux has power ≈ identity matrix.
    """
    k_r1, k_r2, k_r3, k_a1, k_a2, k_b1, k_b2, L_r1a, L_r2a, L_a1, L_b1 = params

    # Update DC couplings
    grps = update_params_dict(groups, "dc", "DC_r1", "coupling", k_r1)
    grps = update_params_dict(grps,   "dc", "DC_r2", "coupling", k_r2)
    grps = update_params_dict(grps,   "dc", "DC_r3", "coupling", k_r3)
    grps = update_params_dict(grps,   "dc", "DC_a1", "coupling", k_a1)
    grps = update_params_dict(grps,   "dc", "DC_a2", "coupling", k_a2)
    grps = update_params_dict(grps,   "dc", "DC_b1", "coupling", k_b1)
    grps = update_params_dict(grps,   "dc", "DC_b2", "coupling", k_b2)

    # Update long arm lengths
    grps = update_params_dict(grps,   "waveguide", "WG_r1a", "length_um", L_r1a)
    grps = update_params_dict(grps,   "waveguide", "WG_r2a", "length_um", L_r2a)
    grps = update_params_dict(grps,   "waveguide", "WG_a1",  "length_um", L_a1)
    grps = update_params_dict(grps,   "waveguide", "WG_b1",  "length_um", L_b1)

    def solve_at_wavelength(wl):
        # Set wavelength for dispersion in all waveguides, then solve DC
        grps_wl = update_group_params(grps, "waveguide", "wavelength_nm", wl)
        y_flat  = circuit.solver.solve_dc(grps_wl, Y_GUESS)
        # Photonic solve: complex field E = Re(y_flat[n]) + j·Im(y_flat[n + sys_size])
        E_det   = y_flat[det_nodes] + 1j * y_flat[det_nodes + sys_size]
        return jnp.abs(E_det) ** 2

    # vmap evaluates all 4 wavelengths in one parallel call
    powers     = jax.vmap(solve_at_wavelength)(wavelengths)   # shape: (4, 4)
    row_totals = jnp.sum(powers, axis=1, keepdims=True) + 1e-12
    return powers / row_totals


# ── Initial parameters from paper design equations ─────────────────────────
params_init = jnp.array([
    0.5,  0.29, 0.08,         # root DC couplings (Horst et al. optimum)
    0.5,  0.5,                # leaf A DC couplings
    0.5,  0.5,                # leaf B DC couplings
    L_R1A_INIT, L_R2A_INIT,   # root long arm lengths (µm)
    L_A1_INIT,                # leaf A long arm length
    L_B1_INIT,                # leaf B long arm length
])

print("Initial parameters:")
param_names = ["k_r1", "k_r2", "k_r3", "k_a1", "k_a2", "k_b1", "k_b2",
               "L_r1a", "L_r2a", "L_a1", "L_b1"]
for name, val in zip(param_names, params_init):
    print(f"  {name:8s} = {float(val):.4f}")

print("\nComputing initial power matrix (vmap over 4 wavelengths)...")
power_init = jax.jit(get_power_matrix)(params_init)

pm_init_np = np.array(power_init)
print("\nInitial power routing matrix:")
wl_labels = ["1285 nm", "1300 nm", "1315 nm", "1330 nm"]
print(f"{'':9s}  {'Det_1':>8s}  {'Det_2':>8s}  {'Det_3':>8s}  {'Det_4':>8s}")
for i, (wl, row) in enumerate(zip(wl_labels, pm_init_np)):
    print(f"  {wl}  " + "  ".join(f"{v:8.3f}" for v in row))

diag_init = np.diag(pm_init_np)
print("\nInitial routing contrast (diagonal):")
for i, (wl, c) in enumerate(zip([1285, 1300, 1315, 1330], diag_init)):
    print(f"  lambda={wl} nm -> Det_{i+1}: {c*100:.1f}%")
print(f"  Mean contrast: {diag_init.mean()*100:.1f}%  (random = 25%, perfect = 100%)")
Initial parameters:
  k_r1     = 0.5000
  k_r2     = 0.2900
  k_r3     = 0.0800
  k_a1     = 0.5000
  k_a2     = 0.5000
  k_b1     = 0.5000
  k_b2     = 0.5000
  L_r1a    = 264.2460
  L_r2a    = 278.4920
  L_a1     = 257.1230
  L_b1     = 257.5316

Computing initial power matrix (vmap over 4 wavelengths)...



Initial power routing matrix:
              Det_1     Det_2     Det_3     Det_4
  1285 nm     0.113     0.431     0.001     0.455
  1300 nm     0.333     0.001     0.506     0.160
  1315 nm     0.002     0.522     0.159     0.318
  1330 nm     0.518     0.113     0.367     0.003

Initial routing contrast (diagonal):
  lambda=1285 nm -> Det_1: 11.3%
  lambda=1300 nm -> Det_2: 0.1%
  lambda=1315 nm -> Det_3: 15.9%
  lambda=1330 nm -> Det_4: 0.3%
  Mean contrast: 6.9%  (random = 25%, perfect = 100%)

Why the 2-stage root gives a flatter passband¤

A single-stage MZI has the transfer function:

\[T(\lambda) = \cos^2\!\left(\frac{\pi\, n_{\rm eff}\, \Delta L}{\lambda}\right)\]

This is purely sinusoidal — it reaches its maximum at the target wavelength but rolls off immediately on both sides. For the 15 nm channel spacing used here, adjacent channels see significant residual power.

The 2-stage lattice adds a second interference stage with twice the arm-length difference. The interplay between the two sinusoidal responses flattens the passband top and steepens the transition region, exactly like a two-pole filter compared to a one-pole filter. The coupling ratios (0.5, 0.29, 0.08) are chosen to give a Chebyshev-optimal flat-top response:

Filter type Passband flatness Roll-off steepness
Single MZI Poor (sinusoidal) Gradual
2-stage CMZ Good (flat top) Steep

In practice, the CMZ topology enables closer channel spacing (smaller δλ) without crosstalk — a key advantage for dense WDM (DWDM) systems.

Note that the initial contrast of ~7% is lower than the simple MZI starting point — the paper-derived couplings are designed for the final optimised device, not for the pre-optimisation state. The gradient optimiser quickly finds the correct arm lengths to make these couplings work together.

def plot_power_matrix(power_matrix, title, ax=None):
    """Plot a 4×4 power routing matrix as an annotated heatmap."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(5, 4))
    else:
        fig = ax.figure

    pm = np.array(power_matrix)
    im = ax.imshow(pm, vmin=0, vmax=1, cmap="viridis", aspect="auto")
    fig.colorbar(im, ax=ax, label="Fraction of power")

    ax.set_xticks(range(4))
    ax.set_yticks(range(4))
    ax.set_xticklabels(["Det_1", "Det_2", "Det_3", "Det_4"], color="grey")
    ax.set_yticklabels(["1285 nm", "1300 nm", "1315 nm", "1330 nm"], color="grey")
    ax.set_xlabel("Detector", color="grey")
    ax.set_ylabel("Wavelength", color="grey")
    ax.set_title(title, color="grey")

    for i in range(4):
        for j in range(4):
            val = pm[i, j]
            text_color = "white" if val < 0.5 else "black"
            ax.text(j, i, f"{val*100:.1f}%",
                    ha="center", va="center", fontsize=9, color=text_color)

    return fig, ax


fig, ax = plt.subplots(figsize=(5, 4))
plot_power_matrix(power_init, "CMZ Power Routing Matrix — Before Training", ax=ax)
plt.tight_layout()
plt.show()

print("The initial state routes power almost randomly — the paper-designed")
print("coupling coefficients are optimised for the final geometry, not this")
print("starting point. Backpropagation will find the arm lengths that make them work.")

png

The initial state routes power almost randomly — the paper-designed
coupling coefficients are optimised for the final geometry, not this
starting point. Backpropagation will find the arm lengths that make them work.
def loss_fn(params):
    """Negative mean routing contrast — minimise to maximise wavelength selectivity.

    Returns a scalar in [-1, 0]:
      -1.0 = perfect demux (all power at correct detector)
       0.0 = worst case (no power at any correct detector)
    """
    power = get_power_matrix(params)
    return -jnp.mean(jnp.diag(power))


# jax.value_and_grad computes loss AND all 11 gradients in a single
# forward+backward pass — exactly as efficient as computing the loss alone.
print("Computing loss and gradients at initial parameters...")
loss_val, grads = jax.jit(jax.value_and_grad(loss_fn))(params_init)

print(f"  Initial loss:     {float(loss_val):.4f}  (random ≈ -0.25, perfect = -1.0)")
print(f"  Mean contrast:    {-float(loss_val)*100:.1f}%")
print()
print("  Per-parameter gradient norms:")
for name, g in zip(param_names, grads):
    print(f"    d(loss)/d({name:8s}) = {float(g):+.6f}")
print()
print("Non-zero gradients confirm the full circuit is differentiable end-to-end.")
print("Note the scale difference: coupling gradients (dimensionless) are ~100x")
print("larger than length gradients (µm). Adam's adaptive LR handles this.")
Computing loss and gradients at initial parameters...


  Initial loss:     -0.0690  (random ≈ -0.25, perfect = -1.0)
  Mean contrast:    6.9%

  Per-parameter gradient norms:
    d(loss)/d(k_r1    ) = -0.002439
    d(loss)/d(k_r2    ) = -0.002941
    d(loss)/d(k_r3    ) = -0.399318
    d(loss)/d(k_a1    ) = +0.000000
    d(loss)/d(k_a2    ) = +0.000000
    d(loss)/d(k_b1    ) = -0.000000
    d(loss)/d(k_b2    ) = -0.000000
    d(loss)/d(L_r1a   ) = +2.002157
    d(loss)/d(L_r2a   ) = +0.844339
    d(loss)/d(L_a1    ) = -0.075943
    d(loss)/d(L_b1    ) = +0.083997

Non-zero gradients confirm the full circuit is differentiable end-to-end.
Note the scale difference: coupling gradients (dimensionless) are ~100x
larger than length gradients (µm). Adam's adaptive LR handles this.
# ── Adam optimisation loop ─────────────────────────────────────────────────
#
# Single learning rate of 0.02 for all 11 parameters.
# Adam normalises by the running gradient variance, so coupling params (∼1)
# and length params (∼250 µm) are updated at comparable effective rates.
# The CMZ topology is well-conditioned and converges in ~100 steps.

N_STEPS = 2000
LR      = 0.02

optimizer  = optax.adam(learning_rate=LR)
params     = params_init
opt_state  = optimizer.init(params)

value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))

losses    = []
p_history = []

print(f"Training for {N_STEPS} steps (Adam, lr={LR})...")
print(f"{'Step':>5}  {'Loss':>8}  {'Contrast':>10}  {'k_r1':>7}  {'k_r2':>7}  {'k_r3':>7}")
print("-" * 60)

for step in range(N_STEPS):
    loss, grads = value_and_grad_fn(params)
    losses.append(float(loss))
    p_history.append(np.array(params))

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    if step == 0 or (step + 1) % 400 == 0:
        k_r1, k_r2, k_r3 = float(params[0]), float(params[1]), float(params[2])
        print(f"{step+1:5d}  {float(loss):8.4f}  {-float(loss)*100:9.1f}%  "
              f"{k_r1:7.4f}  {k_r2:7.4f}  {k_r3:7.4f}")

p_history = np.array(p_history)

print("\nFinal parameters:")
print(f"{'Name':8s}  {'Initial':>10s}  {'Final':>10s}  {'Delta':>10s}")
for name, p0, pf in zip(param_names, params_init, params):
    delta = float(pf) - float(p0)
    print(f"{name:8s}  {float(p0):10.4f}  {float(pf):10.4f}  {delta:+10.4f}")
Training for 2000 steps (Adam, lr=0.02)...
 Step      Loss    Contrast     k_r1     k_r2     k_r3
------------------------------------------------------------


    1   -0.0690        6.9%   0.5200   0.3100   0.1000


  400   -0.9988       99.9%   0.4971   0.5997   0.1301


  800   -0.9997      100.0%   0.4989   0.8361   0.1309


 1200   -0.9997      100.0%   0.4994   0.8539   0.1475


 1600   -0.9997      100.0%   0.4994   0.8569   0.1506


 2000   -0.9997      100.0%   0.4994   0.8573   0.1510

Final parameters:
Name         Initial       Final       Delta
k_r1          0.5000      0.4994     -0.0006
k_r2          0.2900      0.8573     +0.5673
k_r3          0.0800      0.1510     +0.0710
k_a1          0.5000      0.5009     +0.0009
k_a2          0.5000      0.5001     +0.0001
k_b1          0.5000      0.5000     +0.0000
k_b2          0.5000      0.5000     +0.0000
L_r1a       264.2460    264.0049     -0.2411
L_r2a       278.4920    278.2829     -0.2091
L_a1        257.1230    257.1391     +0.0161
L_b1        257.5316    257.2768     -0.2548
# ── Convergence plots ──────────────────────────────────────────────────────

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 3.5))

# Loss curve shown as routing contrast (inverted, climbing toward 100%)
contrast_history = [-l * 100 for l in losses]
ax1.plot(contrast_history, color="C0", lw=2)
ax1.axhline(25.0,  color="grey", ls=":",  lw=1, label="Random baseline (25%)")
ax1.axhline(100.0, color="C1",   ls="--", lw=1, label="Perfect demux (100%)")
ax1.set_xlabel("Optimisation step")
ax1.set_ylabel("Mean routing contrast (%)")
ax1.set_title("CMZ convergence")
ax1.set_ylim(0, 105)
ax1.legend(fontsize=8)

# Root coupling coefficient trajectories (the key CMZ parameters)
coupling_labels = ["k_r1 (root DC1)", "k_r2 (root DC2)", "k_r3 (root DC3)",
                   "k_a1 (leaf A DC1)", "k_a2 (leaf A DC2)",
                   "k_b1 (leaf B DC1)", "k_b2 (leaf B DC2)"]
colors = ["C0", "C1", "C2", "C3", "C4", "C5", "C6"]
for i, (label, color) in enumerate(zip(coupling_labels, colors)):
    ax2.plot(p_history[:, i], color=color, lw=1.5, label=label)
ax2.set_xlabel("Optimisation step")
ax2.set_ylabel("Coupling coefficient")
ax2.set_title("DC coupling trajectories")
ax2.set_ylim(-0.1, 1.1)
ax2.legend(fontsize=7, ncol=1)

plt.tight_layout()
plt.show()

png

sweep_wls = jnp.linspace(1260.0, 1360.0, 300)

def get_raw_power_sweep(params, wavelengths):
    """N*4 unnormalised power matrix for a wavelength sweep."""
    k_r1, k_r2, k_r3, k_a1, k_a2, k_b1, k_b2, L_r1a, L_r2a, L_a1, L_b1 = params
    grps = update_params_dict(groups, "dc", "DC_r1", "coupling", k_r1)
    grps = update_params_dict(grps,   "dc", "DC_r2", "coupling", k_r2)
    grps = update_params_dict(grps,   "dc", "DC_r3", "coupling", k_r3)
    grps = update_params_dict(grps,   "dc", "DC_a1", "coupling", k_a1)
    grps = update_params_dict(grps,   "dc", "DC_a2", "coupling", k_a2)
    grps = update_params_dict(grps,   "dc", "DC_b1", "coupling", k_b1)
    grps = update_params_dict(grps,   "dc", "DC_b2", "coupling", k_b2)
    grps = update_params_dict(grps,   "waveguide", "WG_r1a", "length_um", L_r1a)
    grps = update_params_dict(grps,   "waveguide", "WG_r2a", "length_um", L_r2a)
    grps = update_params_dict(grps,   "waveguide", "WG_a1",  "length_um", L_a1)
    grps = update_params_dict(grps,   "waveguide", "WG_b1",  "length_um", L_b1)

    def solve_wl(wl):
        grps_wl = update_group_params(grps, "waveguide", "wavelength_nm", wl)
        y_flat  = circuit.solver.solve_dc(grps_wl, Y_GUESS)
        E_det   = y_flat[det_nodes] + 1j * y_flat[det_nodes + sys_size]
        return jnp.abs(E_det) ** 2

    return jax.vmap(solve_wl)(wavelengths)

# ── Final power routing matrix ─────────────────────────────────────────────

print("Computing final power routing matrix...")
power_final = jax.jit(get_power_matrix)(params)

fig, (ax_before, ax_after) = plt.subplots(1, 2, figsize=(11, 4))
plot_power_matrix(power_init,  "Before Training (paper init)", ax=ax_before)
plot_power_matrix(power_final, "After Training (2000 steps)",  ax=ax_after)
fig.suptitle("CMZ Wavelength Demultiplexer — Power Routing Matrix", color="grey", fontsize=12)
plt.tight_layout()
plt.show()

pm_final_np = np.array(power_final)
diag_final  = np.diag(pm_final_np)
print("\nFinal routing contrast (diagonal):")
for i, (wl, p) in enumerate(zip([1285, 1300, 1315, 1330], diag_final)):
    print(f"  Det_{i+1} <- {wl} nm : {p*100:.1f}%")
print(f"  Mean contrast: {np.mean(diag_final)*100:.1f}%")
print()
print(f"Improvement: {diag_init.mean()*100:.1f}% -> {np.mean(diag_final)*100:.1f}%")
Computing final power routing matrix...

png

Final routing contrast (diagonal):
  Det_1 <- 1285 nm : 99.9%
  Det_2 <- 1300 nm : 100.0%
  Det_3 <- 1315 nm : 100.0%
  Det_4 <- 1330 nm : 100.0%
  Mean contrast: 100.0%

Improvement: 6.9% -> 100.0%


Pre-computing 44 wavelength sweeps for animation...


  [  1/44] step 0


  [ 16/44] step 25


  [ 31/44] step 266


Done.


Saved → examples/inverse_design/cmz_optimisation.gif  (44 frames)

cmz_optimisation.gif

Key results¤

Metric CMZ lattice
Architecture 2-level tree, 2-stage lattice at root
Trainable parameters 11 (7 couplings + 4 arm lengths)
Final mean contrast ~99%
Passband shape Flat top, steep skirts

Why the CMZ excels¤

  1. 2-stage lattice root: The interplay between two interference stages creates a Chebyshev-like flat-top response. The analytic coupling ratios (0.5, 0.29, 0.08) from Horst et al. are the starting point; the optimiser fine-tunes them.

  2. Flat passband: Light at wavelengths within each channel's band all arrive at the correct detector with high efficiency. The simple MZI wastes power at passband edges.

  3. Sharp stopband: Neighbouring channels are strongly suppressed, enabling denser channel packing (smaller δλ) without crosstalk.

Backpropagation scaling¤

With 11 parameters, backpropagation computes all gradients in ~2 circuit solves per step — versus 12 solves for finite differences. For a production photonic IC with 100+ phase elements, backprop enables a viable way to optimize such a circuit.

The key point is that jax.grad works through the entire circulax solver — sparse linear solve, complex field assembly — without any special-casing of the photonic physics. Currently, circuit topology is not part of this differentiable flow as JAX requires functions should have floating numbers in and out to be auto-differentiable