# Source code for gdsfactory.get_netlist

"""Extract netlist from component port connectivity.

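Assumes two ports of the same port type are connected when their centers coincide (x, y snapped to a grid); for optical ports, width, orientation, shear-angle and offset mismatches are then checked and reported.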

.. code:: yaml

    connections:
        - coupler,N0:bendLeft,W0
        - coupler,N1:bendRight,N0
        - bendLeft,N0:straight,W0
        - bendRight,N0:straight,E0

    ports:
        - coupler,E0
        - coupler,W0

"""

from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable
from typing import Any

import numpy as np
import omegaconf

from gdsfactory import Port
from gdsfactory.component import Component, ComponentReference
from gdsfactory.name import clean_name
from gdsfactory.serialization import clean_dict, clean_value_json
from gdsfactory.snap import snap_to_grid
from gdsfactory.typings import LayerSpec


def get_default_connection_validators():
    """Return the default mapping from port type to connection validator.

    Optical connections are checked by ``validate_optical_connection``;
    electrical connections use ``_null_validator``, which performs no checks.
    """
    validators = {
        "optical": validate_optical_connection,
        "electrical": _null_validator,
    }
    return validators


def get_instance_name_from_alias(
    component: Component,
    reference: ComponentReference,
) -> str:
    """Returns the instance name stored on the reference itself.

    No labels are consulted and no ``<name>_<x>_<y>`` fallback is generated:
    this simply returns ``reference.name``.

    Args:
        component: unused; accepted so that all instance-naming functions share
            the ``(component, reference)`` call signature used by ``get_netlist``.
        reference: reference that needs naming.

    Returns:
        ``reference.name``.
    """
    return reference.name


def get_instance_name_from_label(
    component: Component,
    reference: ComponentReference,
    layer_label: LayerSpec = "LABEL_INSTANCE",
) -> str:
    """Name a reference after a label placed at its (snapped) position.

    Each label of ``component`` is compared against the snapped (x, y) of
    ``reference``. The text of the first label at that position whose layer
    matches ``layer_label[0]`` is returned. If none matches, the name falls
    back to ``clean_name("<parent name>_<x>_<y>")``.

    Args:
        component: component whose labels are searched.
        reference: reference that needs naming.
        layer_label: label layer spec; only index 0 (the layer) is compared,
            index 1 is ignored.
    """
    from gdsfactory.pdk import get_layer

    layer = get_layer(layer_label)

    ref_x = snap_to_grid(reference.x)
    ref_y = snap_to_grid(reference.y)
    all_labels = component.labels

    # Fallback name, returned only when no label sits on the reference position.
    fallback = clean_name(f"{reference.parent.name}_{ref_x}_{ref_y}")

    for candidate in all_labels:
        label_x = snap_to_grid(candidate.position[0])
        label_y = snap_to_grid(candidate.position[1])
        if ref_x == label_x and ref_y == label_y and candidate.layer == layer[0]:
            return candidate.text

    return fallback


def get_netlist_yaml(
    component: Component,
    tolerance: int = 5,
    exclude_port_types: list | None = None,
    **kwargs,
) -> str:
    """Return the netlist of ``component`` serialized as a YAML string.

    Builds the netlist with ``get_netlist`` and dumps it with
    ``omegaconf.OmegaConf.to_yaml``.

    Args:
        component: component to extract the netlist from.
        tolerance: forwarded to ``get_netlist``.
        exclude_port_types: forwarded to ``get_netlist``. Note that the default
            here is ``None`` (no port type excluded), whereas ``get_netlist``
            itself defaults to ``("placement",)``.
        kwargs: extra keyword arguments forwarded to ``get_netlist``.
    """
    netlist = get_netlist(
        component=component,
        tolerance=tolerance,
        exclude_port_types=exclude_port_types,
        **kwargs,
    )
    return omegaconf.OmegaConf.to_yaml(netlist)


def get_netlist(
    component: Component,
    tolerance: int = 5,
    exclude_port_types: list[str] | tuple[str] | None = ("placement",),
    get_instance_name: Callable[..., str] = get_instance_name_from_alias,
    allow_multiple: bool = False,
    merge_info: bool = False,
    extend_recursive_port_names: bool = False,
) -> dict[str, Any]:
    """Returns instances, connections and placements from :class:`Component` as a dict.

    Does two sweeps over the connections:

    1. first tries to connect everything assuming perfect connections at each port.
    2. Then gathers ports which did not perfectly connect to anything and tries
       to find imperfect connections, by grouping ports on a coarse grid.

    Warnings collected during netlisting are reported back into the netlist.
    These include warnings about mismatched port widths, orientations,
    shear angles, excessive offsets, etc. You can also configure warning types
    which should throw an error when encountered by modifying
    DEFAULT_CRITICAL_CONNECTION_ERROR_TYPES. Validators, which will produce
    warnings for each port type, can be overridden with
    DEFAULT_CONNECTION_VALIDATORS.

    A key difference in this algorithm is that we group each port type
    independently. This allows us to use different logic to determine i.e. if
    an electrical port is properly connected vs an optical port. In this
    function, the core logic is the same, but we employ extra validation for
    optical ports.

    snap_to_grid() allows a value of 0, which will return the original value,
    is more efficient when the value is 1, and will throw a more descriptive
    error when the value is <0. The default value of tolerance is 5nm because
    it should allow better performance with the two-grid-sweep approach.

    Args:
        component: to extract netlist.
        tolerance: tolerance in grid_factor to consider two ports connected.
        exclude_port_types: optional list of port types to exclude from netlisting.
        get_instance_name: function ``(component, reference) -> str`` used to
            name each instance.
        allow_multiple: False to raise an error if more than two ports share the
            same connection. If True, ports sharing one location are chained
            into pairwise connections.
        merge_info: if True, values from each instance's ``info`` override the
            ``settings`` entries with the same key (info-only keys are not added).
        extend_recursive_port_names: Compatibility with recursive get_netlist
            port name identifiers.

    Returns:
        Dictionary containing the following:
            instances: Dict of instance name and settings.
            connections: Dict of Instance1Name,portName: Instance2Name,portName.
            placements: Dict of instance names and placements
                (x, y, rotation, mirror).
            ports: Dict of top-level portName: InstanceName,portName.
            name: name of component.
            warnings: only present when warnings were produced; dict keyed by
                port type (e.g. unconnected ports, mismatches).
    """
    placements: dict[str, dict[str, Any]] = {}
    instances: dict[str, dict[str, Any]] = {}
    connections: dict[str, str] = {}
    top_ports: dict[str, str] = {}

    # Lookup from netlist port name ("instance,port" for lower-level ports,
    # plain "port" for top-level ports) to the Port object.
    name2port: dict[str, Port] = {}

    # TOP level ports
    ports = component.get_ports(depth=0)
    ports_by_type: dict[str, list[str]] = defaultdict(list)
    top_port_names: set[str] = set()

    references = _get_references_to_netlist(component)

    for reference in references:
        c = reference.parent
        origin = reference.origin
        x = float(snap_to_grid(origin[0]))
        y = float(snap_to_grid(origin[1]))
        # Computed once per reference: the name does not depend on the port,
        # and some naming functions (e.g. label lookup) scan all labels.
        reference_name = get_instance_name(component, reference)

        is_array = (
            isinstance(reference, ComponentReference)
            and hasattr(reference, "columns")
            and (reference.columns > 1 or reference.rows > 1)
        )
        base_reference_name = reference_name
        if is_array:
            reference_name += "__1_1"

        instance: dict[str, Any] = {}

        if c.info:
            instance.update(component=c.name, info=clean_value_json(c.info))

        # Prefer name from settings over c.name
        if c.settings:
            settings = c.settings.model_dump(exclude_none=True)

            if merge_info:
                info = c.info.model_dump(exclude_none=True)
                # Only keys already present in settings are overridden.
                settings.update({k: v for k, v in info.items() if k in settings})

            settings = clean_value_json(settings)

            instance.update(
                component=c.function_name or c.name,
                settings=settings,
            )

        instances[reference_name] = instance
        placements[reference_name] = {
            "x": x,
            "y": y,
            "rotation": int(reference.rotation or 0),
            "mirror": reference.x_reflection or 0,
        }

        if is_array:
            parent_ports = c.ports
            for i in range(reference.rows):
                for j in range(reference.columns):
                    element_name = f"{base_reference_name}__{i + 1}_{j + 1}"
                    instances[element_name] = instance
                    placements[element_name] = {
                        "x": x + j * reference.spacing[0],
                        "y": y + i * reference.spacing[1],
                        "rotation": int(reference.rotation or 0),
                        "mirror": reference.x_reflection or 0,
                    }
                    for parent_port_name in parent_ports:
                        top_name = f"{parent_port_name}_{i + 1}_{j + 1}"
                        lower_name = f"{element_name},{parent_port_name}"
                        # NOTE: a bit of a hack... get the top-level port for the
                        # ComponentArray by its known naming convention
                        # ("<port>_<row>_<col>"); this breaks if those ports are
                        # renamed.
                        if extend_recursive_port_names:
                            parent_port = component[parent_port_name]
                        else:
                            parent_port = component[top_name]
                        name2port[lower_name] = parent_port
                        top_port_names.add(top_name)
                        ports_by_type[parent_port.port_type].append(lower_name)
        else:
            # lower level ports
            for port in reference.ports.values():
                src = f"{reference_name},{port.name}"
                name2port[src] = port
                ports_by_type[port.port_type].append(src)

    for port in ports:
        src = port.name
        name2port[src] = port
        top_port_names.add(src)
        ports_by_type[port.port_type].append(src)

    warnings: dict[str, Any] = {}
    for port_type, port_names in ports_by_type.items():
        if exclude_port_types and port_type in exclude_port_types:
            continue
        connections_t, warnings_t = extract_connections(
            port_names,
            name2port,
            port_type,
            tolerance=tolerance,
            allow_multiple=allow_multiple,
        )
        if warnings_t:
            warnings[port_type] = warnings_t
        for connection in connections_t:
            if len(connection) != 2:
                continue
            src, dst = connection
            if src in top_port_names:
                top_ports[src] = dst
            elif dst in top_port_names:
                top_ports[dst] = src
            else:
                first, second = sorted([src, dst])
                connections[first] = second

    netlist = {
        "connections": {k: connections[k] for k in sorted(connections)},
        "instances": {k: instances[k] for k in sorted(instances)},
        "placements": {k: placements[k] for k in sorted(placements)},
        "ports": top_ports,
        "name": component.name,
    }
    if warnings:
        netlist["warnings"] = warnings
    return netlist
def extract_connections(
    port_names: list[str],
    ports: dict[str, Port],
    port_type: str,
    tolerance: int = 5,
    validators: dict[str, Callable] | None = None,
    allow_multiple: bool = False,
):
    """Return ``(connections, warnings)`` for ports of a single port type.

    Args:
        port_names: netlist names of the ports to connect.
        ports: lookup from netlist port name to :class:`Port`.
        port_type: port type being connected; selects the validator.
        tolerance: grid size for the coarse sweep of
            ``_extract_connections_two_sweep``.
        validators: mapping of port type to validator; defaults to
            ``DEFAULT_CONNECTION_VALIDATORS``. Port types without an entry use
            ``_null_validator``.
        allow_multiple: if True, more than two ports at the same location are
            chained into pairwise connections instead of raising.
    """
    if validators is None:
        validators = DEFAULT_CONNECTION_VALIDATORS
    validator = validators.get(port_type, _null_validator)
    return _extract_connections_two_sweep(
        port_names,
        ports,
        port_type,
        tolerance=tolerance,
        connection_validator=validator,
        allow_multiple=allow_multiple,
    )


def _extract_connections_two_sweep(
    port_names: list[str],
    ports: dict[str, Port],
    port_type: str,
    connection_validator: Callable,
    tolerance: int,
    raise_error_for_warnings: list[str] | None = None,
    allow_multiple: bool = False,
):
    """Pair up ports whose snapped centers coincide, in one or two sweeps.

    With ``tolerance`` 0 or 1 a single sweep uses that grid. Otherwise a fine
    sweep (grid 1) runs first and ports left unconnected are retried on a
    coarse grid of size ``tolerance``.

    Returns:
        ``(connections, warnings)``: a list of two-name lists, and a dict of
        warning type -> list of warning records.

    Raises:
        ValueError: if ``tolerance`` is negative; if more than two ports share a
            location and ``allow_multiple`` is False; or if any warning type in
            ``raise_error_for_warnings`` (default: the port type's entry in
            ``DEFAULT_CRITICAL_CONNECTION_ERROR_TYPES``) was produced.
    """
    warnings = defaultdict(list)
    if raise_error_for_warnings is None:
        raise_error_for_warnings = DEFAULT_CRITICAL_CONNECTION_ERROR_TYPES.get(
            port_type, []
        )

    unconnected_port_names = list(port_names)
    if tolerance < 0:
        raise ValueError(f"Cannot have a tolerance less than zero. Got {tolerance}")
    elif tolerance <= 1:
        # if tolerance is 0 or 1, do only one sweep with that tolerance
        grids = [("fine", tolerance)]
    else:
        # default: do one fine sweep with a 1nm tolerance, then a coarse sweep
        # with the given tolerance to connect any remaining ports which are not
        # perfectly aligned
        grids = [("fine", 1), ("coarse", tolerance)]

    connections = []

    for _grid_name, grid_size in grids:
        by_xy = defaultdict(list)

        for port_name in unconnected_port_names:
            port = ports[port_name]
            by_xy[tuple(snap_to_grid(port.center, grid_factor=grid_size))].append(
                port_name
            )

        unconnected_port_names = []

        for xy, ports_at_xy in by_xy.items():
            if len(ports_at_xy) == 1:
                unconnected_port_names.append(ports_at_xy[0])

            elif len(ports_at_xy) == 2:
                port1 = ports[ports_at_xy[0]]
                port2 = ports[ports_at_xy[1]]
                connection_validator(port1, port2, ports_at_xy, warnings)
                connections.append(ports_at_xy)

            elif not allow_multiple:
                warnings["multiple_connections"].append(ports_at_xy)
                raise ValueError(f"Found multiple connections at {xy}:{ports_at_xy}")

            else:
                # Chain the ports into a ring of pairwise connections:
                # (last, first), (first, second), ..., (second-to-last, last).
                num_ports = len(ports_at_xy)
                for portindex1, portindex2 in zip(
                    range(-1, num_ports - 1), range(num_ports)
                ):
                    port1 = ports[ports_at_xy[portindex1]]
                    port2 = ports[ports_at_xy[portindex2]]
                    connection_validator(port1, port2, ports_at_xy, warnings)
                    connections.append(
                        [ports_at_xy[portindex1], ports_at_xy[portindex2]]
                    )

    if unconnected_port_names:
        # Only lower-level ports ("instance,port") are reported; unconnected
        # top-level ports are expected.
        unconnected_non_top_level = [
            pname for pname in unconnected_port_names if ("," in pname)
        ]
        if unconnected_non_top_level:
            unconnected_xys = [
                ports[pname].center for pname in unconnected_non_top_level
            ]
            warnings["unconnected_ports"].append(
                _make_warning(
                    ports=unconnected_non_top_level,
                    values=unconnected_xys,
                    message=f"{len(unconnected_non_top_level)} unconnected {port_type} ports!",
                )
            )

    critical_warnings = {
        w: warnings[w] for w in raise_error_for_warnings if w in warnings
    }

    if critical_warnings:
        raise ValueError(
            f"Found critical warnings while extracting netlist: {critical_warnings}"
        )
    return connections, dict(warnings)


def _make_warning(ports: list[str], values: Any, message: str) -> dict[str, Any]:
    """Build a warning record ``{ports, values, message}`` passed through ``clean_dict``."""
    w = {
        "ports": ports,
        "values": values,
        "message": message,
    }
    return clean_dict(w)


def _null_validator(port1: Port, port2: Port, port_names, warnings) -> None:
    """Connection validator that accepts every connection without any checks."""
    pass


def validate_optical_connection(
    port1: Port,
    port2: Port,
    port_names,
    warnings,
    angle_tolerance=0.01,
    offset_tolerance=0.001,
    width_tolerance=0.001,
) -> None:
    """Check an optical connection and append any mismatches to ``warnings``.

    Args:
        port1: first port of the connection.
        port2: second port of the connection.
        port_names: the two netlist names; names without "," are top-level ports.
        warnings: mapping of warning type -> list (a ``defaultdict(list)`` when
            called from ``_extract_connections_two_sweep``); mutated in place.
        angle_tolerance: allowed orientation / shear-angle deviation in degrees.
        offset_tolerance: allowed distance between the port centers.
        width_tolerance: allowed width difference.

    Raises:
        ValueError: if both ports are top-level ports.
    """
    is_top_level = [("," not in pname) for pname in port_names]
    if all(is_top_level):
        raise ValueError(f"Two top-level ports appear to be connected: {port_names}")

    if abs(port1.width - port2.width) > width_tolerance:
        warnings["width_mismatch"].append(
            _make_warning(
                port_names,
                values=[port1.width, port2.width],
                message=f"Widths of ports {port_names[0]} and {port_names[1]} not equal. "
                f"Difference of {abs(port1.width - port2.width)} um",
            )
        )

    if port1.shear_angle and not port2.shear_angle:
        warnings["shear_angle_mismatch"].append(
            _make_warning(
                port_names,
                values=[port1.shear_angle, port2.shear_angle],
                message=f"{port_names[0]} has a shear angle but {port_names[1]} "
                f"does not! Shear angle is {port1.shear_angle} deg",
            )
        )
    elif not port1.shear_angle and port2.shear_angle:
        warnings["shear_angle_mismatch"].append(
            _make_warning(
                port_names,
                values=[port1.shear_angle, port2.shear_angle],
                message=f"{port_names[1]} has a shear angle but {port_names[0]} "
                f"does not! Shear angle is {port2.shear_angle} deg",
            )
        )
    elif port1.shear_angle:
        # Report the same wrapped difference that is compared to the tolerance.
        shear_difference = abs(
            difference_between_angles(port1.shear_angle, port2.shear_angle)
        )
        if shear_difference > angle_tolerance:
            warnings["shear_angle_mismatch"].append(
                _make_warning(
                    port_names,
                    values=[port1.shear_angle, port2.shear_angle],
                    message=f"Shear angle of {port_names[0]} and {port_names[1]} "
                    f"differ by {shear_difference} deg",
                )
            )

    orientation_difference = abs(
        difference_between_angles(port1.orientation, port2.orientation)
    )
    if any(is_top_level):
        # A promoted port must face the same way as the lower-level port.
        if orientation_difference > angle_tolerance:
            top_port, lower_port = port_names if is_top_level[0] else port_names[::-1]
            warnings["orientation_mismatch"].append(
                _make_warning(
                    port_names,
                    values=[port1.orientation, port2.orientation],
                    message=f"{lower_port} was promoted to {top_port} but orientations "
                    f"do not match! Difference of {orientation_difference} deg",
                )
            )
    else:
        # Two connected lower-level ports must face each other (180 deg apart).
        angle_misalignment = abs(orientation_difference - 180)
        if angle_misalignment > angle_tolerance:
            warnings["orientation_mismatch"].append(
                _make_warning(
                    port_names,
                    values=[port1.orientation, port2.orientation],
                    message=f"{port_names[0]} and {port_names[1]} are misaligned by {angle_misalignment} deg",
                )
            )

    offset_mismatch = np.sqrt(np.sum(np.square(port2.center - port1.center)))
    if offset_mismatch > offset_tolerance:
        warnings["offset_mismatch"].append(
            _make_warning(
                port_names,
                values=[port1.center, port2.center],
                message=f"{port_names[0]} and {port_names[1]} are offset by {offset_mismatch} um",
            )
        )


def difference_between_angles(angle2: float, angle1: float) -> float:
    """Return ``angle2 - angle1`` in degrees, wrapped into (-180, 180].

    Uses floor modulo instead of repeated +/-360 steps, so it runs in constant
    time. The previous loop never terminated for infinite or huge inputs; an
    infinite input now yields ``nan``.

    Examples: (350, 10) -> -20, (10, 350) -> 20, (90, 270) -> 180.
    """
    return 180 - (180 - (angle2 - angle1)) % 360


def _get_references_to_netlist(component: Component) -> list[ComponentReference]:
    """Return the references of ``component`` to include in its netlist.

    If the component has no references but its info records a
    ``transformed_cell``, a single reference to that original cell is rebuilt
    from ``CACHE``. The reference uses the origin, rotation and x_reflection
    stored in ``component.settings``.
    """
    from gdsfactory.cell import CACHE

    references = component.references

    if not references and "transformed_cell" in component.info:
        # expand transformed, flattened cells
        ref = component.settings
        original_cell = CACHE[component.info["transformed_cell"]]
        references = [
            ComponentReference(
                original_cell,
                origin=ref["origin"],
                rotation=ref["rotation"],
                x_reflection=ref["x_reflection"],
            )
        ]
    return references


def get_netlist_recursive(
    component: Component,
    component_suffix: str = "",
    get_netlist_func: Callable = get_netlist,
    get_instance_name: Callable[..., str] = get_instance_name_from_alias,
    **kwargs,
) -> dict[str, Any]:
    """Returns recursive netlist for a component and subcomponents.

    Args:
        component: to extract netlist.
        component_suffix: suffix to append to each component name.
            useful if to save and reload a back-annotated netlist.
        get_netlist_func: function to extract individual netlists. It receives
            ``get_instance_name`` as a keyword argument (as ``get_netlist`` does).
        get_instance_name: function ``(component, reference) -> str`` used to
            name instances, both inside each netlist and when linking child
            netlists.

    Keyword Args:
        tolerance: tolerance in grid_factor to consider two ports connected.
        exclude_port_types: optional list of port types to exclude from netlisting.

    Returns:
        Dictionary of netlists, keyed by the name of each component.
    """
    all_netlists = {}

    # only components with references (subcomponents) warrant a netlist
    references = _get_references_to_netlist(component)

    if references:
        # Forward the naming function so the instance names in this netlist
        # match ``inst_name`` below; otherwise the child entry would be added
        # under a second, different name.
        netlist = get_netlist_func(
            component, get_instance_name=get_instance_name, **kwargs
        )
        all_netlists[f"{component.name}{component_suffix}"] = netlist

        # for each reference, expand the netlist
        for ref in references:
            rcell = ref.parent
            grandchildren = get_netlist_recursive(
                component=rcell,
                component_suffix=component_suffix,
                get_netlist_func=get_netlist_func,
                get_instance_name=get_instance_name,
                **kwargs,
            )
            all_netlists |= grandchildren

            child_references = _get_references_to_netlist(ref.ref_cell)

            if child_references:
                inst_name = get_instance_name(component, ref)
                netlist_dict = {"component": f"{rcell.name}{component_suffix}"}
                if rcell.settings:
                    netlist_dict.update(
                        settings=rcell.settings.model_dump(exclude_none=True)
                    )
                if rcell.info:
                    netlist_dict.update(info=rcell.info.model_dump(exclude_none=True))
                netlist["instances"][inst_name] = netlist_dict

    return all_netlists


def _demo_ring_single_array() -> None:
    import gdsfactory as gf

    c = gf.components.ring_single_array()
    c.get_netlist()


def _demo_mzi_lattice() -> None:
    import gdsfactory as gf

    coupler_lengths = [10, 20, 30, 40]
    coupler_gaps = [0.1, 0.2, 0.4, 0.5]
    delta_lengths = [10, 100, 200]

    c = gf.components.mzi_lattice(
        coupler_lengths=coupler_lengths,
        coupler_gaps=coupler_gaps,
        delta_lengths=delta_lengths,
    )
    c.get_netlist()
    print(c.get_netlist_yaml())


DEFAULT_CONNECTION_VALIDATORS = get_default_connection_validators()

# Warning types that abort netlisting (ValueError) for each port type.
DEFAULT_CRITICAL_CONNECTION_ERROR_TYPES = {
    "optical": ["width_mismatch", "shear_angle_mismatch", "orientation_mismatch"]
}


if __name__ == "__main__":
    import gdsfactory as gf

    c = gf.c.mzi()
    n = c.get_netlist()

    # from gdsfactory.decorators import flatten_offgrid_references

    # rotation_value = 35
    # cname = "test_get_netlist_transformed"
    # c = gf.Component(cname)
    # i1 = c.add_ref(gf.components.straight(), "i1")
    # i2 = c.add_ref(gf.components.straight(), "i2")
    # i1.rotate(rotation_value)
    # i2.connect("o2", i1.ports["o1"])

    # # flatten the oddly rotated refs
    # c = flatten_offgrid_references(c)
    # print(c.get_dependencies())
    # c.show()

    # # perform the initial sanity checks on the netlist
    # netlist = c.get_netlist()
    # connections = netlist["connections"]
    # assert len(connections) == 1, len(connections)
    # cpairs = list(connections.items())
    # extracted_port_pair = set(cpairs[0])
    # expected_port_pair = {"i2,o2", "i1,o1"}
    # assert extracted_port_pair == expected_port_pair

    # recursive_netlist = get_netlist_recursive(c)
    # top_netlist = recursive_netlist[cname]
    # # the recursive netlist should have 3 entries, for the top level and two
    # # rotated straights
    # assert len(recursive_netlist) == 1, len(recursive_netlist)
    # # confirm that the child netlists have reference attributes properly set

    # i1_cell_name = top_netlist["instances"]["i1"]["component"]
    # i1_netlist = recursive_netlist[i1_cell_name]
    # # currently for transformed netlists, the instance name of the inner cell is None
    # assert i1_netlist["placements"][None]["rotation"] == rotation_value

    # i2_cell_name = top_netlist["instances"]["i2"]["component"]
    # i2_netlist = recursive_netlist[i2_cell_name]
    # # currently for transformed netlists, the instance name of the inner cell is None
    # assert i2_netlist["placements"][None]["rotation"] == rotation_value