Skip to content

compiler ¤

Compiles the group into ComponentGroups and organizes the node index.

Classes:

Name Description
ComponentGroup

Represents a BATCH of identical components (e.g., ALL Resistors).

Functions:

Name Description
compile_netlist

Compile a netlist into batched, vectorized component groups ready for simulation.

ensure_time_signature

Wraps a model function to ensure it accepts a 't' keyword argument.

get_default_params

Return a copy so callers can't mutate the cache.

get_model_width

Determines the size of the 'vars' vector expected by the model.

merge_dicts

Merges a list of dictionaries into a single dictionary.

solve_connectivity

Resolves Port-to-Port connections into a Port-to-NodeID map.

ComponentGroup ¤

Bases: Module

Represents a BATCH of identical components (e.g., ALL Resistors). Optimized for jax.vmap in the solver.

compile_netlist ¤

compile_netlist(netlist: dict, models_map: dict) -> tuple[dict, int, dict]

Compile a netlist into batched, vectorized component groups ready for simulation.

This function bridges the gap between a human-readable netlist description and the internal representation required by the ODE solver. It resolves net connectivity, instantiates component objects, assigns state variable indices, and batches components of the same type into vectorized groups for efficient JAX execution.

The compilation proceeds in three stages:

  1. Connectivity resolutionbuild_net_map assigns a unique integer node index to every net, returning a flat map of "Instance,Port" keys to node indices.
  2. Instance processing — each instance is instantiated as an Equinox object using its settings, its port indices are looked up in the node map, and it is placed into a bucket keyed by (component_type, tree_structure). The tree structure is included in the key so that instances whose static fields differ (e.g. a callable parameter) are never incorrectly batched together.
  3. Vectorization — each bucket is stacked into a single :class:ComponentGroup with batched parameters and pre-computed Jacobian sparsity index arrays, ready to be passed directly to the solver.

Parameters:

Name Type Description Default
netlist dict

Circuit description dict. Expected to contain an "instances" key mapping instance names to dicts with at least a "component" key (model name string) and an optional "settings" key (parameter dict forwarded to the component constructor). A "GND" instance with component="ground" is recognised and skipped; see models_map note on reserved types.

Settings filtering: settings whose value is None are dropped (GDSFactory convention — None means "use the model default"). Settings whose key is not declared on the model are also silently ignored, which allows GDSFactory netlists carrying geometry-only keys (e.g. dy, dx, cross_section) to be passed directly without preprocessing. Note: a mis-spelled parameter name will be silently ignored rather than raising an error; double-check spelling if a parameter appears to have no effect.

required
models_map dict

Mapping from model name strings to :class:~circulax.components.base_component.CircuitComponent subclasses, e.g. {"Resistor": Resistor, "Capacitor": Capacitor}. "ground" is a reserved component type — it does not need to appear in this map (instances with component="ground" or named "GND" are skipped automatically). Raw SAX model functions (callable with all-defaulted parameters and a sax.SDict/sax.SDense/sax.SCoo/sax.SType return annotation) are accepted directly and auto-wrapped via :func:~circulax.s_transforms.sax_component; no manual decoration required.

required

Returns:

Type Description
dict

A three-tuple (compiled_groups, sys_size, port_to_node_map) where:

int
  • compiled_groups (dict[str, ComponentGroup]) — maps group name to a fully vectorized :class:ComponentGroup. If all instances of a type share the same tree structure there is one group per type, named after the type (e.g. "Resistor"). When a type is split across multiple structures the groups are numbered ("Resistor_0", "Resistor_1", …).
dict
  • sys_size (int) — total number of scalar unknowns in the system vector y, equal to the number of nets plus the total number of state variables across all instances. This is the length of the array passed to the solver.
tuple[dict, int, dict]
  • port_to_node_map (dict[str, int]) — maps both "Instance,port" and "Instance,state" keys to their integer indices in the global state vector y. Port keys resolve to shared node indices (multiple ports on the same net share one index); state keys resolve to unique per-instance indices. Use this to extract specific node voltages or internal state variables from the solution.

Raises:

Type Description
ValueError

If a component type listed in the netlist is not present in models_map, or if a port declared on a component class has no corresponding entry in the netlist connections.

TypeError

If the settings dict for an instance does not match the constructor signature of its component class.

