import jax
import jax.numpy as jnp
from dataclasses import dataclass
from jax_sbgeom.flux_surfaces import ToroidalExtent, mesh_tetrahedra, FluxSurface, FluxSurfaceFourierExtended, ParametrisedSurface
from jax_sbgeom.flux_surfaces.flux_surface_meshing import _mesh_tetrahedra, mesh_watertight_layers
from functools import cached_property
import equinox as eqx
from abc import abstractmethod
# This provides an interface for creating blanket geometries around flux surfaces.
class BlanketMeshStructure(eqx.Module):
    '''
    Class representing the structure of a volume blanket mesh.
    Has several convenience functions to slice the blanket and functions
    defined on the blanket.

    Only the grid dimensions are stored; element counts and per-layer offsets are derived from them.
    Elements are assumed to be ordered layer by layer. Each block contributes 6 tetrahedra,
    except the blocks of the first layer when the axis is included, which contribute 3.
    '''
    n_theta : int
    '''
    Number of poloidal points in the blanket mesh.
    '''
    n_phi : int
    '''
    Number of toroidal points in the blanket mesh.
    '''
    n_s : int
    '''
    Number of radial points in the blanket mesh.
    '''
    include_axis : bool
    '''
    Whether the axis is included in the blanket mesh.
    '''
    full_angle : bool
    '''
    Whether the blanket mesh covers a full torus.
    '''

    @property
    def n_phi_blocks(self):
        '''
        Number of blocks in the toroidal direction: n_phi for a full torus, n_phi - 1 otherwise.
        '''
        return jnp.where(self.full_angle, self.n_phi, self.n_phi - 1)

    @property
    def n_theta_blocks(self):
        '''
        Number of blocks in the poloidal direction (equal to n_theta).
        '''
        return self.n_theta

    @property
    def n_layered_blocks(self):
        '''
        Number of radial layers of blocks (n_s - 1).
        '''
        return self.n_s - 1

    def _norm_negative(self, layer_i : int):
        '''
        Converts a negative layer index into the equivalent index counted from the last layer.
        Non-negative indices are returned unchanged. Works elementwise on index arrays.
        '''
        return jnp.where(layer_i < 0, self.n_layered_blocks + layer_i, layer_i)

    def _n_tet_per_block(self, layer_i : int):
        '''
        Number of tetrahedra per block in the given layer: 3 in the first layer when the axis is included, 6 otherwise.
        '''
        layer_i_mod = self._norm_negative(layer_i)
        return jnp.where(jnp.logical_and(layer_i_mod == 0, self.include_axis), 3, 6)

    def n_blocks_in_layer(self, layer_i : int):
        '''
        Number of elements (tetrahedra) in the given layer, i.e. the number of tetrahedra per block
        multiplied by the number of blocks in the layer. Accepts a scalar index or an array of indices.
        '''
        return self._n_tet_per_block(layer_i) * self.n_theta_blocks * self.n_phi_blocks

    def layer_start(self, layer_i : int):
        '''
        Index of the first element of the given layer in the flat element ordering.

        With the axis included, layer 0 holds 3 tetrahedra per block and every later layer 6.
        Without the axis, every layer holds 6 tetrahedra per block, so layer i starts at
        6 * n_theta_blocks * n_phi_blocks * i.
        '''
        layer_i_mod = self._norm_negative(layer_i)
        blocks_per_layer = self.n_theta_blocks * self.n_phi_blocks
        beyond_first = layer_i_mod > 0
        # The (layer_i_mod > 0) factor zeroes the offset for the first layer.
        start_with_axis = 3 * blocks_per_layer * beyond_first + 6 * blocks_per_layer * (layer_i_mod - 1) * beyond_first
        # Previously this branch was the constant 0, so every layer of an axis-free mesh
        # was reported to start at the beginning of the element array.
        start_without_axis = 6 * blocks_per_layer * layer_i_mod
        return jnp.where(self.include_axis, start_with_axis, start_without_axis)

    def layer_slice(self, layer_i : int):
        '''
        Slice selecting all elements of the given layer from a flat per-element array.
        '''
        start = self.layer_start(layer_i)
        return slice(start, start + self.n_blocks_in_layer(layer_i))

    def reshape_to_layer(self, layer_i : int, arr : jnp.ndarray):
        '''
        Reshapes a flat array of shape (n_elements,) to the shape of the blocks in the given layer, which is (n_phi_blocks, n_theta_blocks, n_tet_per_block).
        This n_tet_per_block is 3 for the first layer if the axis is included, and 6 otherwise. For all other layers, it is 6.
        '''
        return arr[self.layer_slice(layer_i)].reshape(self.n_phi_blocks, self.n_theta_blocks, self._n_tet_per_block(layer_i))

    def reshape_without_axis(self, arr : jnp.ndarray):
        '''
        Reshapes a flat array of shape (n_elements,) to the shape of the blocks in the blanket, which is (n_layered_blocks, n_phi_blocks, n_theta_blocks, n_tet_per_block).
        Discards the first layer if an axis is present (leaving n_layered_blocks - 1 layers) and then reshapes the rest.
        '''
        if self.include_axis:
            return arr[self.layer_start(1):].reshape(self.n_layered_blocks - 1, self.n_phi_blocks, self.n_theta_blocks, 6)
        else:
            return arr.reshape(self.n_layered_blocks, self.n_phi_blocks, self.n_theta_blocks, 6)

    def map_radial_array_to_layers(self, arr : jnp.ndarray):
        '''
        Expands an array of shape (..., n_layered_blocks), holding one value per radial layer, to one value per element.
        Each layer value is repeated along the last axis as many times as there are elements in that layer.
        This is useful for mapping a radial function defined on the layers to the blocks in the blanket (e.g. materials)
        '''
        assert arr.shape[-1] == self.n_layered_blocks, f"Input array has last axis shape {arr.shape[-1]}, but expected shape is {self.n_layered_blocks}"
        return jnp.repeat(arr, self.n_blocks_in_layer(jnp.arange(self.n_layered_blocks)), axis=-1)

    @property
    def n_elements(self):
        '''
        Total number of elements (tetrahedra) in the mesh.
        '''
        return jnp.where(
            self.include_axis,
            6 * self.n_theta_blocks * self.n_phi_blocks * (self.n_s - 2) + 3 * self.n_theta_blocks * self.n_phi_blocks,
            6 * self.n_theta_blocks * self.n_phi_blocks * (self.n_s - 1),
        )

    @property
    def n_points(self):
        '''
        Total number of points in the mesh. With the axis included, the innermost surface
        contributes n_phi points instead of n_theta * n_phi.
        '''
        return jnp.where(
            self.include_axis,
            self.n_theta * self.n_phi * (self.n_s - 1) + self.n_phi,
            self.n_theta * self.n_phi * self.n_s,
        )
