Source code for jax_sbgeom.coils.fourier_coil

from abc import ABC, abstractmethod
from dataclasses import dataclass
import jax 
import jax.numpy as jnp
import numpy as onp
from typing import Literal
from .base_coil import Coil
from .base_coil import _radial_vector_centroid_from_data, _frame_from_radial_vector
from .discrete_coil import DiscreteCoil
from jax_sbgeom.jax_utils import stack_jacfwd
from functools import partial
from .coilset import CoilSet
import jax_sbgeom

@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class FourierCoil(Coil):
    '''
    Class representing a closed coil defined by a real Fourier series in the curve parameter s.

    The curve is periodic in s with period 1. With N = N_modes and mode index i = n - 1
    (row i of the coefficient arrays holds harmonic n = i + 1):

        x(s) = centre_i[0] + sum_{n=1}^{N} [ fourier_cos[n-1, 0] * cos(2 pi n s) + fourier_sin[n-1, 0] * sin(2 pi n s) ]
        y(s) = centre_i[1] + sum_{n=1}^{N} [ fourier_cos[n-1, 1] * cos(2 pi n s) + fourier_sin[n-1, 1] * sin(2 pi n s) ]
        z(s) = centre_i[2] + sum_{n=1}^{N} [ fourier_cos[n-1, 2] * cos(2 pi n s) + fourier_sin[n-1, 2] * sin(2 pi n s) ]

    Tangent and normal vectors are obtained by forward-mode differentiation of this series.

    For creating FourierCoil objects from discrete positions, use the function
    `curve_to_fourier_coefficients` (or `convert_to_fourier_coil` for a DiscreteCoil).
    In general s is NOT proportional to arc length. The parametrisation can be converted
    to (approximately) equal arc length with `convert_fourier_coil_to_equal_arclength`
    or `convert_fourier_coilset_to_equal_arclength`.

    Parameters
    ----------
    fourier_cos : jnp.ndarray
        Fourier cosine coefficients [N_modes, 3]
    fourier_sin : jnp.ndarray
        Fourier sine coefficients [N_modes, 3]
    centre_i : jnp.ndarray
        Constant (zeroth-order) term of the series, i.e. the centre of the coil [3]
    '''
    fourier_cos : jnp.ndarray
    fourier_sin : jnp.ndarray
    centre_i : jnp.ndarray

    def position(self, s):
        '''
        Position along the coil as a function of the curve parameter s.

        Parameters
        ----------
        s : jnp.ndarray
            Curve parameter value(s) (period 1). Only proportional to arc length
            for coils converted to an equal-arclength parametrisation.

        Returns
        -------
        jnp.ndarray [s.shape..., 3]
            Cartesian position(s) along the coil
        '''
        return _fourier_position(self, s)

    def tangent(self, s):
        '''
        Unit tangent vector along the coil, i.e. (dr/ds) / |dr/ds|.

        Parameters
        ----------
        s : jnp.ndarray
            Curve parameter value(s) (period 1)

        Returns
        -------
        jnp.ndarray [s.shape..., 3]
            Unit tangent vector(s) along the coil
        '''
        return _fourier_tangent(self, s)

    def normal(self, s):
        '''
        Normal vector along the coil: the normalised derivative of the unit tangent with respect to s.

        Where the derivative of the unit tangent vanishes (zero curvature, e.g. a straight
        segment) the normalisation divides zero by zero and the result is NaN.

        Parameters
        ----------
        s : jnp.ndarray
            Curve parameter value(s) (period 1)

        Returns
        -------
        jnp.ndarray [s.shape..., 3]
            Unit normal vector(s) along the coil
        '''
        return _fourier_normal(self, s)

    def centre(self):
        '''
        Centre of the coil: the constant term `centre_i` of the Fourier series.

        Returns
        -------
        jnp.ndarray [3]
            Centre of the coil
        '''
        return self.centre_i

    def reverse_parametrisation(self):
        '''
        Return a new FourierCoil traversed in the opposite direction, r_new(s) = r(-s).

        Returns
        -------
        FourierCoil
            Coil with the same cosine coefficients and centre, and negated sine coefficients
        '''
        return _fourier_reverse_parametrisation(self)
