import warnings
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib.axes import Axes
from matplotlib.lines import Line2D
from matplotlib.patches import Circle
from networkx.algorithms.dag import dag_longest_path
from networkx.drawing.nx_agraph import graphviz_layout
from ultrasphere._coordinates import SphericalCoordinates, TCartesian, TSpherical
# ASCII spellings of Greek letters mapped to the corresponding Unicode letter.
# ``_ascii_to_greek`` scans the entries in insertion order.
_ASCII_TO_GREEK = {
    "alpha": "α",  # noqa
    "beta": "β",
    "gamma": "γ",  # noqa
    "delta": "δ",
    "epsilon": "ε",
    "zeta": "ζ",
    # NOTE(review): the only uppercase letter in the table -- confirm intended.
    "theta": "Θ",
    "iota": "ι",  # noqa
    "kappa": "κ",
    "lambda": "λ",
    "mu": "μ",
    "nu": "ν",  # noqa
    "xi": "ξ",
    "pi": "π",
    "rho": "ρ",  # noqa
    "sigma": "σ",  # noqa
    "tau": "τ",
    "upsilon": "υ",  # noqa
    "phi": "φ",
    "chi": "χ",
    "psi": "ψ",
    "omega": "ω",
    "eta": "η",  # last
}


def _ascii_to_greek(s: str) -> str:
    """
    Convert a name such as ``"theta0"`` into a subscripted Greek symbol.

    If ``s`` starts with one of the keys of ``_ASCII_TO_GREEK``, that prefix
    is replaced by the Greek letter and the remainder of the string is
    wrapped in ``_{...}`` (for example ``"phi12"`` becomes ``"φ_{12}"``).
    A name that consists of the key alone yields an empty subscript
    (``"phi"`` becomes ``"φ_{}"``).

    Parameters
    ----------
    s : str
        The name to convert.

    Returns
    -------
    str
        The converted name, or ``s`` unchanged if it does not start with
        any known Greek letter name.
    """
    for name, letter in _ASCII_TO_GREEK.items():
        if s.startswith(name):
            # Only the leading prefix is substituted; later occurrences of
            # the same letter name must stay untouched, otherwise the
            # braces of the subscript group become unbalanced.
            return letter + "_{" + s[len(name):] + "}"
    return s
# (Removed stray "[docs]" marker left over from the Sphinx source-view page.)
def draw(
    c: SphericalCoordinates[TSpherical, TCartesian],
    root_bottom: bool = True,
    ax: Axes | None = None,
) -> tuple[float, float]:
    """
    Nicely draw the rooted tree representing the coordinates using matplotlib.

    The figure owning ``ax`` is resized and its subplot margins are adjusted
    so that the tree and a figure-level legend fit next to each other.

    Parameters
    ----------
    c : SphericalCoordinates[TSpherical, TCartesian]
        The coordinates whose tree (``c.G``) is drawn.
    root_bottom : bool, optional
        Whether to draw the root at the bottom, by default True.
        The node positions are only flipped when the Graphviz layout
        is available; the legend position always follows this flag.
    ax : Axes | None, optional
        The axes to draw on. If None, the current axes (``plt.gca()``)
        is used, by default None

    Returns
    -------
    tuple[float, float]
        The recommended width and height of the figure (in inches).
        The returned width does not include the extra 1.2 inches
        that are added to the figure for the legend.

    Warns
    -----
    RuntimeWarning
        If the Graphviz layout cannot be computed
        (``ImportError`` or ``FileNotFoundError``);
        ``networkx.spring_layout`` is used instead.

    Example
    -------
    .. skip: start

    >>> import ultrasphere as us
    >>> c = us.create_from_branching_types("ccabbab'b'ba")
    >>> us.draw(c)
    (6.5, 3.5)

    .. skip: end

    .. image:: /_static/coordinates.*

    """
    if ax is None:
        ax = plt.gca()
    fig = ax.figure

    # remove spines
    ax.set_frame_on(False)
    ax.grid(False)

    # figure size grows with the number of Cartesian axes (width)
    # and with the depth of the tree (height)
    width = max(c.c_ndim * 0.5 + 1, 3.5)
    height = max((len(dag_longest_path(c.G)) + 1) * 0.5, 3.5)
    # extra room on the right-hand side for the legend
    additional_width = 1.2
    fig.set_size_inches(width + additional_width, height)
    fig.subplots_adjust(
        right=0.9 - additional_width / (width + additional_width),
        left=0,
        top=0.9,
        bottom=0,
    )

    # layout
    try:
        pos = graphviz_layout(c.G, prog="dot", args='-GTBbalance="max"')
    except (ImportError, FileNotFoundError) as e:
        # ImportError: the Python binding (pygraphviz) is missing.
        # FileNotFoundError: the Graphviz executable could not be found.
        warnings.warn(
            "Graphviz is not installed. "
            "The layout will be calculated by spring layout.",
            RuntimeWarning,
            stacklevel=2,
            source=e,
        )
        pos = nx.spring_layout(c.G)
    else:
        if root_bottom:
            # invert y-axis by mirroring every node around the mean height
            y_center = np.mean([y for _, y in pos.values()])
            pos = {k: (x, 2 * y_center - y) for k, (x, y) in pos.items()}

    # Spherical
    nx.draw_networkx_nodes(
        c.G,
        pos,
        nodelist=c.s_nodes,
        node_color="darkgray",
        node_size=850,
        label="Spherical",
        ax=ax,
        margins=0.1,
    )
    nx.draw_networkx_labels(
        c.G,
        pos,
        labels={
            n: f"${_ascii_to_greek(str(n))}$\n{c.branching_types[n].value}/{c.S[n]}"
            for n in c.s_nodes
        },
        ax=ax,
    )

    # Cartesian
    nx.draw_networkx_nodes(
        c.G,
        pos,
        nodelist=c.c_nodes,
        node_color="lightgray",
        node_shape="s",
        label="Cartesian",
        ax=ax,
    )
    nx.draw_networkx_labels(c.G, pos, labels={n: f"{n}" for n in c.c_nodes}, ax=ax)

    # edges
    cos_color = "orange"
    sin_color = "blue"
    nx.draw_networkx_edges(
        c.G,
        pos,
        edgelist=c.cos_edges,
        edge_color=cos_color,
        label="cos",
        style="dashed",
        ax=ax,
    )
    nx.draw_networkx_edges(
        c.G, pos, edgelist=c.sin_edges, edge_color=sin_color, label="sin", ax=ax
    )

    # legend (proxy artists, so the entries do not depend on the drawn nodes)
    handles = [
        Circle(
            (0, 0),
            0.25,
            facecolor="darkgray",
            label="Spherical (Name\nBranching type\n/Descendants)",
        ),
        Circle((0, 0), 0.12, facecolor="lightgray", label="Cartesian (Name)"),
        Line2D([0], [0], color=cos_color, lw=2, label="cos", linestyle="dashed"),
        Line2D([0], [0], color=sin_color, lw=2, label="sin"),
    ]
    fig.legend(
        handles=handles,
        loc="lower right" if root_bottom else "upper right",
    )
    ax.set_title(f"Type {c.branching_types_expression_str} coordinates")
    return width, height