Source code for jax_sbgeom.flux_surfaces.flux_surfaces_base

from abc import abstractmethod

import jax.numpy as jnp
import h5py 
import jax
import numpy as onp
from dataclasses import dataclass
from jax_sbgeom.jax_utils import stack_jacfwd, interpolate_array
from functools import partial
import equinox as eqx

[docs] @jax.tree_util.register_dataclass @dataclass(frozen=True) class FluxSurfaceSettings: mpol : int # maximum poloidal mode number [inclusive] ntor : int # maximum toroidal mode number [inclusive] nfp : int # number of field periods
@eqx.filter_jit
def _create_mpol_vector(settings : FluxSurfaceSettings):
    '''
    Create the poloidal mode number vector for the VMEC representation.

    Layout: [0] * (ntor + 1), [1] * (2 * ntor + 1), [2] * (2 * ntor + 1), ..., [mpol] * (2 * ntor + 1)

    For m = 0 the toroidal modes +n and -n are not independent, so they are combined into a
    single coefficient and only n = 0, ..., ntor is stored (ntor + 1 entries).
    For m > 0 all of n = -ntor, ..., ntor are needed (2 * ntor + 1 entries).

    Parameters
    ----------
    settings : FluxSurfaceSettings
        Uses ``settings.mpol`` (maximum poloidal mode number) and
        ``settings.ntor`` (maximum toroidal mode number).

    Returns
    -------
    jnp.ndarray
        Poloidal mode number of every flattened coefficient,
        length ``ntor + 1 + mpol * (2 * ntor + 1)``.
    '''
    return jnp.concatenate([
        jnp.zeros(settings.ntor + 1, dtype=int),
        jnp.repeat(jnp.arange(1, settings.mpol + 1), 2 * settings.ntor + 1),
    ])


@eqx.filter_jit
def _create_ntor_vector(settings : FluxSurfaceSettings):
    '''
    Create the toroidal mode number vector for the VMEC representation.

    Layout (before scaling): [0, 1, ..., ntor] for m = 0, followed by
    [-ntor, ..., -1, 0, 1, ..., ntor] once for every m = 1, ..., mpol.
    The result is multiplied by the number of field periods, so each entry is the
    toroidal wave number that multiplies phi directly in cos(m * theta - n * phi).

    Parameters
    ----------
    settings : FluxSurfaceSettings
        Uses ``settings.mpol``, ``settings.ntor`` and ``settings.nfp``.

    Returns
    -------
    jnp.ndarray
        Toroidal mode number (times nfp) of every flattened coefficient,
        aligned element-wise with ``_create_mpol_vector(settings)``.
    '''
    return jnp.concatenate([
        jnp.arange(settings.ntor + 1),
        jnp.tile(jnp.arange(-settings.ntor, settings.ntor + 1), settings.mpol),
    ]) * settings.nfp


def _cylindrical_to_cartesian(RZphi : jnp.ndarray):
    '''
    Convert points stored as [..., (R, Z, phi)] to [..., (x, y, z)].
    '''
    R = RZphi[..., 0]
    Z = RZphi[..., 1]
    phi = RZphi[..., 2]
    x = R * jnp.cos(phi)
    y = R * jnp.sin(phi)
    return jnp.stack([x, y, Z], axis=-1)


def _cartesian_to_cylindrical(XYZ : jnp.ndarray):
    '''
    Convert points stored as [..., (x, y, z)] to [..., (R, Z, phi)], with phi in (-pi, pi].
    '''
    x = XYZ[..., 0]
    y = XYZ[..., 1]
    z = XYZ[..., 2]
    R = jnp.sqrt(x**2 + y**2)
    phi = jnp.arctan2(y, x)
    return jnp.stack([R, z, phi], axis=-1)


@jax.jit
def _check_whether_make_normals_point_outwards_required(Rmnc : jnp.ndarray, Zmns : jnp.ndarray, mpol_vector : jnp.ndarray):
    '''
    * Internal *
    Check whether the Fourier coefficients need to be modified such that the normals point outwards.

    Only the outermost surface (last row of the coefficients) is inspected. There are four cases:
        1. theta = 0 is outboard:
            a. dZ_dtheta > 0 at theta = 0 -> normals point outwards -> no change
            b. dZ_dtheta < 0 at theta = 0 -> normals point inwards  -> reverse theta
        2. theta = pi is outboard (theta = 0 is inboard):
            a. dZ_dtheta > 0 at theta = 0 -> normals point inwards  -> reverse theta
            b. dZ_dtheta < 0 at theta = 0 -> normals point outwards -> no change

    Parameters
    -----------
    Rmnc : jnp.ndarray
        Radial Fourier coefficients. Shape (nsurf, nmodes), or (nmodes,) for a single surface.
    Zmns : jnp.ndarray
        Vertical Fourier coefficients. Same shape as Rmnc.
    mpol_vector : jnp.ndarray
        Poloidal mode numbers. Shape (nmodes,)

    Returns
    --------
    flip_theta : bool
        Whether to reverse theta to ensure normals point outwards.
    '''
    # atleast_2d lets single-surface (1D) coefficients be treated as one surface.
    Rmnc_lcfs = jnp.atleast_2d(Rmnc)[-1]
    Zmns_lcfs = jnp.atleast_2d(Zmns)[-1]

    # dZ/dtheta at theta = phi = 0:  sum_mn m * Zmns * cos(0) = sum_mn m * Zmns
    dZ_dtheta_at_0 = jnp.sum(Zmns_lcfs * mpol_vector)

    # R at (theta, phi) = (0, 0) is sum(Rmnc); at (pi, 0) it is sum(Rmnc * (-1)**m).
    # Whichever is larger lies on the outboard midplane.
    theta0_is_outboard = jnp.sum(Rmnc_lcfs) > jnp.sum(Rmnc_lcfs * (-1)**mpol_vector)

    z_increasing = dZ_dtheta_at_0 > 0
    # Outward orientation: Z increases with theta on the outboard side, i.e. decreases at an inboard theta = 0.
    normals_point_outwards = jnp.where(theta0_is_outboard, z_increasing, jnp.logical_not(z_increasing))
    return jnp.logical_not(normals_point_outwards)