#===================================================================================================================================================================
# Implementation
#===================================================================================================================================================================
@jax.jit
def _fourier_position(coil : FourierCoil, s):
    '''
    Position along the coil as a function of the curve parameter s.

    Evaluates centre_i + sum_{i=0}^{N_modes-1} [ fourier_cos[i] cos(2 pi (i+1) s) + fourier_sin[i] sin(2 pi (i+1) s) ].

    Parameters
    ----------
    coil : FourierCoil
        Single (unbatched) coil: fourier_cos / fourier_sin of shape [N_modes, 3], centre_i of shape [3]
    s : jnp.ndarray
        Curve parameter value(s), any shape; the curve is periodic in s with period 1.
        s is only proportional to arc length for equal-arclength coils.

    Returns
    -------
    jnp.ndarray [s.shape..., 3]
        Cartesian position(s) along the coil
    '''
    n_modes = coil.fourier_cos.shape[-2]
    n = jnp.arange(1.0, n_modes + 1.0) * 2 * jnp.pi # angular frequencies 2*pi*(i+1), shape (N_modes,)

    # Coefficients are indexed as fourier_cos[i, :] below, so the leading axis must be the mode
    # axis: this function handles one coil at a time. Batched coils must be mapped over
    # (e.g. with jax.vmap). The output shape is therefore s.shape + (3,).
    final_shape = jnp.array(s).shape + (3,)
    initial_sum = jnp.zeros(final_shape)

    def fourier_sum(vals, i):
        xyz = vals # running partial sum, shape (s_shape..., 3)
        angle_cos = jnp.cos(n[i] * s) # shape (s_shape...)
        angle_sin = jnp.sin(n[i] * s) # shape (s_shape...)
        xyz = xyz + coil.fourier_cos[i, :] * angle_cos[..., None] + coil.fourier_sin[ i, :] * angle_sin[..., None]
        return xyz, None

    # Accumulate the series one mode at a time; scan only carries the (s_shape..., 3) partial sum.
    xyz = jax.lax.scan(fourier_sum, initial_sum, jnp.arange(n_modes))[0]
    return xyz + coil.centre_i

# dr/ds evaluated element-wise over s (scalar s -> 3-vector, broadcast over any shape of s);
# the coil (argument 0) is excluded from vectorisation.
# NOTE(review): assumes stack_jacfwd(f, argnums=1) returns the forward-mode Jacobian of f w.r.t. s.
_grad_fourier_position = jax.jit(jnp.vectorize(stack_jacfwd(_fourier_position, argnums=1), excluded=(0,), signature='()->(3)'))

@jax.jit
def _fourier_tangent(coil : FourierCoil, s):
    '''
    Unit tangent vector along the coil, (dr/ds) / |dr/ds|.

    Parameters
    ----------
    coil : FourierCoil
        Coil object
    s : jnp.ndarray
        Curve parameter value(s) (period 1)

    Returns
    -------
    jnp.ndarray [s.shape..., 3]
        Unit tangent vector(s) along the coil
    '''
    grad_pos = _grad_fourier_position(coil, s) # shape (..., 3)
    tangent = grad_pos / jnp.linalg.norm(grad_pos, axis=-1, keepdims=True)
    return tangent

# NOTE: despite its name this is d(unit tangent)/ds, not the second derivative of the position.
_grad_grad_fourier_position = jax.jit(jnp.vectorize(stack_jacfwd(_fourier_tangent, argnums=1), excluded=(0,), signature='()->(3)'))

@jax.jit
def _fourier_normal(coil : FourierCoil, s):
    '''
    Normal vector along the coil: normalised derivative of the unit tangent with respect to s.

    Where that derivative vanishes (zero curvature) the normalisation is 0/0 and yields NaN.

    Parameters
    ----------
    coil : FourierCoil
        Coil object
    s : jnp.ndarray
        Curve parameter value(s) (period 1)

    Returns
    -------
    jnp.ndarray [s.shape..., 3]
        Unit normal vector(s) along the coil
    '''
    tangent_deriv = _grad_grad_fourier_position(coil, s) # shape (..., 3)
    normal = tangent_deriv / jnp.linalg.norm(tangent_deriv, axis=-1, keepdims=True)
    return normal

@jax.jit
def _arc_length_fourier(fourier_coil, s):
    # Arc-length density |dr/ds| (i.e. d(arc length)/ds), not the cumulative arc length.
    tangent = _grad_fourier_position(fourier_coil, s)
    return jnp.linalg.norm(tangent,axis=-1)

@jax.jit
def _fourier_reverse_parametrisation(coil : FourierCoil):
    # r(-s): cos terms are even in s and unchanged, sin terms are odd and change sign.
    return FourierCoil(fourier_cos=coil.fourier_cos, fourier_sin= -1.0 * coil.fourier_sin, centre_i=coil.centre_i)

