utils
circulax utilities.
Functions:
| Name | Description |
|---|---|
| apply_global_params | Forward global scalar params to all component groups that declare them. |
| update_group_params | Updates a parameter for ALL instances in a component group. |
| update_params_dict | Updates a parameter for a specific instance within a component group. |
apply_global_params
Forward global scalar params to all component groups that declare them.
For each (param_name, value) pair in params, updates every group whose batched params object has an attribute with that name, broadcasting the value to all instances in that group.
Safe to use under jax.jit and jax.vmap: the dict walk happens at the Python level (static at trace time), and each value is the only traced leaf.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| groups | dict | Compiled groups dict as returned by :func: | required |
| params | dict | Mapping from parameter name to scalar JAX-traceable value. | required |
Returns:
| Type | Description |
|---|---|
| dict | New groups dict with updated parameter values (immutable functional update). |
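The dict walk described above can be sketched in plain Python. This is a minimal, non-authoritative sketch: ResistorParams, CapacitorParams, Group and apply_global_params_sketch are hypothetical stand-ins rather than circulax classes, and per-instance values are stored as tuples where circulax holds batched (JAX) parameter objects.

```python
from dataclasses import dataclass, replace
from typing import Any


# Hypothetical stand-ins (not circulax classes): each params field holds
# one value per instance, mimicking a batched params object.
@dataclass(frozen=True)
class ResistorParams:
    R: tuple
    temp: tuple


@dataclass(frozen=True)
class CapacitorParams:
    C: tuple


@dataclass(frozen=True)
class Group:
    instance_names: tuple
    params: Any


def apply_global_params_sketch(groups: dict, params: dict) -> dict:
    """Broadcast each global scalar to every group that declares it."""
    new_groups = {}
    for group_name, group in groups.items():
        n = len(group.instance_names)
        # Plain Python loop: which groups receive which keys is decided here,
        # independent of the values themselves.
        updates = {
            key: (value,) * n  # broadcast the scalar to all instances
            for key, value in params.items()
            if hasattr(group.params, key)
        }
        if updates:
            group = replace(group, params=replace(group.params, **updates))
        new_groups[group_name] = group
    return new_groups  # the input dict and its groups are left untouched


groups = {
    "resistors": Group(("R1", "R2"), ResistorParams(R=(1e3, 2e3), temp=(290.0, 290.0))),
    "capacitors": Group(("C1",), CapacitorParams(C=(1e-9,))),
}
updated = apply_global_params_sketch(groups, {"temp": 300.0})
print(updated["resistors"].params.temp)  # (300.0, 300.0)
```

The capacitor group has no temp attribute, so it is passed through unchanged. With JAX arrays as leaves, the tuple broadcast would more naturally be something like jnp.full_like(old, value); that substitution is an assumption, not taken from the circulax source.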
update_group_params
```python
update_group_params(
    groups_dict: dict, group_name: str, param_key: str, new_value: float
) -> dict[str, ComponentGroup]
```
Updates a parameter for ALL instances in a component group.
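A group-wide update can be sketched as overwriting one field of the batched params with the same value for every instance. This is a minimal sketch under stated assumptions: DiodeParams, Group and update_group_params_sketch are hypothetical stand-ins, and per-instance values are tuples rather than the arrays circulax actually stores.

```python
from dataclasses import dataclass, replace
from typing import Any


# Hypothetical stand-ins (not circulax classes).
@dataclass(frozen=True)
class DiodeParams:
    Is: tuple  # saturation current, one entry per instance
    n: tuple   # ideality factor, one entry per instance


@dataclass(frozen=True)
class Group:
    instance_names: tuple
    params: Any


def update_group_params_sketch(
    groups_dict: dict, group_name: str, param_key: str, new_value: float
) -> dict:
    group = groups_dict[group_name]
    # Overwrite the parameter for every instance in the group.
    new_params = replace(
        group.params, **{param_key: (new_value,) * len(group.instance_names)}
    )
    # Build a new outer dict so the caller's dict is not mutated.
    return {**groups_dict, group_name: replace(group, params=new_params)}


groups = {"diodes": Group(("D1", "D2", "D3"), DiodeParams(Is=(1e-14,) * 3, n=(1.0,) * 3))}
new_groups = update_group_params_sketch(groups, "diodes", "n", 1.8)
print(new_groups["diodes"].params.n)  # (1.8, 1.8, 1.8)
```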
update_params_dict
```python
update_params_dict(
    groups_dict: dict,
    group_name: str,
    instance_name: str,
    param_key: str,
    new_value: float,
) -> dict[str, ComponentGroup]
```
Updates a parameter for a specific instance within a component group.