Source code for jax_sbgeom.flux_surfaces.flux_surfaces_extended

import jax.numpy as jnp
import h5py 
import jax
import numpy as onp
from dataclasses import dataclass
from jax_sbgeom.jax_utils import stack_jacfwd, interpolate_fractions, bilinear_interp
from functools import partial
from .flux_surfaces_base import FluxSurface, FluxSurfaceBase, ToroidalExtent, FluxSurfaceSettings, FluxSurfaceData, _data_modes_settings_from_hdf5, _cartesian_to_cylindrical, _principal_curvatures_interpolated, _cylindrical_position_interpolated, _cylindrical_to_cartesian, ParametrisedSurface, _normalize_theta_phi_full_mod


from .flux_surfaces_base import _cartesian_position_interpolated, _normal_interpolated
import equinox as eqx
from typing import Tuple



class FluxSurfaceNormalExtended(FluxSurfaceBase):
    '''
    Flux surface whose radial label is continued past the LCFS along the surface normal.

    Behaviour by radial label:
        - s <= 1.0 : the underlying flux surface is evaluated as-is
        - s  > 1.0 : the s = 1.0 point is pushed along the s = 1.0 normal direction
    '''

    def cartesian_position(self, s, theta, phi):
        '''Cartesian (x, y, z) position of the normal-extended surface.'''
        result = _normal_extended_cartesian_position(self, s, theta, phi)
        return result

    def cylindrical_position(self, s, theta, phi):
        '''Cylindrical position of the normal-extended surface.'''
        result = _normal_extended_cylindrical_position(self, s, theta, phi)
        return result

    def normal(self, s, theta, phi):
        '''Surface normal; beyond the LCFS this is the s = 1.0 normal.'''
        result = _normal_extended_normal(self, s, theta, phi)
        return result

    def principal_curvatures(self, s, theta, phi):
        '''Principal curvatures, using offset-surface formulas for s > 1.0.'''
        result = _normal_extended_principal_curvatures(self, s, theta, phi)
        return result
class FluxSurfaceNormalExtendedNoPhi(FluxSurfaceBase):
    '''
    Flux surface continued past the LCFS along the normal with its toroidal (phi) component removed.

    Behaviour by radial label:
        - s <= 1.0 : the underlying flux surface is evaluated as-is
        - s  > 1.0 : the s = 1.0 point is pushed along the s = 1.0 normal, after projecting out
                     the toroidal component of that normal (and renormalising it)

    The result is an extension label with phi_in = phi_out that still extends in a straight line.
    The label no longer means 'distance to the lcfs', because the displacement is not along the
    true normal direction.

    normal() and principal_curvatures() return NaN-filled arrays; they are not implemented for
    this extension.
    '''

    def cartesian_position(self, s, theta, phi):
        '''Cartesian (x, y, z) position of the phi-free normal extension.'''
        result = _normal_extended_no_phi_cartesian_position(self, s, theta, phi)
        return result

    def cylindrical_position(self, s, theta, phi):
        '''Cylindrical position of the phi-free normal extension.'''
        result = _normal_extended_no_phi_cylindrical_position(self, s, theta, phi)
        return result

    def normal(self, s, theta, phi):
        '''NaN-filled array shaped like the broadcast inputs plus a trailing axis of 3.'''
        result = _normal_extended_no_phi_normal(self, s, theta, phi)
        return result

    def principal_curvatures(self, s, theta, phi):
        '''NaN-filled array shaped like the broadcast inputs plus a trailing axis of 2.'''
        result = _normal_extended_no_phi_principal_curvatures(self, s, theta, phi)
        return result
