from collections.abc import Sequence
from functools import lru_cache
from typing import Any, Literal
import networkx as nx
import numpy as np
from ultrasphere._coordinates import BranchingType, SphericalCoordinates
from ultrasphere._coordinates import SphericalCoordinates as cls
def _parse_branching_types(
    branching_types: str | Sequence[BranchingType],
) -> list[BranchingType]:
    """
    Normalize branching types into a list of ``BranchingType``.

    Parameters
    ----------
    branching_types : str | Sequence[BranchingType]
        The branching types. A string is read left to right, where ``"b'"``
        (or its ASCII alias ``"bp"``) becomes ``BranchingType.BP`` and each of
        ``"a"``, ``"b"``, ``"c"`` is converted with ``BranchingType(...)``.
        Any other sequence is copied into a list unchanged.

    Returns
    -------
    list[BranchingType]
        The branching types as a list.

    Raises
    ------
    ValueError
        If the string contains an unknown branching type.
    """
    if not isinstance(branching_types, str):
        return list(branching_types)
    # "bp" is accepted as an alias for "b'".
    remaining = branching_types.replace("bp", "b'")
    parsed: list[BranchingType] = []
    while remaining:
        if remaining.startswith("b'"):
            parsed.append(BranchingType.BP)
            remaining = remaining[2:]
        elif remaining[0] in ["a", "b", "c"]:
            parsed.append(BranchingType(remaining[0]))
            remaining = remaining[1:]
        else:
            raise ValueError(f"Invalid branching type: {remaining}")
    return parsed


def _get_digraph_from_branching_type(
    branching_types: str | Sequence[BranchingType],
) -> nx.DiGraph:
    """
    Get a rooted tree from the branching types.

    The tree is built starting at the root spherical node ``theta0``. Each
    branching type gives the current spherical node a ``"cos"`` child and a
    ``"sin"`` child (stored as the ``type`` edge attribute):

    - ``a``: both children are Cartesian nodes. The walk then continues at the
      most recently deferred ``c`` subtree.
    - ``b``: the ``cos`` child is Cartesian and the ``sin`` child is the next
      spherical node.
    - ``b'``: the ``cos`` child is the next spherical node and the ``sin``
      child is Cartesian.
    - ``c``: both children are spherical. The ``cos`` child is visited next;
      the ``sin`` child is deferred until an ``a`` closes the current subtree.

    Spherical nodes are named by ``_s_node_name_default`` (``theta0``,
    ``theta1``, ...). Cartesian nodes are the integers ``0, 1, ...``, numbered
    in creation order.

    Parameters
    ----------
    branching_types : str | Sequence[BranchingType]
        The branching types. e.g. "ba" for standard spherical coordinates.

    Returns
    -------
    nx.DiGraph
        The rooted tree representing the coordinates.

    Raises
    ------
    ValueError
        If the branching types are invalid. This covers an unknown branching
        type, an ``a`` with no deferred subtree left to visit, or a sequence
        that ends before every spherical node has received children.
    """
    branching_types_ = _parse_branching_types(branching_types)
    G = nx.DiGraph()
    # Deferred sin-children of "c" nodes, expanded last-in first-out.
    type_c_stack: list[Any] = []
    next_e_idx = 0
    next_s_idx = 0
    current_node = _s_node_name_default(next_s_idx)
    G.add_node(current_node)
    next_s_idx += 1
    for i, branching_type in enumerate(branching_types_):
        if branching_type == BranchingType.A:
            G.add_node(next_e_idx)
            G.add_edge(current_node, next_e_idx, type="cos")
            next_e_idx += 1
            G.add_node(next_e_idx)
            G.add_edge(current_node, next_e_idx, type="sin")
            next_e_idx += 1
            if i == len(branching_types_) - 1:
                break
            # This subtree is complete; resume at the latest deferred node.
            try:
                next_node = type_c_stack.pop()
            except IndexError as e:
                raise ValueError("Invalid branching types.") from e
        elif branching_type == BranchingType.B:
            G.add_node(next_e_idx)
            G.add_edge(current_node, next_e_idx, type="cos")
            next_e_idx += 1
            s_node = _s_node_name_default(next_s_idx)
            G.add_node(s_node)
            G.add_edge(current_node, s_node, type="sin")
            next_node = s_node
            next_s_idx += 1
        elif branching_type == BranchingType.BP:
            s_node = _s_node_name_default(next_s_idx)
            G.add_node(s_node)
            G.add_edge(current_node, s_node, type="cos")
            next_node = s_node
            next_s_idx += 1
            G.add_node(next_e_idx)
            G.add_edge(current_node, next_e_idx, type="sin")
            next_e_idx += 1
        elif branching_type == BranchingType.C:
            cos_node = _s_node_name_default(next_s_idx)
            G.add_node(cos_node)
            G.add_edge(current_node, cos_node, type="cos")
            next_node = cos_node
            next_s_idx += 1
            sin_node = _s_node_name_default(next_s_idx)
            G.add_node(sin_node)
            G.add_edge(current_node, sin_node, type="sin")
            type_c_stack.append(sin_node)
            next_s_idx += 1
        else:
            # Without this, an unknown element would reuse a stale (or
            # unbound) next_node.
            raise ValueError(f"Invalid branching type: {branching_type!r}")
        current_node = next_node
    # A non-empty description must end with "a" and leave no deferred "c"
    # subtree unexpanded. Otherwise some spherical node would end up as a
    # childless leaf.
    if branching_types_ and (
        branching_types_[-1] != BranchingType.A or type_c_stack
    ):
        raise ValueError("Invalid branching types.")
    return G