def _reverse_theta_single(m_vec, n_vec, coeff_vec, cosine_sign : bool):
    '''
    * Internal *
    Changes the Fourier coefficients such that theta is replaced by -theta.

    Under theta -> -theta:
        cos(m * theta - n * phi) -> cos(m * theta - (-n) * phi)
        sin(m * theta - n * phi) -> -sin(m * theta - (-n) * phi)
    so for m > 0 the coefficient of mode (m, n) moves to mode (m, -n), with a sign change
    for sine series. The m = 0 modes do not depend on theta and are left untouched.

    Parameters
    -----------
    m_vec : jnp.ndarray
        Array of poloidal mode numbers.
    n_vec : jnp.ndarray
        Array of toroidal mode numbers.
    coeff_vec : jnp.ndarray
        Array of Fourier coefficients (Rmnc or Zmns) for a single surface; must be 1D.
    cosine_sign : bool
        If True, the coefficients correspond to cosine terms. If False, they correspond to sine terms.

    Returns
    --------
    new_coeff_vec : jnp.ndarray
        The modified Fourier coefficients after reversing theta.
    '''
    assert coeff_vec.ndim == 1
    assert m_vec.shape == n_vec.shape == coeff_vec.shape

    keys = jnp.stack([m_vec, n_vec], axis=1)
    reversed_keys = jnp.stack([m_vec, -n_vec], axis=1)
    # Target key per entry: (m, -n) for m > 0, unchanged (m, n) for m = 0.
    reversed_keys_mod = jnp.where(keys[:, 0:1] > 0, reversed_keys, keys)

    # matches[i, j] is True when source entry i holds the mode that target entry j needs.
    matches = jnp.all(keys[:, None, :] == reversed_keys_mod[None, :, :], axis=-1)
    # NOTE: if a target mode were missing, argmax would silently pick index 0. The layout built by
    # _create_mpol_vector/_create_ntor_vector contains every (m, -n) for m > 0, so this does not occur there.
    idx_map = jnp.argmax(matches, axis=0)

    # Cosine series: coefficients are only permuted.
    # Sine series: permuted and negated, except the m = 0 terms, which stay the same.
    new_coeff_vec = jax.lax.cond(cosine_sign,
                                 lambda _ : coeff_vec[idx_map],
                                 lambda _ : jnp.where(keys[:, 0] > 0, -coeff_vec[idx_map], coeff_vec[idx_map]),
                                 operand=None)
    return new_coeff_vec


# Applies _reverse_theta_single to every row of a (nsurf, nmodes) coefficient array.
reverse_theta_total = jax.jit(jax.vmap(_reverse_theta_single, in_axes=(None, None, 0, None), out_axes=0))
@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class FluxSurfaceModes:
    '''
    Mode-number vectors of the flattened VMEC representation.

    Attributes
    ----------
    mpol_vector : jnp.ndarray
        Poloidal mode number of every flattened coefficient.
    ntor_vector : jnp.ndarray
        Toroidal mode number (already multiplied by nfp) of every flattened coefficient.
    '''
    mpol_vector : jnp.ndarray
    ntor_vector : jnp.ndarray

    @classmethod
    def from_settings(cls, settings : FluxSurfaceSettings):
        '''
        Build the mode vectors implied by ``settings`` (mpol, ntor, nfp).
        '''
        return cls(
            mpol_vector=_create_mpol_vector(settings),
            ntor_vector=_create_ntor_vector(settings),
        )
@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class FluxSurfaceData:
    '''
    Fourier coefficients of one or more flux surfaces in the flattened VMEC layout.

    Attributes
    ----------
    Rmnc : jnp.ndarray
        Cosine coefficients of R. Shape (nsurf, nmodes), or (nmodes,) for a single surface.
    Zmns : jnp.ndarray
        Sine coefficients of Z. Same shape as ``Rmnc``.
    '''
    Rmnc : jnp.ndarray
    Zmns : jnp.ndarray

    @classmethod
    def from_rmnc_zmns_settings(cls, Rmnc : jnp.ndarray, Zmns : jnp.ndarray, settings : FluxSurfaceSettings, make_normals_point_outwards : bool = True):
        '''
        Create FluxSurfaceData from raw coefficients, optionally re-orienting theta.

        Parameters
        ----------
        Rmnc : jnp.ndarray
            Radial Fourier coefficients. Shape (nsurf, nmodes).
        Zmns : jnp.ndarray
            Vertical Fourier coefficients. Shape (nsurf, nmodes).
        settings : FluxSurfaceSettings
            Settings object containing parameters mpol, ntor, nfp; determines nmodes.
        make_normals_point_outwards : bool
            If True (default), reverse theta in all coefficients when the outermost surface
            is oriented such that dX/dphi x dX/dtheta points inwards.

        Returns
        -------
        FluxSurfaceData
            The (possibly re-oriented) coefficients.

        Raises
        ------
        AssertionError
            If the shapes of Rmnc and Zmns differ, or their second axis does not match
            the number of modes implied by ``settings``.
        '''
        mpol_vector = _create_mpol_vector(settings)
        ntor_vector = _create_ntor_vector(settings)

        # Validate before doing any work, so a shape mismatch is reported here with a clear
        # message instead of surfacing from inside the vmapped theta reversal.
        assert Rmnc.shape == Zmns.shape, \
            "Rmnc and Zmns must have the same shape, got {} and {}".format(Rmnc.shape, Zmns.shape)
        assert Rmnc.ndim >= 2 and Rmnc.shape[1] == len(mpol_vector), \
            "Expected coefficients of shape (nsurf, {}), got {}".format(len(mpol_vector), Rmnc.shape)

        if make_normals_point_outwards:
            flip_theta = _check_whether_make_normals_point_outwards_required(Rmnc, Zmns, mpol_vector)
            # jnp.where (rather than a Python `if`) because flip_theta is a JAX boolean array.
            Rmnc = jnp.where(flip_theta, reverse_theta_total(mpol_vector, ntor_vector, Rmnc, True), Rmnc)
            Zmns = jnp.where(flip_theta, reverse_theta_total(mpol_vector, ntor_vector, Zmns, False), Zmns)

        return cls(Rmnc=Rmnc, Zmns=Zmns)
def _data_modes_settings_from_hdf5(filename : str, make_normals_point_outwards : bool = True):
    """Load flux-surface data, mode vectors and settings from a VMEC-type HDF5 file.

    Reads the datasets ``rmnc``, ``zmns``, ``mpol``, ``ntor``, ``nfp``, ``xm`` and ``xn``.
    ``xm``/``xn`` are only used as a sanity check that the file uses the same mode layout
    as ``_create_mpol_vector``/``_create_ntor_vector``.

    Parameters
    ----------
    filename : str
        Path to the HDF5 file.
    make_normals_point_outwards : bool
        Whether to re-orient theta so that normals point outwards (default True).
        Forwarded to ``FluxSurfaceData.from_rmnc_zmns_settings``.

    Returns
    -------
    data : FluxSurfaceData
        The Fourier coefficients.
    modes : FluxSurfaceModes
        The mode-number vectors.
    settings : FluxSurfaceSettings
        mpol, ntor and nfp read from the file.

    Raises
    ------
    AssertionError
        If the file's xm/xn vectors do not match the expected mode layout.
    """
    # Open read-only explicitly; older h5py releases defaulted to a writable ('a') mode.
    with h5py.File(filename, "r") as f:
        Rmnc = jnp.array(f['rmnc'])
        Zmns = jnp.array(f['zmns'])
        mpol = int(f['mpol'][()]) - 1  # vmec uses mpol 1 larger than maximum poloidal mode number
        ntor = int(f['ntor'][()])
        nfp = int(f['nfp'][()])
        settings = FluxSurfaceSettings(
            mpol=mpol,
            ntor=ntor,
            nfp=nfp,
        )
        assert jnp.all(_create_mpol_vector(settings) == jnp.array(f['xm']))  # sanity check
        assert jnp.all(_create_ntor_vector(settings) == jnp.array(f['xn']))  # sanity check

    # Previously make_normals_point_outwards was accepted but never forwarded.
    data = FluxSurfaceData.from_rmnc_zmns_settings(Rmnc, Zmns, settings, make_normals_point_outwards)
    modes = FluxSurfaceModes.from_settings(settings)
    return data, modes, settings
