jax_sbgeom.jax_utils package

stack_jacfwd(fun, argnums)

interpolate_fractions(s, nsurf)

Compute the bracketing indices and interpolation fraction for uniform sampling.

Parameters:
  • s (jnp.ndarray) – Normalized parameter(s) between 0 and 1

  • nsurf (int) – Number of samples

Returns:

  • i0 (jnp.ndarray) – Lower indices

  • i1 (jnp.ndarray) – Upper indices

  • ds (jnp.ndarray) – Fraction between i0 and i1
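To illustrate what the three return values mean, here is a minimal sketch in JAX. It assumes the nsurf samples sit at s = k / (nsurf - 1), k = 0, ..., nsurf - 1, and that ds is the linear fraction from sample i0 towards sample i1. interpolate_fractions_sketch is a hypothetical name, and the library's own endpoint handling may differ.

```python
import jax.numpy as jnp


def interpolate_fractions_sketch(s, nsurf):
    """Bracketing indices and fraction on a grid of nsurf samples.

    Assumption: samples sit at s = k / (nsurf - 1), endpoints included.
    """
    # Position of s in units of the grid spacing (clipped to [0, 1] for safety)
    x = jnp.clip(jnp.asarray(s), 0.0, 1.0) * (nsurf - 1)
    # Lower index, clipped so that i1 = i0 + 1 stays in range at s = 1
    i0 = jnp.clip(jnp.floor(x).astype(jnp.int32), 0, nsurf - 2)
    i1 = i0 + 1
    # Fractional distance from sample i0 towards sample i1, in [0, 1]
    ds = x - i0
    return i0, i1, ds


i0, i1, ds = interpolate_fractions_sketch(jnp.array([0.0, 0.3, 1.0]), 5)
```

With five samples, s = 0.3 lies 0.2 of the way from sample 1 to sample 2, and s = 1 is represented as the last interval with ds = 1.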

interpolate_array(x_interp, s)

Interpolate a uniformly sampled array at the normalized parameter(s) s.

Parameters:
  • x_interp (jnp.ndarray [1D]) – Array to interpolate

  • s (jnp.ndarray [1D]) – Normalized parameter(s) between 0 and 1

Returns:

Interpolated array

Return type:

jnp.ndarray
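A minimal sketch of the same idea, assuming linear blending between the two bracketing samples and samples at s = k / (n - 1) with n = len(x_interp). interpolate_array_sketch is hypothetical and is not presented as the library function.

```python
import jax.numpy as jnp


def interpolate_array_sketch(x_interp, s):
    """Linear interpolation of a 1D table.

    Assumption: x_interp[k] is the value at s = k / (n - 1), n = len(x_interp).
    """
    n = x_interp.shape[0]
    x = jnp.clip(jnp.asarray(s), 0.0, 1.0) * (n - 1)
    i0 = jnp.clip(jnp.floor(x).astype(jnp.int32), 0, n - 2)
    ds = x - i0
    # Weighted average of the two neighboring samples
    return (1.0 - ds) * x_interp[i0] + ds * x_interp[i0 + 1]


values = jnp.array([0.0, 10.0, 20.0, 40.0])  # samples at s = 0, 1/3, 2/3, 1
result = interpolate_array_sketch(values, jnp.array([0.0, 0.5, 1.0]))
```

Under these assumptions the result matches jnp.interp(s, jnp.linspace(0, 1, n), x_interp) for s in [0, 1].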

interp1d_jax(x, y, x_new)

Simple 1D linear interpolation function in JAX.

Extrapolates flat: outside the range of x, the boundary values of y are returned.

Parameters:
  • x (jnp.ndarray [N]) – x-coordinates of the data points

  • y (jnp.ndarray [N]) – y-coordinates of the data points

  • x_new (jnp.ndarray [M]) – x-coordinates where to interpolate

Returns:

Interpolated y-coordinates at x_new

Return type:

jnp.ndarray [M]
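One way to obtain linear interpolation with flat extrapolation is to locate each x_new with jnp.searchsorted and clamp the local fraction to [0, 1]. The sketch below assumes x is strictly increasing; interp1d_sketch is a hypothetical name and is not claimed to be how interp1d_jax is implemented.

```python
import jax.numpy as jnp


def interp1d_sketch(x, y, x_new):
    """Linear interpolation with flat extrapolation; x must be strictly increasing."""
    # Index i of the interval [x[i], x[i + 1]] used for each query point
    i = jnp.clip(jnp.searchsorted(x, x_new, side="right") - 1, 0, x.shape[0] - 2)
    x0, x1 = x[i], x[i + 1]
    # Clamping the local fraction to [0, 1] holds the boundary values outside [x[0], x[-1]]
    t = jnp.clip((x_new - x0) / (x1 - x0), 0.0, 1.0)
    return y[i] + t * (y[i + 1] - y[i])


x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.0, 10.0, 0.0])
x_new = jnp.array([-1.0, 0.5, 1.5, 3.0])
y_new = interp1d_sketch(x, y, x_new)
```

jnp.interp uses the same flat extrapolation by default, which makes it a convenient reference when testing such a function.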

interpolate_fractions_modulo(s, nsurf)

Compute the bracketing indices and interpolation fraction for uniform sampling with modulo wrapping, i.e., s = 1 maps back to index 0.

Parameters:
  • s (jnp.ndarray) – Normalized parameter(s) between 0 and 1

  • nsurf (int) – Number of samples

Returns:

  • i0 (jnp.ndarray) – Lower indices

  • i1 (jnp.ndarray) – Upper indices

  • ds (jnp.ndarray) – Fraction between i0 and i1
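A minimal sketch of the periodic variant, assuming the nsurf samples sit at s = k / nsurf (the endpoint s = 1 is not stored, so it wraps to index 0) and that the sample after the last one is sample 0. interpolate_fractions_modulo_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def interpolate_fractions_modulo_sketch(s, nsurf):
    """Bracketing indices and fraction on a periodic grid of nsurf samples.

    Assumption: samples sit at s = k / nsurf, so s = 1 wraps to index 0.
    """
    # Position in grid units, wrapped into [0, nsurf)
    x = jnp.mod(jnp.asarray(s) * nsurf, nsurf)
    # Clip guards against x rounding up to exactly nsurf for tiny negative s
    i0 = jnp.clip(jnp.floor(x).astype(jnp.int32), 0, nsurf - 1)
    # The neighbor after the last sample is sample 0
    i1 = jnp.mod(i0 + 1, nsurf)
    ds = x - i0
    return i0, i1, ds


i0, i1, ds = interpolate_fractions_modulo_sketch(jnp.array([0.0, 0.3, 0.9, 1.0]), 4)
```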

interpolate_array_modulo(x_interp, s)

Interpolate array for uniform sampling with modulo wrapping, i.e., s = 1 maps back to index 0.

Parameters:
  • x_interp (jnp.ndarray [1D]) – Array to interpolate

  • s (jnp.ndarray [1D]) – Normalized parameter(s) between 0 and 1

Returns:

Interpolated array

Return type:

jnp.ndarray
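A sketch of periodic linear interpolation, assuming samples at s = k / n with n = len(x_interp) and linear blending between neighbors; interpolate_array_modulo_sketch is a hypothetical name. The query at s = 0.875 lies between the last sample and sample 0, which exercises the wrap-around.

```python
import jax.numpy as jnp


def interpolate_array_modulo_sketch(x_interp, s):
    """Periodic linear interpolation of a 1D table.

    Assumption: x_interp[k] is the value at s = k / n, with n = len(x_interp).
    """
    n = x_interp.shape[0]
    x = jnp.mod(jnp.asarray(s) * n, n)
    i0 = jnp.clip(jnp.floor(x).astype(jnp.int32), 0, n - 1)
    i1 = jnp.mod(i0 + 1, n)  # wraps from the last sample back to sample 0
    ds = x - i0
    return (1.0 - ds) * x_interp[i0] + ds * x_interp[i1]


values = jnp.array([0.0, 1.0, 2.0, 3.0])  # samples at s = 0, 1/4, 2/4, 3/4
result = interpolate_array_modulo_sketch(values, jnp.array([0.5, 0.875, 1.0]))
```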

interpolate_array_modulo_broadcasted(x_interp, s)

Interpolate array for uniform sampling with modulo wrapping, i.e., s = 1 maps back to index 0. This version supports broadcasting of s to higher dimensions.

Parameters:
  • x_interp (jnp.ndarray [s.shape, interpolation_dimension, :]) – Array to interpolate

  • s (jnp.ndarray [s.shape]) – Normalized parameter(s) between 0 and 1

Returns:

Interpolated array

Return type:

jnp.ndarray [s.shape, :]
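A sketch of the broadcasted variant, reading the documented shapes as follows: x_interp has leading dimensions equal to s.shape, then the interpolation axis of length n, then any trailing dimensions, and each element of s interpolates its own slice. This reading, the periodic grid at s = k / n, linear blending, and the function name are all assumptions.

```python
import jax.numpy as jnp


def interpolate_array_modulo_broadcasted_sketch(x_interp, s):
    """x_interp: [*s.shape, n, *rest] -> result: [*s.shape, *rest].

    Assumption: each element of s interpolates its own slice of x_interp along
    the axis that follows s.shape, on a periodic grid at k / n.
    """
    s = jnp.asarray(s)
    n = x_interp.shape[s.ndim]
    rest = x_interp.shape[s.ndim + 1:]
    # Flatten to [m, n, r]: one row per element of s
    xf = x_interp.reshape((s.size, n, -1))
    x = jnp.mod(s.ravel() * n, n)
    i0 = jnp.clip(jnp.floor(x).astype(jnp.int32), 0, n - 1)
    i1 = jnp.mod(i0 + 1, n)
    ds = (x - i0)[:, None]
    rows = jnp.arange(s.size)
    # Pick each row's two neighboring samples and blend them linearly
    out = (1.0 - ds) * xf[rows, i0] + ds * xf[rows, i1]  # [m, r]
    return out.reshape(s.shape + rest)


# Usage: two parameters, each with its own 4-sample periodic table of 3 components
base = jnp.tile(jnp.arange(4.0)[:, None], (1, 3))  # [4, 3]
x_interp = jnp.stack([base, 10.0 * base])          # [2, 4, 3]
out = interpolate_array_modulo_broadcasted_sketch(x_interp, jnp.array([0.5, 0.875]))
```

Flattening the leading dimensions and indexing with a row index plus a per-row sample index keeps the gather explicit and works for any number of trailing dimensions.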

bilinear_interp(norm_array_0, norm_array_1, interpolate_array)

Bilinear interpolation for uniform sampling in 2D. It is assumed that interpolate_array is defined on a uniform grid normalized to 1 in both dimensions.

Parameters:
  • norm_array_0 (jnp.ndarray [shape]) – Normalized parameter(s) between 0 and 1 in first dimension

  • norm_array_1 (jnp.ndarray [shape]) – Normalized parameter(s) between 0 and 1 in second dimension

  • interpolate_array (jnp.ndarray [N0, N1]) – Array to interpolate

Returns:

Interpolated array

Return type:

jnp.ndarray [shape]
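A minimal sketch, assuming the grid points along each axis sit at k / (N - 1) (endpoints included) and that norm_array_0 and norm_array_1 index the first and second axes of interpolate_array. The function names are hypothetical.

```python
import jax.numpy as jnp


def _bracket(s, n):
    """Lower index and fraction on a grid of n samples at k / (n - 1) (assumed)."""
    x = jnp.clip(s, 0.0, 1.0) * (n - 1)
    i0 = jnp.clip(jnp.floor(x).astype(jnp.int32), 0, n - 2)
    return i0, x - i0


def bilinear_interp_sketch(norm_array_0, norm_array_1, interpolate_array):
    n0, n1 = interpolate_array.shape
    i, u = _bracket(jnp.asarray(norm_array_0), n0)  # first dimension
    j, v = _bracket(jnp.asarray(norm_array_1), n1)  # second dimension
    f = interpolate_array
    # Weighted sum of the four corners of the enclosing grid cell
    return ((1 - u) * (1 - v) * f[i, j] + u * (1 - v) * f[i + 1, j]
            + (1 - u) * v * f[i, j + 1] + u * v * f[i + 1, j + 1])


# Usage: sample g(a, b) = 2a + 3b + ab on a 5 x 4 grid and query off-grid points
a, b = jnp.linspace(0, 1, 5), jnp.linspace(0, 1, 4)
grid = 2 * a[:, None] + 3 * b[None, :] + a[:, None] * b[None, :]
qa, qb = jnp.array([0.3, 0.77, 1.0]), jnp.array([0.1, 0.5, 1.0])
result = bilinear_interp_sketch(qa, qb, grid)
```

Because bilinear interpolation reproduces any function of the form c0 + c1 a + c2 b + c3 ab exactly within each cell, result should equal g evaluated at the query points.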

cumulative_trapezoid_uniform_periodic(y: Array, dx: float, initial: float = 0.0)

Cumulative trapezoidal integration of y with respect to x, assuming uniform spacing and periodicity.

y must be sampled at uniform intervals in x with spacing dx, excluding the endpoint, i.e., at jnp.linspace(0, period, n_samples, endpoint=False).

Parameters:
  • y (Array) – Values to integrate

  • dx (float) – Spacing between samples in x

  • initial (float) – Initial value of the integral

Returns:

Cumulative integral of y with respect to x

Return type:

jnp.ndarray [N+1], where N is the length of y
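A minimal sketch of periodic cumulative trapezoidal integration: the first sample is repeated after the last to close the period, so N samples give N intervals and N + 1 cumulative values, consistent with the documented return shape. The function name is hypothetical, and treating initial as a constant offset is an assumption.

```python
import jax.numpy as jnp


def cumulative_trapezoid_uniform_periodic_sketch(y, dx, initial=0.0):
    """Cumulative trapezoid rule over one period; returns N + 1 values.

    Assumption: y[k] is sampled at x = k * dx, k = 0..N-1, and y(N * dx) = y[0].
    """
    # Close the period by repeating the first sample at x = N * dx
    y_closed = jnp.concatenate([y, y[:1]])
    # Trapezoid area of each of the N intervals
    areas = 0.5 * dx * (y_closed[:-1] + y_closed[1:])
    # Start at `initial` and accumulate the interval areas
    return initial + jnp.concatenate([jnp.zeros(1, dtype=areas.dtype), jnp.cumsum(areas)])


n = 64
x = jnp.linspace(0.0, 2 * jnp.pi, n, endpoint=False)
F = cumulative_trapezoid_uniform_periodic_sketch(1.0 + jnp.cos(x), 2 * jnp.pi / n)
```

For 1 + cos(x) over one period the last entry equals 2π up to rounding, and the earlier entries approximate x + sin(x) with second-order accuracy in dx.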

pchip_interpolation(x, y, x_new)

Piecewise cubic Hermite interpolating polynomial (PCHIP) interpolation in JAX.

Convenience function: it simply calls _pchip_derivatives followed by _pchip_evaluation. If you need derivatives or vectorized evaluation, you can apply jax.grad to, or vectorize, the _pchip_evaluation function. If you need multiple calls with the same x, y data, compute the derivatives once with _pchip_derivatives and pass them to _pchip_evaluation. If you need to do this for many different x, y datasets, consider applying jax.vmap to the _pchip_derivatives function.

Parameters:
  • x (jnp.ndarray [N]) – x-coordinates of the data points

  • y (jnp.ndarray [N]) – y-coordinates of the data points

  • x_new (jnp.ndarray [M]) – x-coordinates where to interpolate

Returns:

Interpolated y-coordinates at x_new

Return type:

jnp.ndarray [M]
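The split between computing knot derivatives and evaluating the interpolant can be sketched as follows. The slope formulas are the weighted-harmonic-mean interior slopes and shape-limited one-sided end slopes described in SciPy's PchipInterpolator documentation; whether _pchip_derivatives uses exactly these end conditions is an assumption, and all function names below are hypothetical. The sketch needs at least three data points and extrapolates the end cubics outside [x[0], x[-1]].

```python
import jax.numpy as jnp


def pchip_derivatives_sketch(x, y):
    """Shape-preserving slopes at each knot; assumes N >= 3 and increasing x."""
    h = jnp.diff(x)          # interval widths, length N-1
    delta = jnp.diff(y) / h  # secant slopes, length N-1

    # Interior knots: weighted harmonic mean of the adjacent secant slopes,
    # zero where the slopes differ in sign or one of them is zero.
    h0, h1 = h[:-1], h[1:]
    d0, d1 = delta[:-1], delta[1:]
    w1, w2 = 2.0 * h1 + h0, h1 + 2.0 * h0
    same_sign = jnp.sign(d0) * jnp.sign(d1) > 0
    safe_d0 = jnp.where(same_sign, d0, 1.0)  # avoid division by zero
    safe_d1 = jnp.where(same_sign, d1, 1.0)
    interior = jnp.where(same_sign, (w1 + w2) / (w1 / safe_d0 + w2 / safe_d1), 0.0)

    # End knots: one-sided three-point estimate, limited to preserve shape
    def end_slope(ha, hb, da, db):
        d = ((2.0 * ha + hb) * da - ha * db) / (ha + hb)
        d = jnp.where(jnp.sign(d) != jnp.sign(da), 0.0, d)
        too_big = (jnp.sign(da) != jnp.sign(db)) & (jnp.abs(d) > 3.0 * jnp.abs(da))
        return jnp.where(too_big, 3.0 * da, d)

    d_first = end_slope(h[0], h[1], delta[0], delta[1])
    d_last = end_slope(h[-1], h[-2], delta[-1], delta[-2])
    return jnp.concatenate([d_first[None], interior, d_last[None]])


def pchip_evaluate_sketch(x, y, d, x_new):
    """Evaluate the cubic Hermite interpolant with knot slopes d at x_new."""
    i = jnp.clip(jnp.searchsorted(x, x_new, side="right") - 1, 0, x.shape[0] - 2)
    h = x[i + 1] - x[i]
    t = (x_new - x[i]) / h
    # Cubic Hermite basis functions on the unit interval
    h00 = (1.0 + 2.0 * t) * (1.0 - t) ** 2
    h10 = t * (1.0 - t) ** 2
    h01 = t**2 * (3.0 - 2.0 * t)
    h11 = t**2 * (t - 1.0)
    return h00 * y[i] + h10 * h * d[i] + h01 * y[i + 1] + h11 * h * d[i + 1]


def pchip_interpolation_sketch(x, y, x_new):
    return pchip_evaluate_sketch(x, y, pchip_derivatives_sketch(x, y), x_new)


x = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = jnp.array([0.0, 1.0, 1.0, 2.0, 5.0])  # monotone, with a flat segment on [1, 2]
x_dense = jnp.linspace(0.0, 4.0, 201)
y_dense = pchip_interpolation_sketch(x, y, x_dense)
```

Because pchip_derivatives_sketch returns an ordinary array, the slopes can be computed once and reused across many pchip_evaluate_sketch calls, or batched over datasets with jax.vmap.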

resample_uniform_periodic_linear(non_uniform_values: Array, n_points_desired: int)

Resample non-uniform values to uniform values using linear interpolation. The input values are assumed periodic and sampled on a uniform grid.

Parameters:
  • non_uniform_values (Array) – Non-uniform values to resample

  • n_points_desired (int) – Number of desired uniformly sampled points

Returns:

Resampled uniform values

Return type:

jnp.ndarray [n_points_desired]

resample_uniform_periodic_pchip(non_uniform_values: Array, n_points_desired: int)

Resample non-uniform values to uniform values using PCHIP interpolation. The input values are assumed periodic and sampled on a uniform grid.

Parameters:
  • non_uniform_values (Array) – Non-uniform values to resample

  • n_points_desired (int) – Number of desired uniformly sampled points

Returns:

Resampled uniform values

Return type:

jnp.ndarray [n_points_desired]

reverse_except_begin(array: Array)

mesh_to_pyvista_mesh(pts, conn=None)

Convert a mesh defined by pts and conn to a PyVista mesh.

Either pass a tuple (pts, conn) or pts and conn separately.

Parameters:
  • pts (jnp.ndarray [N_points, 3]) – Points of the mesh

  • conn (jnp.ndarray [N_elements, M], optional) – Connectivity of the mesh (triangles or tetrahedra)

Return type:

pyvista mesh

surface_normals_from_mesh(mesh)

Compute surface normals from a triangular mesh.

Parameters:
  • mesh (tuple (pts, conn)) – Triangular mesh, consisting of:

      ◦ pts (jnp.ndarray [N_points, 3]) – Points of the mesh

      ◦ conn (jnp.ndarray [N_triangles, 3]) – Connectivity of the mesh (triangles)

Returns:

Normals at each face of the mesh

Return type:

jnp.ndarray [N_faces, 3]
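A minimal sketch using the cross product of two triangle edges. It assumes the normals are normalized to unit length and oriented by the right-hand rule with respect to the vertex order in conn; the library may return unnormalized normals or use a different orientation, and surface_normals_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def surface_normals_sketch(mesh):
    """Unit face normals of a triangle mesh given as (pts, conn).

    Assumption: normals follow the right-hand rule for the vertex order in conn
    and are normalized to unit length.
    """
    pts, conn = mesh
    tri = pts[conn]  # [N_triangles, 3 vertices, 3 coordinates]
    n = jnp.cross(tri[:, 1] - tri[:, 0], tri[:, 2] - tri[:, 0])
    return n / jnp.linalg.norm(n, axis=-1, keepdims=True)


pts = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
conn = jnp.array([[0, 1, 2], [0, 2, 1]])  # same triangle, opposite winding
normals = surface_normals_sketch((pts, conn))
```

Reversing the winding of a triangle flips its normal, which is why the two rows of normals point in opposite directions.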

vertices_to_pyvista_polyline(pts: Array)

Convert a set of points to a pyvista polyline

Parameters:

pts (Array) – Points of the polyline

Return type:

pyvista PolyData line

boundary_centroids_from_tetrahedron(tetrahedron: Array)

Compute the centroids of the four boundary faces of each tetrahedron.

Parameters:

tetrahedron (jnp.ndarray [..., 4, 3]) – Tetrahedron vertex locations

Returns:

centroids – Centroids of the boundary faces

Return type:

jnp.ndarray [..., 4, 3]
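A minimal sketch, assuming face k is the face opposite vertex k. Under that convention each face centroid is the mean of the other three vertices, i.e., (sum of all four vertices minus vertex k) / 3. The face ordering and the function name are assumptions.

```python
import jax.numpy as jnp


def boundary_centroids_sketch(tetrahedron):
    """Face centroids of tetrahedra with vertices in [..., 4, 3].

    Assumption: face k is the face opposite vertex k, so its centroid is the
    mean of the other three vertices.
    """
    total = jnp.sum(tetrahedron, axis=-2, keepdims=True)  # [..., 1, 3]
    return (total - tetrahedron) / 3.0                     # [..., 4, 3]


tet = jnp.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
centroids = boundary_centroids_sketch(tet)  # shape (1, 4, 3)
```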

boundary_normal_vectors_from_tetrahedron(tetrahedron: Array)

Compute the normal vectors of the four boundary faces of each tetrahedron.

Parameters:

tetrahedron (jnp.ndarray [..., 4, 3]) – Tetrahedron vertex locations

Returns:

boundary – Normal vectors of the boundary faces

Return type:

jnp.ndarray [..., 4, 3]
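A minimal sketch, assuming face k is opposite vertex k and that the normals should be unit vectors pointing out of the tetrahedron. Each normal is the cross product of two face edges, flipped if it points towards the opposite vertex. The face ordering, normalization, orientation, and function name are all assumptions.

```python
import jax.numpy as jnp

# Assumption: row k lists the vertices of the face opposite vertex k
FACES = jnp.array([[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]])


def boundary_normals_sketch(tetrahedron):
    """Outward unit face normals of tetrahedra with vertices in [..., 4, 3]."""
    face_pts = tetrahedron[..., FACES, :]  # [..., 4 faces, 3 vertices, 3 coords]
    a, b, c = face_pts[..., 0, :], face_pts[..., 1, :], face_pts[..., 2, :]
    n = jnp.cross(b - a, c - a)            # [..., 4, 3]
    # Flip each normal that points towards the opposite vertex (inward)
    outward = jnp.sum(n * (a - tetrahedron), axis=-1, keepdims=True)
    n = jnp.where(outward < 0, -n, n)
    return n / jnp.linalg.norm(n, axis=-1, keepdims=True)


tet = jnp.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
normals = boundary_normals_sketch(tet)  # shape (1, 4, 3)
```

Orienting by the opposite vertex makes the result independent of the vertex winding, at the cost of one extra dot product per face.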

Submodules