[docs]
def create_polar() -> 'SphericalCoordinates[Literal["phi"], Literal[0, 1]]':
    r"""
    Polar coordinates.

    .. math::

        x_0 &= r \cos(\phi) \\
        x_1 &= r \sin(\phi)

    Built from the single branching type ``"a"``, with the only spherical
    node renamed from ``theta0`` to ``phi``.

    Returns
    -------
    SphericalCoordinates
        The polar coordinates.

    Examples
    --------
    >>> c = create_polar()
    >>> c
    SphericalCoordinates(a)
    >>> c.s_nodes
    ['phi']
    >>> c.c_nodes
    [0, 1]
    """
    renaming = {"theta0": "phi"}
    tree = nx.relabel_nodes(_get_digraph_from_branching_type("a"), renaming)
    return cls(tree)
[docs]
def create_spherical() -> (
    'SphericalCoordinates[Literal["theta", "phi"], Literal[0, 1, 2]]'
):
    r"""
    Spherical coordinates.

    .. math::

        x_0 &= r \sin(\theta) \cos(\phi) \\
        x_1 &= r \sin(\theta) \sin(\phi) \\
        x_2 &= r \cos(\theta)

    Returns
    -------
    SphericalCoordinates
        The spherical coordinates.

    Examples
    --------
    >>> c = create_spherical()
    >>> c
    SphericalCoordinates(ba)
    >>> c.s_nodes
    ['theta', 'phi']
    >>> c.c_nodes
    [0, 1, 2]
    """
    G = _get_digraph_from_branching_type("ba")
    # "ba" yields x0 = cos(theta0), x1 = sin(theta0) cos(theta1) and
    # x2 = sin(theta0) sin(theta1). The Cartesian labels are rotated
    # cyclically (0 -> 2, 1 -> 0, 2 -> 1; this is not a plain swap of x0
    # and x2) so that x2 = cos(theta) as in the formula above. theta0 and
    # theta1 are renamed to theta and phi.
    G = nx.relabel_nodes(G, {0: 2, 2: 1, 1: 0, "theta0": "theta", "theta1": "phi"})
    return cls(G)
