from abc import ABC, abstractmethod
from dataclasses import dataclass
import jax
import jax.numpy as jnp
import numpy as onp
from typing import Literal
from .base_coil import Coil, _finite_size_from_data,_radial_vector_centroid_from_data, _frame_from_radial_vector
from jax_sbgeom.jax_utils import interpolate_array_modulo_broadcasted, interpolate_fractions_modulo
from jax_sbgeom.jax_utils.numerical import reverse_except_begin
import warnings
from functools import partial
from .base_coil import _rmf_radial_vector_from_data
@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class DiscreteCoil(Coil):
    '''
    Coil described by a finite set of Cartesian points together with a stored centre.

    The centre is computed once (as the mean of the points) when the coil is built through
    :meth:`from_positions`. It is only returned by :meth:`centre`; it plays no role in
    evaluating positions along the coil.
    '''

    # [..., 3] Cartesian positions of the discrete coil points.
    positions : jnp.ndarray
    # Centre of the coil, i.e. the mean of ``positions``. Stored as a field rather than a
    # cached property because the latter does not play well with JAX.
    _centre_i : jnp.ndarray

    @classmethod
    def from_positions(cls, positions: jnp.ndarray):
        '''
        Build a DiscreteCoil from a set of discrete points.

        The points are assumed to be non-periodic, i.e. the first and last point
        are not equal.

        Parameters
        ----------
        positions : jnp.ndarray
            Cartesian positions of the discrete coil points [..., 3]

        Returns
        -------
        DiscreteCoil
            Coil holding ``positions`` and their mean over the second-to-last axis as centre
        '''
        mean_point = jnp.mean(positions, axis=-2)
        return cls(positions=positions, _centre_i=mean_point)

    def position(self, s):
        '''
        Evaluate the coil position as a function of arc length.

        Parameters
        ----------
        s : jnp.ndarray
            Arc length(s) along the coil

        Returns
        -------
        jnp.ndarray
            Cartesian position(s) along the coil
        '''
        return _discrete_coil_position(self, s)

    def tangent(self, s):
        '''
        Evaluate the tangent vector of the coil as a function of arc length.

        Parameters
        ----------
        s : jnp.ndarray
            Arc length(s) along the coil

        Returns
        -------
        jnp.ndarray
            Tangent vector(s) along the coil
        '''
        return _discrete_coil_tangent(self, s)

    def centre(self):
        '''
        Return the stored centre of the coil.

        Returns
        -------
        jnp.ndarray
            Centre of the coil
        '''
        return self._centre_i

    def normal(self, s):
        '''
        Normal vector along the coil as a function of arc length.

        A normal is not defined for a DiscreteCoil, so an array filled with NaN is returned.

        Parameters
        ----------
        s : jnp.ndarray
            Arc length(s) along the coil

        Returns
        -------
        jnp.ndarray
            Array of jnp.nan whose shape is the shape of ``s`` followed by a trailing axis of length 3
        '''
        nan_shape = (*jnp.array(s).shape, 3)
        return jnp.full(nan_shape, jnp.nan)

    def reverse_parametrisation(self):
        '''
        Return a coil with the reversed parametrisation.

        All points are reversed except the first one, so s=0 remains s=0.
        '''
        return _discrete_coil_reverse_parametrisation(self)
# ===================================================================================================================================================================================
# Implementation
# ===================================================================================================================================================================================
@jax.jit
def _discrete_coil_discrete_position(discrete_coil: DiscreteCoil, index):
    '''
    Look up coil point(s) by integer index, wrapping around the number of points.

    Parameters
    ----------
    discrete_coil : DiscreteCoil
        Discrete coil object
    index : jnp.ndarray
        Discrete coil index(es); taken modulo the length of the first axis of ``positions``

    Returns
    -------
    jnp.ndarray
        Entries of ``positions`` selected along its first axis
    '''
    points = discrete_coil.positions
    n_points = points.shape[0]
    return points[index % n_points]