class ParametrisedSurface(eqx.Module):
    '''
    * Internal *
    Abstract interface for a family of surfaces parametrised by (s, theta, phi).

    s is the surface index or normalized flux label, theta the poloidal angle and
    phi the toroidal angle. This is used for the interpolation-based methods that
    require 2D data. Concrete subclasses implement every method below.
    '''

    @abstractmethod
    def cylindrical_position(self, s, theta, phi):
        '''Position in cylindrical coordinates at (s, theta, phi). Must be overridden.'''
        raise NotImplementedError

    @abstractmethod
    def cartesian_position(self, s, theta, phi):
        '''Position in Cartesian coordinates at (s, theta, phi). Must be overridden.'''
        raise NotImplementedError

    @abstractmethod
    def normal(self, s, theta, phi):
        '''Surface normal at (s, theta, phi). Must be overridden.'''
        raise NotImplementedError

    @abstractmethod
    def principal_curvatures(self, s, theta, phi):
        '''Principal curvatures at (s, theta, phi). Must be overridden.'''
        raise NotImplementedError
class FluxSurfaceBase(ParametrisedSurface):
    '''
    Set of flux surfaces in a VMEC-like Fourier representation (abstract base).

    This class only holds the data and provides alternate constructors; evaluation is
    implemented by subclasses (e.g. FluxSurface), which allows interpolation-based and
    direct evaluation-based implementations to share one interface.

    Attributes:
    -----------
    data : FluxSurfaceData
        Fourier coefficients Rmnc and Zmns
    modes : FluxSurfaceModes
        Mode vectors mpol_vector and ntor_vector
    settings : FluxSurfaceSettings
        Parameters mpol, ntor, nfp
    '''
    data : FluxSurfaceData = None
    modes : FluxSurfaceModes = None
    settings : FluxSurfaceSettings = None

    @classmethod
    def from_hdf5(cls, filename : str):
        '''
        Load a flux surface from a VMEC-type HDF5 file.

        Parameters
        ----------
        filename : str
            Path to the HDF5 file.

        Returns
        -------
        FluxSurfaceBase
            An instance of ``cls`` built from the file contents.
        '''
        loaded_data, loaded_modes, loaded_settings = _data_modes_settings_from_hdf5(filename)
        return cls(data=loaded_data, modes=loaded_modes, settings=loaded_settings)

    @classmethod
    def from_flux_surface(cls, flux_surface_base : "FluxSurfaceBase"):
        '''
        Copy constructor; also converts between subclasses of FluxSurfaceBase.

        Parameters
        ----------
        flux_surface_base : FluxSurfaceBase
            Source surface whose data, modes and settings are reused.

        Returns
        -------
        FluxSurfaceBase
            An instance of ``cls`` sharing the source's data, modes and settings.
        '''
        source = flux_surface_base
        return cls(data=source.data, modes=source.modes, settings=source.settings)

    @classmethod
    def from_rmnc_zmns_settings(cls, Rmnc : jnp.ndarray, Zmns : jnp.ndarray, settings : FluxSurfaceSettings, make_normals_point_outwards : bool = True):
        '''
        Build a flux surface from raw Fourier coefficients and settings.

        Parameters
        ----------
        Rmnc : jnp.ndarray
            Radial Fourier coefficients. Shape (nsurf, nmodes)
        Zmns : jnp.ndarray
            Vertical Fourier coefficients. Shape (nsurf, nmodes)
        settings : FluxSurfaceSettings
            Parameters mpol, ntor, nfp
        make_normals_point_outwards : bool
            Re-orient theta so normals point outwards (default True).

        Returns
        -------
        FluxSurfaceBase
            An instance of ``cls``.
        '''
        fourier_data = FluxSurfaceData.from_rmnc_zmns_settings(Rmnc, Zmns, settings, make_normals_point_outwards)
        mode_vectors = FluxSurfaceModes.from_settings(settings)
        return cls(data=fourier_data, modes=mode_vectors, settings=settings)

    @classmethod
    def from_data_settings(cls, data : FluxSurfaceData, settings : FluxSurfaceSettings):
        '''
        Build a flux surface from existing FluxSurfaceData and settings (data used as-is).

        Parameters
        ----------
        data : FluxSurfaceData
            Fourier coefficients Rmnc and Zmns
        settings : FluxSurfaceSettings
            Parameters mpol, ntor, nfp

        Returns
        -------
        FluxSurfaceBase
            An instance of ``cls``.
        '''
        mode_vectors = FluxSurfaceModes.from_settings(settings)
        return cls(data=data, modes=mode_vectors, settings=settings)

    @classmethod
    def from_data_settings_full(cls, data : FluxSurfaceData, settings : FluxSurfaceSettings):
        '''
        Like ``from_data_settings``, but promotes the coefficients to at least 2D.

        The interpolation-based functions index coefficients as (surface, mode), so they
        require 2D data.

        Parameters
        ----------
        data : FluxSurfaceData
            Fourier coefficients Rmnc and Zmns
        settings : FluxSurfaceSettings
            Parameters mpol, ntor, nfp

        Returns
        -------
        FluxSurfaceBase
            An instance of ``cls`` with 2D coefficient arrays.
        '''
        data_2d = FluxSurfaceData(Rmnc=jnp.atleast_2d(data.Rmnc), Zmns=jnp.atleast_2d(data.Zmns))
        mode_vectors = FluxSurfaceModes.from_settings(settings)
        return cls(data=data_2d, modes=mode_vectors, settings=settings)

    @property
    def nfp(self):
        '''Number of field periods (shortcut for ``self.settings.nfp``).'''
        return self.settings.nfp