class LayeredDiscreteBlanket(eqx.Module):
    '''
    Class representing a layered, structured, discrete blanket structure around a flux surface.

    This is an abstract interface: subclasses provide the meshing functions, the mesh structure,
    the s spacing and the mapping to physical spacing.
    '''
    n_theta : int
    '''
    Number of poloidal points in the blanket mesh.
    '''
    n_phi : int
    '''
    Number of toroidal points in the blanket mesh.
    '''
    resolution_layers : tuple
    '''
    Tuple of number of discrete layers in each layer of the blanket. The total number of discrete layers is given by the sum of the entries in this tuple. The length of this tuple is the number of physical layers.
    '''
    toroidal_extent : ToroidalExtent
    '''
    Toroidal extent of the blanket. This can be a full torus, a half module, etc. depending on the application.
    '''

    def __check_init__(self):
        # No validation is performed at construction time.
        pass

    @property
    def n_discrete_layers(self):
        '''
        Total number of discrete layers, i.e. the sum of the entries of resolution_layers.
        '''
        return jnp.sum(jnp.array(self.resolution_layers))

    @property
    def n_physical_layers(self):
        '''
        Number of physical layers, i.e. the length of resolution_layers.
        '''
        return len(self.resolution_layers)

    @abstractmethod
    def volume_mesh(self, parametrised_surface : ParametrisedSurface):
        '''
        Creates the volume mesh of the blanket for the given parametrised surface. Implemented by subclasses.
        '''
        ...

    @abstractmethod
    def surface_mesh(self, parametrised_surface : ParametrisedSurface):
        '''
        Creates the surface mesh of the blanket for the given parametrised surface. Implemented by subclasses.
        '''
        ...

    @property
    @abstractmethod
    def volume_mesh_structure(self) -> BlanketMeshStructure:
        '''
        The structure describing the volume mesh of this blanket. Implemented by subclasses.
        '''
        ...

    @property
    def layer_array(self):
        '''
        Index of the physical layer for every element of the volume mesh.
        The physical layer index is first repeated for each of its discrete layers (according to resolution_layers)
        and then expanded to the elements through the volume mesh structure.
        This is useful for assigning per-physical-layer quantities to the elements of the blanket (e.g. materials)
        '''
        discrete_layer_index = jnp.repeat(jnp.arange(self.n_physical_layers), jnp.array(self.resolution_layers))
        return self.volume_mesh_structure.map_radial_array_to_layers(discrete_layer_index)

    @property
    @abstractmethod
    def s_spacing(self):
        '''
        The s spacing for the layered discrete blanket used for meshing the discrete layers.
        '''
        ...

    def map_to_physical_spacing(self, d_layers : jnp.ndarray):
        '''
        Maps the s_spacing property to physical spacing.
        Takes a 1D array with one entry per physical layer.
        The result has the meaning of a normal radial coordinate until s = 1.0, beyond is the distance from the lcfs. In other words,
        1.2 means a distance of 0.2 from the LCFS.

        Parameters
        ----------
        d_layers : jnp.ndarray
            Array of cumulative physical layer boundary positions. Must have shape (n_physical_layers,).

        Returns
        -------
        jnp.ndarray
            The physical spacing for the layered discrete blanket.
        '''
        assert d_layers.shape == (self.n_physical_layers,), f"Input array has shape {d_layers.shape}, but expected shape is {(self.n_physical_layers,)}"
        return self._map_to_physical_spacing(d_layers)

    @abstractmethod
    def _map_to_physical_spacing(self, d_layers : jnp.ndarray):
        '''
        Performs the mapping for map_to_physical_spacing after the shape check. Implemented by subclasses.
        '''
        ...