@jax.jit
def _discrete_coil_discrete_tangent(discrete_coil: DiscreteCoil, index):
    '''
    Unit vector pointing from the point at ``index`` to the point at ``index + 1``.

    Both indices wrap around the number of points, so the last point connects
    back to the first one.

    Parameters
    ----------
    discrete_coil : DiscreteCoil
        Discrete coil object
    index : jnp.ndarray
        Discrete coil index(es)

    Returns
    -------
    jnp.ndarray
        Normalised segment direction(s); normalisation is over the last axis
    '''
    points = discrete_coil.positions
    n_points = points.shape[0]
    segment_start = points[index % n_points]
    segment_end = points[(index + 1) % n_points]
    segment = segment_end - segment_start
    # No guard against zero-length segments: coincident neighbouring points give a 0/0 division.
    return segment / jnp.linalg.norm(segment, axis=-1, keepdims=True)
@jax.jit
def _discrete_coil_position(discrete_coil: DiscreteCoil, s):
    '''
    Position along the discrete coil at arc length s.

    Delegates to ``interpolate_array_modulo_broadcasted`` applied to the coil points.

    Parameters
    ----------
    discrete_coil : DiscreteCoil
        Discrete coil object
    s : jnp.ndarray
        Arc length(s) along the coil

    Returns
    -------
    jnp.ndarray
        Position(s) along the coil
    '''
    points = discrete_coil.positions
    return interpolate_array_modulo_broadcasted(points, s)
@jax.jit
def _discrete_coil_tangent(discrete_coil: DiscreteCoil, s):
    '''
    Tangent along the discrete coil at arc length s.

    The tangent is the normalised difference between the two data points returned by
    ``interpolate_fractions_modulo`` for ``s``; the interpolation fraction itself is not used.

    Parameters
    ----------
    discrete_coil : DiscreteCoil
        Discrete coil object
    s : jnp.ndarray
        Arc length(s) along the coil

    Returns
    -------
    jnp.ndarray
        Unit tangent vector(s) along the coil
    '''
    n_points = discrete_coil.positions.shape[0]
    idx_lower, idx_upper, _ = interpolate_fractions_modulo(s, n_points)
    point_lower = _discrete_coil_discrete_position(discrete_coil, idx_lower)
    point_upper = _discrete_coil_discrete_position(discrete_coil, idx_upper)
    segment = point_upper - point_lower
    return segment / jnp.linalg.norm(segment, axis=-1, keepdims=True)
@jax.jit
def _discrete_coil_reverse_parametrisation(discrete_coil: DiscreteCoil):
    '''
    Build a new DiscreteCoil whose points are passed through ``reverse_except_begin``.

    The centre of the new coil is recomputed by ``DiscreteCoil.from_positions``.

    Parameters
    ----------
    discrete_coil : DiscreteCoil
        Discrete coil object

    Returns
    -------
    DiscreteCoil
        Coil with reversed parametrisation
    '''
    reversed_points = reverse_except_begin(discrete_coil.positions)
    return DiscreteCoil.from_positions(reversed_points)
# ===================================================================================================================================================================================
# Finite Sizes
# ===================================================================================================================================================================================
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Centroids
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
@jax.jit
def _discrete_coil_radial_vector_centroid_index(discrete_coil: DiscreteCoil, index):
    '''
    Internal function computing the centroid radial vector at discrete coil indices.

    The coil centre, the point at ``index`` and the discrete tangent at ``index`` are
    handed to ``_radial_vector_centroid_from_data``.

    Parameters
    ----------
    discrete_coil : DiscreteCoil
        Discrete coil object
    index : jnp.ndarray
        Discrete coil indices

    Returns
    -------
    jnp.ndarray [..., 3]
        Radial vector(s) at the discrete coil indices
    '''
    return _radial_vector_centroid_from_data(
        discrete_coil._centre_i,
        _discrete_coil_discrete_position(discrete_coil, index),
        _discrete_coil_discrete_tangent(discrete_coil, index),
    )
