Source code for jax_sbgeom.jax_utils.splines

import jax
import jax.numpy as jnp
from functools import partial
from dataclasses import dataclass
import equinox as eqx

@partial(jax.jit, static_argnums = (3,))
def _cox_de_boor(x : jnp.ndarray, t : jnp.ndarray, i : int, k : int) -> jnp.ndarray:
    """
    Cox-de Boor recursive definition of B-splines. 
    Since k is fixed, the recursion depth is fixed and can be jitted.

    Parameters
    ----------
    x : jnp.ndarray [M]
        Points at which to evaluate the basis function.
    t : jnp.ndarray [N]
        Knot vector
    i : int  [0,..,N-k-2]
        Knot index
    k : int
        Degree of the basis function.
    Returns 
    -------
    jnp.ndarray [M]
        Values of the basis function at points x

    """


@partial(jax.jit, static_argnums = (3,4))
def _cox_de_boor_gradients(x, t, i, k, derivative):
    """
    Cox-de Boor recursive definition of B-splines. 
    To jit this function, the recursion depth must be fixed. Therefore, the degree k and derivative order are static arguments
    and standard python control flow is used.

    Parameters
    ----------
    x : jnp.ndarray [M]
        Points at which to evaluate the basis function.
    t : jnp.ndarray [N]
        Knot vector
    i : int  [0,..,N-k-2]
        Knot index
    k : int
        Degree of the basis function.
    derivative : int
        Order of the derivative to compute
    Returns 
    -------
    jnp.ndarray [M]
        Values of the basis function at points x

    """

    if k == 0:
        return jnp.where((t[i] <= x) & (x < t[i+1]), 1.0, 0.0) if derivative == 0 else jnp.zeros_like(x)
    else:
        denom1 = t[i+k] - t[i]
        denom2 = t[i+k+1] - t[i+1]

        if derivative == 0:
            term1 = jnp.where(denom1 != 0, (x - t[i]) / denom1 * _cox_de_boor_gradients(x, t, i, k-1, 0), 0.0)
            term2 = jnp.where(denom2 != 0, (t[i+k+1] - x) / denom2 * _cox_de_boor_gradients(x, t, i+1, k-1, 0), 0.0)
            return term1 + term2
        else:
            term1 = jnp.where(denom1 != 0, k / denom1 * _cox_de_boor_gradients(x, t, i, k-1, derivative-1), 0.0)
            term2 = jnp.where(denom2 != 0, -k / denom2 * _cox_de_boor_gradients(x, t, i+1, k-1, derivative-1), 0.0)
            return term1 + term2