[docs]
@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class DiscreteFiniteSizeCoilSet:
'''
Class representing a set of discrete finite size coils forming a coilset for blanket creation.
Attributes
-------
n_points_per_coil : int
Number of discrete points per coil
toroidal_extent : ToroidalExtent
Toroidal extent of the coilset
width_R : float
Radial width of the finite sized coils
width_phi : float
Toroidal width of the finite sized coils
'''
n_points_per_coil : int
toroidal_extent : ToroidalExtent
width_R : float
width_phi : float
def _compute_layer_slice(n_discrete_layers : int, n_theta : int, n_phi : int, layer_i : int):
    '''
    Computes the slice of elements belonging to one discrete layer of a flat, layer-ordered element array.

    The first layer holds 3 * n_theta * (n_phi - 1) elements, every later layer holds 6 * n_theta * (n_phi - 1).

    Parameters
    ----------
    n_discrete_layers : int
        Total number of discrete layers.
    n_theta : int
        Number of poloidal points.
    n_phi : int
        Number of toroidal points.
    layer_i : int
        Index of the layer. Negative values count from the last layer.

    Returns
    -------
    slice
        Slice covering the elements of the requested layer.

    Raises
    ------
    ValueError
        If layer_i lies outside [-n_discrete_layers, n_discrete_layers).
    '''
    requested_layer = layer_i
    if layer_i < 0:
        layer_i = n_discrete_layers + layer_i
    # An index that is still negative after normalisation was too far below zero; without this check
    # it would fall through and produce a slice with negative bounds.
    if not 0 <= layer_i < n_discrete_layers:
        raise ValueError(f"Layer {requested_layer} is out of bounds for {n_discrete_layers} layers.")

    blocks_per_layer = n_theta * (n_phi - 1)
    first_layer_size = 3 * blocks_per_layer
    if layer_i == 0:
        return slice(0, first_layer_size)

    other_layer_size = 6 * blocks_per_layer
    start = first_layer_size + (layer_i - 1) * other_layer_size
    return slice(start, start + other_layer_size)
def _compute_s_spacing_transformed_axis(blanket : LayeredDiscreteBlanket, s_power_sampling : int):
    '''
    Computes the s spacing for the layered discrete blanket, with s in [0, 1 + n_physical_layers].

    The spacing consists of:
      - resolution_layers[0] samples of linspace(0, 1) raised to the power s_power_sampling,
      - for each further physical layer i + 1, resolution_layers[i + 1] evenly spaced samples in [2 + i, 3 + i) (end excluded),
      - the closing value 1 + n_physical_layers.

    Parameters
    ----------
    blanket : LayeredDiscreteBlanket
        The layered discrete blanket to compute the spacing for. Only resolution_layers and n_physical_layers are read.
    s_power_sampling : int
        The power to which the radial coordinate s is raised when sampling.
        Higher values move the samples in [0, 1] towards s = 0.

    Returns
    -------
    jnp.ndarray
        The s spacing for the layered discrete blanket.
    '''
    segments = [jnp.linspace(0.0, 1.0, blanket.resolution_layers[0]) ** s_power_sampling]
    # Collect all pieces in one flat list so a blanket with a single physical layer (no outer
    # segments) does not hand an empty list to jnp.concatenate, which raises.
    segments.extend(
        jnp.linspace(2.0 + i, 3.0 + i, blanket.resolution_layers[i + 1], endpoint=False)
        for i in range(blanket.n_physical_layers - 1)
    )
    segments.append(jnp.array([1.0 + blanket.n_physical_layers]))
    return jnp.concatenate(segments)