[docs]
def create_standard(s_ndim: int) -> "SphericalCoordinates[Any, Any]":
    r"""
    Standard spherical coordinates.

    .. math::

        x_0 &= \cos(\theta_0) \\
        x_1 &= \sin(\theta_0) \cos(\theta_1) \\
        x_2 &= \sin(\theta_0) \sin(\theta_1) \cos(\theta_2) \\
        &\vdots \\
        x_{n-1} &= \sin(\theta_0) \cdots \sin(\theta_{n-2}) \cos(\theta_{n-1}) \\
        x_n &= \sin(\theta_0) \cdots \sin(\theta_{n-2}) \sin(\theta_{n-1})

    where :math:`n` is ``s_ndim``. These are the coordinates described by
    the branching types ``"b" * (s_ndim - 1) + "a"``.

    Parameters
    ----------
    s_ndim : int
        The number of spherical dimensions. Must be non-negative; ``0``
        gives the coordinates built from the empty branching types ``""``.

    Returns
    -------
    SphericalCoordinates
        The standard coordinates.

    Raises
    ------
    ValueError
        If ``s_ndim`` is negative.

    Examples
    --------
    >>> c = create_standard(4)
    >>> c
    SphericalCoordinates(bbba)
    >>> c.s_nodes
    ['theta0', 'theta1', 'theta2', 'theta3']
    >>> c.c_nodes
    [0, 1, 2, 3, 4]
    """
    if s_ndim < 0:
        # "b" * negative == "", so without this check a negative s_ndim would
        # silently return the same coordinates as s_ndim == 1.
        raise ValueError("s_ndim should be non-negative.")
    if s_ndim == 0:
        return create_from_branching_types("")
    return cls(_get_digraph_from_branching_type("b" * (s_ndim - 1) + "a"))
[docs]
def create_standard_prime(s_ndim: int) -> "SphericalCoordinates[Any, Any]":
    r"""
    Standard prime spherical coordinates.

    .. math::

        x_0 &= \sin(\theta_0) \\
        x_1 &= \cos(\theta_0) \sin(\theta_1) \\
        x_2 &= \cos(\theta_0) \cos(\theta_1) \sin(\theta_2) \\
        &\vdots \\
        x_{n-1} &= \cos(\theta_0) \cdots \cos(\theta_{n-2}) \cos(\theta_{n-1}) \\
        x_n &= \cos(\theta_0) \cdots \cos(\theta_{n-2}) \sin(\theta_{n-1})

    where :math:`n` is ``s_ndim``. These are the coordinates described by
    the branching types ``"b'" * (s_ndim - 1) + "a"``. The trailing ``"a"``
    is why the last two rows break the :math:`\cos \cdots \cos \sin` pattern.

    Parameters
    ----------
    s_ndim : int
        The number of spherical dimensions. Must be non-negative; ``0``
        gives the coordinates built from the empty branching types ``""``.

    Returns
    -------
    SphericalCoordinates
        The standard prime coordinates.

    Raises
    ------
    ValueError
        If ``s_ndim`` is negative.

    Examples
    --------
    >>> c = create_standard_prime(4)
    >>> c
    SphericalCoordinates(b'b'b'a)
    >>> c.s_nodes
    ['theta0', 'theta1', 'theta2', 'theta3']
    >>> c.c_nodes
    [0, 1, 2, 3, 4]
    """
    if s_ndim < 0:
        # "bp" * negative == "", so without this check a negative s_ndim
        # would silently return the same coordinates as s_ndim == 1.
        raise ValueError("s_ndim should be non-negative.")
    if s_ndim == 0:
        return create_from_branching_types("")
    return cls(_get_digraph_from_branching_type("bp" * (s_ndim - 1) + "a"))
[docs]
def create_hopf(q: int) -> "SphericalCoordinates[Any, Any]":
"""
Hopf coordinates.
Parameters
----------
q : int
Where 2^q = c.c_ndim.
Returns
-------
SphericalCoordinates
The Hopf coordinates.
Examples
--------
>>> c = create_hopf(3)
>>> c
SphericalCoordinates(ccaacaa)
>>> c.s_nodes
['theta0', 'theta1', 'theta2', 'theta3', 'theta4', 'theta5', 'theta6']
>>> c.c_nodes
[0, 1, 2, 3, 4, 5, 6, 7]
"""
@lru_cache
def _hoph(q: int) -> str:
if q < 0:
raise ValueError("q should be non-negative.")
elif q == 0:
return ""
elif q == 1:
return "a"
return f"c{_hoph(q - 1)}{_hoph(q - 1)}"
return cls(_get_digraph_from_branching_type(_hoph(q)))
[docs]
def create_from_branching_types(
    branching_types: str | Sequence[BranchingType],
) -> "SphericalCoordinates[Any, Any]":
    """
    Build spherical coordinates described by branching types.

    The branching types are turned into a rooted tree, which is then wrapped
    in ``SphericalCoordinates``.

    Parameters
    ----------
    branching_types : str | Sequence[BranchingType]
        The branching types, e.g. ``"ba"`` for standard spherical
        coordinates.

    Returns
    -------
    SphericalCoordinates
        The spherical coordinates.

    Raises
    ------
    ValueError
        Propagated from the tree construction when the branching types are
        rejected as invalid.
    """
    tree = _get_digraph_from_branching_type(branching_types)
    return cls(tree)