@jax.jit
def _discrete_coil_radial_vector_centroid(discrete_coil: DiscreteCoil, s):
    '''
    Internal function computing the centroid radial vector at arc length s.

    The radial vectors at the two surrounding data points are blended linearly
    with the interpolation fraction belonging to ``s``.

    Parameters
    ----------
    discrete_coil : DiscreteCoil
        Discrete coil object
    s : jnp.ndarray
        Arc length(s) along the coil

    Returns
    -------
    jnp.ndarray [..., 3]
        Radial vector(s) along the coil
    '''
    n_points = discrete_coil.positions.shape[0]
    idx_lower, idx_upper, fraction = interpolate_fractions_modulo(s, n_points)

    radial_lower = _discrete_coil_radial_vector_centroid_index(discrete_coil, idx_lower)
    radial_upper = _discrete_coil_radial_vector_centroid_index(discrete_coil, idx_upper)

    # Add a trailing axis so the scalar weights broadcast over the vector components.
    weight_lower = (1.0 - fraction)[..., jnp.newaxis]
    weight_upper = fraction[..., jnp.newaxis]
    return radial_lower * weight_lower + radial_upper * weight_upper
@jax.jit
def _discrete_coil_finite_size_frame_centroid(discrete_coil: DiscreteCoil, s):
    '''
    Finite size frame at location s along the discrete coil, using the centroid method.

    A frame is built at each of the two surrounding data points from the discrete tangent
    and the centroid radial vector there; the two frames are then blended linearly.

    Parameters
    ----------
    discrete_coil : DiscreteCoil
        Discrete coil object
    s : jnp.ndarray
        Arc length(s) along the coil

    Returns
    -------
    jnp.ndarray [..., 2, 3]
        Finite size frame(s) along the coil
    '''
    def _frame_at(idx):
        # Frame at a single data point: tangent and centroid radial vector of that point.
        radial_vector = _discrete_coil_radial_vector_centroid_index(discrete_coil, idx)
        tangent = _discrete_coil_discrete_tangent(discrete_coil, idx)
        return _frame_from_radial_vector(tangent, radial_vector)

    n_points = discrete_coil.positions.shape[0]
    idx_lower, idx_upper, fraction = interpolate_fractions_modulo(s, n_points)

    frame_lower = _frame_at(idx_lower)
    frame_upper = _frame_at(idx_upper)

    # Two trailing axes so the weights broadcast over the last two axes of each frame.
    weight_lower = (1 - fraction)[..., jnp.newaxis, jnp.newaxis]
    weight_upper = fraction[..., jnp.newaxis, jnp.newaxis]
    return weight_lower * frame_lower + weight_upper * frame_upper
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Frenet-Serret
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
@jax.jit
def _discrete_coil_radial_vector_frenet_serret(discrete_coil: DiscreteCoil, s):
    '''
    Internal function for the Frenet-Serret radial vector at arc length s.

    Not valid for discrete coils because the curvature vanishes between the data points.
    A RuntimeWarning is issued and an array of NaN is returned. Note that this function is
    wrapped in ``jax.jit``, so the Python-level warning is issued while the function is
    being traced rather than on every call.

    Parameters
    ----------
    discrete_coil : DiscreteCoil
        Discrete coil object (not used)
    s : jnp.ndarray
        Arc length(s) along the coil

    Returns
    -------
    jnp.ndarray [..., 3]
        Radial vector(s) along the coil (jnp.nan)
    '''
    warnings.warn(
        "Frenet-Serret frame is ill-defined for DiscreteCoil due to zero curvature. Returning NaN. ",
        RuntimeWarning,
    )
    nan_shape = (*s.shape, 3)
    return jnp.full(nan_shape, jnp.nan)
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# RMF
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------