Source code in circulax/compiler.py
def compile_netlist(netlist: dict, models_map: dict) -> tuple[dict, int, dict]:  # noqa: C901, PLR0912, PLR0915
    """Compile a netlist into batched, vectorized component groups ready for simulation.

    This function bridges the gap between a human-readable netlist description and
    the internal representation required by the ODE solver. It resolves net
    connectivity, instantiates component objects, assigns state variable indices,
    and batches components of the same type into vectorized groups for efficient
    JAX execution.

    The compilation proceeds in three stages:

    1. **Connectivity resolution** — ``build_net_map`` assigns a unique integer
       node index to every net, returning a flat map of ``"Instance,Port"`` keys
       to node indices.
    2. **Instance processing** — each instance is instantiated as an Equinox
       object using its settings, its port indices are looked up in the node map,
       and it is placed into a bucket keyed by ``(component_type, tree_structure)``.
       The tree structure is included in the key so that instances whose static
       fields differ (e.g. a callable parameter) are never incorrectly batched
       together.
    3. **Vectorization** — each bucket is stacked into a single
       :class:`ComponentGroup` with batched parameters and pre-computed Jacobian
       sparsity index arrays, ready to be passed directly to the solver.

    Args:
        netlist: Circuit description dict. Expected to contain an
            ``"instances"`` key mapping instance names to dicts with at least
            a ``"component"`` key (model name string) and an optional
            ``"settings"`` key (parameter dict forwarded to the component
            constructor). A ``"GND"`` instance with ``component="ground"`` is
            recognised and skipped; see ``models_map`` note on reserved types.

            **Settings filtering**: settings whose value is ``None`` are
            dropped (GDSFactory convention — ``None`` means "use the model
            default"). Settings whose key is not declared on the model are
            also silently ignored, which allows GDSFactory netlists carrying
            geometry-only keys (e.g. ``dy``, ``dx``, ``cross_section``) to
            be passed directly without preprocessing.  *Note*: a mis-spelled
            parameter name will be silently ignored rather than raising an
            error; double-check spelling if a parameter appears to have no
            effect.
        models_map: Mapping from model name strings to
            :class:`~circulax.components.base_component.CircuitComponent`
            subclasses, e.g. ``{"Resistor": Resistor, "Capacitor": Capacitor}``.
            ``"ground"`` is a **reserved** component type — it does not need
            to appear in this map (instances with ``component="ground"`` or
            named ``"GND"`` are skipped automatically).
            Raw SAX model functions (callable with all-defaulted parameters and
            a ``sax.SDict``/``sax.SDense``/``sax.SCoo``/``sax.SType`` return
            annotation) are accepted directly and auto-wrapped via
            :func:`~circulax.s_transforms.sax_component`; no manual decoration
            required.

    Returns:
        A three-tuple ``(compiled_groups, sys_size, port_to_node_map)`` where:

        - **compiled_groups** (``dict[str, ComponentGroup]``) — maps group name
            to a fully vectorized :class:`ComponentGroup`. If all instances of a
            type share the same tree structure there is one group per type, named
            after the type (e.g. ``"Resistor"``). When a type is split across
            multiple structures the groups are numbered (``"Resistor_0"``,
            ``"Resistor_1"``, …).
        - **sys_size** (``int``) — total number of scalar unknowns in the system
            vector ``y``, equal to the number of nets plus the total number of
            state variables across all instances. This is the length of the array
            passed to the solver.
        - **port_to_node_map** (``dict[str, int]``) — maps both
            ``"Instance,port"`` and ``"Instance,state"`` keys to their integer
            indices in the global state vector ``y``. Port keys resolve to shared
            node indices (multiple ports on the same net share one index); state
            keys resolve to unique per-instance indices. Use this to extract
            specific node voltages or internal state variables from the solution.

    Raises:
        ValueError: If a component type listed in the netlist is not present in
            ``models_map``, or if a port declared on a component class has no
            corresponding entry in the netlist connections.
        TypeError: If the settings dict for an instance does not match the
            constructor signature of its component class.

    """
    # ``"ground"`` is a reserved component type — instances with
    # ``component="ground"`` (or named ``"GND"``) are skipped during
    # compilation and never looked up in ``models_map``.  Users do NOT need
    # to include ``"ground"`` in the models map; if they do, the entry is
    # silently ignored.
    _RESERVED = frozenset({"ground"})

    # Auto-wrap raw SAX model functions into CircuitComponent classes so callers
    # can pass PDKs of plain SAX models straight through the netlist interface.
    # CircuitComponent subclasses pass through unchanged; anything else raises.
    from circulax.s_transforms import _normalize_model  # noqa: PLC0415
    models_map = {
        k: _normalize_model(v, name=k)
        for k, v in models_map.items()
        if k not in _RESERVED
    }

    port_to_node_map, num_nodes = build_net_map(netlist)

    # Buckets: Key = (comp_type_name, tree_structure), Value = list of instances
    buckets = defaultdict(list)
    sys_size = num_nodes

    # --- 2. Process Instances ---
    instances = netlist.get("instances", {})

    for name, data in instances.items():
        comp_type = data["component"]

        # Skip ground (it's just a marker, already handled in build_net_map)
        if comp_type == "ground" or name == "GND":
            continue

        if comp_type not in models_map:
            msg = f"Model '{comp_type}' not found for '{name}'"
            raise ValueError(msg)

        comp_cls = models_map[comp_type]
        # GDSFactory netlists carry geometry settings that don't appear on
        # the simulation model (e.g. ``dy``/``dx`` on a ``coupler_strip``
        # instance, or ``allow_min_radius_violation`` on a ``bend_euler``).
        # Filter to fields the model actually declares, and drop ``None``
        # values (GDSFactory convention: ``None`` means "use the default").
        known_fields = {f.name for f in dataclasses.fields(comp_cls)}
        settings = {
            k: v for k, v in data.get("settings", {}).items()
            if v is not None and k in known_fields
        }

        # A. Create Equinox Object
        try:
            comp_obj = comp_cls(**settings)
        except TypeError as e:
            msg = f"Settings error for {name}: {e}"
            raise TypeError(msg)  # noqa: B904

        # B. Get Port Indices
        port_indices = []
        for port in comp_cls.ports:
            key = f"{name},{port}"

            if key in port_to_node_map:
                port_indices.append(port_to_node_map[key])
            else:
                msg = f"Port '{port}' on '{name}' is unconnected.\nYour netlist connections must include '{key}'"
                raise ValueError(msg)

        # Group by Type AND Structure (to handle static field differences)
        structure = jax.tree.structure(comp_obj)
        buckets[(comp_type, structure)].append(
            {
                "obj": comp_obj,
                "ports": port_indices,
                "num_states": len(comp_cls.states),
                "name": name,
            }
        )

    compiled_groups = {}

    # Helper to generate unique names for split groups
    type_counts = defaultdict(int)
    for ctype, _ in buckets:
        type_counts[ctype] += 1
    type_counters = defaultdict(int)

    for (comp_type, _), items in buckets.items():
        comp_cls = models_map[comp_type]

        # Generate Group Name
        if type_counts[comp_type] > 1:
            idx = type_counters[comp_type]
            group_name = f"{comp_type}_{idx}"
            type_counters[comp_type] += 1
        else:
            group_name = comp_type

        # A. Assign Internal States
        all_var_indices = []
        for item in items:
            state_indices = []
            for s_name in comp_cls.states:
                port_to_node_map[f"{item['name']},{s_name}"] = sys_size
                state_indices.append(sys_size)
                sys_size += 1
            all_var_indices.append(item["ports"] + state_indices)

        # B. Batch Params
        instance_objects = [item["obj"] for item in items]
        batched_params = jax.tree.map(lambda *args: jnp.stack(args), *instance_objects)

        # C. Matrices
        var_indices_arr = jnp.array(all_var_indices, dtype=jnp.int32)
        width = var_indices_arr.shape[1]
        count = len(items)

        jac_rows = jnp.broadcast_to(var_indices_arr[:, :, None], (count, width, width))
        jac_cols = jnp.broadcast_to(var_indices_arr[:, None, :], (count, width, width))

        # Create Index Map for parameter updates
        index_map = {item["name"]: i for i, item in enumerate(items)}

        compiled_groups[group_name] = ComponentGroup(
            name=group_name,
            var_indices=var_indices_arr,
            eq_indices=var_indices_arr,
            params=batched_params,
            physics_func=comp_cls.solver_call,
            jac_rows=jac_rows,
            jac_cols=jac_cols,
            index_map=index_map,
            is_fdomain=getattr(comp_cls, "_is_fdomain", False),
            amplitude_param=getattr(comp_cls, "amplitude_param", ""),
        )

    return compiled_groups, sys_size, port_to_node_map