# ===================================================================================================================================================================================
# Finite Sizes
# ===================================================================================================================================================================================
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# Centroids
# -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
@jax.jit
def _fourier_coil_radial_vector_centroid(coil : FourierCoil, s):
    '''
    Internal function to find the centroid radial vector at curve parameter s.

    Combines the coil centre, the position and the unit tangent at s via
    `_radial_vector_centroid_from_data`.

    Parameters
    ----------
    coil : FourierCoil
        Coil object
    s : jnp.ndarray
        Curve parameter value(s) (period 1)

    Returns
    -------
    jnp.ndarray [..., 3]
        Radial vector(s) along the coil
    '''
    pos_i = _fourier_position(coil, s)
    tangent_i = _fourier_tangent(coil, s)
    return _radial_vector_centroid_from_data(coil.centre_i, pos_i, tangent_i)

@jax.jit
def _fourier_coil_finite_size_frame_centroid(coil : FourierCoil, s):
    '''
    Internal function to find the centroid finite size frame at curve parameter s.

    Built by `_frame_from_radial_vector` from the unit tangent and the centroid radial vector.

    Parameters
    ----------
    coil : FourierCoil
        Coil object
    s : jnp.ndarray
        Curve parameter value(s) (period 1)

    Returns
    -------
    jnp.ndarray [..., 3, 3]
        Finite size frame(s) along the coil
    '''
    radial_vector = _fourier_coil_radial_vector_centroid(coil, s)
    tangent = _fourier_tangent(coil, s)
    return _frame_from_radial_vector(tangent, radial_vector)

#=====================================================================================================================================================================================
# Converting curve to Fourier coefficients
#=====================================================================================================================================================================================
@partial(jax.jit, static_argnums = 1)
def _xyz_to_fourier_coefficients(positions : jnp.ndarray, n_modes : int):
    '''
    Real Fourier coefficients of one uniformly sampled periodic 1D signal.

    The samples are assumed to be taken at s_j = j / N, j = 0..N-1 (one period, endpoint
    excluded), matching the series used by `_fourier_position`:

        x_j = centre + sum_{k=1}^{M} [ cos_k * cos(2 pi k j / N) + sin_k * sin(2 pi k j / N) ]

    The coefficients come from the real FFT X_k: cos_k = 2 Re(X_k) / N,
    sin_k = -2 Im(X_k) / N, centre = Re(X_0) / N.

    Parameters
    ----------
    positions : jnp.ndarray [N]
        Samples of one Cartesian component along the curve
    n_modes : int or None (static)
        Number of modes to keep. None keeps all N // 2 available modes. Values larger than
        N // 2 are silently truncated to N // 2 by slicing.

    Returns
    -------
    tuple of (jnp.ndarray [M], jnp.ndarray [M], jnp.ndarray [])
        Cosine coefficients, sine coefficients and the constant term
    '''
    # positions is a 1D array.
    N = positions.shape[0]
    loc_fourier = jnp.fft.rfft(positions)
    loc_fourier_cos = jnp.real(loc_fourier[1:]) / N * 2.0
    loc_fourier_sin = - jnp.imag(loc_fourier[1:]) / N * 2.0
    centre = jnp.real(loc_fourier[0]) / N
    if N%2==0:
        # The k = N/2 (Nyquist) bin has no separate conjugate partner, so the factor 2 above
        # double-counts it; its sine part is identically zero on the sample points.
        loc_fourier_cos = loc_fourier_cos.at[-1].set(loc_fourier_cos[-1] * 0.5)
    loc_fourier_cos = loc_fourier_cos[:n_modes]
    loc_fourier_sin = loc_fourier_sin[:n_modes]
    return loc_fourier_cos, loc_fourier_sin, centre

# Vectorised over all leading axes of `positions` (the sample axis is last); n_modes is static.
xyz_fourier_batched = jax.jit(jnp.vectorize(_xyz_to_fourier_coefficients, signature='(N)->(M),(M),()', excluded=(1,)), static_argnums=(1,))