class FluxSurface(FluxSurfaceBase):
    '''
    Flux-surface family evaluated by interpolating the Fourier coefficients in s.

    All methods require 2D coefficient data (nsurf, nmodes); see make_2d_flux_surface
    for single-surface data. Inputs s, theta, phi may be scalars or broadcastable arrays.
    '''

    def cylindrical_position(self, s, theta, phi):
        '''
        Cylindrical position at (s, theta, phi).

        Parameters
        -----------
        s : jnp.ndarray
            Surface index or normalized flux label
        theta : jnp.ndarray
            Poloidal angle(s)
        phi : jnp.ndarray
            Toroidal angle(s)

        Returns
        --------
        jnp.ndarray
            Positions [R, Z, phi] along the last axis.
        '''
        return _cylindrical_position_interpolated(self, s, theta, phi)

    def cartesian_position(self, s, theta, phi):
        '''
        Cartesian position at (s, theta, phi).

        Parameters
        -----------
        s : jnp.ndarray
            Surface index or normalized flux label
        theta : jnp.ndarray
            Poloidal angle(s)
        phi : jnp.ndarray
            Toroidal angle(s)

        Returns
        --------
        jnp.ndarray
            Positions [x, y, z] along the last axis.
        '''
        return _cartesian_position_interpolated(self, s, theta, phi)

    def normal(self, s, theta, phi):
        '''
        Unit normal vector at (s, theta, phi), along dX/dphi x dX/dtheta.

        Parameters
        -----------
        s : jnp.ndarray
            Surface index or normalized flux label
        theta : jnp.ndarray
            Poloidal angle(s)
        phi : jnp.ndarray
            Toroidal angle(s)

        Returns
        --------
        jnp.ndarray
            Unit normal vectors along the last axis (length 3).
        '''
        return _normal_interpolated(self, s, theta, phi)

    def principal_curvatures(self, s, theta, phi):
        '''
        Principal curvatures at (s, theta, phi).

        Parameters
        -----------
        s : jnp.ndarray
            Surface index or normalized flux label
        theta : jnp.ndarray
            Poloidal angle(s)
        phi : jnp.ndarray
            Toroidal angle(s)

        Returns
        --------
        jnp.ndarray
            The two principal curvatures (k1, k2) along the last axis.
        '''
        return _principal_curvatures_interpolated(self, s, theta, phi)
def make_2d_flux_surface(fs : FluxSurfaceBase) -> FluxSurfaceBase:
    '''
    Promote a single-surface (1D) flux surface to 2D coefficient data.

    A leading surface axis is added, so the interpolation-based functions, which require
    (nsurf, nmodes) data, can be used. Interpolating on one surface returns that surface.
    Data that is already 2D is left as-is by ``jnp.atleast_2d``.

    Parameters
    -----------
    fs : FluxSurfaceBase
        Input surface, typically with 1D data.

    Returns
    --------
    FluxSurfaceBase
        A new object of the same type as ``fs`` with 2D data and the same modes and settings.
    '''
    surface_type = type(fs)
    promoted = FluxSurfaceData(jnp.atleast_2d(fs.data.Rmnc), jnp.atleast_2d(fs.data.Zmns))
    return surface_type(data=promoted, modes=fs.modes, settings=fs.settings)
[docs] @jax.tree_util.register_dataclass @dataclass(frozen=True) class ToroidalExtent: ''' Class representing a toroidal extent in phi. Attributes: ----------- start : float Starting toroidal angle (in radians) end : float Ending toroidal angle (in radians) ''' start : float end : float
[docs] @classmethod def half_module(self, flux_surface : FluxSurfaceBase, dphi = 0.0): ''' Create a ToroidalExtent representing half a field period. Parameters ----------- flux_surface : FluxSurface The flux surface for which to create the toroidal extent. dphi : float An optional offset to add to both the start and end angles. Returns -------- ToroidalExtent The created ToroidalExtent. ''' return self(dphi, 2 * jnp.pi / flux_surface.nfp / 2.0 + dphi)
[docs] @classmethod def full_module(self, flux_surface : FluxSurfaceBase, dphi = 0.0): ''' Create a ToroidalExtent representing a full field period. Parameters ----------- flux_surface : FluxSurface The flux surface for which to create the toroidal extent. dphi : float An optional offset to add to both the start and end angles. Returns -------- ToroidalExtent The created ToroidalExtent. ''' return self(dphi, 2 * jnp.pi / flux_surface.nfp + dphi)
[docs] @classmethod def full(self): ''' Create a ToroidalExtent representing the full toroidal angle [0, 2pi]. Returns -------- ToroidalExtent The created ToroidalExtent. ''' return self(0.0, 2 * jnp.pi)
[docs] def full_angle(self): ''' Whether the toroidal extent represents a full 2pi angle. Used for creating closed meshes. Returns -------- bool True if the toroidal extent represents a full 2pi angle, False otherwise. ''' return bool(jnp.allclose(self.end - self.start, 2 * jnp.pi))
def __iter__(self): return iter((self.start, self.end, self.full_angle()))
# ===================================================================================================================================================================================
# Positions
# ===================================================================================================================================================================================

def _fourier_carry_dtype(coefficients, *points):
    '''
    * Internal *
    Dtype of the accumulator used when summing a Fourier series of ``coefficients`` at ``points``.

    ``jax.lax.scan`` requires the carry to keep one dtype throughout. Starting from
    ``zeros_like(theta)`` breaks when theta is a strongly typed integer array (the per-mode
    terms are floating point) or is narrower than the coefficients. The Python ``float``
    entry is weakly typed: it turns integer inputs into the default float type without
    widening floating-point inputs.
    '''
    return jnp.result_type(coefficients, *points, float)


@partial(jax.jit)
def _cylindrical_position_direct(flux_surface : FluxSurfaceBase, theta, phi):
    '''
    Cylindrical position on the flux surface as a function of (theta, phi) for a single surface (1D data).

    Computes:
        R = sum_mn [ Rmnc * cos(m * theta - n * phi) ]
        Z = sum_mn [ Zmns * sin(m * theta - n * phi) ]

    theta and phi may be scalars or broadcast-compatible arrays.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        The flux surface object; its coefficient data must be 1D.
    theta : jnp.ndarray
        Poloidal angle(s)
    phi : jnp.ndarray
        Toroidal angle(s)

    Returns
    --------
    jnp.ndarray
        Cylindrical position(s) on the flux surface [R, Z, phi] along the last axis.
    '''
    assert flux_surface.data.Rmnc.ndim == 1, "Rmnc must be a 1D array but is of shape {}".format(flux_surface.data.Rmnc.shape)

    Rmnc = flux_surface.data.Rmnc
    Zmns = flux_surface.data.Zmns
    mpol_vector = flux_surface.modes.mpol_vector
    ntor_vector = flux_surface.modes.ntor_vector

    # Summing all modes at once, e.g.
    #   R = jnp.sum(Rmnc[:, None] * jnp.cos(mpol_vector[:, None] * theta[None] - ntor_vector[:, None] * phi[None]), axis=0)
    # would create an (n_modes x n_points) intermediate. Scanning over the modes only ever
    # holds arrays of the broadcast (theta, phi) shape.
    def fourier_sum(vals, i):
        R, Z = vals
        phase = mpol_vector[i] * theta - ntor_vector[i] * phi
        # Cast so the scan carry keeps a single dtype whatever the input dtypes are.
        R = R + (Rmnc[i] * jnp.cos(phase)).astype(R.dtype)
        Z = Z + (Zmns[i] * jnp.sin(phase)).astype(Z.dtype)
        return (R, Z), None

    # The scan needs an initial carry with the final (broadcast) shape. phi_bc is also needed
    # so that phi can be stacked alongside R and Z.
    theta_bc, phi_bc = jnp.broadcast_arrays(theta, phi)
    R0 = jnp.zeros(theta_bc.shape, dtype=_fourier_carry_dtype(Rmnc, theta_bc, phi_bc))
    Z0 = jnp.zeros(theta_bc.shape, dtype=_fourier_carry_dtype(Zmns, theta_bc, phi_bc))
    n_modes = Rmnc.shape[0]
    R, Z = jax.lax.scan(fourier_sum, (R0, Z0), jnp.arange(n_modes))[0]
    return jnp.stack([R, Z, phi_bc], axis=-1)