ensure_time_signature ¤

ensure_time_signature(model_func: callable) -> callable

Wraps a model function to ensure it accepts a 't' keyword argument.

If the original model doesn't take 't', the wrapper swallows it.

Source code in circulax/compiler.py
def ensure_time_signature(model_func: callable) -> callable:
    """Wraps a model function to ensure it accepts a 't' keyword argument.

    If the original model doesn't take 't', the wrapper swallows it.
    """
    sig = inspect.signature(model_func)

    if "t" in sig.parameters or "**kwargs" in str(sig):
        return model_func

    # Wrapper for static models
    @wraps(model_func)
    def time_aware_wrapper(signal_states: Signals, params: dict, t: float | None = None) -> PhysicsReturn:  # noqa: ARG001
        return model_func(signal_states, params)

    return time_aware_wrapper

get_default_params ¤

get_default_params(func: callable) -> dict[str, Any]

Return a copy so callers can't mutate the cache.

Source code in circulax/compiler.py
def get_default_params(func: callable) -> dict[str, Any]:
    """Return a copy so callers can't mutate the cache."""
    return dict(_get_default_params_cached(func))

get_model_width ¤

get_model_width(func: callable) -> int

Determines the size of the 'vars' vector expected by the model.

Source code in circulax/compiler.py
def get_model_width(func: callable) -> int:
    """Determines the size of the 'vars' vector expected by the model."""
    sig = inspect.signature(func)
    if "vars" not in sig.parameters:
        msg = f"{func.__name__} missing 'vars' argument"
        raise ValueError(msg)
    default_val = sig.parameters["vars"].default
    if default_val is inspect.Parameter.empty:
        msg = f"{func.__name__} 'vars' must have a default (e.g. jnp.zeros(3))"
        raise ValueError(msg)
    return len(default_val)