@partial(jax.jit, static_argnums = (1,2,3))
def _sampling_positions_equal_arc_length(fourier_coil : FourierCoil, n_points_sample : int, n_points_desired : int, method : Literal['pchip', 'linear']):
    '''
    Curve parameter values at which a Fourier coil is (approximately) equally spaced in arc length.

    The arc-length density |dr/ds| is sampled at n_points_sample equally spaced parameter values
    in [0, 1) (endpoint excluded). It is then handed to a periodic resampling routine from
    `jax_sbgeom.jax_utils` that inverts the arc-length map.
    NOTE(review): the exact output convention (e.g. whether the first value is s = 0) is defined
    by `resample_uniform_periodic_linear` / `resample_uniform_periodic_pchip`; confirm there.

    Parameters
    ----------
    fourier_coil : FourierCoil
        Fourier coil object
    n_points_sample : int
        Number of points sampled for resolving the arc length inverse
    n_points_desired : int
        Number of parameter values to return
    method : Literal['pchip', 'linear']
        Method to use for resampling ('pchip' or 'linear')

    Returns
    -------
    jnp.ndarray
        Curve parameter values s (presumably shape [n_points_desired]). These are NOT Cartesian
        positions; evaluate them with `fourier_coil.position` to obtain the resampled points.

    Raises
    ------
    ValueError
        If `method` is neither 'pchip' nor 'linear' (raised at trace time).
    '''
    s_sampling = jnp.linspace(0.0, 1.0, n_points_sample, endpoint=False)
    # |dr/ds| at each sample (arc-length density, not cumulative arc length)
    arc_length = _arc_length_fourier(fourier_coil, s_sampling)
    if method == 'linear':
        return jax_sbgeom.jax_utils.resample_uniform_periodic_linear(arc_length, n_points_desired)
    elif method == 'pchip':
        return jax_sbgeom.jax_utils.resample_uniform_periodic_pchip(arc_length, n_points_desired)
    else:
        raise ValueError(f"Unknown method {method} for resampling to equal arc length")
@partial(jax.jit, static_argnums = (1,2,3))
def convert_fourier_coil_to_equal_arclength(fourier_coil : FourierCoil, n_points_sample : int = None, n_points_desired : int = None, method : Literal['pchip', 'linear'] = 'pchip'):
    '''
    Reparametrise a Fourier coil so that its curve parameter is (approximately) proportional to arc length.

    The coil is evaluated at n_points_desired parameter values that are equally spaced in
    arc length. A new Fourier series is then fitted to those positions with
    `curve_to_fourier_coefficients`.

    Parameters
    ----------
    fourier_coil : FourierCoil
        Single Fourier coil object with N modes (fourier_cos of shape [N, 3])
    n_points_sample : int
        Number of points to sample the arc length inverse. If None, uses N*16 points.
    n_points_desired : int
        Number of points to resample to. If None, uses N*4 points.
    method : Literal['pchip', 'linear']
        Method to use for interpolating the arc length inverse. 'pchip' is significantly better
        while not increasing runtime. Most of the time is spent on computing the Fourier sums.

    Returns
    -------
    FourierCoil
        New coil approximating the same curve with an (approximately) equal-arclength
        parametrisation. It has n_points_desired // 2 modes (fourier_cos / fourier_sin of shape
        [n_points_desired // 2, 3]), which in general differs from N.

    Raises
    ------
    ValueError
        If `method` is neither 'pchip' nor 'linear'.
    '''
    # Mode axis is the second-to-last axis, consistent with `_fourier_position`.
    n_modes = fourier_coil.fourier_cos.shape[-2]
    if n_points_sample is None:
        n_points_sample = n_modes * 16
    if n_points_desired is None:
        n_points_desired = n_modes * 4
    s_resampled = _sampling_positions_equal_arc_length(fourier_coil, n_points_sample, n_points_desired, method)
    resampling_positions = fourier_coil.position(s_resampled)
    # n_modes=None keeps every mode the real FFT provides, i.e. n_points_desired // 2.
    fourier_cos, fourier_sin, centre = curve_to_fourier_coefficients(resampling_positions)
    return FourierCoil(fourier_cos=fourier_cos, fourier_sin=fourier_sin, centre_i=centre)
