"""Cell decorator for functions that return a Component."""
from __future__ import annotations
import functools
import hashlib
import inspect
import warnings
from collections.abc import Callable, Sequence
from functools import partial
from typing import TypeVar, overload
from pydantic import validate_call
from gdsfactory.component import Component, name_counters
from gdsfactory.component_layout import CellSettings
from gdsfactory.config import CONF
from gdsfactory.name import clean_name, get_name_short
from gdsfactory.serialization import clean_value_name
# Built components, keyed by the cache name the ``cell`` wrapper computes from
# the function name and its arguments.
CACHE: dict[str, Component] = {}
# ``id()`` of every component stored in CACHE. ``cell`` checks it to detect a
# function returning an object that is already cached, so it can copy it first.
CACHE_IDS: set[int] = set()
# Not referenced in this section.
# NOTE(review): presumably a version tag for the component info schema — confirm.
INFO_VERSION = 2
# Type variable for "any callable returning a Component"; used by the ``cell``
# typing overloads so a decorated function keeps its own type.
_F = TypeVar("_F", bound=Callable[..., Component])
class CellReturnTypeError(ValueError):
    """Raised when a ``@cell``-decorated function does not return a Component."""
def remove_from_cache(name: str | Component) -> None:
    """Drop one entry from CACHE and reset its name counter if it was used once.

    Args:
        name: cache key, or a Component whose ``name`` attribute is the key.
    """
    key = name if isinstance(name, str) else name.name

    # pop() with a default is a no-op for names that were never cached.
    CACHE.pop(key, None)

    if name_counters[key] == 1:
        name_counters[key] = 0
def clear_cache() -> None:
    """Empty the Component CACHE, the cached-id set and the name counters."""
    for registry in (CACHE, CACHE_IDS, name_counters):
        registry.clear()
def print_cache() -> None:
    """Print the key of every cached Component, one per line."""
    for cached_name in CACHE.keys():
        print(cached_name)