merge_dicts ¤

merge_dicts(dict_list: list[dict[str, Any]]) -> dict[str, Any]

Merges a list of dictionaries into a single dictionary.

Source code in circulax/compiler.py
def merge_dicts(dict_list: list[dict[str, Any]]) -> dict[str, Any]:
    """Merges a list of dictionaries into a single dictionary."""
    merged = {}
    for d in dict_list:
        merged.update(d)
    return merged

solve_connectivity ¤

solve_connectivity(connections: dict) -> dict

Resolves Port-to-Port connections into a Port-to-NodeID map.

Example: {"R1,p1": "V1,p1"} -> {"R1,p1": 1, "V1,p1": 1}

Source code in circulax/compiler.py
def solve_connectivity(connections: dict) -> dict:  # noqa: C901
    """Resolves Port-to-Port connections into a Port-to-NodeID map.

    Example: {"R1,p1": "V1,p1"} -> {"R1,p1": 1, "V1,p1": 1}
    """
    parent = {}

    def find(i: int):  # noqa: ANN202
        if i not in parent:
            parent[i] = i
        if parent[i] != i:
            parent[i] = find(parent[i])
        return parent[i]

    def union(i: int, j: int) -> None:
        root_i = find(i)
        root_j = find(j)
        if root_i != root_j:
            parent[root_i] = root_j

    # 1. Process all connections
    for src, targets in connections.items():
        # Ensure 'targets' is a list
        if not isinstance(targets, (list, tuple)):
            targets = [targets]

        # Link Source to all Targets
        for tgt in targets:
            union(src, tgt)

    # 2. Assign Node IDs
    # We reserve ID 0 for Ground (any group containing "GND")
    groups = {}
    node_map = {}

    # Identify the root for "GND" if it exists
    gnd_roots = {find(k) for k in parent if "GND" in k}

    node_counter = 1

    for port in parent:
        root = find(port)

        if root in gnd_roots:
            node_id = 0
        else:
            if root not in groups:
                groups[root] = node_counter
                node_counter += 1
            node_id = groups[root]

        node_map[port] = node_id

    return node_map, node_counter