# Batched form of `convert_fourier_coil_to_equal_arclength`. The coil pytree is mapped over its
# leading (coil) axis. n_points_sample, n_points_desired and method are shared by every coil
# and kept static, because they determine array sizes and the interpolation branch at trace time.
_convert_fourier_coilset_to_equal_arclength_internal = jax.jit(
    jax.vmap(
        convert_fourier_coil_to_equal_arclength,
        in_axes=(0, None, None, None),
    ),
    static_argnums=(1, 2, 3),
)
@partial(jax.jit, static_argnums = (1,2,3))
def convert_fourier_coilset_to_equal_arclength(fourier_coilset : CoilSet, n_points_sample : int = None, n_points_desired : int = None, method : Literal['pchip', 'linear'] = 'pchip'):
    '''
    Reparametrise every coil of a Fourier coil set to (approximately) equal arc length.

    Applies `convert_fourier_coil_to_equal_arclength` to each coil (mapped over the leading
    coil axis of `fourier_coilset.coils`) with the same sampling settings for all coils.

    Parameters
    ----------
    fourier_coilset : CoilSet
        Coil set whose `coils` attribute is a batched FourierCoil with N modes per coil
        (fourier_cos / fourier_sin of shape [n_coils, N, 3])
    n_points_sample : int
        Number of points to sample the arc length inverse. If None, uses N*16 points.
    n_points_desired : int
        Number of points to resample to. If None, uses N*4 points.
    method : Literal['pchip', 'linear']
        Method to use for interpolating the arc length inverse. 'pchip' is significantly better
        while not increasing runtime. Most of the time is spent on computing the Fourier sums.

    Returns
    -------
    CoilSet
        Coil set of reparametrised FourierCoils, each with n_points_desired // 2 modes
        (fourier_cos / fourier_sin of shape [n_coils, n_points_desired // 2, 3])

    Raises
    ------
    ValueError
        If `method` is neither 'pchip' nor 'linear'.
    '''
    return CoilSet(_convert_fourier_coilset_to_equal_arclength_internal(fourier_coilset.coils, n_points_sample, n_points_desired, method))
#------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- # Converting curve to Fourier coefficients: Convenience #-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
@partial(jax.jit, static_argnums = 1)
def curve_to_fourier_coefficients(positions : jnp.ndarray, n_modes : int = None):
    '''
    Fourier coefficients of closed curves given by uniformly sampled Cartesian positions.

    Each Cartesian component is transformed independently along the sample axis. The result
    can be passed directly to `FourierCoil(fourier_cos=..., fourier_sin=..., centre_i=...)`.

    Parameters
    ----------
    positions : jnp.ndarray [..., N, 3]
        Positions along the curve(s), sample axis second to last, xyz last. Samples are assumed
        to be equally spaced in the curve parameter over one period (endpoint excluded).
    n_modes : int or None (static)
        Number of modes to keep. If None, all N // 2 available modes are kept.

    Returns
    -------
    tuple of jnp.ndarray
        (fourier_cos [..., M, 3], fourier_sin [..., M, 3], centre [..., 3])
    '''
    # Put xyz in front so the per-component transform vectorises over (xyz, batch..., N)
    per_component = jnp.moveaxis(positions, -1, 0)
    cos_coeffs, sin_coeffs, centres = xyz_fourier_batched(per_component, n_modes)
    # Move xyz back to the trailing axis for every output
    return tuple(jnp.moveaxis(arr, 0, -1) for arr in (cos_coeffs, sin_coeffs, centres))
@partial(jax.jit, static_argnums = 1)
def convert_to_fourier_coil(coil : DiscreteCoil, n_modes : int = None):
    '''
    Build a FourierCoil by Fourier-transforming the positions of a DiscreteCoil.

    Parameters
    ----------
    coil : DiscreteCoil
        Discrete coil whose `positions` are the uniformly sampled points of the closed curve
    n_modes : int
        Number of Fourier modes to keep. If None, uses N // 2 modes (every mode available from
        the real FFT), where N is the number of discrete points in the coil.

    Returns
    -------
    FourierCoil
        Fourier coil object
    '''
    cos_coeffs, sin_coeffs, constant_term = curve_to_fourier_coefficients(coil.positions, n_modes)
    return FourierCoil(centre_i=constant_term, fourier_cos=cos_coeffs, fourier_sin=sin_coeffs)
@partial(jax.jit, static_argnums = 1)
def convert_to_fourier_coilset(coilset : CoilSet, n_modes : int = None):
    '''
    Convert a CoilSet of DiscreteCoils to a CoilSet of FourierCoils.

    The Fourier coefficients of all coils are computed in one batched call from
    `coilset.coils.positions`.

    Parameters
    ----------
    coilset : CoilSet
        Coil set whose `coils` attribute has a `positions` array, presumably of shape
        [n_coils, N, 3] (all coils share the same number of points N)
    n_modes : int
        Number of Fourier modes to use. If None, uses N // 2 modes (every mode available from
        the real FFT), where N is the number of discrete points per coil.

    Returns
    -------
    CoilSet
        Coil set wrapping a batched FourierCoil (fourier_cos / fourier_sin of shape
        [n_coils, M, 3], centre_i of shape [n_coils, 3])
    '''
    fourier_cos, fourier_sin, centre = curve_to_fourier_coefficients(coilset.coils.positions, n_modes)
    return CoilSet(FourierCoil(fourier_cos=fourier_cos, fourier_sin=fourier_sin, centre_i=centre))