# Typing overload for the parametrized form, e.g. ``@cell(autoname=False)``:
# no function is supplied, so the call returns a ``partial`` of ``cell`` that
# is then applied to the function being decorated.
# ``func`` defaults to None (as in the implementation) so that keyword-only
# calls actually match this overload.
@overload
def cell(
    func: None = None,
    /,
    *,
    autoname: bool = True,
    copy_if_cached: bool = True,
    max_name_length: int | None = None,
    include_module: bool = False,
    with_hash: bool = False,
    ports_offgrid: str | None = None,
    ports_not_manhattan: str | None = None,
    flatten: bool = False,
    naming_style: str = "default",
    default_decorator: Callable[[Component], Component] | None = None,
    add_settings: bool = True,
    validate: bool = False,
    get_child_name: bool = False,
    post_process: Sequence[Callable] | None = None,
    info: dict[str, int | float | str] | None = None,
) -> partial: ...
# Typing overload for the bare form, ``@cell`` / ``cell(func)``: the function
# is passed positionally and the declared result has the same type ``_F``, so
# type checkers keep the decorated function's own signature.
@overload
def cell(
    func: _F,
    /,
    *,
    autoname: bool = True,
    copy_if_cached: bool = True,
    max_name_length: int | None = None,
    include_module: bool = False,
    with_hash: bool = False,
    ports_offgrid: str | None = None,
    ports_not_manhattan: str | None = None,
    flatten: bool = False,
    naming_style: str = "default",
    default_decorator: Callable[[Component], Component] | None = None,
    add_settings: bool = True,
    validate: bool = False,
    get_child_name: bool = False,
    post_process: Sequence[Callable] | None = None,
    info: dict[str, int | float | str] | None = None,
) -> _F: ...
# Runtime implementation behind the two ``cell`` typing overloads above.
def cell(
    func=None,
    /,
    *,
    autoname: bool = True,
    copy_if_cached: bool = True,
    max_name_length: int | None = None,
    include_module: bool = False,
    with_hash: bool = False,
    ports_offgrid: str | None = None,
    ports_not_manhattan: str | None = None,
    flatten: bool = False,
    naming_style: str = "default",
    default_decorator: Callable[[Component], Component] | None = None,
    add_settings: bool = True,
    validate: bool = False,
    get_child_name: bool = False,
    post_process: Sequence[Callable] | None = None,
    info: dict[str, int | float | str] | None = None,
) -> Callable[..., Component] | partial:
    """Parametrized decorator for functions that return a Component.

    Implements a cache so that if a component has already been built it is
    returned from the cache directly. This avoids creating two exact
    Components that have the same name.
    Can autoname components based on the function name and arguments.

    Args:
        func: function to decorate. When None, returns a partial of ``cell``
            with the given options, to be applied to the function later.
        autoname: True renames Component based on args and kwargs. True by default.
        copy_if_cached: True by default. If the returned component object is
            already in the cache, a copy of it is used instead.
        max_name_length: truncates name beyond some characters with a hash.
            Defaults to CONF.max_name_length, read at call time.
        include_module: True adds module name to the cell name.
        with_hash: True adds a hash to the cell name.
        ports_offgrid: "warn", "error" or "ignore". Checks if ports are on grid.
            Defaults to CONF.ports_offgrid, read at call time.
        ports_not_manhattan: "warn", "error" or "ignore". Checks if ports are
            manhattan. Defaults to CONF.ports_not_manhattan, read at call time.
        flatten: False by default. True flattens component hierarchy.
        naming_style: "default" or "updk". "default" is the default naming style.
        default_decorator: deprecated, use post_process instead. Decorator
            applied to the component when the call passes no ``decorator``.
        add_settings: True by default. Adds settings to the component.
        validate: validate the function call. Does not work with annotations
            that have None | Callable.
        get_child_name: Use child name as component name prefix.
        post_process: list of post processing functions to apply to the component.
        info: dictionary with metadata to add to the component.

    Returns:
        The wrapped function, or a ``functools.partial`` of ``cell`` when
        ``func`` is None.

    Raises:
        ValueError: at decoration time if ``func`` is not callable; when the
            decorated function is called, if ``naming_style`` is unknown, if
            ``get_child_name`` is set but the component has no child, or if
            the decorator is not callable.
        CellReturnTypeError: when the decorated function is called and does
            not return a Component.

    A decorator is a function that runs over a function, so when you do

    .. code::

        import gdsfactory as gf

        @gf.cell
        def mzi_with_bend():
            c = gf.Component()
            mzi = c << gf.components.mzi()
            bend = c << gf.components.bend_euler()
            return c

    it's equivalent to

    .. code::

        def mzi_with_bend():
            c = gf.Component()
            mzi = c << gf.components.mzi()
            bend = c << gf.components.bend_euler()
            return c

        mzi_with_bend_decorated = gf.cell(mzi_with_bend)

    """
    if func is None:
        return partial(
            cell,
            autoname=autoname,
            copy_if_cached=copy_if_cached,
            max_name_length=max_name_length,
            include_module=include_module,
            with_hash=with_hash,
            ports_offgrid=ports_offgrid,
            ports_not_manhattan=ports_not_manhattan,
            flatten=flatten,
            naming_style=naming_style,
            default_decorator=default_decorator,
            add_settings=add_settings,
            validate=validate,
            get_child_name=get_child_name,
            post_process=post_process,
            info=info,
        )

    # Fail at decoration time with the intended message; previously this guard
    # sat inside the wrapper, after inspect.signature had already raised.
    if not callable(func):
        raise ValueError(
            f"{func!r} is not callable! @cell decorator is only for functions"
        )

    if default_decorator is not None:
        warnings.warn(
            "default_decorator is deprecated and will be removed soon. Use post_process instead.",
            DeprecationWarning,
            stacklevel=2,
        )

    # The signature and its defaults do not change between calls: compute once.
    sig = inspect.signature(func)
    default = {
        p.name: p.default
        for p in sig.parameters.values()
        if p.default is not inspect.Parameter.empty
    }

    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> Component:
        from gdsfactory.pdk import get_active_pdk

        active_pdk = get_active_pdk()

        name = kwargs.pop("name", None)
        prefix = kwargs.pop("prefix", None)
        metadata = info or {}

        if name:
            warnings.warn(
                f"name is deprecated and will be removed soon. {func.__name__}",
                stacklevel=2,
            )
        if prefix:
            warnings.warn(
                f"prefix is deprecated and will be removed soon. {func.__name__}",
                stacklevel=2,
            )

        prefix = prefix or func.__name__

        # Resolve CONF fallbacks into locals on every call. Writing them back
        # to the closure would freeze the first value seen and ignore later
        # changes to CONF.
        name_length_limit = (
            CONF.max_name_length if max_name_length is None else max_name_length
        )
        offgrid_policy = CONF.ports_offgrid if ports_offgrid is None else ports_offgrid
        not_manhattan_policy = (
            CONF.ports_not_manhattan
            if ports_not_manhattan is None
            else ports_not_manhattan
        )

        # Map positional arguments onto their parameter names. ``decorator`` is
        # still in kwargs at this point, so it takes part in naming/settings.
        args_as_kwargs = dict(zip(sig.parameters.keys(), args))
        args_as_kwargs.update(kwargs)

        # Full settings: defaults overridden by whatever was passed.
        full = {**default, **args_as_kwargs}

        if naming_style == "updk":
            full_args_list = [
                f"{key}={clean_value_name(full[key])}" for key in sorted(full.keys())
            ]
            named_args_string = ",".join(full_args_list)
            name = f"{prefix}:{named_args_string}" if named_args_string else prefix
            name = clean_name(name, allowed_characters=[":", ".", "="])

        elif naming_style == "default":
            # defaults and explicitly passed args rendered as "key=value" strings
            default_args_list = [
                f"{key}={clean_value_name(default[key])}" for key in sorted(default)
            ]
            passed_args_list = [
                f"{key}={clean_value_name(args_as_kwargs[key])}"
                for key in sorted(args_as_kwargs)
            ]

            # only arguments whose rendered value differs from the default
            # contribute to the name
            changed_arg_list = sorted(
                set(passed_args_list).difference(default_args_list)
            )
            named_args_string = "_".join(changed_arg_list)

            if include_module:
                named_args_string += f"_{func.__module__}"
            if changed_arg_list:
                # fall back to a short hash when requested, or when the
                # readable form is long or contains quotes/braces
                named_args_string = (
                    hashlib.md5(named_args_string.encode()).hexdigest()[:8]
                    if with_hash
                    or len(named_args_string) > 28
                    or "'" in named_args_string
                    or "{" in named_args_string
                    else named_args_string
                )

            name_signature = (
                clean_name(f"{prefix}_{named_args_string}")
                if named_args_string
                else clean_value_name(prefix)
            )
            name = name or name_signature

        else:
            raise ValueError('naming_style must be "default" or "updk"')

        name = get_name_short(name, max_name_length=name_length_limit)
        decorator = kwargs.pop("decorator", default_decorator)

        # if no decorator is specified, but there is one specified for the
        # active PDK, use the PDK's default decorator
        if decorator is None and active_pdk.default_decorator is not None:
            decorator = active_pdk.default_decorator

        if decorator:
            warnings.warn(
                f"decorator is deprecated and will be removed soon. {func.__name__}",
                stacklevel=2,
            )

        if name in CACHE:
            return CACHE[name]

        if validate:
            component = validate_call(func)(*args, **kwargs)
        else:
            component = func(*args, **kwargs)

        # Check the return type before touching the result, so a wrong return
        # value reports CellReturnTypeError instead of an AttributeError.
        if not isinstance(component, Component):
            raise CellReturnTypeError(
                f"function {func.__name__!r} return type = {type(component)}, "
                "make sure that functions with @cell decorator return a Component"
            )

        if offgrid_policy in ("warn", "error"):
            component.assert_ports_on_grid(error_type=offgrid_policy)
        if not_manhattan_policy in ("warn", "error"):
            component.assert_ports_manhattan(error_type=not_manhattan_policy)
        if flatten:
            component = component.flatten()

        # if the component is already in the cache, but under a different alias,
        # make sure we use a copy, so we don't run into mutability errors
        if copy_if_cached and id(component) in CACHE_IDS:
            component = component.copy()

        if get_child_name:
            if not component.child:
                raise ValueError(
                    f"component {component.name} does not have a child component. "
                    "Make sure you use component_new.copy_child_info(component)"
                )
            child_name = component.child.function_name
            component_name = get_name_short(
                f"{child_name}_{name}", max_name_length=name_length_limit
            )
        else:
            component_name = name

        if autoname:
            component.rename(component_name, max_name_length=name_length_limit)

        if add_settings:
            component.settings = CellSettings(**full)
            component.function_name = func.__name__
            component.module = func.__module__
            component.__doc__ = func.__doc__

        for post in post_process or []:
            component = post(component)

        component.info.update(metadata)

        if decorator:
            if not callable(decorator):
                raise ValueError(f"decorator = {type(decorator)} needs to be callable")
            component_new = decorator(component)
            component = component_new or component

        # The cache key is the signature-derived name, not the (possibly
        # child-prefixed) component name.
        CACHE[name] = component
        component._locked = True
        CACHE_IDS.add(id(component))
        return component

    # Advertise the wrapped function's parameters with a Component return type.
    wrapper.__signature__ = sig.replace(return_annotation=Component)  # type: ignore
    return wrapper
