Source code for jax_sbgeom.jax_utils.mesh


import jax
import jax.numpy as jnp
import numpy as onp
# ===================================================================================================================================================================================
#                                                                           Pyvista mesh conversion
# ===================================================================================================================================================================================
[docs] def mesh_to_pyvista_mesh(pts, conn = None): ''' Convert a mesh defined by pts and conn to a pyvista mesh Either pass a tuple (pts, conn) or pts and conn separately. Parameters ---------- pts : jnp.ndarray [N_points, 3] Points of the mesh conn : jnp.ndarray [N_elements, M] optional Connectivity of the mesh (triangles or tetrahedra) Returns ------- pyvista mesh ''' if type(pts) is tuple: conn = pts[1] pts = pts[0] import pyvista as pv if conn.shape[-1] ==2: points_onp = onp.array(pts) conn_onp = onp.array(conn) lines = onp.hstack([onp.full((conn_onp.shape[0], 1), 2), conn_onp]).astype(onp.int64) lines = lines.flatten() mesh = pv.PolyData(points_onp, lines=lines) return mesh elif conn.shape[-1] == 3: points_onp = onp.array(pts) conn_onp = onp.array(conn) faces_pv = onp.hstack([onp.full((conn_onp.shape[0], 1), 3), conn_onp]).astype(onp.int64) faces_pv = faces_pv.flatten() mesh = pv.PolyData(points_onp, faces_pv) return mesh elif conn.shape[-1] == 4: points_onp = onp.array(pts) conn_onp = onp.array(conn) cells = onp.hstack([onp.full((conn_onp.shape[0], 1), 4), conn_onp]).astype(onp.int64) cells = cells.flatten() mesh = pv.UnstructuredGrid(cells, onp.full(conn_onp.shape[0], 10), points_onp) return mesh elif conn.shape[-1] == 8: points_onp = onp.array(pts) conn_onp = onp.array(conn) cells = onp.hstack([onp.full((conn_onp.shape[0], 1), 8), conn_onp]).astype(onp.int64) cells = cells.flatten() mesh = pv.UnstructuredGrid(cells, onp.full(conn_onp.shape[0], 12), points_onp) return mesh else: raise ValueError("Connectivity must be triangles, tetrahedra, or hexahedra")
[docs] def vertices_to_pyvista_polyline(pts : jnp.ndarray): ''' Convert a set of points to a pyvista polyline Parameters ---------- pts : jnp.ndarray [N_points, 3] Points of the polyline Returns ------- pyvista PolyData line ''' import pyvista as pv points_onp = onp.array(pts) # Create a PolyData line from the points lines = onp.hstack([[points_onp.shape[0], *range(points_onp.shape[0])]]).astype(onp.int64) polyline = pv.PolyData(points_onp, lines=lines) return polyline
# =================================================================================================================================================================================== # Mesh utilities # ===================================================================================================================================================================================
[docs] def surface_normals_from_mesh(mesh): ''' Compute surface normals from triangular mesh Parameters ---------- mesh : tuple (pts, conn) pts : jnp.ndarray [N_points, 3] Points of the mesh conn : jnp.ndarray [N_triangles, 3] Connectivity of the mesh (triangles) Returns ------- jnp.ndarray [N_faces, 3] Normals at each face of the mesh ''' assert mesh[1].shape[-1] == 3, "Mesh must be triangular" positions_surface = mesh[0][mesh[1]] e_1 = positions_surface[..., 1, :] - positions_surface[...,0, :] e_2 = positions_surface[..., 2, :] - positions_surface[...,0, :] normal = jnp.cross(e_1, e_2, axis=-1) normals = normal / jnp.linalg.norm(normal, axis=-1, keepdims=True) return normals
[docs] def boundary_normal_vectors_from_tetrahedron(tetrahedron : jnp.ndarray): ''' Create a boundary vector from tetrahedron Parameters ----------- tetrahedron : jnp.ndarray [..., 4,3] Tetrahedron vertex locations Returns ------- boundary : jnp.ndarray [..., 4,3] Normal ''' assert tetrahedron.shape[-2] == 4 and tetrahedron.shape[-1] == 3, "Tetrahedron must have shape [...,4,3]" v0 = tetrahedron[...,1,:] - tetrahedron[...,0,:] v1 = tetrahedron[...,2,:] - tetrahedron[...,0,:] v2 = tetrahedron[...,3,:] - tetrahedron[...,0,:] v3 = tetrahedron[...,2,:] - tetrahedron[...,1,:] v4 = tetrahedron[...,3,:] - tetrahedron[...,1,:] n0 = jnp.cross(v0, v1) n1 = jnp.cross(v2, v0) n2 = jnp.cross(v1, v2) n3 = jnp.cross(v4, v3) normals = jnp.stack([n0, n1, n2, n3], axis=-2) positive_volume = jnp.sign(jnp.sum( (v0 * jnp.cross(v2, v1)), axis= -1)) # ... normals = normals / jnp.linalg.norm(normals, axis=-1, keepdims=True) #[...,4,3] return normals * positive_volume[..., None, None]
[docs] def boundary_centroids_from_tetrahedron(tetrahedron : jnp.ndarray): ''' Create the centroids of all boundaries from tetrahedron Parameters ----------- tetrahedra : jnp.ndarray [..., 4,3] Tetrahedron vertex locations Returns ------- centroids : jnp.ndarray [..., 4,3] Centroids ''' assert tetrahedron.shape[-2] == 4 and tetrahedron.shape[-1] == 3, "Tetrahedron must have shape [...,4,3]" v0 = tetrahedron[...,0,:] v1 = tetrahedron[...,1,:] v2 = tetrahedron[...,2,:] v3 = tetrahedron[...,3,:] c0 = (v0 + v1 + v2) / 3.0 c1 = (v0 + v1 + v3) / 3.0 c2 = (v0 + v2 + v3) / 3.0 c3 = (v1 + v2 + v3) / 3.0 centroids = jnp.stack([c0, c1, c2, c3,], axis=-2) return centroids
# =================================================================================================================================================================================== # Functions on meshes # ===================================================================================================================================================================================
[docs] @jax.jit def volumes_tetrahedra(positions : jnp.ndarray, connectivity : jnp.ndarray): assert connectivity.shape[-1] == 4, "Connectivity must have shape (n_tetrahedra, 4) for tetrahedral meshes." a = positions[connectivity[:,0], :] b = positions[connectivity[:,1], :] c = positions[connectivity[:,2], :] d = positions[connectivity[:,3], :] ab = b - a ac = c - a ad = d - a volumes = jnp.abs(jnp.einsum('ij,ij->i', ab, jnp.cross(ac, ad))) / 6.0 return volumes
[docs] @jax.jit def volume_of_mesh(positions : jnp.ndarray, connectivity : jnp.ndarray): if connectivity.shape[-1] == 3: # Triangle mesh calculation a = positions[connectivity[:,0], :] b = positions[connectivity[:,1], :] c = positions[connectivity[:,2], :] cross_prod = jnp.cross(b - a, c - a) volume = jnp.sum(jnp.einsum('ij,ij->i', a, cross_prod)) / 6.0 return volume elif connectivity.shape[-1] == 4: # Tetrahedral mesh calculation return jnp.abs(volumes_tetrahedra(positions, connectivity)).sum() else: raise ValueError("Connectivity must have shape (n_elements, 3) for triangles or (n_elements, 4) for tetrahedra.")