Source code for jax_sbgeom.coils.biot_savart

### Experimental!!
import jax.numpy as jnp
import jax
from functools import partial
class Constant:
    '''
    Namespace for the physical constants used by the Biot-Savart routines.
    '''

    # Permeability of free space (mu_0), numerically 4 * pi * 1e-7 (the SI value).
    MU_0 = jnp.pi * 4e-7
@jax.jit
def _biot_savart_internal(current: float,
                          position: jnp.ndarray,
                          delta_line_segment: jnp.ndarray,
                          query_point: jnp.ndarray):
    '''
    Field contribution of one line element at one query point (Biot-Savart law).

    Evaluates ``mu_0 * I / (4 * pi) * (dl x r) / |r|**3`` with
    ``r = query_point - position``. The expression is singular when the query
    point coincides with the element position; no regularisation is applied.

    Parameters
    ----------
    current : float
        Current carried by the line element
    position : jnp.ndarray
        Location of the line element [3,]
    delta_line_segment : jnp.ndarray
        Line element vector dl [3,]
    query_point : jnp.ndarray
        Location at which the field is evaluated [3,]

    Returns
    -------
    jnp.ndarray
        Magnetic field contribution at the query point [3,]
    '''
    separation = query_point - position
    distance = jnp.linalg.norm(separation, axis=-1)
    dl_cross_r = jnp.cross(delta_line_segment, separation, axis=-1)
    prefactor = (Constant.MU_0 * current) / (4.0 * jnp.pi)
    # The trailing axis lets |r|**3 broadcast against the vector components.
    return prefactor * dl_cross_r / (distance**3)[..., None]


# Same kernel mapped over axis 0 of current, position and line element;
# the query point is shared by every element.
_vmapped_biot_savart = jax.jit(
    jax.vmap(_biot_savart_internal, in_axes=(0, 0, 0, None))
)
def biot_savart_single(currents: jnp.ndarray,
                       positions: jnp.ndarray,
                       delta_line_segments: jnp.ndarray,
                       query_point: jnp.ndarray):
    '''
    Total Biot-Savart field of many line elements at a single query point.

    Every element's contribution is evaluated independently and the
    contributions are summed over the element axis.

    Parameters
    ----------
    currents : jnp.ndarray
        Current of each line element [N,]
    positions : jnp.ndarray
        Location of each line element [N, 3]
    delta_line_segments : jnp.ndarray
        Line element vectors [N, 3]
    query_point : jnp.ndarray
        Location at which the field is evaluated [3,]

    Returns
    -------
    jnp.ndarray
        Magnetic field at the query point [3,]
    '''
    # One row per line element: shape (N, 3).
    per_segment_field = _vmapped_biot_savart(
        currents, positions, delta_line_segments, query_point
    )
    # Collapse the element axis to obtain the total field, shape (3,).
    return per_segment_field.sum(axis=0)
# biot_savart_single mapped over axis 0 of its fourth argument (the query
# point); the three line-element arrays are passed through unbatched.
_biot_savart_multiple = jax.jit(
    jax.vmap(biot_savart_single, in_axes=(None,) * 3 + (0,))
)
@jax.jit
def biot_savart(currents: jnp.ndarray,
                positions: jnp.ndarray,
                delta_line_segments: jnp.ndarray,
                query_points: jnp.ndarray):
    '''
    Biot-Savart field of many line elements at many query points.

    All query points are evaluated in one vectorised call.

    Parameters
    ----------
    currents : jnp.ndarray
        Current of each line element [N,]
    positions : jnp.ndarray
        Location of each line element [N, 3]
    delta_line_segments : jnp.ndarray
        Line element vectors [N, 3]
    query_points : jnp.ndarray
        Locations at which the field is evaluated [M, 3]

    Returns
    -------
    jnp.ndarray
        Magnetic field at the query points [M, 3]
    '''
    field_at_points = _biot_savart_multiple(
        currents, positions, delta_line_segments, query_points
    )
    return field_at_points
@partial(jax.jit, static_argnums=(4,))
def biot_savart_batch(currents: jnp.ndarray,
                      positions: jnp.ndarray,
                      delta_line_segments: jnp.ndarray,
                      query_points: jnp.ndarray,
                      batch_size: int | None = None):
    '''
    Biot-Savart law for multiple line elements and multiple query points,
    evaluated in batches of query points via ``jax.lax.map``.

    Parameters
    ----------
    currents : jnp.ndarray
        Currents in the coil segments [N,]
    positions : jnp.ndarray
        Positions of the line elements [N, 3]
    delta_line_segments : jnp.ndarray
        Line element vectors [N, 3]
    query_points : jnp.ndarray
        Points at which to evaluate the field [M, 3]
    batch_size : int, optional
        Number of query points handed to ``jax.lax.map`` per batch. It is a
        static (compile-time) argument of the jitted function. When None, all
        M query points are processed as a single batch.

    Returns
    -------
    jnp.ndarray
        Magnetic field at the query points [M, 3]
    '''
    def field_at(query_point):
        return biot_savart_single(currents, positions, delta_line_segments, query_point)

    num_queries = query_points.shape[0]
    if num_queries == 0:
        # Array shapes are static under jit, so this branch is resolved while
        # tracing. Without it the default batch size would be 0 and
        # jax.lax.map would be asked to split the points into batches of
        # size 0. vmap over the empty axis yields the empty result directly.
        return jax.vmap(field_at)(query_points)

    if batch_size is None:
        batch_size = num_queries
    return jax.lax.map(field_at, query_points, batch_size=batch_size)
def create_coilset_total_arrays(jax_coilset, currents, number_of_samples_per_coil):
    '''
    Flatten a coil set into per-segment arrays.

    Each coil is sampled at ``number_of_samples_per_coil`` evenly spaced
    parameter values in [0, 1) (endpoint excluded). The segment vector of a
    sample is the difference to the next sample along axis 1, wrapping from
    the last sample back to the first.

    Parameters
    ----------
    jax_coilset
        Object exposing ``position(s)``. The result is rolled along axis 1 and
        reshaped to (-1, 3), so it is presumably shaped
        [n_coils, n_samples, 3] -- TODO confirm against the coil set class.
    currents : jnp.ndarray
        Current of each coil (presumably [n_coils,] -- TODO confirm)
    number_of_samples_per_coil : int
        Number of sample points per coil

    Returns
    -------
    tuple
        ``(currents, positions, segment_vectors)``: the currents repeated once
        per sample and flattened to 1-D, the sample positions reshaped to
        (-1, 3), and the sample-to-next-sample differences reshaped to (-1, 3).
    '''
    sample_parameters = jnp.linspace(0, 1, number_of_samples_per_coil, endpoint=False)
    sample_points = jax_coilset.position(sample_parameters)

    # Difference to the following sample; roll wraps the last one to the first.
    next_points = jnp.roll(sample_points, -1, axis=1)
    segment_vectors = next_points - sample_points

    # Repeat the currents along a new trailing axis, one copy per sample.
    current_copies = [currents] * number_of_samples_per_coil
    per_sample_currents = jnp.stack(current_copies, axis=-1)

    flat_currents = per_sample_currents.reshape(-1)
    flat_points = sample_points.reshape(-1, 3)
    flat_segments = segment_vectors.reshape(-1, 3)
    return flat_currents, flat_points, flat_segments