class FluxSurfaceNormalExtendedConstantPhi(FluxSurfaceBase):
    '''
    Class representing a flux surface that is extended along the normal direction, but keeping the toroidal angle (phi) constant during the extension.

    The extension is done such that:
        - For s <= 1.0, the original flux surface is used
        - For s > 1.0, the position is given by moving along the normal direction of the flux surface at s = 1.0,
          while adjusting the internal toroidal angle so that the output toroidal angle equals the requested phi

    This is useful for creating an extension label that preserves phi_in = phi_out while retaining the meaning of
    'distance to the lcfs', as the extension is still along the normal direction.
    However, the extension is no longer a straight line in 3D space.

    The internal toroidal angle is solved with the secant method using the default number of iterations
    of the module-level helper functions (n_iter = 5); it is not configurable through these methods.
    '''

    def cartesian_position(self, s, theta, phi):
        '''Cartesian (x, y, z) position of the constant-phi normal extension.'''
        return _normal_extended_constant_phi_cartesian_position(self, s, theta, phi)

    def cylindrical_position(self, s, theta, phi):
        '''Cylindrical position of the constant-phi normal extension.'''
        return _normal_extended_constant_phi_cylindrical_position(self, s, theta, phi)

    # The points produced by this class are the points of the plain normal extension, evaluated at a
    # solved internal toroidal angle phi_c: only the parametrisation differs, the offset surface is the same.
    # For a normal extended flux surface, the normal *remains the same* in the extended region (the
    # s = 1.0 normal), so it is evaluated at phi_c.
    def normal(self, s, theta, phi):
        '''Surface normal evaluated at the solved internal toroidal angle.'''
        return _normal_extended_constant_phi_normal(self, s, theta, phi)

    # For the same reason the offset-surface principal curvature formulas apply unchanged in the extension
    # region; they are evaluated at the solved internal toroidal angle phi_c.
    def principal_curvatures(self, s, theta, phi):
        '''Principal curvatures evaluated at the solved internal toroidal angle.'''
        return _normal_extended_constant_phi_principal_curvatures(self, s, theta, phi)
class FluxSurfaceFourierExtended(FluxSurfaceBase):
    '''
    Flux surface continued past the LCFS by a second flux surface defined in Fourier space.

    The extension surface does not need to share mpol / ntor with the inner surface.

    Radial label convention:
        - s <= 1.0               : inner flux surface (s = 1.0 is its LCFS)
        - 1.0 < s < 2.0          : linear blend between the LCFS and the first extension surface
        - s = 2.0                : first surface of the extension
        - s = n_extension + 1.0  : last surface of the extension
    Beyond the last extension surface, the additional s is ignored.
    '''

    # Flux surface evaluated for s > 1.0.
    extension_flux_surface : FluxSurfaceBase = None

    @classmethod
    def from_flux_surface_and_extension(cls, flux_surface : FluxSurfaceBase, extension_flux_surface : FluxSurfaceBase):
        '''
        Build a FluxSurfaceFourierExtended from an inner surface and an extension surface.

        Parameters
        -----------
        flux_surface : FluxSurfaceBase
            Inner surface; its data, modes and settings are reused.
        extension_flux_surface : FluxSurfaceBase
            Surface used for s > 1.0.

        Returns
        -------
        FluxSurfaceFourierExtended
        '''
        inner_fields = dict(
            data     = flux_surface.data,
            modes    = flux_surface.modes,
            settings = flux_surface.settings,
        )
        return cls(**inner_fields, extension_flux_surface = extension_flux_surface)

    def cartesian_position(self, s, theta, phi):
        '''Cartesian (x, y, z) position of the Fourier-extended surface.'''
        outer = self.extension_flux_surface
        return _fourier_extended_cartesian_position(self, outer, s, theta, phi)

    def cylindrical_position(self, s, theta, phi):
        '''Cylindrical position of the Fourier-extended surface.'''
        outer = self.extension_flux_surface
        return _fourier_extended_cylindrical_position(self, outer, s, theta, phi)

    def normal(self, s, theta, phi):
        '''Normal of the Fourier-extended surface.'''
        outer = self.extension_flux_surface
        return _fourier_extended_normal(self, outer, s, theta, phi)

    def principal_curvatures(self, s, theta, phi):
        '''Principal curvatures of the Fourier-extended surface.'''
        outer = self.extension_flux_surface
        return _fourier_extended_principal_curvatures(self, outer, s, theta, phi)