def _s_node_name_default(idx: int) -> Any:
"""
The naming convention for the spherical node.
Parameters
----------
idx : int
The index of the spherical node.
Returns
-------
str
The name of the spherical node.
"""
return f"theta{idx}"
def _e_node_name_default(idx: int) -> Any:
"""
The naming convention for the Cartesian node.
Parameters
----------
idx : int
The index of the Cartesian node.
Returns
-------
str
The name of the Cartesian node.
"""
return idx
def _get_random_digraph(
    s_ndim: int, *, rng: np.random.Generator | None = None
) -> nx.DiGraph:
    """
    Get a random rooted tree representing the coordinates.

    Starting from a single node, the following is repeated ``s_ndim`` times:
    a current leaf is picked with ``rng.choice`` and given a ``"cos"`` child
    and a ``"sin"`` child (stored as the ``type`` edge attribute).

    Afterwards the nodes are relabeled:

    - leaves become Cartesian nodes, numbered in the order they sit in the
      internal leaf list;
    - inner nodes become spherical nodes, numbered in creation order, so for
      ``s_ndim > 0`` the root is ``theta0``.

    Parameters
    ----------
    s_ndim : int
        The number of spherical dimensions. Must be non-negative.
    rng : np.random.Generator | None, optional
        The random number generator, by default None
        (a fresh ``np.random.default_rng()`` is used).

    Returns
    -------
    nx.DiGraph
        The rooted tree representing the coordinates.

    Raises
    ------
    ValueError
        If ``s_ndim`` is negative.
    """
    if s_ndim < 0:
        # range(negative) is empty, which would silently act like s_ndim == 0.
        raise ValueError("s_ndim should be non-negative.")
    rng = np.random.default_rng() if rng is None else rng
    G = nx.DiGraph()
    leaf_nodes = [0]
    G.add_node(0)
    for _ in range(s_ndim):
        # rng.choice returns a numpy scalar; keep plain int node labels.
        node_parent = int(rng.choice(leaf_nodes))
        leaf_nodes.remove(node_parent)
        node_cos = len(G)
        node_sin = node_cos + 1
        for edge_type, node in (("cos", node_cos), ("sin", node_sin)):
            G.add_node(node)
            G.add_edge(node_parent, node, type=edge_type)
            leaf_nodes.append(node)
    leaf_set = set(leaf_nodes)
    # G.nodes iterates in insertion order, so inner nodes are numbered in the
    # order they were created. Iterating a set difference instead would
    # depend on unspecified set ordering.
    non_leaf_nodes = [node for node in G.nodes if node not in leaf_set]
    mapping = {node: _e_node_name_default(i) for i, node in enumerate(leaf_nodes)}
    mapping |= {
        node: _s_node_name_default(i) for i, node in enumerate(non_leaf_nodes)
    }
    return nx.relabel_nodes(G, mapping)
[docs]
def create_random(
    s_ndim: int, *, rng: np.random.Generator | None = None
) -> "SphericalCoordinates[Any, Any]":
    """
    Create randomly shaped spherical coordinates.

    A random rooted tree with ``s_ndim`` spherical nodes is generated and
    wrapped in ``SphericalCoordinates``.

    Parameters
    ----------
    s_ndim : int
        The number of spherical dimensions.
    rng : np.random.Generator | None, optional
        The random number generator used to shape the tree, by default None.

    Returns
    -------
    SphericalCoordinates
        The random spherical coordinates.

    Examples
    --------
    >>> from array_api_compat import numpy as np
    >>> rng = np.random.default_rng(0)
    >>> c = create_random(5, rng=rng)
    >>> c
    SphericalCoordinates(bcb'aa)
    >>> c.s_nodes
    ['theta0', 'theta1', 'theta2', 'theta3', 'theta4']
    >>> c.c_nodes
    [0, 1, 2, 3, 4, 5]
    """
    tree = _get_random_digraph(s_ndim, rng=rng)
    return cls(tree)