[docs] @partial(jax.jit, static_argnums = (3, 4)) def bspline(x : jnp.ndarray, t : jnp.ndarray, c : jnp.ndarray, k : int, derivative : int): ''' Evaluate a B-spline at points x. Note that it is required that c.shape[0] == t.shape[0] - k - 1 Parameters ---------- x : jnp.ndarray [M] Points at which to evaluate the B-spline. t : jnp.ndarray [N] Knot vector c : jnp.ndarray [N - k - 1] Coefficients of the B-spline basis functions. k : int Degree of the B-spline. derivative : int Order of the derivative to compute Returns ------- jnp.ndarray [M] Values of the B-spline at points x ''' return jnp.dot(c, jax.vmap(_cox_de_boor_gradients, in_axes=(None, None, 0, None, None))(x,t, jnp.arange(t.shape[0] - k - 1), k, derivative))
def _periodic_knots(x : jnp.ndarray, k :int): ''' Creates knot vector for periodic B-spline interpolation. See scipy.interpolate.BSpline._periodic_knots for reference. Parameters ---------- x : jnp.ndarray [n] Data points to be interpolated. k : int Degree of the B-spline. Returns ------- jnp.ndarray [n + 2*k] Knot vector for periodic B-spline interpolation. ''' xc = x n = xc.shape[0] if k%2 == 0: dx = jnp.diff(xc) xc = xc.at[1:-1].set(xc[1:-1] - 0.5 * (dx[:-1] )) dx = jnp.diff(xc) t = jnp.zeros(n + 2 * k) t = t.at[k:-k].set(xc) for i in range(k): t = t.at[k - i - 1].set(t[k - i] - dx[-(i % (n - 1)) - 1]) t = t.at[-k+i].set(t[-k + i - 1] + dx[i % (n - 1)]) return t bspline_vectorized = jax.jit(jnp.vectorize(bspline, excluded = (3,4), signature = '(k),(n),(m)->(k)'), static_argnums=(3,4))
[docs] @jax.tree_util.register_dataclass @dataclass(frozen=True) class BSpline: ''' Convenience class for representing an arbitrarily batched B-spline. Data ---------- t : jnp.ndarray [..., N] Knot vector c : jnp.ndarray [..., N - k - 1] Coefficients of the B-spline basis functions. k : int Degree of the B-spline. ''' t : jnp.ndarray c : jnp.ndarray k : int def __call__(self, x : jnp.ndarray, derivative : int = 0) -> jnp.ndarray: ''' Evaluate the B-spline at points x with derivative order. Can be arbitrarily batched, but x is not batched and evaluated for the same x points for all splines. ''' return bspline_vectorized(x, self.t, self.c, int(self.k), derivative)
@partial(jax.jit, static_argnums = (2,)) def _fit_periodic_interpolating_spline(x,y,k:int): ''' Fit a periodic interpolating B-spline to data points (x,y) of degree k. The resulting spline satisfies S(x[i]) = y[i] for all i and is periodic (including derivatives up to < k). Currently construct a fully dense matrix and directly solves the linear system. See scipy.interpolate._bsplines._make_periodic_spline for a more advanced implementation. This matrix is: A_{ij} = B_{j}^0(x[i]) for i < n-1 (periodic last point not used) A_{n-1 + m, j} = B_{j}^{(m)}(x[0]) - B_{j}^{(m)}(x[n-1]) for m in 0,..,k-1 (periodic derivative conditions) where m is the derivative order and B_{j}^{(m)} is the m-th derivative of the j-th B-spline basis function (see _cox_de_boor_gradients). Parameters ---------- x : jnp.ndarray [n] Data points to be interpolated. y : jnp.ndarray [n] Data values to be interpolated. k : int Degree of the B-spline. Returns ------- t_knots : jnp.ndarray [n + 2*k] Knot vector for periodic B-spline interpolation. c_coeff : jnp.ndarray [n + k - 1] Coefficients of the B-spline basis functions. ''' t_knots = _periodic_knots(x, k) nt = len(t_knots) - k - 1 total_AA = jax.vmap(_cox_de_boor_gradients, in_axes = (None, None, 0, None, None))(x[:-1], t_knots, jnp.arange(nt), k, 0) rhs = jnp.concatenate([y[:-1], jnp.zeros(k)]) derivs = [] for m in range(k): derivs.append(jax.vmap(_cox_de_boor_gradients, in_axes=(None, None, 0, None, None))(jnp.array([x[0], x[-1]]), t_knots, jnp.arange(nt), k, m)) A_total = jnp.zeros((nt,nt)) A_total = A_total.at[:-k, :].set(total_AA.T) for i in range(k): A_total = A_total.at[-i-1, :].set(derivs[i][:, -1] - derivs[i][:, 0]) c_coeff = jnp.linalg.solve(A_total, rhs) return t_knots, c_coeff _fit_periodic_interpolating_spline_vec = jax.jit(jnp.vectorize(_fit_periodic_interpolating_spline, signature = '(n),(n)->(m),(k)', excluded=(2,)), static_argnums=(2,))
[docs] @partial(jax.jit, static_argnums = (2,)) def periodic_interpolating_spline(x,y, k : int): ''' Fit a periodic interpolating B-spline to batched data points (x,y) of degree k. The resulting spline satisfies S(x[i]) = y[i] for all i and is periodic (including derivatives up to < k). Parameters ---------- x : jnp.ndarray [..., n] Data points to be interpolated. y : jnp.ndarray [..., n] Data values to be interpolated. k : int Degree of the B-spline. Returns ------- BSpline Fitted periodic interpolating B-spline. ''' return BSpline(*_fit_periodic_interpolating_spline_vec(x,y, k), k)
#======================================================================= # BSpline for paramatrization #=======================================================================
[docs] def make_periodic_knots(n: int, k: int) -> jnp.ndarray: """ Build a periodic knot vector for n control points, degree k, on the domain [0, 2*pi). This is different than _periodic_knots, which builds a knot vector for periodic interpolation of given data points. This function is for building a knot vector for periodic B-spline parametrization. Parameters ---------- n : int Number of control points. k : int Degree of the B-spline. Returns ------- jnp.ndarray [n + 2*k + 1] Knot vector for periodic B-spline interpolation. """ base = jnp.linspace(0, 2 * jnp.pi, n + 1) # n+1 pts, spacing = 2pi/n # prepend k knots by wrapping backward left = base[-k-1:-1] - 2 * jnp.pi # k knots < 0 # append k knots by wrapping forward right = base[1:k+1] + 2 * jnp.pi # k knots > 2pi return jnp.concatenate([left, base, right])
[docs] def make_periodic_coeffs(c: jnp.ndarray, k: int) -> jnp.ndarray: """ Extend n coefficients to n+k by wrapping the first k, enforcing periodicity. Parameters ---------- c : jnp.ndarray [n] Coefficients of the B-spline basis functions. k : int Degree of the B-spline. Returns ------- jnp.ndarray [n + k] Extended coefficients for periodic B-spline interpolation. """ return jnp.concatenate([c, c[:k]])
[docs] @eqx.filter_jit def periodic_bspline(theta: jnp.ndarray, c: jnp.ndarray, k: int, derivative: int = 0): """ Function to evaluate a periodic B-spline at angles theta given coefficients c and degree k. Note that the coefficients c are the free parameters, which do not map directly to a control point given by their index. Use :func`greville_abscissa_periodic_bspline` to compute the control points location. Parameters ---------- theta : jnp.ndarray [M] Angles at which to evaluate the B-spline, in radians. c : jnp.ndarray [n] Coefficients of the B-spline basis functions. Must have length n, which is the number of free parameters. Internally extended to n+k for periodicity. k : int Degree of the B-spline. derivative : int, optional Order of the derivative to compute. Default is 0 (the function itself). Returns ------- jnp.ndarray [M] Values of the periodic B-spline (or its derivative) at the angles theta. """ n = c.shape[0] t = make_periodic_knots(n, k) # [n + 2k + 1] c_ext = make_periodic_coeffs(c, k) # [n + k] (= (n+2k+1) - k - 1 required for bspline) theta_wrapped = jnp.mod(theta, 2 * jnp.pi) return bspline(theta_wrapped, t, c_ext, k, derivative)
[docs] @eqx.filter_jit def greville_abscissa_periodic_bspline(n : int, k: int) -> jnp.ndarray: t = make_periodic_knots(n, k) # [n + 2k + 1] theta_ctrl = jnp.array([jnp.mean(t[i:i+k+2]) for i in range(n)]) # [n] return jnp.mod(theta_ctrl, 2 * jnp.pi)