class FluxSurfaceExtendedDistanceMatrix(ParametrisedSurface):
    '''
    Parametrised surface that evaluates an extended flux surface with a remapped radial label.

    The radial label s is converted by :meth:`s_interp` using tabulated distance layers, and the
    result is passed to ``flux_surface_extended``:

        - s <= 1.0      : s is used unchanged
        - 1.0 < s < 2.0 : 1.0 + (s - 1.0) * d_0(theta, phi), a linear ramp up to the first layer
        - s >= 2.0      : 1.0 + d(s, theta, phi), with d interpolated between the layers; the layers
                          span 2.0 <= s <= 2.0 + (n_layers - 1), assuming ``interpolate_fractions``
                          maps [0, 1] linearly onto the layer index (TODO confirm)

    normal and principal_curvatures are not implemented.
    '''

    #: Flux surface extension used. Cannot be of type FluxSurface (that would mean no extension exists).
    flux_surface_extended : FluxSurfaceBase

    #: Distance layers, indexed as d_layers[i_layer, i_theta, i_phi]; each layer is a
    #: [n_theta_sampled, n_phi_sampled] grid. d_layers[:, 0, 0] corresponds to (theta, phi) = (0, 0),
    #: d_layers[:, -1, -1] to (2pi, 2pi/nfp).
    d_layers : jnp.ndarray

    def __check_init__(self):
        assert not isinstance(self.flux_surface_extended, FluxSurface), "flux_surface_extended cannot be of type FluxSurface; this means no extension exists and class is useless"

    @eqx.filter_jit
    def cylindrical_position(self, s, theta, phi):
        '''Cylindrical position, obtained by converting :meth:`cartesian_position`.'''
        return _cartesian_to_cylindrical(self.cartesian_position(s, theta, phi))

    def s_interp(self, s, theta, phi):
        '''
        Map the external radial label s to the label used on ``flux_surface_extended``.

        Parameters
        -----------
        s, theta, phi : jnp.ndarray
            Coordinates; broadcast against each other.

        Returns
        -------
        jnp.ndarray
            Internal radial label, shaped like the broadcast inputs.
        '''
        s_bc, theta_bc, phi_bc = jnp.broadcast_arrays(s, theta, phi)
        d_total_interp = _d_interp_vectorized(self.d_layers, self.flux_surface_extended.nfp, s_bc, theta_bc, phi_bc)
        # Fraction of the layer distance applied: 0 inside the LCFS, ramping linearly to 1 at s = 2.0.
        ds_internal = jnp.clip(s_bc - 1.0, min = 0.0, max = 1.0)
        return jnp.where(s_bc <= 1.0, s_bc, 1.0 + ds_internal * d_total_interp)

    @eqx.filter_jit
    def cartesian_position(self, s, theta, phi):
        '''Cartesian position of ``flux_surface_extended`` at the remapped radial label.'''
        return self.flux_surface_extended.cartesian_position(self.s_interp(s, theta, phi), theta, phi)

    def normal(self, s, theta, phi):
        '''Not implemented for this surface.'''
        raise NotImplementedError("normal is not implemented for FluxSurfaceExtendedDistanceMatrix")

    def principal_curvatures(self, s, theta, phi):
        '''Not implemented for this surface.'''
        raise NotImplementedError("principal_curvatures is not implemented for FluxSurfaceExtendedDistanceMatrix")
# ===================================================================================================================================================================================
# Normal Extended
# ===================================================================================================================================================================================

