# Source code for jax_sbgeom.flux_surfaces.convert_to_vmec

import jax 
import jax.numpy as jnp
from functools import partial
from .flux_surfaces_base import FluxSurfaceBase, _create_mpol_vector, _create_ntor_vector, FluxSurface, FluxSurfaceData, FluxSurfaceModes, FluxSurfaceSettings, _interpolate_s_grid_full_mod, _arc_length_theta_interpolating_s_grid_full_mod, _arc_length_theta_interpolating_s_grid_full_mod_finite_difference
from .flux_surfaces_base import _arc_length_theta_direct, _cylindrical_position_direct
from .flux_surfaces_extended import FluxSurfaceNormalExtended, FluxSurfaceNormalExtendedNoPhi, FluxSurfaceNormalExtendedConstantPhi, FluxSurfaceFourierExtended
from jax_sbgeom.jax_utils import bilinear_interp, resample_uniform_periodic_pchip, resample_uniform_periodic_linear
from warnings import warn
from typing import Type, Tuple
import equinox as eqx
from dataclasses import dataclass

@jax.jit
def _dft_forward(points : jnp.ndarray) -> jnp.ndarray:
    '''
    Compute the forward-normalised 2D discrete Fourier transform of a grid of points.

    With ``norm='forward'`` the transform is scaled by 1 / (n_theta * n_phi): entry [0, 0] is the
    mean of ``points``, and an unscaled inverse DFT (``ifft2(..., norm='forward')``) reproduces ``points``.

    Parameters
    ----------
    points : jnp.ndarray
        2D grid of points to compute the DFT of, shape (n_theta, n_phi).
        (``fft2`` transforms the last two axes, so leading batch axes are passed through.)

    Returns
    -------
    fft_values : jnp.ndarray
        Complex array with the same shape as ``points`` holding the scaled DFT coefficients.
        The grid sizes are not returned separately; read them from ``fft_values.shape``.
    '''
    return jnp.fft.fft2(points, norm='forward')

