from collections.abc import Callable, Mapping
from typing import Any, Literal, overload
import array_api_extra as xpx
import numpy as np
from array_api._2024_12 import Array, ArrayNamespaceFull
from scipy.special import roots_jacobi
from ._coordinates import (
BranchingType,
SphericalCoordinates,
TCartesian,
TSpherical,
get_child,
)
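# NOTE: a stray "[docs]" Sphinx viewcode link was left here by HTML extraction; it is not part of the module.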
def roots(
    c: SphericalCoordinates[TSpherical, TCartesian],
    n: int,
    *,
    expand_dims_x: bool,
    expand_dims_w: bool = False,
    device: Any | None = None,
    dtype: Any | None = None,
    xp: ArrayNamespaceFull,
) -> tuple[Mapping[TSpherical, Array], Mapping[TSpherical, Array]]:
    r"""
    Gauss-Jacobi quadrature roots and weights.

    .. math::
        \int_{\mathbb{S}^{d-1}} f d\omega^{d-1} \approx
        \sum_{(\theta_1, w_1)} w_1 \cdots \sum_{(\theta_{d-1}, w_{d-1})} w_{d-1}
        f(\theta_1, \ldots, \theta_{d-1})

    For every spherical node of ``c`` one set of angles and weights is built,
    chosen by the node's branching type:

    * type A: ``2 * n`` equally spaced angles ``k * pi / n`` with the
      constant weight ``pi / n``;
    * type B / BP: ``n`` Gauss-Jacobi nodes with equal exponents, mapped to
      angles through ``arccos`` / ``arcsin`` respectively;
    * type C: ``n`` Gauss-Jacobi nodes mapped through ``arccos(x) / 2``, with
      the weights divided by ``2 ** (alpha + beta + 2)``.

    Parameters
    ----------
    c : SphericalCoordinates[TSpherical, TCartesian]
        The spherical coordinates. Its ``s_nodes``, ``branching_types``,
        ``S``, ``G`` and ``s_ndim`` attributes are read.
    n : int
        The number of roots.
    expand_dims_x : bool
        Whether to expand dimensions of the roots (required, keyword-only).
        If True, the roots of the ``i``-th node are indexed so that they lie
        along axis ``i`` of a ``c.s_ndim``-dimensional array whose other axes
        have length 1.
    expand_dims_w : bool, optional
        Whether to expand dimensions of the weights in the same way,
        by default False
    device : Any, optional
        The device, by default None
    dtype : Any, optional
        The data type, by default None
    xp : ArrayNamespaceFull
        The array namespace used to create the returned arrays.

    Returns
    -------
    tuple[Mapping[TSpherical, Array], Mapping[TSpherical, Array]]
        roots and weights, both keyed by spherical node

    Raises
    ------
    ValueError
        If the branching type is invalid.

    Example
    -------
    >>> from array_api_compat import numpy as np
    >>> import ultrasphere as us
    >>> c = us.create_spherical()
    >>> xs, ws = us.roots(c, 5, expand_dims_x=True, expand_dims_w=True, xp=np)
    >>> xs
    {'theta': array([[2.7049577 ],
           [2.13941585],
           [1.57079633],
           [1.0021768 ],
           [0.43663495]]), 'phi': array([[0.        , 0.62831853, 1.25663706, 1.88495559, 2.51327412,
            3.14159265, 3.76991118, 4.39822972, 5.02654825, 5.65486678]])}
    >>> ws
    {'theta': array([[0.23692689],
           [0.47862867],
           [0.56888889],
           [0.47862867],
           [0.23692689]]), 'phi': array([[0.62831853, 0.62831853, 0.62831853, 0.62831853, 0.62831853,
            0.62831853, 0.62831853, 0.62831853, 0.62831853, 0.62831853]])}

    """
    xs: dict[TSpherical, Array] = {}
    ws: dict[TSpherical, Array] = {}
    for i, node in enumerate(c.s_nodes):
        branching_type = c.branching_types[node]
        if branching_type == BranchingType.A:
            # Equally spaced angles with equal weights.
            x = xp.arange(2 * n, device=device, dtype=dtype) * xp.pi / n
            w = xp.ones(2 * n, device=device, dtype=dtype) * xp.pi / n
        elif branching_type == BranchingType.B:
            s_beta = c.S[get_child(c.G, node, "sin")]
            beta = s_beta / 2
            x, w = roots_jacobi(n, beta, beta)
            # np.arccos (not np.acos) so this also runs on NumPy < 2.0.
            x = np.arccos(x)
        elif branching_type == BranchingType.BP:
            s_alpha = c.S[get_child(c.G, node, "cos")]
            alpha = s_alpha / 2
            x, w = roots_jacobi(n, alpha, alpha)
            # np.arcsin (not np.asin) so this also runs on NumPy < 2.0.
            x = np.arcsin(x)
        elif branching_type == BranchingType.C:
            s_alpha = c.S[get_child(c.G, node, "cos")]
            s_beta = c.S[get_child(c.G, node, "sin")]
            alpha = s_alpha / 2
            beta = s_beta / 2
            # NOTE(review): scipy's roots_jacobi uses the weight
            # (1 - x)^alpha * (1 + x)^beta, and with x = cos(2 * theta) the
            # factor (1 - x) equals 2 * sin(theta)^2. Please confirm that
            # passing the "cos" exponent first matches the convention of c.S.
            x, w = roots_jacobi(n, alpha, beta)
            w = w / 2 ** (alpha + beta + 2)
            x = np.arccos(x) / 2
        else:
            raise ValueError(f"Invalid branching type {branching_type}.")
        x = xp.asarray(x, device=device, dtype=dtype)
        w = xp.asarray(w, device=device, dtype=dtype)
        if expand_dims_x or expand_dims_w:
            # Put this node's samples on axis ``i`` and add singleton axes
            # for every other spherical node.
            index = (None,) * i + (slice(None),) + (None,) * (c.s_ndim - i - 1)
            if expand_dims_x:
                x = x[index]
            if expand_dims_w:
                w = w[index]
        xs[node] = x
        ws[node] = w
    return xs, ws
# Typing-only overload: with ``does_f_support_separation_of_variables`` given
# as ``Literal[True]``, ``f`` is a per-coordinate mapping of arrays (or a
# callable returning one) and the result is a mapping of arrays.
@overload
def integrate(
c: SphericalCoordinates[TSpherical, TCartesian],
f: (
Callable[
[Mapping[TSpherical, Array]],
Mapping[TSpherical, Array],
]
| Mapping[TSpherical, Array]
),
does_f_support_separation_of_variables: Literal[True],
n: int,
*,
xp: ArrayNamespaceFull,
device: Any | None = None,
dtype: Any | None = None,
) -> Mapping[TSpherical, Array]: ...
# Typing-only overload: with ``does_f_support_separation_of_variables`` given
# as ``Literal[False]``, ``f`` is a single array (or a callable returning one)
# and the result is a single array.
@overload
def integrate(
c: SphericalCoordinates[TSpherical, TCartesian],
f: (
Callable[
[Mapping[TSpherical, Array]],
Array,
]
| Array
),
does_f_support_separation_of_variables: Literal[False],
n: int,
*,
xp: ArrayNamespaceFull,
device: Any | None = None,
dtype: Any | None = None,
) -> Array: ...
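# NOTE: a stray "[docs]" Sphinx viewcode link was left here by HTML extraction; it is not part of the module.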
def integrate(
    c: SphericalCoordinates[TSpherical, TCartesian],
    f: (
        Callable[
            [Mapping[TSpherical, Array]],
            Mapping[TSpherical, Array] | Array,
        ]
        | Mapping[TSpherical, Array]
        | Array
    ),
    does_f_support_separation_of_variables: bool,
    n: int,
    *,
    xp: ArrayNamespaceFull,
    device: Any | None = None,
    dtype: Any | None = None,
) -> Array | Mapping[TSpherical, Array]:
    r"""
    Integrate the function over the hypersphere.

    .. math::
        \int_{\mathbb{S}^{d-1}} f d\omega^{d-1}

    Parameters
    ----------
    c : SphericalCoordinates[TSpherical, TCartesian]
        The spherical coordinates whose nodes define the quadrature.
    f : Callable[ [Mapping[TSpherical, Array]], Mapping[TSpherical, Array] | Array, ] | Mapping[TSpherical, Array] | Array  # noqa: E501
        The function to integrate or the values of the function.
        If mapping, the separated parts of the function for each spherical coordinate.
        If mapping, the shapes do not need to be broadcastable.
        If function, if does_f_support_separation_of_variables is True,
        1D array of integration points are passed,
        and extra axis should be added to the last dimension.
        If function, if does_f_support_separation_of_variables is False,
        ``c.s_ndim``-D array of integration points are passed,
        and extra axis should be added to the last dimension.
    does_f_support_separation_of_variables : bool
        Whether the function supports separation of variables.
        This could significantly reduce the computational cost.
    n : int
        The number of roots.
    xp : ArrayNamespaceFull
        The array namespace.
    device : Any, optional
        The device, by default None
    dtype : Any, optional
        The data type, by default None

    Returns
    -------
    Array | Mapping[TSpherical, Array]
        The integrated value.
        Has the same shape as the return values of f or the values of f.
        When the values are a mapping, a mapping with one integrated
        factor per spherical node is returned (the factors are not
        multiplied together).

    Raises
    ------
    ValueError
        If the (non-mapping) values have fewer than ``c.s_ndim`` dimensions.

    Example
    -------
    >>> from array_api_compat import numpy as np
    >>> import ultrasphere as us
    >>> c = us.create_spherical()
    >>> f = lambda spherical: spherical["theta"] ** 2 * spherical["phi"]
    >>> np.round(us.integrate(
    ...     c,
    ...     f,
    ...     False, # does not support separation of variables
    ...     10, # number of quadrature points
    ...     xp=np # the array namespace
    ... ), 5)
    np.float64(110.02621)

    """
    xs, ws = roots(
        c,
        n,
        device=device,
        dtype=dtype,
        expand_dims_x=not does_f_support_separation_of_variables,
        xp=xp,
    )
    # ``f`` is either sampled at the quadrature nodes or already holds values.
    val = f(xs) if callable(f) else f

    # in case f(theta1, ...) = f_1(theta1) * f_2(theta2) * ...
    if isinstance(val, Mapping):
        result = {}
        for node in c.s_nodes:
            value = val[node]
            # supports vectorized function
            # axis=0 because in sph_harm
            # we add axis to the last dimension
            # theta(node),u1,...,uM
            # Called only for validation; the returned shape is discarded.
            xpx.broadcast_shapes(value.shape[:1], ws[node].shape)
            w = xp.reshape(ws[node], (-1,) + (1,) * (value.ndim - 1))
            if value.shape[0] == 1:
                # The factor is constant along this angle.
                result[node] = value[0, ...] * xp.sum(w)
            else:
                # ``w`` and ``value`` have the same rank, so ``-value.ndim``
                # is axis 0 written as a negative (standard-portable) axis.
                result[node] = xp.vecdot(w, value, axis=-value.ndim)
        # we don't know how to einsum the result
        return result

    if val.ndim < c.s_ndim:
        raise ValueError(
            f"The dimension of the return value of f should be at least {c.s_ndim}, got {val.ndim}."
        )
    # Called only for validation of the leading (angular) axes.
    xpx.broadcast_shapes(
        val.shape[: c.s_ndim],
        xpx.broadcast_shapes(*(xs[node].shape for node in c.s_nodes)),
    )
    # theta1,...,thetaN,u1,...,uM
    # Each iteration contracts the current leading axis with that node's
    # weights, so the next node's axis becomes axis 0.
    for node in c.s_nodes:
        w = ws[node]
        if val.shape[0] == 1:
            val = val[0, ...] * xp.sum(w)
        else:
            w_expanded = w[(slice(None),) + (None,) * (val.ndim - 1)]
            # Same rank on both operands: ``-val.ndim`` is axis 0.
            val = xp.vecdot(w_expanded, val, axis=-val.ndim)
    return val