@eqx.filter_jit
def _normal_extended_cartesian_position(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Extend the cartesian position of a flux surface along the normal direction.

    For s <= 1.0, the original flux surface is used, while for s > 1.0, the position is given by moving along
    the normal direction of the flux surface at s = 1.0 by a distance s - 1.0.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the position.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the position.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the position.

    Returns
    -------
    jnp.ndarray
        Cartesian position(s) of the extended flux surface.
    '''
    positions = _cartesian_position_interpolated(flux_surface, jnp.minimum(s, 1.0), theta, phi)
    # This will not give NaNs, as s = 1.0 is always on the surface (non axis).
    normals = _normal_interpolated(flux_surface, 1.0, theta, phi)
    distance_1d = jnp.maximum(s - 1.0, 0.0)
    # Neither term produces NaN values: positions are evaluated at s <= 1.0 and normals at s = 1.0.
    return positions + normals * distance_1d[..., None]


@eqx.filter_jit
def _normal_extended_cylindrical_position(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Extend the cartesian position of a flux surface along the normal direction and afterwards convert to cylindrical coordinates.

    For s <= 1.0, the original flux surface is used, while for s > 1.0, the position is given by moving along
    the normal direction of the flux surface at s = 1.0.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the position.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the position.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the position.

    Returns
    -------
    jnp.ndarray
        Cylindrical position(s) of the extended flux surface.
    '''
    return _cartesian_to_cylindrical(_normal_extended_cartesian_position(flux_surface, s, theta, phi))


@eqx.filter_jit
def _normal_extended_normal(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Normal of a flux surface extended along the normal direction.

    The normal of an offset surface equals the normal of the base surface, so for s > 1.0 the normal at
    s = 1.0 is returned.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the normal.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the normal.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the normal.

    Returns
    -------
    jnp.ndarray
        Normal
    '''
    return _normal_interpolated(flux_surface, jnp.minimum(s, 1.0), theta, phi)


@eqx.filter_jit
def _normal_extended_principal_curvatures(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Extend the principal curvatures of a flux surface along the normal direction.

    Uses the principal curvatures formulas in [1]: kappa_d = kappa / (1 + kappa * d), written as
    kappa / |1 + kappa * d| * sign(1 + kappa * d).

    [1]: Farouki, R. T. (1986). The approximation of non-degenerate offset surfaces. Computer Aided Geometric Design, 3(1), 15-43.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to extend.
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the principal curvatures.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the principal curvatures.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the principal curvatures.

    Returns
    -------
    jnp.ndarray
        Principal curvatures
    '''
    curvatures = _principal_curvatures_interpolated(flux_surface, jnp.minimum(s, 1.0), theta, phi)
    d = jnp.maximum(s - 1.0, 0.0)
    gamma_0 = jnp.where(1 + curvatures[..., 0] * d >= 0.0, jnp.ones_like(d), jnp.ones_like(d) * -1.0)  # the >= or > is arbitrary
    gamma_1 = jnp.where(1 + curvatures[..., 1] * d >= 0.0, jnp.ones_like(d), jnp.ones_like(d) * -1.0)  # the >= or > is arbitrary
    kappa_0 = curvatures[..., 0] / jnp.abs(1.0 + curvatures[..., 0] * d) * gamma_0
    kappa_1 = curvatures[..., 1] / jnp.abs(1.0 + curvatures[..., 1] * d) * gamma_1
    return jnp.stack([kappa_0, kappa_1], axis=-1)


# ===================================================================================================================================================================================
# No Phi
# ===================================================================================================================================================================================

def _hat_phi(positions):
    '''
    Compute the unit vector in the toroidal (phi) direction for given cartesian positions.

    Parameters
    -----------
    positions : jnp.ndarray
        Cartesian positions at which to compute the hat phi vector.

    Returns
    -------
    jnp.ndarray
        Unit vectors in the toroidal direction at the given positions.
    '''
    x = positions[..., 0]
    y = positions[..., 1]
    z = positions[..., 2]
    r = jnp.sqrt(x**2 + y**2)
    # Avoid division by zero on the z-axis.
    safe_r = jnp.clip(r, min = 1e-12)
    hat_phi = jnp.stack([-y / safe_r, x / safe_r, jnp.zeros_like(z)], axis=-1)
    return hat_phi


@eqx.filter_jit
def _normal_extended_no_phi_cartesian_position(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Extend the cartesian position of a flux surface along the normal direction with no toroidal (phi) component.

    For s <= 1.0, the original flux surface is used, while for s > 1.0, the position is given by moving along
    the normal direction of the flux surface at s = 1.0, but with the toroidal component removed (and the
    remaining direction renormalised).

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the position.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the position.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the position.

    Returns
    -------
    jnp.ndarray
        Cartesian position(s) of the extended flux surface.
    '''
    positions = _cartesian_position_interpolated(flux_surface, jnp.minimum(s, 1.0), theta, phi)
    normals = _normal_interpolated(flux_surface, 1.0, theta, phi)
    hat_phi = _hat_phi(positions)
    # Project out the toroidal component of the normal.
    phi_component = jnp.einsum("...i,...i->...", normals, hat_phi)
    normal_no_phi = normals - phi_component[..., None] * hat_phi
    normal_no_phi_normalised = normal_no_phi / jnp.linalg.norm(normal_no_phi, axis=-1, keepdims=True)
    distance_1d = jnp.maximum(s - 1.0, 0.0)
    return positions + normal_no_phi_normalised * distance_1d[..., None]


@eqx.filter_jit
def _normal_extended_no_phi_cylindrical_position(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Extend the cartesian position of a flux surface along the normal direction with no toroidal component and convert to cylindrical coordinates.

    For s <= 1.0, the original flux surface is used, while for s > 1.0, the position is given by moving along
    the normal direction of the flux surface at s = 1.0, but with the toroidal component removed.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the position.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the position.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the position.

    Returns
    -------
    jnp.ndarray
        Cylindrical position(s) of the extended flux surface.
    '''
    return _cartesian_to_cylindrical(_normal_extended_no_phi_cartesian_position(flux_surface, s, theta, phi))


@eqx.filter_jit
def _normal_extended_no_phi_normal(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Compute the normal for a flux surface extended with no toroidal component.

    Returns NaN values as the normal is not well-defined in the extended region.
    If one desires the normal as desired from the original surface, use flux_surface.normal(jnp.minimum(s, 1.0), theta, phi) instead.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the normal.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the normal.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the normal.

    Returns
    -------
    jnp.ndarray
        Array of NaN values with shape matching the broadcast input and trailing dimension 3.
    '''
    s_bc, _, _ = jnp.broadcast_arrays(s, theta, phi)
    return jnp.full(s_bc.shape + (3,), jnp.nan)


@eqx.filter_jit
def _normal_extended_no_phi_principal_curvatures(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Compute the principal curvatures for a flux surface extended with no toroidal component.

    Returns NaN values as the principal curvatures are not implemented in the extended region.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the principal curvatures.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the principal curvatures.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the principal curvatures.

    Returns
    -------
    jnp.ndarray
        Array of NaN values with shape matching the broadcast input and trailing dimension 2.
    '''
    s_bc, _, _ = jnp.broadcast_arrays(s, theta, phi)
    return jnp.full(s_bc.shape + (2,), jnp.nan)


# d(position)/d(theta) of the no-phi extension, vectorised over broadcast (s, theta, phi).
# Argument order is (flux_surface, s, theta, phi): argnums=2 selects theta, and only flux_surface is
# excluded from vectorisation so the three remaining inputs match the '(),(),()' signature.
# eqx.filter_jit (instead of jax.jit) treats non-array leaves of the flux surface module as static.
_normal_extended_no_phi_dx_dtheta = eqx.filter_jit(
    jnp.vectorize(
        jax.jacfwd(_normal_extended_no_phi_cartesian_position, argnums=2),
        excluded=(0,),
        signature='(),(),()->(3)',
    )
)


@eqx.filter_jit
def __normal_extended_no_phi_arc_length_theta(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Compute the arc length derivative with respect to theta for a flux surface extended with no toroidal component.

    Uses autodiff to compute the derivative.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the arc length.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the arc length.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the arc length.

    Returns
    -------
    jnp.ndarray
        Arc length derivative with respect to theta.
    '''
    return jnp.linalg.norm(_normal_extended_no_phi_dx_dtheta(flux_surface, s, theta, phi), axis=-1)


# ===================================================================================================================================================================================
# Constant Phi
# ===================================================================================================================================================================================

def _distance_between_angles(angle1, angle2):
    '''
    Compute the signed angular distance between two angles, taking into account periodicity.

    Parameters
    -----------
    angle1 : jnp.ndarray
        First angle(s) in radians.
    angle2 : jnp.ndarray
        Second angle(s) in radians.

    Returns
    -------
    jnp.ndarray
        Signed angular distance from angle2 to angle1, in the range [-pi, pi].
    '''
    return jnp.arctan2(jnp.sin(angle1 - angle2), jnp.cos(angle1 - angle2))


def _distance_between_phi_phi_desired(flux_surface : FluxSurfaceBase, s, theta, phi, x):
    '''
    Compute the angular distance between the toroidal angle of a position and a desired toroidal angle.

    Used as the objective function for finding the correct phi in constant-phi extensions.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s).
    theta : jnp.ndarray
        Poloidal angle(s).
    phi : jnp.ndarray
        Desired toroidal angle(s).
    x : jnp.ndarray
        Trial toroidal angle(s) for the extended position.

    Returns
    -------
    jnp.ndarray
        Angular distance between the computed phi and the desired phi.
    '''
    positions = _normal_extended_cartesian_position(flux_surface, s, theta, x)
    return _distance_between_angles(jnp.arctan2(positions[..., 1], positions[..., 0]), phi)


@eqx.filter_jit
def _normal_extended_constant_phi_find_phi(flux_surface : FluxSurfaceBase, s, theta, phi, n_iter : int = 5):
    '''
    Find the toroidal angle needed in the extended region to maintain a constant output toroidal angle.

    Uses the secant method to solve for the angle that produces the desired phi after normal extension.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s).
    theta : jnp.ndarray
        Poloidal angle(s).
    phi : jnp.ndarray
        Desired output toroidal angle(s).
    n_iter : int, optional
        Number of secant iterations to perform (default: 5).

    Returns
    -------
    jnp.ndarray
        Adjusted toroidal angle(s) to use for the extended position.
    '''
    assert n_iter >= 1, "n_iter must be at least 1"
    _, _, phi_bc = jnp.broadcast_arrays(s, theta, phi)
    # Two starting points for the secant method: the requested angle and a small offset from it.
    x_minus_two = phi_bc + 1e-3
    x_minus_one = phi_bc
    f_minus_two = _distance_between_phi_phi_desired(flux_surface, s, theta, phi, x_minus_two)

    def secant_iteration(i, vals):
        x_minus_two, x_minus_one, f_minus_two = vals
        f_minus_one = _distance_between_phi_phi_desired(flux_surface, s, theta, phi, x_minus_one)
        # The small constant guards against division by zero once the iteration has converged.
        x_new = x_minus_one - f_minus_one * (x_minus_one - x_minus_two) / (f_minus_one - f_minus_two + 1e-16)
        return (x_minus_one, x_new, f_minus_one)

    x_final = jax.lax.fori_loop(0, n_iter, secant_iteration, (x_minus_two, x_minus_one, f_minus_two))[1]
    return x_final


@eqx.filter_jit
def _normal_extended_constant_phi_cartesian_position(flux_surface : FluxSurfaceBase, s, theta, phi, n_iter : int = 5):
    '''
    Extend the cartesian position of a flux surface along the normal direction while keeping the toroidal angle constant.

    For s <= 1.0, the original flux surface is used, while for s > 1.0, the position is given by moving along
    the normal direction while adjusting the internal phi to maintain the output phi constant.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the position.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the position.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the position.
    n_iter : int, optional
        Number of iterations for the phi solver (default: 5).

    Returns
    -------
    jnp.ndarray
        Cartesian position(s) of the extended flux surface.
    '''
    phi_c = _normal_extended_constant_phi_find_phi(flux_surface, s, theta, phi, n_iter)
    return _normal_extended_cartesian_position(flux_surface, s, theta, phi_c)


@eqx.filter_jit
def _normal_extended_constant_phi_cylindrical_position(flux_surface : FluxSurfaceBase, s, theta, phi, n_iter : int = 5):
    '''
    Extend the position of a flux surface along the normal direction while keeping the toroidal angle constant and convert to cylindrical coordinates.

    For s <= 1.0, the original flux surface is used, while for s > 1.0, the position is given by moving along
    the normal direction while adjusting the internal phi to maintain the output phi constant.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the position.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the position.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the position.
    n_iter : int, optional
        Number of iterations for the phi solver (default: 5).

    Returns
    -------
    jnp.ndarray
        Cylindrical position(s) of the extended flux surface.
    '''
    return _cartesian_to_cylindrical(_normal_extended_constant_phi_cartesian_position(flux_surface, s, theta, phi, n_iter))


@eqx.filter_jit
def _normal_extended_constant_phi_normal(flux_surface : FluxSurfaceBase, s, theta, phi, n_iter : int = 5):
    '''
    Compute the normal for a flux surface extended along the normal direction with constant toroidal angle.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the normal.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the normal.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the normal.
    n_iter : int, optional
        Number of iterations for the phi solver (default: 5).

    Returns
    -------
    jnp.ndarray
        Normal vector(s) of the extended flux surface.
    '''
    phi_c = _normal_extended_constant_phi_find_phi(flux_surface, s, theta, phi, n_iter)
    return _normal_extended_normal(flux_surface, s, theta, phi_c)


@eqx.filter_jit
def _normal_extended_constant_phi_principal_curvatures(flux_surface : FluxSurfaceBase, s, theta, phi, n_iter : int = 5):
    '''
    Compute the principal curvatures for a flux surface extended along the normal direction with constant toroidal angle.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Flux surface to evaluate
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the principal curvatures.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the principal curvatures.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the principal curvatures.
    n_iter : int, optional
        Number of iterations for the phi solver (default: 5).

    Returns
    -------
    jnp.ndarray
        Principal curvatures of the extended flux surface.
    '''
    phi_c = _normal_extended_constant_phi_find_phi(flux_surface, s, theta, phi, n_iter)
    return _normal_extended_principal_curvatures(flux_surface, s, theta, phi_c)


# ===================================================================================================================================================================================
# Fourier Extended
# ===================================================================================================================================================================================

def _fourier_extended_blend(evaluate, flux_surface : FluxSurfaceBase, extension : FluxSurfaceBase, s, theta, phi):
    '''
    Evaluate a per-surface quantity on a Fourier-extended flux surface (shared by position/normal/curvature).

    For s <= 1.0, the base flux surface is used. For 1.0 < s < 2.0, the quantity is linearly interpolated between
    the LCFS and the first extension surface. For s >= 2.0, the extension surface is used directly, with
    s = 2.0 + i mapping linearly onto the extension's radial label.

    Parameters
    -----------
    evaluate : callable
        ``evaluate(surface, s, theta, phi)`` returning an array whose last axis holds the quantity's components.
    flux_surface : FluxSurfaceBase
        Base flux surface to evaluate (s <= 1.0).
    extension : FluxSurfaceBase
        Extension flux surface for s > 1.0.
    s, theta, phi : jnp.ndarray
        Coordinates at which to evaluate.

    Returns
    -------
    jnp.ndarray
        Blended quantity.
    '''
    # This is not necessarily completely efficient, but both branches must be evaluated in batched operations.
    # With only one extension surface the denominator below would be zero; clamping to 2 keeps it valid
    # (s = 0.5 interpolation on a single-surface extension is just the surface itself anyway).
    n_surf_extension = max(extension.data.Rmnc.shape[0], 2)
    inner_values = evaluate(flux_surface, jnp.minimum(s, 1.0), theta, phi)
    d_value = jnp.maximum(s - 1.0, 0.0)
    d_value_extension = jnp.maximum(d_value - 1.0, 0.0)
    normalized_d_value = d_value_extension / (n_surf_extension - 1.0)
    extension_values = evaluate(extension, normalized_d_value, theta, phi)
    extension_values_d0 = evaluate(extension, jnp.zeros_like(s), theta, phi)
    only_extension = jnp.asarray(s >= 2.0)
    # For s <= 1.0, d_value is 0 so this reduces to the inner values.
    blended = inner_values + (extension_values_d0 - inner_values) * d_value[..., None]
    return jnp.where(only_extension[..., None], extension_values, blended)


@eqx.filter_jit
def _fourier_extended_cylindrical_position(flux_surface : FluxSurfaceBase, extension : FluxSurfaceBase, s, theta, phi):
    '''
    Compute the cylindrical position for a flux surface with Fourier-based extension.

    For s <= 1.0, the original flux surface is used. For 1.0 < s < 2.0, linear interpolation between the LCFS
    and the first extension surface is used. For s >= 2.0, the extension surface is used directly.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Base flux surface to evaluate (s <= 1.0).
    extension : FluxSurfaceBase
        Extension flux surface for s > 1.0.
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the position.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the position.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the position.

    Returns
    -------
    jnp.ndarray
        Cylindrical position(s) of the extended flux surface.
    '''
    return _fourier_extended_blend(_cylindrical_position_interpolated, flux_surface, extension, s, theta, phi)


@eqx.filter_jit
def _fourier_extended_cartesian_position(flux_surface : FluxSurfaceBase, extension : FluxSurfaceBase, s, theta, phi):
    '''
    Compute the cartesian position for a flux surface with Fourier-based extension.

    For s <= 1.0, the original flux surface is used. For 1.0 < s < 2.0, linear interpolation (in cylindrical
    coordinates) between the LCFS and the first extension surface is used. For s >= 2.0, the extension surface
    is used directly.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Base flux surface to evaluate (s <= 1.0).
    extension : FluxSurfaceBase
        Extension flux surface for s > 1.0.
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the position.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the position.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the position.

    Returns
    -------
    jnp.ndarray
        Cartesian position(s) of the extended flux surface.
    '''
    return _cylindrical_to_cartesian(_fourier_extended_cylindrical_position(flux_surface, extension, s, theta, phi))


@eqx.filter_jit
def _fourier_extended_normal(flux_surface : FluxSurfaceBase, extension : FluxSurfaceBase, s, theta, phi):
    '''
    Compute the unit normal vector for a flux surface with Fourier-based extension.

    For s <= 1.0, the original flux surface normal is used. For 1.0 < s < 2.0, the LCFS normal and the first
    extension surface normal are linearly interpolated and renormalised to unit length. For s >= 2.0, the
    extension surface normal is used.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Base flux surface to evaluate (s <= 1.0).
    extension : FluxSurfaceBase
        Extension flux surface for s > 1.0.
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the normal.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the normal.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the normal.

    Returns
    -------
    jnp.ndarray
        Normal vector(s) of the extended flux surface.
    '''
    blended = _fourier_extended_blend(_normal_interpolated, flux_surface, extension, s, theta, phi)
    # A linear blend of two unit vectors is shorter than unit length in 1 < s < 2; restore unit length.
    return blended / jnp.linalg.norm(blended, axis=-1, keepdims=True)


@eqx.filter_jit
def _fourier_extended_principal_curvatures(flux_surface : FluxSurfaceBase, extension : FluxSurfaceBase, s, theta, phi):
    '''
    Compute the principal curvatures for a flux surface with Fourier-based extension.

    For s <= 1.0, the original flux surface curvatures are used. For 1.0 < s < 2.0, linear interpolation between
    the LCFS curvatures and the first extension surface curvatures is used. For s >= 2.0, the extension surface
    curvatures are used.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        Base flux surface to evaluate (s <= 1.0).
    extension : FluxSurfaceBase
        Extension flux surface for s > 1.0.
    s : jnp.ndarray
        Radial coordinate(s) at which to evaluate the principal curvatures.
    theta : jnp.ndarray
        Poloidal angle(s) at which to evaluate the principal curvatures.
    phi : jnp.ndarray
        Toroidal angle(s) at which to evaluate the principal curvatures.

    Returns
    -------
    jnp.ndarray
        Principal curvatures of the extended flux surface.
    '''
    return _fourier_extended_blend(_principal_curvatures_interpolated, flux_surface, extension, s, theta, phi)


#====================================================================================================================================================================================
# FluxSurfaceExtendedDistanceMatrix methods
#====================================================================================================================================================================================

def _d_interp(d_layers, nfp, s, theta, phi):
    '''
    Interpolate the distance layers at a single (s, theta, phi) point.

    s is mapped onto the layer axis via (s - 2.0) / (n_layers - 1), clamped to [0, 1]. Within each of the two
    bracketing layers, bilinear interpolation in (theta, phi) is used; the layers are then blended linearly.

    Parameters
    -----------
    d_layers : jnp.ndarray
        Distance layers indexed as d_layers[i_layer, i_theta, i_phi].
    nfp : int
        Number of field periods, used to normalise phi.
    s, theta, phi : jnp.ndarray
        Scalar coordinates (vectorised by ``_d_interp_vectorized``).

    Returns
    -------
    jnp.ndarray
        Interpolated distance.
    '''
    n_layers = d_layers.shape[0]
    # Guard the single-layer case: (s - 2.0) / 0 would be NaN at s = 2.0.
    s_norm = jnp.maximum(0.0, (s - 2.0) / max(n_layers - 1, 1))
    s_norm = jnp.minimum(1.0, s_norm)
    i0, i1, ds = interpolate_fractions(s_norm, n_layers)
    # The normalised (theta, phi) are identical for both layers; compute them once.
    theta_phi_norm = _normalize_theta_phi_full_mod(theta, phi, nfp)
    d_lower_interp = bilinear_interp(*theta_phi_norm, d_layers[i0])
    d_upper_interp = bilinear_interp(*theta_phi_norm, d_layers[i1])
    return d_lower_interp * (1 - ds) + d_upper_interp * ds


# d_layers and nfp are excluded from vectorisation; the three remaining scalar inputs are (s, theta, phi).
_d_interp_vectorized = eqx.filter_jit(jnp.vectorize(_d_interp, excluded=(0, 1), signature='(),(),()->()'))