@jax.jit
def _cos_sin_from_dft_forward(dft_coefficients : jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    '''
    Split a forward-normalised 2D DFT (as produced by ``_dft_forward``) into four real coefficient tables.

    Each table has shape (N // 2 + 1, M // 2 + 1), where (N, M) = dft_coefficients.shape.
    Axis 0 is the frequency k along the first grid axis and axis 1 is the frequency l along the second.

    Parameters
    ----------
    dft_coefficients : jnp.ndarray
        2D complex DFT of a grid of shape (N, M). The input grid is presumably real-valued, so that
        the DFT is conjugate-symmetric (NOTE(review): not checked here).

    Returns
    -------
    xckl : jnp.ndarray
        2 * Re of the (k, l) entries. The DC entry (0, 0) is not doubled.
    xcmkl : jnp.ndarray
        2 * Re of the (k, M - l) entries for k >= 1, l >= 1. Row 0 and column 0 are zero.
    xskl : jnp.ndarray
        -2 * Im of the (k, l) entries.
    xsmkl : jnp.ndarray
        -2 * Im of the (k, M - l) entries for k >= 1, l >= 1. Row 0 and column 0 are zero.
        When N is even, its last (Nyquist) row has its sign flipped.

    Notes
    -----
    In all tables the Nyquist row (when N is even) and the Nyquist column (when M is even) are halved.
    Presumably xckl / xskl are the amplitudes of cos / sin(k*theta + l*phi), and xcmkl / xsmkl those of
    cos / sin(k*theta - l*phi). NOTE(review): confirm against the derivation.
    '''
    N, M = dft_coefficients.shape # static so can use control flow

    # Number of non-negative frequencies kept per axis: k = 0 .. N//2, l = 0 .. M//2
    N_h = N // 2 + 1
    M_h = M // 2 + 1

    def divide_nyquist(arr, N, M):
        # Halve the last row when N is even and the last column when M is even.
        # For even lengths these are the Nyquist frequencies (k = N/2, l = M/2).
        if N % 2 == 0:
            arr = arr.at[-1, :].divide(2.0)
        if M % 2 == 0:
            arr = arr.at[:, -1].divide(2.0)
        return arr

    # x^c_{kl}: 2 * Re of the (k, l) entries; the DC term appears only once, so it is halved back
    xckl = 2 * jnp.real(dft_coefficients[:N_h, :M_h])
    xckl = xckl.at[0, 0].divide(2.0)
    xckl = divide_nyquist(xckl, N, M)

    # x^{c-}_{kl}: 2 * Re of the (k, M - l) entries; row 0 and column 0 stay zero
    xcmkl = jnp.zeros_like(xckl)
    # flipped[k, j] = Re dft[k, M - 1 - j], so flipped[k, l - 1] = Re dft[k, M - l]
    flipped = jnp.real(dft_coefficients[:, ::-1])
    xcmkl = xcmkl.at[1:, 1:].set(2 * flipped[1:N_h, :M_h - 1])
    xcmkl = divide_nyquist(xcmkl, N, M)

    # x^s_{kl}: -2 * Im of the (k, l) entries
    xskl = -2 * jnp.imag(dft_coefficients[:N_h, :M_h])
    xskl = divide_nyquist(xskl, N, M)

    # x^{s-}_{kl}: -2 * Im of the (k, M - l) entries; row 0 and column 0 stay zero
    xsmkl = jnp.zeros_like(xskl)
    flipped_imag = jnp.imag(dft_coefficients[:, ::-1])
    xsmkl = xsmkl.at[1:, 1:].set(-2 * flipped_imag[1:N_h, :M_h - 1])
    xsmkl = divide_nyquist(xsmkl, N, M)

    # Sign flip of the Nyquist row of the "minus" sine table for even N.
    # NOTE(review): presumably because, on the sampled grid, the k = N/2 mode cannot tell +l from -l,
    # and sin(k*theta - l*phi) then equals -sin(k*theta + l*phi) up to the grid sign; confirm.
    if N % 2 == 0:
        xsmkl = xsmkl.at[-1, :].multiply(-1.0)

    return xckl, xcmkl, xskl, xsmkl


@partial(jax.jit, static_argnums = 4)
def _convert_cos_sin_to_vmec(xckl : jnp.ndarray, xcmkl : jnp.ndarray, xskl : jnp.ndarray, xsmkl : jnp.ndarray, cosine : bool) -> jnp.ndarray:
    '''
    Gather the coefficient tables of ``_cos_sin_from_dft_forward`` into one flat mode vector.

    The mode ordering is that of ``FluxSurfaceModes.from_settings`` with
    mpol = xckl.shape[0] - 1, ntor = xckl.shape[1] - 1 and nfp = 1.

    Parameters
    ----------
    xckl, xcmkl, xskl, xsmkl : jnp.ndarray
        Coefficient tables indexed [m, |n|]. All four are indexed with the same indices, so they are
        expected to share xckl's shape.
    cosine : bool
        Static argument (argnum 4). True builds the cosine coefficients (used for R in this module).
        False builds the sine coefficients (used for Z).

    Returns
    -------
    jnp.ndarray
        One coefficient per (mpol_vector[i], ntor_vector[i]) mode.

    Notes
    -----
    NOTE(review): the case split looks consistent with the VMEC convention
    R = sum Rmnc cos(m theta - n phi) and Z = sum Zmns sin(m theta - n phi), with xc/xs holding the
    (k theta + l phi) and xcm/xsm the (k theta - l phi) amplitudes; confirm.
    jnp.where evaluates every candidate array; the nested where then selects one value per mode.
    '''
    mpol = xckl.shape[0] - 1
    ntor = xckl.shape[1] - 1
    settings    = FluxSurfaceSettings(mpol=mpol, ntor=ntor, nfp=1)

    modes       = FluxSurfaceModes.from_settings(settings)
    mpol_vector = modes.mpol_vector
    ntor_vector = modes.ntor_vector
    # The tables are indexed by |n|; the sign of n selects which table is used
    ntor_vector_abs = jnp.abs(ntor_vector)

    if cosine:        
        mn_is0     = jnp.logical_or(mpol_vector == 0, ntor_vector == 0)
        n_isneg    = ntor_vector < 0
        v_mn0      =  xckl [mpol_vector, ntor_vector_abs] + xcmkl[mpol_vector, ntor_vector_abs]      # m = 0 or n = 0 (xcmkl is zero there)
        v_pos_npos =  xcmkl[mpol_vector, ntor_vector_abs]                                            # m > 0, n > 0
        v_pos_nneg =  xckl [mpol_vector, ntor_vector_abs]                                            # m > 0, n < 0
        return jnp.where( mn_is0, v_mn0, jnp.where( n_isneg, v_pos_nneg, v_pos_npos))
    else:
        mn_isboth0 = jnp.logical_and(mpol_vector == 0, ntor_vector == 0)
        m_is0      = mpol_vector == 0
        n_isneg    = ntor_vector < 0
        n_is0      = ntor_vector == 0

        # (0, 0) has no sine component; the integer zeros are promoted by jnp.where below
        v_mnboth0      = jnp.zeros_like(mpol_vector)                                                   # m = 0, n = 0
        v_m0_npos      = - xskl [mpol_vector, ntor_vector_abs]   + xsmkl[mpol_vector, ntor_vector_abs] # m = 0, n > 0
        v_mpos_n0      =   xskl [mpol_vector, ntor_vector_abs]   + xsmkl[mpol_vector,ntor_vector_abs ] # m > 0, n = 0
        v_mpos_nneg    =   xskl [mpol_vector, ntor_vector_abs]                                         # m > 0, n < 0
        v_mpos_npos    =   xsmkl[mpol_vector, ntor_vector_abs]                                         # m > 0, n > 0
        return jnp.where( mn_isboth0, v_mnboth0,                                        # m = 0, n = 0
                                      jnp.where( m_is0,      v_m0_npos,                 # m = 0, n != 0
                                                jnp.where(n_is0,  v_mpos_n0,            # m > 0, n = 0
                                                        jnp.where(n_isneg, v_mpos_nneg, # m > 0, n < 0
                                                                    v_mpos_npos))))     # m > 0, n > 0

@jax.jit
def _rz_to_vmec_representation(R_grid : jnp.ndarray, Z_grid : jnp.ndarray) -> FluxSurfaceData:
    '''
    Fourier-transform sampled R and Z grids into flat Rmnc / Zmns coefficient vectors.

    Parameters
    ----------
    R_grid : jnp.ndarray
        Sampled R values on a uniform 2D grid.
    Z_grid : jnp.ndarray
        Sampled Z values on the same grid as R_grid.

    Returns
    -------
    FluxSurfaceData
        Cosine coefficients of R and sine coefficients of Z, in the mode ordering of
        ``_convert_cos_sin_to_vmec``.
    '''
    assert R_grid.shape == Z_grid.shape, "R and Z grids must have the same shape but got {} and {}".format(R_grid.shape, Z_grid.shape)

    # (xckl, xcmkl, xskl, xsmkl) tables for each coordinate
    r_tables = _cos_sin_from_dft_forward(_dft_forward(R_grid))
    z_tables = _cos_sin_from_dft_forward(_dft_forward(Z_grid))

    return FluxSurfaceData(
        _convert_cos_sin_to_vmec(*r_tables, cosine=True),
        _convert_cos_sin_to_vmec(*z_tables, cosine=False),
    )

def _index_mn(m,n, ntor):
    '''
    Position of mode (m, n) in a flat mode vector with toroidal resolution ``ntor``.

    The formula matches a layout where m = 0 holds n = 0 .. ntor and every m > 0 holds
    n = -ntor .. ntor, i.e. 2 * ntor + 1 entries per m (see ``_size_mn``).
    Works elementwise on arrays as well as on plain ints.
    '''
    row_stride = 2 * ntor + 1
    return n + row_stride * m * (m > 0)

def _size_mn(mpol, ntor):
    '''
    Total number of modes in the layout used by ``_index_mn``.

    That is ntor + 1 entries for m = 0, plus 2 * ntor + 1 entries for each m = 1 .. mpol.
    '''
    modes_per_m = 2 * ntor + 1
    return modes_per_m * mpol + (ntor + 1)

@eqx.filter_jit
def _convert_array_to_different_settings(array : jnp.ndarray, new_settings : FluxSurfaceSettings, old_settings : FluxSurfaceSettings) -> jnp.ndarray:    
    '''
    Re-index Fourier coefficients from one (mpol, ntor) layout to another by zero-padding or truncating.

    Field-period symmetry is ignored (the target layout is built with nfp = 1), so this can also be
    used when only nfp differs.

    Parameters
    ----------
    array : jnp.ndarray
        Array of shape (..., N) where N is the number of Fourier modes in the old representation.
    new_settings : FluxSurfaceSettings
        The new Fourier settings (mpol, ntor).
    old_settings : FluxSurfaceSettings
        The old Fourier settings (mpol, ntor).

    Returns
    -------
    array_new : jnp.ndarray
        Array of shape (..., N_new), where N_new is the number of modes in the new representation.
        Modes absent from the old representation are 0.0.
    '''
    target_layout = FluxSurfaceSettings(mpol=new_settings.mpol, ntor=new_settings.ntor, nfp=1)
    m_target = _create_mpol_vector(target_layout)
    n_target = _create_ntor_vector(target_layout)

    present_in_old = jnp.logical_and(m_target <= old_settings.mpol, jnp.abs(n_target) <= old_settings.ntor)

    # jnp.where evaluates both branches, so modes missing from the old layout are first redirected to
    # the safe index 0 (avoiding an out-of-bounds gather) and then overwritten with 0.0.
    gather_index = jnp.where(present_in_old, _index_mn(m_target, n_target, old_settings.ntor), 0)
    gathered = array[..., gather_index]
    return jnp.where(present_in_old, gathered, 0.0)

@eqx.filter_jit
def _convert_fluxsurfacedata_to_different_settings(data : FluxSurfaceData, new_settings : FluxSurfaceSettings, old_settings : FluxSurfaceSettings):
    '''
    Apply ``_convert_array_to_different_settings`` to both Rmnc and Zmns of ``data``.

    Returns
    -------
    FluxSurfaceData
        Coefficients re-indexed from ``old_settings`` to ``new_settings``.
    '''
    def _remap(coefficients):
        return _convert_array_to_different_settings(coefficients, new_settings, old_settings)

    return FluxSurfaceData(_remap(data.Rmnc), _remap(data.Zmns))

def convert_to_different_settings(fluxsurface : FluxSurfaceBase, settings_new : FluxSurfaceSettings) -> FluxSurface:
    '''
    Convert a flux surface to a different (mpol, ntor) Fourier representation.

    Modes present in both representations are copied. Modes only in the new representation are
    set to zero, and modes that do not fit in the new representation are dropped
    (zero-padding / truncation).

    The result is built as ``type(fluxsurface)(data=..., modes=..., settings=...)``, so it has the same
    type as the input. However, only ``data``, ``modes`` and ``settings`` are passed to the constructor.
    Extra data carried by a subclass (e.g. the extension of a FluxSurfaceFourierExtended) is not
    converted or forwarded. This therefore only works for types whose constructor accepts these three
    keyword arguments alone.

    Parameters
    ----------
    fluxsurface : FluxSurfaceBase
        The flux surface to convert.
    settings_new : FluxSurfaceSettings
        The new Fourier settings (mpol, ntor).

    Returns
    -------
    fluxsurface_new : FluxSurfaceBase
        New flux surface of type ``type(fluxsurface)`` with Fourier coefficients in the new
        (mpol, ntor) representation.
    '''
    data_new = _convert_fluxsurfacedata_to_different_settings(fluxsurface.data, settings_new, fluxsurface.settings)
    return type(fluxsurface)(data = data_new, modes = FluxSurfaceModes.from_settings(settings_new), settings = settings_new)
@eqx.filter_jit
def _convert_to_equal_arclength_single(flux_surface : FluxSurfaceBase, n_theta : int, n_phi : int, n_theta_s_arclength : int) -> Tuple[FluxSurfaceData, FluxSurfaceSettings]:
    '''
    Convert a single flux surface to a Fourier representation sampled on an equal-arclength poloidal grid.

    This requires a full FluxSurface instead of only settings and data, because the position function
    is needed. It also makes batching easier: the function can be vmapped over FluxSurface objects
    (see convert_to_equal_arclength).

    Parameters
    ----------
    flux_surface : FluxSurfaceBase
        Flux surface to convert. Its data must be unbatched (Rmnc.ndim == 1).
    n_theta : int
        Number of poloidal sample points of the equal-arclength grid.
        The result has mpol = n_theta // 2.
    n_phi : int
        Number of toroidal sample points over one field period [0, 2pi/nfp).
        The result has ntor = n_phi // 2.
    n_theta_s_arclength : int
        Number of uniformly spaced poloidal points used to measure the arclength.

    Returns
    -------
    flux_surface_data : FluxSurfaceData
        Fourier representation of the resampled flux surface.
    settings : FluxSurfaceSettings
        Settings of the Fourier representation: (n_theta // 2, n_phi // 2, flux_surface.nfp).

    Raises
    ------
    AssertionError
        If the flux surface data is batched.
    '''
    assert flux_surface.data.Rmnc.ndim == 1, "convert_to_equal_arclength only supports single surface conversion"

    # Uniform measurement grid: n_theta_s_arclength poloidal x n_phi toroidal points (one field period)
    theta_s = jnp.linspace(0, 2 * jnp.pi, n_theta_s_arclength, endpoint=False)
    phi_s   = jnp.linspace(0, 2 * jnp.pi / flux_surface.nfp, n_phi, endpoint=False)
    theta_mg_s, phi_mg_s = jnp.meshgrid(theta_s, phi_s, indexing='ij')
    arc_lengths = _arc_length_theta_direct(flux_surface, theta_mg_s, phi_mg_s)  # [n_theta_s_arclength, n_phi]

    # Per toroidal column, resample to n_theta equal-arclength poloidal angles.
    # NOTE(review): the * 2pi suggests the resampler returns normalised angles in [0, 1); confirm.
    new_theta_mg = jax.vmap(resample_uniform_periodic_pchip, in_axes=(1, None), out_axes=1)(arc_lengths, n_theta) * 2 * jnp.pi
    _, phi_mg = jnp.meshgrid(jnp.zeros(new_theta_mg.shape[0]), phi_s, indexing='ij')  # [new_theta_mg.shape[0], n_phi]

    RZphi_sampled = _cylindrical_position_direct(flux_surface, new_theta_mg, phi_mg)  # [..., 3]: (R, Z, phi)
    flux_surface_data = _rz_to_vmec_representation(RZphi_sampled[..., 0], RZphi_sampled[..., 1])
    return flux_surface_data, FluxSurfaceSettings(*mpol_ntor_from_ntheta_nphi(n_theta, n_phi), flux_surface.nfp)
@eqx.filter_jit
def convert_to_equal_arclength(flux_surface : FluxSurfaceBase, n_theta : int, n_phi : int, n_theta_s_arclength : int) -> Tuple[FluxSurfaceData, FluxSurfaceSettings]:
    '''
    Re-parametrise flux surface(s) on an equal-arclength poloidal grid and return the new Fourier representation.

    A single surface (Rmnc.ndim == 1) is converted directly. Otherwise the conversion is vmapped over
    the leading axis of ``flux_surface.data``, while modes and settings are shared.

    Parameters
    ----------
    flux_surface : FluxSurfaceBase
        Flux surface, or a batch of flux surfaces stacked along the leading data axis.
    n_theta : int
        Number of poloidal sample points; the result has mpol = n_theta // 2.
    n_phi : int
        Number of toroidal sample points over one field period; the result has ntor = n_phi // 2.
    n_theta_s_arclength : int
        Number of poloidal points used to measure the arclength.

    Returns
    -------
    flux_surface_data : FluxSurfaceData
        Fourier coefficients, batched along the leading axis if the input was batched.
    settings : FluxSurfaceSettings
        Settings of the Fourier representation: (n_theta // 2, n_phi // 2, flux_surface.nfp).
    '''
    if flux_surface.data.Rmnc.ndim == 1:
        return _convert_to_equal_arclength_single(flux_surface, n_theta, n_phi, n_theta_s_arclength)

    # Batched input: map over axis 0 of the data only; modes and settings are not mapped.
    surface_axes = type(flux_surface)(FluxSurfaceData(0, 0), FluxSurfaceModes(None, None), FluxSurfaceSettings(None, None, None))
    batched_conversion = jax.vmap(_convert_to_equal_arclength_single, in_axes=(surface_axes, None, None, None))
    flux_surface_data, _ = batched_conversion(flux_surface, n_theta, n_phi, n_theta_s_arclength)
    return flux_surface_data, FluxSurfaceSettings(*mpol_ntor_from_ntheta_nphi(n_theta, n_phi), flux_surface.nfp)
def mpol_ntor_from_ntheta_nphi(n_theta : int, n_phi : int) -> Tuple[int,int]:
    '''
    Fourier resolution (mpol, ntor) associated with an (n_theta, n_phi) sample grid.

    Parameters
    ----------
    n_theta : int
        Number of poloidal sample points.
    n_phi : int
        Number of toroidal sample points.

    Returns
    -------
    (mpol, ntor) : Tuple[int, int]
        (n_theta // 2, n_phi // 2)
    '''
    return n_theta // 2, n_phi // 2
@eqx.filter_jit
def create_fourier_representation(flux_surface : FluxSurfaceBase, s : jnp.ndarray, theta_grid : jnp.ndarray) -> Tuple[FluxSurfaceData, FluxSurfaceSettings]:
    '''
    Create a Fourier representation of a flux surface at given (s, theta) grid points.

    The toroidal angles are n_phi uniformly spaced points over one field period [0, 2pi/nfp),
    shared by every poloidal row.

    Parameters
    ----------
    flux_surface : FluxSurfaceBase
        Flux surface to create the Fourier representation of.
    s : jnp.ndarray [n_theta, n_phi] or float
        Radial coordinate(s) at which to sample the flux surface. If an array, it must have the same
        shape as theta_grid.
    theta_grid : jnp.ndarray [n_theta, n_phi]
        Grid of poloidal angles at which to sample the flux surface.

    Returns
    -------
    flux_surface_data : FluxSurfaceData
        Fourier representation of the sampled flux surface.
    settings : FluxSurfaceSettings
        Settings of the Fourier representation: (n_theta // 2, n_phi // 2, flux_surface.nfp).

    Raises
    ------
    AssertionError
        If theta_grid is not 2D, or if s is not a scalar and its shape differs from theta_grid.

    Warns
    -----
    UserWarning
        For FluxSurfaceNormalExtended instances (phi_in != phi_out) and for the plain FluxSurface
        base class (no extension beyond the LCFS).
    '''
    # Static checks (shapes are static under jit). jnp.shape also works for Python scalars and
    # sequences, which have no .shape attribute.
    assert theta_grid.ndim == 2, "theta_grid must be a 2D grid (n_theta, n_phi) but got shape {}".format(theta_grid.shape)
    s_shape = jnp.shape(s)
    if len(s_shape) != 0:
        assert s_shape == theta_grid.shape, "If s is an array, it must have the same shape as theta_grid but got s shape {} and theta_grid shape {}".format(s_shape, theta_grid.shape)

    # Static warnings (emitted at trace time).
    # NOTE(review): isinstance also matches any subclasses of FluxSurfaceNormalExtended; confirm that
    # FluxSurfaceNormalExtendedNoPhi / ConstantPhi are not subclasses, or this warning fires for them too.
    if isinstance(flux_surface, FluxSurfaceNormalExtended):
        warn("FluxSurfaceNormalExtended does not have phi_in = phi_out. This introduces errors when Fourier transforming", UserWarning)
    # Exact type check on purpose: subclasses of FluxSurface may extend beyond the LCFS
    if type(flux_surface) is FluxSurface:
        warn("FluxSurface base class does not extend beyond the LCFS. Any conversion with s>0.0 will reproduce the LCFS", UserWarning)

    n_theta, n_phi = theta_grid.shape
    phi_grid = jnp.linspace(0, 2*jnp.pi / flux_surface.nfp, n_phi, endpoint=False)
    # phi_mg[i, j] = phi_grid[j] for every poloidal row i
    phi_mg = jnp.broadcast_to(phi_grid, (n_theta, n_phi))

    RZphi_sampled = flux_surface.cylindrical_position(s, theta_grid, phi_mg)
    flux_surface_data = _rz_to_vmec_representation(RZphi_sampled[..., 0], RZphi_sampled[..., 1])
    return flux_surface_data, FluxSurfaceSettings(*mpol_ntor_from_ntheta_nphi(n_theta, n_phi), flux_surface.nfp)
@eqx.filter_jit
def _create_fourier_representation_d_interp_single(flux_surfaces : FluxSurfaceBase, d : jnp.ndarray, n_theta : int, n_phi : int):
    '''
    Create a Fourier representation of an extended flux surface with an interpolated extension distance.

    Parameters
    ----------
    flux_surfaces : FluxSurfaceBase
        Flux surface to extend using the distance function. It must be of type
        FluxSurfaceNormalExtendedNoPhi or FluxSurfaceNormalExtendedConstantPhi to ensure valid results
        (phi_in must be phi_out for FFT).
    d : jnp.ndarray [n_theta_sampled, n_phi_sampled] or float
        Distance function to extend the flux surface with. Assumed to cover the full module:
        phi in [0, 2pi/nfp] and theta in [0, 2pi] (endpoints included).
    n_theta : int
        Number of poloidal points in the output Fourier representation.
    n_phi : int
        Number of toroidal points in the output Fourier representation.

    Returns
    -------
    flux_surface_data : FluxSurfaceData
        Fourier representation of the sampled surface (from create_fourier_representation).
    settings : FluxSurfaceSettings
        Settings of the Fourier representation: (n_theta // 2, n_phi // 2, flux_surfaces.nfp).
    '''
    theta = jnp.linspace(0, 2*jnp.pi, n_theta, endpoint=False)
    phi   = jnp.linspace(0, 2*jnp.pi/flux_surfaces.nfp, n_phi, endpoint=False)
    theta_mg, phi_mg = jnp.meshgrid(theta, phi, indexing='ij')
    # NOTE(review): d is shifted by +1.0 before interpolation, presumably mapping a distance beyond the
    # LCFS (s = 1) onto the extended radial coordinate s = 1 + d; confirm against the extension classes.
    s_interp = _interpolate_s_grid_full_mod(theta_mg, phi_mg, flux_surfaces.nfp, jnp.atleast_2d(d) + 1.0)
    return create_fourier_representation(flux_surfaces, s_interp, theta_mg)

# ===================================================================================================================================================================================
# Convenience functions
# ===================================================================================================================================================================================

# Batched variant: maps over the leading axis of d; the flux surface and grid sizes are shared.
_create_fourier_representation_d_interp_vmap = jax.vmap(_create_fourier_representation_d_interp_single, in_axes=(None, 0, None, None))
@eqx.filter_jit
def create_fourier_representation_d_interp(flux_surface : FluxSurfaceBase, d : jnp.ndarray, n_theta : int, n_phi : int):
    '''
    Create a Fourier representation of an extended flux surface with an interpolated extension distance.

    Can be batched over d. If d is a scalar or 2D array, a single flux surface is created. If d is a
    1D or 3D array, multiple flux surfaces are created (batched along the leading axis of d).

    Parameters
    ----------
    flux_surface : FluxSurfaceBase
        Flux surface to extend using the distance function. It must be of type
        FluxSurfaceNormalExtendedNoPhi or FluxSurfaceNormalExtendedConstantPhi to ensure valid results
        (phi_in must be phi_out for FFT).
    d : jnp.ndarray
        Distance function to extend the flux surface with. Assumed to cover the full module:
        phi in [0, 2pi/nfp] and theta in [0, 2pi] (endpoints included).
        A scalar or 2D array gives a single surface; a 1D or 3D array gives a batch.
    n_theta : int
        Number of poloidal points in the output Fourier representation.
    n_phi : int
        Number of toroidal points in the output Fourier representation.

    Returns
    -------
    flux_surface_data : FluxSurfaceData
        Fourier representation of the sampled flux surface(s).
    settings : FluxSurfaceSettings
        Settings of the Fourier representation: (n_theta // 2, n_phi // 2, flux_surface.nfp).

    Raises
    ------
    AssertionError / ValueError
        If d has 4 or more dimensions. The ValueError is raised when assertions are disabled.
    '''
    d = jnp.array(d)
    # ndim is static under jit, so this dispatch happens at trace time
    assert d.ndim < 4, f"d must be a scalar or 1D, 2D or 3D array but got shape {d.shape}"
    new_settings = FluxSurfaceSettings(*mpol_ntor_from_ntheta_nphi(n_theta, n_phi), flux_surface.nfp)
    if d.ndim in (0, 2):
        flux_surface_data, _ = _create_fourier_representation_d_interp_single(flux_surface, d, n_theta, n_phi)
    elif d.ndim in (1, 3):
        flux_surface_data, _ = _create_fourier_representation_d_interp_vmap(flux_surface, d, n_theta, n_phi)
    else:
        # Only reachable when assertions are stripped (python -O)
        raise ValueError("d must be a scalar or 1D, 2D or 3D array but got shape {}".format(d.shape))
    return flux_surface_data, new_settings
def create_flux_surface_d_interp(flux_surface : FluxSurfaceBase, d : jnp.ndarray, n_theta : int, n_phi : int, type_c : Type = FluxSurface) -> FluxSurface:
    '''
    Convenience function: create_fourier_representation_d_interp followed by
    type_c.from_data_settings_full. Returns a flux surface of type ``type_c``.

    Parameters
    ----------
    flux_surface : FluxSurfaceBase
        Flux surface to extend using the distance function. It must be of type
        FluxSurfaceNormalExtendedNoPhi or FluxSurfaceNormalExtendedConstantPhi to ensure valid results
        (phi_in must be phi_out for FFT).
    d : jnp.ndarray
        Distance function to extend the flux surface with. Assumed to cover the full module:
        phi in [0, 2pi/nfp] and theta in [0, 2pi] (endpoints included).
        A scalar or 2D array gives a single surface; a 1D or 3D array gives a batch.
    n_theta : int
        Number of poloidal points in the output Fourier representation.
    n_phi : int
        Number of toroidal points in the output Fourier representation.
    type_c : Type, optional
        Class whose ``from_data_settings_full(data, settings)`` builds the result. Defaults to FluxSurface.

    Returns
    -------
    flux_surface : type_c
        Flux surface with the new Fourier representation.
    '''
    flux_surface_data, settings = create_fourier_representation_d_interp(flux_surface, d, n_theta, n_phi)
    return type_c.from_data_settings_full(flux_surface_data, settings)
def create_extended_flux_surface_d_interp(flux_surface : FluxSurfaceBase, d : jnp.ndarray, n_theta : int, n_phi : int):
    '''
    Create a FluxSurfaceFourierExtended by extending a flux surface with an interpolated distance function d.

    Convenience function: create_flux_surface_d_interp followed by
    FluxSurfaceFourierExtended.from_flux_surface_and_extension. Unlike create_flux_surface_d_interp,
    it directly returns a FluxSurfaceFourierExtended.

    Parameters
    ----------
    flux_surface : FluxSurfaceBase
        Flux surface to extend using the distance function. It must be of type
        FluxSurfaceNormalExtendedNoPhi or FluxSurfaceNormalExtendedConstantPhi to ensure valid results
        (phi_in must be phi_out for FFT).
    d : jnp.ndarray
        Distance function to extend the flux surface with. Assumed to cover the full module:
        phi in [0, 2pi/nfp] and theta in [0, 2pi] (endpoints included).
    n_theta : int
        Number of poloidal points in the output Fourier representation.
    n_phi : int
        Number of toroidal points in the output Fourier representation.

    Returns
    -------
    flux_surface_extended : FluxSurfaceFourierExtended
        Built from (1) the input's data/modes/settings as a plain FluxSurface and (2) the interpolated
        surface from create_flux_surface_d_interp.
    '''
    # Only the input's Fourier description is kept; any extension data it carries is not forwarded
    base_surface = FluxSurface(data = flux_surface.data, modes = flux_surface.modes, settings = flux_surface.settings)
    extension_surface = create_flux_surface_d_interp(flux_surface, d, n_theta, n_phi, type_c=FluxSurface)
    return FluxSurfaceFourierExtended.from_flux_surface_and_extension(base_surface, extension_surface)
@eqx.filter_jit
def create_fourier_representation_d_interp_equal_arclength(flux_surface : FluxSurfaceBase, d : jnp.ndarray, n_theta : int, n_phi : int, n_theta_s_arclength : int):
    '''
    Convenience function: create_fourier_representation_d_interp followed by convert_to_equal_arclength.

    Parameters
    ----------
    flux_surface : FluxSurfaceBase
        Flux surface to extend using the distance function. It must be of type
        FluxSurfaceNormalExtendedNoPhi or FluxSurfaceNormalExtendedConstantPhi to ensure valid results
        (phi_in must be phi_out for FFT).
    d : jnp.ndarray
        Distance function to extend the flux surface with. Assumed to cover the full module:
        phi in [0, 2pi/nfp] and theta in [0, 2pi] (endpoints included).
    n_theta : int
        Number of poloidal points in the output Fourier representation.
    n_phi : int
        Number of toroidal points in the output Fourier representation.
    n_theta_s_arclength : int
        Number of poloidal points to use for the arclength parametrization.

    Returns
    -------
    flux_surface_data : FluxSurfaceData
        Fourier representation of the sampled flux surface on an equal-arclength poloidal grid.
    settings : FluxSurfaceSettings
        Settings of the Fourier representation (mpol, ntor, nfp).
    '''
    # Step 1: uniform-angle Fourier representation of the extended surface
    sampled_data, sampled_settings = create_fourier_representation_d_interp(flux_surface, d, n_theta, n_phi)
    intermediate_surface = FluxSurface.from_data_settings(sampled_data, sampled_settings)
    # Step 2: resample it on an equal-arclength poloidal grid
    return convert_to_equal_arclength(intermediate_surface, n_theta, n_phi, n_theta_s_arclength)
def create_flux_surface_d_interp_equal_arclength(flux_surface : FluxSurfaceBase, d : jnp.ndarray, n_theta : int, n_phi : int, n_theta_s_arclength : int, type_c : Type = FluxSurface):
    '''
    Convenience function: create_fourier_representation_d_interp, then convert_to_equal_arclength, then
    type_c.from_data_settings_full. Returns a flux surface of type ``type_c``.

    Parameters
    ----------
    flux_surface : FluxSurfaceBase
        Flux surface to extend using the distance function. It must be of type
        FluxSurfaceNormalExtendedNoPhi or FluxSurfaceNormalExtendedConstantPhi to ensure valid results
        (phi_in must be phi_out for FFT).
    d : jnp.ndarray
        Distance function to extend the flux surface with. Assumed to cover the full module:
        phi in [0, 2pi/nfp] and theta in [0, 2pi] (endpoints included).
    n_theta : int
        Number of poloidal points in the output Fourier representation.
    n_phi : int
        Number of toroidal points in the output Fourier representation.
    n_theta_s_arclength : int
        Number of poloidal points to use for the arclength parametrization.
    type_c : Type, optional
        Class whose ``from_data_settings_full(data, settings)`` builds the result. Defaults to FluxSurface.

    Returns
    -------
    flux_surface : type_c
        Flux surface with a Fourier representation sampled on an equal-arclength poloidal grid.
    '''
    arclength_data, arclength_settings = create_fourier_representation_d_interp_equal_arclength(
        flux_surface, d, n_theta, n_phi, n_theta_s_arclength
    )
    return type_c.from_data_settings_full(arclength_data, arclength_settings)
def create_extended_flux_surface_d_interp_equal_arclength(flux_surface : FluxSurfaceBase, d : jnp.ndarray, n_theta : int, n_phi : int, n_theta_s_arclength : int):
    '''
    Create a FluxSurfaceFourierExtended by extending a flux surface with an interpolated distance function d,
    sampled on an equal-arclength poloidal grid.

    Convenience function: create_flux_surface_d_interp_equal_arclength followed by
    FluxSurfaceFourierExtended.from_flux_surface_and_extension. Unlike
    create_flux_surface_d_interp_equal_arclength, it directly returns a FluxSurfaceFourierExtended.

    Parameters
    ----------
    flux_surface : FluxSurfaceBase
        Flux surface to extend using the distance function. It must be of type
        FluxSurfaceNormalExtendedNoPhi or FluxSurfaceNormalExtendedConstantPhi to ensure valid results
        (phi_in must be phi_out for FFT).
    d : jnp.ndarray
        Distance function to extend the flux surface with. Assumed to cover the full module:
        phi in [0, 2pi/nfp] and theta in [0, 2pi] (endpoints included).
        A scalar or 2D array gives a single surface; a 1D or 3D array gives a batch.
    n_theta : int
        Number of poloidal points in the output Fourier representation.
    n_phi : int
        Number of toroidal points in the output Fourier representation.
    n_theta_s_arclength : int
        Number of poloidal points to use for the arclength parametrization.

    Returns
    -------
    flux_surface_extended : FluxSurfaceFourierExtended
        Extended flux surface with a Fourier representation sampled on an equal-arclength poloidal grid.
    '''
    inner_surface = FluxSurface(data = flux_surface.data, modes = flux_surface.modes, settings = flux_surface.settings)
    outer_surface = create_flux_surface_d_interp_equal_arclength(flux_surface, d, n_theta, n_phi, n_theta_s_arclength, type_c=FluxSurface)
    return FluxSurfaceFourierExtended.from_flux_surface_and_extension(inner_surface, outer_surface)