@partial(jax.jit)
def _cartesian_position_direct(flux_surface : FluxSurfaceBase, theta, phi):
    '''
    Cartesian position on the flux surface as a function of (theta, phi) for a single surface (1D data).

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        The flux surface object; its coefficient data must be 1D.
    theta : jnp.ndarray
        Poloidal angle(s)
    phi : jnp.ndarray
        Toroidal angle(s)

    Returns
    --------
    jnp.ndarray
        Cartesian position(s) on the flux surface [x, y, z] along the last axis.
    '''
    RZphi = _cylindrical_position_direct(flux_surface, theta, phi)
    return _cylindrical_to_cartesian(RZphi)


# dX/dtheta of the Cartesian position for a single surface; vectorized over broadcastable (theta, phi).
_dx_dtheta_direct = jax.jit(jnp.vectorize(jax.jacfwd(_cartesian_position_direct, argnums=1), excluded=(0,), signature='(),()->(3)'))


@partial(jax.jit)
def _arc_length_theta_direct(flux_surface : FluxSurfaceBase, theta, phi):
    '''
    Arc-length density with respect to theta, |dX/dtheta|, for a single surface (1D data).

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        The flux surface object; its coefficient data must be 1D.
    theta : jnp.ndarray
        Poloidal angle(s)
    phi : jnp.ndarray
        Toroidal angle(s)

    Returns
    --------
    jnp.ndarray
        |dX/dtheta| at each (theta, phi).
    '''
    dx_dtheta = _dx_dtheta_direct(flux_surface, theta, phi)
    dx_dtheta_norm = jnp.linalg.norm(dx_dtheta, axis=-1)
    return dx_dtheta_norm


