import re

import jax.numpy as jnp
import matplotlib.pyplot as plt

from qpdk.models import unimon_coupled

# Evaluate the `unimon_coupled` S-parameter model on a 1-10 GHz sweep
# (501 points) and plot the magnitude of each S_ij in dB.


def _port_index(port: str) -> int:
    """Return the integer index at the end of a port name.

    ``"o2"`` -> 2 and ``"o10"`` -> 10. The whole trailing digit run is parsed
    (not just the last character), so multi-digit port numbers are correct.

    Raises:
        ValueError: if ``port`` does not end in a digit.
    """
    match = re.search(r"\d+$", port)
    if match is None:
        raise ValueError(f"Port name {port!r} has no trailing numeric index")
    return int(match.group())


freq = jnp.linspace(1e9, 10e9, 501)  # Hz
freq_ghz = freq / 1e9
s_model = unimon_coupled(f=freq)

for (pout, pin), sij in s_model.items():
    i = _port_index(pout)  # "o2" -> 2
    j = _port_index(pin)  # "o1" -> 1

    # Only the lower triangle (i >= j, diagonal included) is drawn. This
    # presumably relies on the model being reciprocal (S_ij == S_ji) -- confirm.
    if i >= j:
        plt.plot(
            freq_ghz,
            20 * jnp.log10(jnp.abs(sij)),  # |S_ij| in dB
            label=f"$S_{{{i}{j}}}$",
        )

plt.xlabel("Frequency [GHz]", fontsize=12)
plt.ylabel("Magnitude [dB]", fontsize=12)
plt.grid(True)
plt.legend()
plt.show()