# Plain alias of ``cell``.
# NOTE(review): presumably kept for backwards compatibility — confirm.
cell_without_validator = cell
# ``cell`` preset that appends the function's module to the generated name.
cell_with_module = partial(cell, include_module=True)
# ``cell`` preset that neither renames the component nor attaches settings.
cell_import_gds = partial(cell, autoname=False, add_settings=False)
# ``cell`` preset that prefixes the name with the child's function name; the
# decorated function must set ``component.child`` (e.g. via copy_child_info).
cell_with_child = partial(cell, get_child_name=True)
@cell_with_child
def container(
    component,
    function: Callable[..., None] | None = None,
    **kwargs,
) -> Component:
    """Returns new component with a component reference.

    The new component holds one reference to ``component``, exposes that
    reference's ports and records ``component`` as its child.

    Args:
        component: to add to container. Resolved with ``gf.get_component``.
        function: function to apply to the container; called as
            ``function(c, **kwargs)`` and its return value is ignored.
        kwargs: keyword arguments to pass to function.
    """
    # Imported inside the function, as in the original.
    # NOTE(review): presumably to avoid a circular import — confirm.
    import gdsfactory as gf

    component = gf.get_component(component)
    c = Component()
    cref = c << component
    c.add_ports(cref.ports)
    if function:
        function(c, **kwargs)
    # Required by the ``cell_with_child`` naming, which reads ``c.child``.
    c.copy_child_info(component)
    return c
@cell_with_child
def component_with_function(
    component,
    function: Callable[..., None] | None = None,
    **kwargs,
) -> Component:
    """Returns new component with a component reference.

    Unlike ``container``, the keyword arguments configure the wrapped
    component, and ``function`` receives only the new container.

    Args:
        component: to add to container. Resolved with
            ``gf.get_component(component, **kwargs)``.
        function: function to apply to the container; called as
            ``function(c)`` and its return value is ignored.
        kwargs: keyword arguments to pass to component.
    """
    # Imported inside the function, as in the original.
    # NOTE(review): presumably to avoid a circular import — confirm.
    import gdsfactory as gf

    component = gf.get_component(component, **kwargs)
    c = Component()
    cref = c << component
    c.add_ports(cref.ports)
    if function:
        function(c)
    # Required by the ``cell_with_child`` naming, which reads ``c.child``.
    c.copy_child_info(component)
    return c
if __name__ == "__main__":
    # Manual smoke test: wrap an MZI factory with a fiber array and display it.
    from functools import partial

    import gdsfactory as gf

    mzi_factory = partial(gf.components.mzi)
    demo = gf.routing.add_fiber_array(mzi_factory)
    demo.show()