@eqx.filter_jit
def _cylindrical_position_interpolated(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Cylindrical position on the flux surface as a function of (s, theta, phi) for multiple surfaces (2D data).

    Computes:
        R = sum_mn [ Rmnc(s) * cos(m * theta - n * phi) ]
        Z = sum_mn [ Zmns(s) * sin(m * theta - n * phi) ]
    where Rmnc(s) and Zmns(s) are obtained with interpolate_array along the surface axis, and all
    arrays are in the VMEC representation (flattened coefficients with explicit mode vectors).

    Vectorized fully over s, theta, phi (scalars or broadcastable arrays). The modes are summed
    sequentially, so no (n_modes x n_points) intermediate array is created.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        The flux surface object; its coefficient data must be 2D (nsurf, nmodes).
    s : jnp.ndarray
        Surface index or normalized flux label
    theta : jnp.ndarray
        Poloidal angle(s)
    phi : jnp.ndarray
        Toroidal angle(s)

    Returns
    --------
    jnp.ndarray
        Cylindrical position(s) on the flux surface [R, Z, phi] along the last axis.
    '''
    assert flux_surface.data.Rmnc.ndim == 2, "Data must be a 2D array but is of shape {} [Did you use a FluxSurfaceBase with only 1D data? Check out make_2d_flux_surface.]".format(flux_surface.data.Rmnc.shape)

    Rmnc = flux_surface.data.Rmnc
    Zmns = flux_surface.data.Zmns
    mpol_vector = flux_surface.modes.mpol_vector
    ntor_vector = flux_surface.modes.ntor_vector

    # interpolate_array is called once per mode for all points. A fully vectorized sum over
    # modes would create (n_modes x n_points) intermediates; the scan avoids that.
    def fourier_sum(vals, i):
        R, Z = vals
        phase = mpol_vector[i] * theta - ntor_vector[i] * phi
        # Cast so the scan carry keeps a single dtype whatever the input dtypes are.
        R = R + (interpolate_array(Rmnc[..., i], s) * jnp.cos(phase)).astype(R.dtype)
        Z = Z + (interpolate_array(Zmns[..., i], s) * jnp.sin(phase)).astype(Z.dtype)
        return (R, Z), None

    # The scan needs an initial carry with the final (broadcast) shape. phi_bc is also needed
    # so that phi can be stacked alongside R and Z.
    s_bc, theta_bc, phi_bc = jnp.broadcast_arrays(s, theta, phi)
    R0 = jnp.zeros(theta_bc.shape, dtype=_fourier_carry_dtype(Rmnc, s_bc, theta_bc, phi_bc))
    Z0 = jnp.zeros(theta_bc.shape, dtype=_fourier_carry_dtype(Zmns, s_bc, theta_bc, phi_bc))
    n_modes = Rmnc.shape[1]
    R, Z = jax.lax.scan(fourier_sum, (R0, Z0), jnp.arange(n_modes))[0]
    return jnp.stack([R, Z, phi_bc], axis=-1)


@eqx.filter_jit
def _cartesian_position_interpolated(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Cartesian position on the flux surface as a function of (s, theta, phi) for multiple surfaces (2D data).

    Evaluates the cylindrical position (see _cylindrical_position_interpolated) and converts it
    with x = R cos(phi), y = R sin(phi), z = Z. Vectorized fully over s, theta, phi.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        The flux surface object; its coefficient data must be 2D (nsurf, nmodes).
    s : jnp.ndarray
        Surface index or normalized flux label
    theta : jnp.ndarray
        Poloidal angle(s)
    phi : jnp.ndarray
        Toroidal angle(s)

    Returns
    --------
    jnp.ndarray
        Cartesian position(s) on the flux surface [x, y, z] along the last axis.
    '''
    RZphi = _cylindrical_position_interpolated(flux_surface, s, theta, phi)
    return _cylindrical_to_cartesian(RZphi)


# dX/dtheta of the Cartesian position; vectorized over broadcastable (s, theta, phi).
_dx_dtheta = jax.jit(jnp.vectorize(jax.jacfwd(_cartesian_position_interpolated, argnums=2), excluded=(0,), signature='(),(),()->(3)'))


@eqx.filter_jit
def _arc_length_theta(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Arc-length density with respect to theta, |dX/dtheta|, for multiple surfaces (2D data).

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        The flux surface object; its coefficient data must be 2D (nsurf, nmodes).
    s : jnp.ndarray
        Surface index or normalized flux label
    theta : jnp.ndarray
        Poloidal angle(s)
    phi : jnp.ndarray
        Toroidal angle(s)

    Returns
    --------
    jnp.ndarray
        |dX/dtheta| at each (s, theta, phi).
    '''
    dx_dtheta = _dx_dtheta(flux_surface, s, theta, phi)
    dx_dtheta_norm = jnp.linalg.norm(dx_dtheta, axis=-1)
    return dx_dtheta_norm


# ===================================================================================================================================================================================
# Normals
# ===================================================================================================================================================================================

# jacfwd over (theta, phi) yields a tuple, which jnp.vectorize cannot handle, so stack_jacfwd stacks it into
# a single (3, 2) array per point: [..., 0] is dX/dtheta, [..., 1] is dX/dphi.
# jnp.vectorize (instead of vmap) keeps the flexibility of scalar or arbitrarily shaped broadcastable inputs.
_cartesian_position_interpolated_grad = jax.jit(jnp.vectorize(stack_jacfwd(_cartesian_position_interpolated, argnums=(2,3)), excluded=(0,), signature='(),(),()->(3,2)'))


@eqx.filter_jit
def _dx_dphi_cross_dx_dtheta(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Compute the (unnormalized) cross product dX/dphi x dX/dtheta for multiple surfaces (2D data).

    This ordering is the one used throughout for the surface normal. Its length is the
    area element |dX/dtheta x dX/dphi|.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        The flux surface object; its coefficient data must be 2D (nsurf, nmodes).
    s : jnp.ndarray
        Surface index or normalized flux label
    theta : jnp.ndarray
        Poloidal angle(s)
    phi : jnp.ndarray
        Toroidal angle(s)

    Returns
    --------
    jnp.ndarray
        dX/dphi x dX/dtheta along the last axis (length 3).
    '''
    dX_dtheta_and_dX_dphi = _cartesian_position_interpolated_grad(flux_surface, s, theta, phi)
    # dX/dphi (index 1) x dX/dtheta (index 0). With theta oriented as enforced by
    # FluxSurfaceData.from_rmnc_zmns_settings, this is meant to point away from the plasma.
    n = jnp.cross(dX_dtheta_and_dX_dphi[..., 1], dX_dtheta_and_dX_dphi[..., 0])
    return n


@eqx.filter_jit
def _normal_interpolated(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Unit normal vector on the flux surface as a function of (s, theta, phi) for multiple surfaces (2D data).

    The normal is dX/dphi x dX/dtheta divided by its length.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        The flux surface object; its coefficient data must be 2D (nsurf, nmodes).
    s : jnp.ndarray
        Surface index or normalized flux label
    theta : jnp.ndarray
        Poloidal angle(s)
    phi : jnp.ndarray
        Toroidal angle(s)

    Returns
    --------
    jnp.ndarray
        Unit normal vector(s) along the last axis (length 3).
    '''
    n = _dx_dphi_cross_dx_dtheta(flux_surface, s, theta, phi)
    n = n / jnp.linalg.norm(n, axis=-1, keepdims=True)
    return n


# ===================================================================================================================================================================================
# Principal curvatures
# ===================================================================================================================================================================================

# Second derivatives in (theta, phi), stacked to shape (3, 2, 2) per point.
_cartesian_position_interpolated_grad_grad = jax.jit(jnp.vectorize(stack_jacfwd(stack_jacfwd(_cartesian_position_interpolated, argnums=(2,3)), argnums = (2,3)), excluded=(0,), signature='(),(),()->(3,2,2)'))


@eqx.filter_jit
def _principal_curvatures_interpolated(flux_surface : FluxSurfaceBase, s, theta, phi):
    '''
    Principal curvatures on the flux surface as a function of (s, theta, phi) for multiple surfaces (2D data).

    Uses the first (E, F, G) and second (L, M, N) fundamental forms in the (theta, phi)
    parametrisation, with unit normal along dX/dphi x dX/dtheta:
        H = (E N - 2 F M + G L) / (2 (E G - F^2)),   K = (L N - M^2) / (E G - F^2)
        k1 = -(H + sqrt(H^2 - K)),                    k2 = -(H - sqrt(H^2 - K))

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        The flux surface object; its coefficient data must be 2D (nsurf, nmodes).
    s : jnp.ndarray
        Surface index or normalized flux label
    theta : jnp.ndarray
        Poloidal angle(s)
    phi : jnp.ndarray
        Toroidal angle(s)

    Returns
    --------
    jnp.ndarray
        Principal curvatures on the flux surface, shape (..., 2): index 0 is k1 and index 1 is k2.
    '''
    dX_dtheta_and_dX_dphi = _cartesian_position_interpolated_grad(flux_surface, s, theta, phi)
    d2X_dtheta2_and_d2X_dthetadphi_and_d2X_dphi2 = _cartesian_position_interpolated_grad_grad(flux_surface, s, theta, phi)
    # dX_dtheta_and_dX_dphi has shape (..., 3, 2): last index 0 is d/dtheta, last index 1 is d/dphi.
    # d2X_... has shape (..., 3, 2, 2): (0,0) is d2/dtheta2, (0,1) and (1,0) are d2/dthetadphi, (1,1) is d2/dphi2.

    E = jnp.einsum("...i, ...i->...", dX_dtheta_and_dX_dphi[..., 0], dX_dtheta_and_dX_dphi[..., 0])
    F = jnp.einsum("...i, ...i->...", dX_dtheta_and_dX_dphi[..., 0], dX_dtheta_and_dX_dphi[..., 1])
    G = jnp.einsum("...i, ...i->...", dX_dtheta_and_dX_dphi[..., 1], dX_dtheta_and_dX_dphi[..., 1])

    normal_vector = jnp.cross(dX_dtheta_and_dX_dphi[..., 1], dX_dtheta_and_dX_dphi[..., 0])
    normal_vector = normal_vector / jnp.linalg.norm(normal_vector, axis=-1, keepdims=True)

    L = jnp.einsum("...i, ...i->...", normal_vector, d2X_dtheta2_and_d2X_dthetadphi_and_d2X_dphi2[..., 0, 0])
    M = jnp.einsum("...i, ...i->...", normal_vector, d2X_dtheta2_and_d2X_dthetadphi_and_d2X_dphi2[..., 0, 1])
    N = jnp.einsum("...i, ...i->...", normal_vector, d2X_dtheta2_and_d2X_dthetadphi_and_d2X_dphi2[..., 1, 1])

    H = (E * N - 2 * F * M + G * L) / (2 * (E * G - F**2))
    K = (L * N - M**2) / (E * G - F**2)

    # H^2 - K >= 0 mathematically, but round-off near umbilic points (k1 ~ k2) can make it
    # slightly negative, which would turn both curvatures into NaN. Clamp at zero.
    sqrt_discriminant = jnp.sqrt(jnp.maximum(H**2 - K, 0.0))
    k1 = - (H + sqrt_discriminant)
    k2 = - (H - sqrt_discriminant)
    return jnp.stack([k1, k2], axis=-1)
# ===================================================================================================================================================================================
# Volume and Surface
# ===================================================================================================================================================================================

@eqx.filter_jit
def _volume_from_fourier(flux_surface : FluxSurfaceBase, s : float):
    '''
    Compute the volume enclosed by the flux surface at s using the Fourier representation.
    A full module (field period) is sampled and the result is multiplied by nfp.

    Using the divergence theorem, the volume is computed as:
        V = (1/3) * ∫∫ (r · n) dA
    where r is the position vector, n is the outward unit normal and dA the area element.
    Since dA = |dx/dtheta x dx/dphi| dtheta dphi, we have
        r · n dA = r · (dx/dphi x dx/dtheta) dtheta dphi.
    The periodic trapezoidal rule (a plain sum on an endpoint-free uniform grid) is used.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        The flux surface object containing Fourier coefficients and settings (2D data).
    s : float
        The normalized flux surface label (0 <= s <= 1).

    Returns
    --------
    volume : float
        The volume enclosed by the flux surface at s.
    '''
    # r · (dx/dphi x dx/dtheta) is a product of three truncated Fourier series, so it contains
    # harmonics up to 3 * mpol in theta and 3 * ntor (per field period) in phi. Sampling
    # 6 * (max mode) + 1 points per period comfortably resolves them.
    nyquist_sampling = 6
    n_theta = flux_surface.settings.mpol * nyquist_sampling + 1
    n_phi = flux_surface.settings.ntor * nyquist_sampling + 1

    theta = jnp.linspace(0, 2 * jnp.pi, n_theta, endpoint=False)
    phi = jnp.linspace(0, 2 * jnp.pi / flux_surface.nfp, n_phi, endpoint=False)
    dtheta = 2 * jnp.pi / n_theta
    dphi = 2 * jnp.pi / flux_surface.nfp / n_phi

    tt, pp = jnp.meshgrid(theta, phi, indexing='ij')
    surface_normals = _dx_dphi_cross_dx_dtheta(flux_surface, s, tt, pp)
    r = _cartesian_position_interpolated(flux_surface, s, tt, pp)
    f_ij = jnp.einsum('...i,...i->...', r, surface_normals)
    volume = jnp.sum(f_ij) * dtheta * dphi / 3.0 * flux_surface.nfp
    return volume


@eqx.filter_jit
def _volume_from_fourier_half_mod(flux_surface : FluxSurfaceBase, s : float):
    '''
    Compute the volume enclosed by the flux surface at s using the Fourier representation,
    sampling only half a module.

    Uses the same divergence-theorem integrand as _volume_from_fourier:
        V = (1/3) * ∫∫ r · (dx/dphi x dx/dtheta) dtheta dphi
    on the same full-module toroidal grid, but only evaluates the half-module part of it.
    This relies on stellarator symmetry, R(-theta, -phi) = R(theta, phi) and
    Z(-theta, -phi) = -Z(theta, phi). The Rmnc-cosine / Zmns-sine representation guarantees it,
    and it makes the theta-summed integrand symmetric under phi -> -phi.

    Parameters
    -----------
    flux_surface : FluxSurfaceBase
        The flux surface object containing Fourier coefficients and settings (2D data).
    s : float
        The normalized flux surface label (0 <= s <= 1).

    Returns
    --------
    volume : float
        The volume enclosed by the flux surface at s.
    '''
    nyquist_sampling = 6
    n_theta = flux_surface.settings.mpol * nyquist_sampling + 1  # We add one to always satisfy nyquist.

    # Full-module periodic grid: phi_k = k * dphi, k = 0 .. n_phi_full - 1, with n_phi_full odd
    # (same grid as _volume_from_fourier). By symmetry f_k == f_{n_phi_full - k}, hence
    #   sum_{k=0}^{n_phi_full-1} f_k = 2 * sum_{k=0}^{n_phi_half-1} f_k - f_0,
    # with n_phi_half = (n_phi_full + 1) // 2 samples from phi = 0 up to (excluding) the
    # half-module boundary. For ntor = 0 this gives a single phi sample, as it should.
    n_phi_full = flux_surface.settings.ntor * nyquist_sampling + 1
    n_phi_half = (n_phi_full + 1) // 2

    theta = jnp.linspace(0, 2 * jnp.pi, n_theta, endpoint=False)
    dtheta = 2 * jnp.pi / n_theta
    # Computed analytically: differencing phi samples fails when only one sample exists.
    dphi = 2 * jnp.pi / flux_surface.nfp / n_phi_full
    phi = jnp.arange(n_phi_half) * dphi

    tt, pp = jnp.meshgrid(theta, phi, indexing='ij')
    surface_normals = _dx_dphi_cross_dx_dtheta(flux_surface, s, tt, pp)
    r = _cartesian_position_interpolated(flux_surface, s, tt, pp)
    f_ij = jnp.einsum('...i,...i->...', r, surface_normals)

    half_module_sum = jnp.sum(f_ij)
    phi0_sum = jnp.sum(f_ij[:, 0])  # phi = 0 column; it has no mirror partner, so it is counted once
    return (2.0 * half_module_sum - phi0_sum) * dtheta * dphi / 3.0 * flux_surface.nfp

#==========================================================================================================================================================================================
# d(theta,phi) extensions
#==========================================================================================================================================================================================
#----------------------------------------------------------------------------------------------------------------------------------------------------------
# interpolating d(theta,phi) on full module grid
#----------------------------------------------------------------------------------------------------------------------------------------------------------
from jax_sbgeom.jax_utils import bilinear_interp


def _normalize_theta_phi_full_mod(theta : jnp.ndarray, phi : jnp.ndarray, nfp : int):
    '''
    Normalize theta and phi to the [0, 1) range for full-module interpolation.

    Computes:
        theta_norm = (theta % (2pi))     / (2pi)
        phi_norm   = (phi % (2pi/nfp))   / (2pi/nfp)

    Parameters
    -----------
    theta : jnp.ndarray
        Poloidal angles
    phi : jnp.ndarray
        Toroidal angles
    nfp : int
        Number of field periods in the flux surface

    Returns
    --------
    theta_norm : jnp.ndarray
        Normalized poloidal angles
    phi_norm : jnp.ndarray
        Normalized toroidal angles
    '''
    module_length = 2 * jnp.pi / nfp
    return (theta % (2 * jnp.pi)) / (2 * jnp.pi), (phi % module_length) / module_length


def _interpolate_s_grid_full_mod(theta : jnp.ndarray, phi : jnp.ndarray, nfp : int, s_grid : jnp.ndarray):
    '''
    Interpolate s values on a full-module grid using bilinear interpolation.

    The grid of s values is assumed to be a uniformly sampled full-module grid:
        s[0,0] is at (theta, phi) = (0, 0).
        s[-1,-1] is at (theta, phi) = (2pi, 2pi/nfp).

    Theta and phi are first normalised to the [0, 1) range (within a full module) and then passed
    to bilinear_interp.

    Parameters
    -----------
    theta : jnp.ndarray
        Poloidal angles to interpolate at
    phi : jnp.ndarray
        Toroidal angles to interpolate at
    nfp : int
        Number of field periods in the flux surface
    s_grid : jnp.ndarray [n_theta_sampled, n_phi_sampled]
        Grid of s values to interpolate from. Assumed to be full module: i.e. phi in [0, 2pi/nfp], theta in [0, 2pi] (included endpoints)

    Returns
    --------
    s_interp : jnp.ndarray
        Interpolated s values at (theta, phi)
    '''
    theta_norm, phi_norm = _normalize_theta_phi_full_mod(theta, phi, nfp)
    return bilinear_interp(theta_norm, phi_norm, s_grid)


@jax.jit
def _cartesian_position_interpolating_s_grid_full_mod(flux_surface : ParametrisedSurface, s_grid : jnp.ndarray, theta : jnp.ndarray, phi : jnp.ndarray):
    '''
    Compute the Cartesian position of a flux surface at interpolated s values.

    The grid of s values is assumed to be a uniformly sampled full-module grid:
        s[0,0] is at (theta, phi) = (0, 0).
        s[-1,-1] is at (theta, phi) = (2pi, 2pi/nfp).

    If the tangent is desired, use _dx_dtheta_interpolating_s_grid_full_mod instead of the base
    _dx_dtheta in flux_surface_base.py: it differentiates through the s interpolation and therefore
    includes the ds/dtheta term.

    Parameters
    -----------
    flux_surface : ParametrisedSurface
        Flux surface to compute position on.
    s_grid : jnp.ndarray [n_theta_sampled, n_phi_sampled]
        Grid of s values to interpolate from. Assumed to be full module: i.e. phi in [0, 2pi/nfp], theta in [0, 2pi] (included endpoints)
    theta : jnp.ndarray
        Poloidal angles to compute position at
    phi : jnp.ndarray
        Toroidal angles to compute position at

    Returns
    --------
    positions : jnp.ndarray [..., 3]
        Cartesian positions at (theta, phi) with interpolated s values.
    '''
    s_values = _interpolate_s_grid_full_mod(theta, phi, flux_surface.nfp, s_grid)
    return flux_surface.cartesian_position(s_values, theta, phi)


# dx/dtheta including the ds/dtheta contribution: jacfwd differentiates with respect to theta (argnums=2)
# through the s-grid interpolation. It is vectorised over scalar (theta, phi) pairs; flux_surface and
# s_grid (positions 0 and 1) are excluded from vectorisation.
_dx_dtheta_interpolating_s_grid_full_mod = jax.jit(
    jnp.vectorize(
        jax.jacfwd(_cartesian_position_interpolating_s_grid_full_mod, argnums=2),
        excluded=(0, 1),
        signature="(),()->(3)",
    )
)


@jax.jit
def _arc_length_theta_interpolating_s_grid_full_mod(flux_surface : ParametrisedSurface, s_grid : jnp.ndarray, theta : jnp.ndarray, phi : jnp.ndarray):
    '''
    Compute |dx/dtheta|, the arc-length derivative with respect to theta, of a flux surface at interpolated s values.

    The grid of s values is assumed to be a uniformly sampled full-module grid:
        s[0,0] is at (theta, phi) = (0, 0).
        s[-1,-1] is at (theta, phi) = (2pi, 2pi/nfp).

    Parameters
    -----------
    flux_surface : ParametrisedSurface
        Flux surface to compute position on.
    s_grid : jnp.ndarray [n_theta_sampled, n_phi_sampled]
        Grid of s values to interpolate from. Assumed to be full module: i.e. phi in [0, 2pi/nfp], theta in [0, 2pi] (included endpoints)
    theta : jnp.ndarray
        Poloidal angles to compute arc length derivative at
    phi : jnp.ndarray
        Toroidal angles to compute arc length derivative at

    Returns
    --------
    arc_length : jnp.ndarray
        Arc length derivative with respect to theta at (theta, phi) with interpolated s values.
    '''
    return jnp.linalg.norm(_dx_dtheta_interpolating_s_grid_full_mod(flux_surface, s_grid, theta, phi), axis=-1)


# Weights of the 8th-order central finite-difference stencil for a first derivative:
#   f'(x) ~= sum_k w_k * (f(x + k*h) - f(x - k*h)) / h
_FD8_CENTRAL_FIRST_DERIVATIVE = ((1, 4.0 / 5.0), (2, -1.0 / 5.0), (3, 4.0 / 105.0), (4, -1.0 / 280.0))


@partial(jax.jit, static_argnums=(2,))
def _arc_length_theta_interpolating_s_grid_full_mod_finite_difference(flux_surface : ParametrisedSurface, s_grid : jnp.ndarray, n_theta : int, phi : float):
    '''
    Compute |dx/dtheta|, the arc-length derivative with respect to theta, of a flux surface at interpolated s values using finite differences.

    _arc_length_theta_interpolating_s_grid_full_mod uses JAX autodiff to compute the derivative, which can be
    slow (especially for FluxSurfaceNormalExtendedConstantPhi). This function instead samples a uniform
    periodic theta grid and applies an 8th-order central finite difference. The grid wraps around
    periodically via jnp.roll.

    The grid of s values is assumed to be a uniformly sampled full-module grid:
        s[0,0] is at (theta, phi) = (0, 0).
        s[-1,-1] is at (theta, phi) = (2pi, 2pi/nfp).

    Parameters
    -----------
    flux_surface : ParametrisedSurface
        Flux surface to compute position on.
    s_grid : jnp.ndarray [n_theta_sampled, n_phi_sampled]
        Grid of s values to interpolate from. Assumed to be full module: i.e. phi in [0, 2pi/nfp], theta in [0, 2pi] (included endpoints)
    n_theta : int
        Number of theta points used for the finite difference (static). The stencil spans 9 points,
        so n_theta should be at least 9; otherwise the periodically wrapped stencil aliases onto itself.
    phi : float
        Toroidal angle to compute the arc length derivative at

    Returns
    --------
    arc_length : jnp.ndarray [n_theta]
        |dx/dtheta| at theta = 2pi * i / n_theta, i = 0, ..., n_theta - 1.
    '''
    theta_grid = jnp.linspace(0, 2 * jnp.pi, n_theta, endpoint=False)
    s_values = _interpolate_s_grid_full_mod(theta_grid, phi, flux_surface.nfp, s_grid)
    positions = flux_surface.cartesian_position(s_values, theta_grid, phi)

    # Grid spacing computed directly so that n_theta == 1 does not index out of range.
    du = 2 * jnp.pi / n_theta

    # jnp.roll(positions, -k)[i] == positions[i + k], i.e. x(theta + k * du) on the periodic grid.
    dx_dtheta = sum(
        weight * (jnp.roll(positions, -offset, axis=0) - jnp.roll(positions, offset, axis=0))
        for offset, weight in _FD8_CENTRAL_FIRST_DERIVATIVE
    ) / du
    return jnp.linalg.norm(dx_dtheta, axis=-1)