# Source code for jax_sbgeom.jax_utils.raytracing

import jax 
import jax.numpy as jnp
from typing import Type
from dataclasses import dataclass
from functools import partial

def _get_norm_centroids(positions : jnp.ndarray, connectivity : jnp.ndarray) -> jnp.ndarray:
    '''
    Compute the centroid of every primitive and rescale the centroids, per axis,
    onto the unit interval spanned by all centroids.

    Parameters
    ------------
    positions : jnp.ndarray
        (N, 3) array of vertex positions
    connectivity : jnp.ndarray
        (M, 3) array of vertex indices, one row per primitive
    Returns
    ------------
    normalized_centroids : jnp.ndarray
        (M, 3) array of centroids, each axis mapped to [0, 1]
    '''
    primitive_vertices = positions[connectivity]            # (M, 3, 3)
    centers            = jnp.mean(primitive_vertices, axis=1)

    lower  = jnp.min(centers, axis=0)
    upper  = jnp.max(centers, axis=0)
    extent = upper - lower

    # A flat axis (all centroids equal) would divide by zero; use 1 so it maps to 0.
    extent = jnp.where(extent == 0, 1.0, extent)

    return (centers - lower) / extent


def _create_morton_codes(normalized_positions : jnp.ndarray) -> jnp.ndarray:
    '''
    Create 30-bit Morton codes (stored as uint32) from normalized 3D positions.

    Each coordinate is quantized to 10 bits and the three 10-bit integers are
    interleaved as x9 y9 z9 ... x0 y0 z0.

    See https://developer.nvidia.com/blog/thinking-parallel-part-iii-tree-construction-gpu/

    Parameters
    ------------
    normalized_positions : jnp.ndarray
        (M, 3) array of positions, expected to lie in [0, 1] per axis
        (e.g. the output of `_get_norm_centroids`). Values outside this range are
        clamped to the nearest grid cell.
    Returns
    ------------
    morton_codes : jnp.ndarray
        (M,) uint32 array of morton codes
    '''
    N_BIT       = 10
    max_cell    = float(2 ** N_BIT - 1)
    value_scale = 2.0 ** N_BIT

    def expand_bits(v):
        # Spread the 10 low bits of v so that two zero bits follow every input bit.
        v = jnp.asarray(v, dtype=jnp.uint32)

        # Use JAX integer constants (avoid raw Python ints in bitwise expressions)
        c1 = jnp.uint32(0x00010001)  # 65537
        m1 = jnp.uint32(0xFF0000FF)
        c2 = jnp.uint32(0x00000101)  # 257
        m2 = jnp.uint32(0x0F00F00F)
        c3 = jnp.uint32(0x00000011)  # 17
        m3 = jnp.uint32(0xC30C30C3)
        c4 = jnp.uint32(0x00000005)  # 5
        m4 = jnp.uint32(0x49249249)

        v = (v * c1) & m1
        v = (v * c2) & m2
        v = (v * c3) & m3
        v = (v * c4) & m4

        return v

    def morton_3d(norm_centroids):
        # Clamp while still in floating point: converting a negative or too-large
        # float to an unsigned integer is implementation-defined, so the cast must
        # only ever see values in [0, 2**N_BIT - 1].
        cells      = jnp.clip(jnp.floor(norm_centroids * value_scale), 0.0, max_cell)
        int_coords = cells.astype(jnp.uint32)
        return expand_bits(int_coords[:, 0]) * 4 + expand_bits(int_coords[:, 1]) * 2 + expand_bits(int_coords[:, 2])

    morton_codes = morton_3d(normalized_positions)
    return morton_codes

def _common_prefix_length(a : jnp.ndarray, b : jnp.ndarray) -> jnp.ndarray:
    '''
    Number of leading bits that two unsigned integers have in common.

    The differing bits are isolated with an XOR; the count of leading zeros of
    that value (jax.lax.clz) is the length of the shared prefix.

    Parameters
    -----------
    a : jnp.ndarray
        Unsigned integer array
    b : jnp.ndarray
        Unsigned integer array
    Returns
    -----------
    jnp.ndarray
        Common prefix length, elementwise (equals the bit width where a == b)

    '''
    differing_bits = jnp.bitwise_xor(a, b)
    return jax.lax.clz(differing_bits)


def _delta_ij(sorted_morton_codes : jnp.ndarray, i : int, j : int, int_dtype : jnp.dtype = jnp.int32):
    '''
    Evaluate the delta(i, j) function of [1]: the length of the common prefix of
    the keys at positions i and j, or -1 when either position is out of range.

    Identical morton codes are disambiguated by appending the common prefix
    length of the positions themselves.

    [1] Karras, T. (2012, June). Maximizing parallelism in the construction of BVHs, octrees, and k-d trees. In Proceedings of the Fourth ACM SIGGRAPH/Eurographics Conference on High-Performance Graphics (pp. 33-37).

    Parameters
    -----------
    sorted_morton_codes : jnp.ndarray
        Sorted morton codes array
    i : int
        Index i
    j : int
        Index j
    int_dtype : jnp.dtype
        Integer dtype for the output
    Returns
    -----------
    jnp.ndarray
        Delta_ij value (scalar)
    '''
    n_codes    = sorted_morton_codes.shape[0]
    code_dtype = sorted_morton_codes.dtype

    # Clamp so that the gathers below never read out of bounds ...
    i_clamped = jnp.clip(i, 0, n_codes - 1)
    j_clamped = jnp.clip(j, 0, n_codes - 1)
    code_i    = sorted_morton_codes[i_clamped]
    code_j    = sorted_morton_codes[j_clamped]

    # ... and remember whether the requested positions were actually valid.
    i_valid    = jnp.logical_and(i >= 0, i < n_codes)
    j_valid    = jnp.logical_and(j >= 0, j < n_codes)
    both_valid = jnp.logical_and(i_valid, j_valid)

    # Both prefix lengths are always evaluated (jnp.where is not lazy).
    prefix_of_codes   = jnp.astype(_common_prefix_length(code_i, code_j), int_dtype)
    prefix_of_indices = jnp.astype(
        _common_prefix_length(i_clamped.astype(code_dtype), j_clamped.astype(code_dtype)),
        int_dtype,
    )

    # The index prefix only contributes when the two codes are identical.
    tie_break = jnp.where(jnp.equal(code_i, code_j), prefix_of_indices, 0)

    return jnp.where(both_valid, prefix_of_codes + tie_break, -1)

@jax.jit
def _create_parallel_binary_radix_tree(morton_codes : jnp.ndarray):
    '''
    Create a parallel binary radix tree as defined in [1].

    The codes are sorted, after which every one of the N - 1 internal nodes is
    processed independently (vmapped): the range of sorted keys it covers and the
    split position inside that range are found with searches over `_delta_ij`.

    Node numbering of the result: indices 0 .. N-1 are the leaves (in sorted order),
    indices N .. 2N-2 are the internal nodes (internal node k is stored at N + k).

    [1] Karras, T. (2012, June). Maximizing parallelism in the construction of BVHs, octrees, and k-d trees. In Proceedings of the Fourth ACM SIGGRAPH/Eurographics Conference on High-Performance Graphics (pp. 33-37).

    Parameters
    -----------
    morton_codes : jnp.ndarray
        (N,) array of morton codes (unsorted; sorting happens here)

    Returns
    -----------
    total_left_idx : jnp.ndarray
        (2N - 1,) left child index of every node. The first N entries (leaves) are
        simply their own index.
    total_right_idx : jnp.ndarray
        (2N - 1,) right child index of every node, same layout as total_left_idx.
    morton_order : jnp.ndarray
        (N,) result of argsort on the morton codes (sorted position -> input position).
    '''
    # See https://developer.nvidia.com/blog/parallelforall/wp-content/uploads/2012/11/karras2012hpg_paper.pdf

    N = morton_codes.shape[0]                      
    morton_order        = jnp.argsort(morton_codes) # int_dtype
    sorted_morton_codes = morton_codes[morton_order]# uint_dtype

    morton_order_type   = morton_order.dtype        # type(morton_order): used for indexing
    
    internal_nodes = N - 1 # size_type
    internal_idx   = jnp.arange(0, internal_nodes, dtype = morton_order_type) # size_type

    # Upper bound on the number of doublings in the exponential search below, so
    # that the loop always terminates.
    max_doublings = jnp.ceil(jnp.log2(N)).astype(morton_order_type) + 2
    l0            = jnp.astype(1, morton_order_type)

    def internal_node_function(idx):
        # Direction (+1 / -1) in which the range of this node extends: towards the
        # neighbour that shares the longer prefix with idx.
        d = jnp.sign(_delta_ij(sorted_morton_codes, idx + 1, idx) - _delta_ij(sorted_morton_codes, idx, idx - 1))        
        # Prefix length shared with the neighbour on the other side; every key in
        # the range of this node must share a longer prefix than this.
        delta_min = _delta_ij(sorted_morton_codes, idx, idx - d)
        
        # Exponential search to find upper bound on lmax
        def lmax_condition(state):
            l, step = state
            return jnp.logical_and(_delta_ij(sorted_morton_codes, idx, idx + d * l) > delta_min, step < max_doublings)
        def lmax_body(state):
            l, step = state            
            return l * 2, step + 1
        
        l_final, _ = jax.lax.while_loop(lmax_condition, lmax_body, (l0, 0))        
        
        # Binary search for the exact range length: t is halved every iteration and
        # added to l whenever the key at distance l + t still belongs to the range.
        def binary_search_condition(state):
            t, _ = state
            return t > 0
        
        def binary_search_body(state):
            t, l = state
            l_carry = jax.lax.cond(_delta_ij(sorted_morton_codes, idx, idx + d * (l+t))> delta_min, lambda _ : l + t, lambda _ : l, operand=None)
            t_carry = t // 2
            return t_carry, l_carry
        
        t_final, l_final = jax.lax.while_loop(binary_search_condition, binary_search_body, (l_final //2 , jnp.astype(0, morton_order_type)))
        
        # Other end of the range covered by this node
        j_idx = idx + d * l_final

        # Split position

        # Prefix length shared by the whole range [idx, j_idx]
        delta_node = _delta_ij(sorted_morton_codes, idx, j_idx)
        
        s = jnp.astype(0, morton_order_type)

        def split_condition(state):
            _, _, t_one_done = state
            
            return ~t_one_done # this continues including the t=1 case just once
        
        def split_body(state):
            t, s, _    = state 
            
            # Advance s by t if the key at distance s + t shares more than delta_node bits with idx
            s_carry = jax.lax.cond(_delta_ij(sorted_morton_codes, idx, idx + d * (s + t)) > delta_node, lambda _ : s + t, lambda _ : s, operand=None)
            t_carry = (t+1)  // 2 # ceil(t / 2)
            return t_carry, s_carry, jnp.equal(t, 1)
        
        t_final_split, s_final, _ = jax.lax.while_loop(split_condition, split_body, ((l_final+1) //2 , s, False))        
        # Split position: the left child covers keys up to gamma, the right child starts at gamma + 1
        gamma                     = idx + s_final * d + jnp.minimum(d, 0 )        
        
        # A child that coincides with an end of the range is a single key, i.e. a leaf
        # (index < N). Otherwise it is an internal node, stored with an offset of N.
        left_idx               = jax.lax.cond(jnp.minimum(idx, j_idx) == gamma, lambda _ : gamma, lambda _ : gamma + N, operand=None)
        right_idx              = jax.lax.cond(jnp.maximum(idx, j_idx) == gamma + 1, lambda _ : gamma + 1, lambda _ : gamma + 1 + N, operand=None)
        return left_idx, right_idx    
    
    left_idx, right_idx = jax.vmap(internal_node_function)(internal_idx)

    # Prepend the leaves, which reference themselves as both children.
    total_left_idx = jnp.concatenate([jnp.arange(N, dtype = morton_order_type), left_idx])
    total_right_idx = jnp.concatenate([jnp.arange(N, dtype = morton_order_type), right_idx])
    return total_left_idx, total_right_idx, morton_order

@jax.jit
def _check_binary_radix_tree(left_idx, right_idx):
    '''
    Convenience check that a binary radix tree is valid, i.e. that every node can be
    resolved bottom-up starting from the leaves.

    A node counts as resolved once both of its children are resolved; leaves are
    resolved from the start. The sweep is repeated until all internal nodes are
    resolved or an iteration cap of 2 * N_leaves is reached.

    Parameters
    -----------
    left_idx : jnp.ndarray
        Left child indices for each node
    right_idx : jnp.ndarray
        Right child indices for each node

    Returns
    -----------
    jnp.ndarray
        Boolean scalar, True when every node was resolved.
    '''
    n_leaves   = (left_idx.shape[0] + 1) // 2
    n_nodes    = 2 * n_leaves - 1
    max_sweeps = 2 * n_leaves

    # Leaves occupy the first n_leaves slots.
    resolved_at_start = jnp.arange(n_nodes) < n_leaves

    def keep_sweeping(carry):
        resolved, sweep = carry
        internal_pending = jnp.logical_not(jnp.all(resolved[n_leaves:]))
        return jnp.logical_and(internal_pending, sweep < max_sweeps)

    def sweep_once(carry):
        resolved, sweep = carry
        both_children_resolved = jnp.logical_and(resolved[left_idx], resolved[right_idx])
        return both_children_resolved, sweep + 1

    resolved, _ = jax.lax.while_loop(keep_sweeping, sweep_once, (resolved_at_start, 0))
    return jnp.all(resolved)

@jax.jit
def _create_aabb(primitive_coordinates : jnp.ndarray) -> jnp.ndarray:
    '''
    Compute the axis-aligned bounding box of every primitive.

    Parameters
    ----------
    primitive_coordinates : jnp.ndarray of shape (N, M, 3)
        Vertex coordinates of N primitives with M vertices each

    Returns
    -------
    jnp.ndarray of shape (N, 2, 3)
        Per primitive the minimum corner ([:, 0]) and maximum corner ([:, 1]).
    '''
    # Reduce over the vertex axis, then place the two corners next to each other.
    corners = (
        jnp.min(primitive_coordinates, axis=-2),
        jnp.max(primitive_coordinates, axis=-2),
    )
    return jnp.stack(corners, axis=1)

[docs] @jax.tree_util.register_dataclass @dataclass(frozen=True) class BVH: ''' Dataclass representing a Bounding Volume Hierarchy (BVH). Build with `build_lbvh`. ''' left_idx : jnp.ndarray # left child indices for each node right_idx : jnp.ndarray # right child indices for each node aabb : jnp.ndarray # AABB for each node leafs : jnp.ndarray # bool array order : jnp.ndarray # order from primitives -> BVH inverse_order : jnp.ndarray # order from BVH -> primitives
@jax.jit
def build_lbvh(positions, connectivity):
    '''
    Build a Linear Bounding Volume Hierarchy (LBVH) for primitives given by
    3D vertex positions and a connectivity table.

    Steps: Morton codes of the normalized primitive centroids -> binary radix
    tree over the sorted codes -> bounding boxes of the internal nodes, obtained
    by repeatedly merging the children's boxes until every node is resolved.

    Parameters
    ----------
    positions : jnp.ndarray
        Array of shape (N, 3) with point coordinates.
    connectivity : jnp.ndarray
        Array of shape (M, K) with primitive connectivity indices.

    Returns
    -------
    BVH
        LBVH structure with child indices, AABBs, leaf mask, and ordering maps.
    '''
    codes    = _create_morton_codes(_get_norm_centroids(positions, connectivity))
    n_leaves = codes.shape[0]

    left_idx, right_idx, morton_order = _create_parallel_binary_radix_tree(codes)

    # Leaf boxes, rearranged into the (Morton-sorted) order of the tree leaves.
    leaf_boxes = _create_aabb(positions[connectivity])[morton_order]

    # Children of the internal nodes only (the first n_leaves entries are leaves).
    children_left  = left_idx[n_leaves:]
    children_right = right_idx[n_leaves:]

    leaf_ready = jnp.ones((n_leaves,), dtype=bool)
    leaf_mask  = jnp.concatenate([leaf_ready, jnp.zeros((n_leaves - 1,), dtype=bool)], axis=0)

    def unresolved_nodes_left(state):
        _, ready = state
        return ~jnp.all(ready)

    def merge_children(state):
        boxes, ready = state
        # Every internal box is recomputed from its children in each sweep; it is
        # only final once both children were final in the previous sweep.
        lower = jnp.minimum(boxes[children_left, 0], boxes[children_right, 0])
        upper = jnp.maximum(boxes[children_left, 1], boxes[children_right, 1])
        internal_boxes = jnp.stack([lower, upper], axis=1)
        internal_ready = ready[children_left] & ready[children_right]
        return (
            jnp.concatenate([leaf_boxes, internal_boxes], axis=0),
            jnp.concatenate([leaf_ready, internal_ready], axis=0),
        )

    start_boxes = jnp.concatenate([leaf_boxes, jnp.zeros((n_leaves - 1, 2, 3))], axis=0)
    all_boxes, _ = jax.lax.while_loop(unresolved_nodes_left, merge_children, (start_boxes, leaf_mask))

    return BVH(left_idx, right_idx, all_boxes, leaf_mask, morton_order, jnp.argsort(morton_order))
def _point_in_aabb(point : jnp.ndarray, aabb : jnp.ndarray) -> jnp.ndarray:
    """Return whether a point lies inside an axis-aligned bounding box."""
    min_corner = aabb[0]
    max_corner = aabb[1]
    inside = jnp.all((point >= min_corner) & (point <= max_corner))
    return inside


# Batched variants of _point_in_aabb:
#   points_in_aabbs   : every point against every box    -> (N_boxes, N_points)
#   points_in_aabb    : every point against a single box -> (N_points,)
#   points_in_aabbvec : point k against box k            -> (N_points,)
points_in_aabbs   = jax.vmap(jax.vmap(_point_in_aabb, in_axes=(0, None)), in_axes = (None, 0))
points_in_aabb    = jax.vmap(_point_in_aabb, in_axes=(0, None))
points_in_aabbvec = jax.vmap(_point_in_aabb, in_axes=(0, 0))


@partial(jax.jit, static_argnums = (2,3))
def _probe_bvh_imp(bvh : BVH, points : jnp.ndarray, stack_size : int = 64, max_hit_size : int = 64):
    '''
    Probe a BVH with a set of points to find which AABBs contain the points.

    All points are traversed in lock-step, each with its own explicit stack of
    internal nodes. A point whose stack has run empty is masked out for the
    remaining iterations, so it neither records hits nor modifies its stack while
    other points are still being traversed.

    Parameters
    -----------
    bvh : BVH
        The BVH to probe.
    points : jnp.ndarray
        The points to probe the BVH with. Shape (N_points, 3)
    stack_size : int
        The size of the stack to use for traversal.
    max_hit_size : int
        The maximum number of hits to record per point.

    Returns
    -----------
    jnp.ndarray
        An array of shape (N_points, max_hit_size) containing the leaf indices (in BVH order) of the
        AABBs that contain each point. If a point hits fewer than max_hit_size AABBs, the remaining
        entries are filled with -1.
    int
        The number of traversal loops performed.
    '''
    N_leafs         = (bvh.leafs.shape[0] + 1) // 2
    N_points        = points.shape[0]
    n_points_arange = jnp.arange(N_points)

    def condition(state):
        stack_idx, _, _, _, loop_idx = state
        # Continue while any point still has nodes on its stack (with an iteration cap).
        return jnp.any(stack_idx >= 0) & (loop_idx < N_leafs)

    def loop(state):
        stack_idx, stack, hits, no_hits, loop_idx = state

        # Points with an empty stack (stack_idx < 0) are finished. Without this mask
        # the negative index would wrap around and the point would keep testing
        # (and possibly re-traversing) unrelated nodes, producing duplicate hits.
        active   = stack_idx >= 0
        safe_top = jnp.maximum(stack_idx, 0)

        # Finished points read the root as a harmless placeholder; everything
        # derived from it is masked by `active`.
        current_idx = jnp.where(active, stack[n_points_arange, safe_top], N_leafs)
        left_idx    = bvh.left_idx[current_idx]
        right_idx   = bvh.right_idx[current_idx]

        left_contains  = jnp.logical_and(active, points_in_aabbvec(points, bvh.aabb[left_idx]))   # shape (N_points, )
        right_contains = jnp.logical_and(active, points_in_aabbvec(points, bvh.aabb[right_idx]))  # shape (N_points, )

        left_is_leaf  = bvh.leafs[left_idx]
        right_is_leaf = bvh.leafs[right_idx]

        # Record leaf hits (left first, then right) in the next free hit slots.
        left_leaf_hit     = jnp.logical_and(left_contains, left_is_leaf)
        right_leaf_hit    = jnp.logical_and(right_contains, right_is_leaf)
        hits_with_left    = jnp.where(left_leaf_hit[:, None], hits.at[n_points_arange, no_hits].set(left_idx), hits)
        no_hits_with_left = jnp.where(left_leaf_hit, no_hits + 1, no_hits)
        hits_with_both    = jnp.where(right_leaf_hit[:, None], hits_with_left.at[n_points_arange, no_hits_with_left].set(right_idx), hits_with_left)
        no_hits_with_both = jnp.where(right_leaf_hit, no_hits_with_left + 1, no_hits_with_left)

        # Internal children that contain the point have to be descended into.
        traverse_left   = jnp.logical_and(left_contains, jnp.logical_not(left_is_leaf))
        traverse_right  = jnp.logical_and(right_contains, jnp.logical_not(right_is_leaf))
        transverse_any  = jnp.logical_or(traverse_left, traverse_right)
        transverse_both = jnp.logical_and(traverse_left, traverse_right)

        # One child: the top of the stack is replaced. Two children: one extra entry
        # is pushed. No child: the top is popped.
        stepped_stack_idx = stack_idx + transverse_both.astype(jnp.int32) - (~transverse_any).astype(jnp.int32)
        stepped_stack = jnp.where(
            transverse_any[:, None],
            jnp.where(
                transverse_both[:, None],
                stack.at[n_points_arange, safe_top + 1].set(right_idx).at[n_points_arange, safe_top].set(left_idx),
                jnp.where(
                    traverse_left[:, None],
                    stack.at[n_points_arange, safe_top].set(left_idx),
                    stack.at[n_points_arange, safe_top].set(right_idx)
                )
            ),
            stack.at[n_points_arange, safe_top].set(-1)
        )

        # Finished points keep their state untouched.
        new_stack_idx = jnp.where(active, stepped_stack_idx, stack_idx)
        new_stack     = jnp.where(active[:, None], stepped_stack, stack)

        return new_stack_idx, new_stack, hits_with_both, no_hits_with_both, loop_idx + 1

    stack = jnp.full((N_points, stack_size, ), -1)
    hits  = jnp.full((N_points, max_hit_size,), -1)

    initial_stack = stack.at[..., 0].set(N_leafs) # start with root node
    initial_state = (jnp.zeros(N_points, dtype=int), initial_stack, hits, jnp.zeros(N_points, dtype=int), 0)

    final_stack_idx, final_stack, final_hits, final_no_hits, n_loops = jax.lax.while_loop(condition, loop, initial_state)
    return final_hits, n_loops
def probe_bvh(bvh : BVH, points : jnp.ndarray, stack_size : int = 64, max_hit_size : int = 64):
    '''
    Find, for every query point, the primitives whose AABB contains the point.

    Thin wrapper around `_probe_bvh_imp` that translates the leaf indices it
    returns (BVH order) into indices of the original mesh via `bvh.order`.

    Parameters
    -----------
    bvh : BVH
        The BVH to probe.
    points : jnp.ndarray
        The points to probe the BVH with. Shape (N_points, 3)
    stack_size : int
        The size of the stack to use for traversal.
    max_hit_size : int
        The maximum number of hits to record per point.

    Returns
    -----------
    jnp.ndarray
        An array of shape (N_points, max_hit_size) with the indices of the hit
        primitives in the original mesh order. Unused slots are -1.
    '''
    leaf_hits, _ = _probe_bvh_imp(bvh, points, stack_size, max_hit_size)

    slot_is_used  = leaf_hits >= 0
    primitive_ids = bvh.order[leaf_hits]
    return jnp.where(slot_is_used, primitive_ids, -1)
def ray_intersects_aabb(origin : jnp.ndarray, direction : jnp.ndarray, aabb : jnp.ndarray):
    '''
    Vectorized function to check ray-AABB intersections using the slab method.

    Parameters
    -----------
    origin : jnp.ndarray
        Ray origin of shape (..., 3)
    direction : jnp.ndarray
        Ray direction of shape (..., 3)
    aabb : jnp.ndarray
        Axis-aligned bounding box of shape (..., 2, 3)

    Returns
    -----------
    jnp.ndarray
        Boolean array of shape (...) indicating whether each ray intersects its AABB.
        Rays for which the test is undetermined (NaN, see below) are reported as hits.
    '''
    # Precompute the inverse to avoid divides; a zero component maps to +inf.
    inv_dir = jnp.where(direction == 0.0, jnp.inf, 1.0 / direction)

    t_lower = (aabb[..., 0, :] - origin) * inv_dir
    t_upper = (aabb[..., 1, :] - origin) * inv_dir

    # For negative direction components the two slab distances swap roles.
    t_enter = jnp.max(jnp.minimum(t_lower, t_upper), axis=-1)
    t_exit  = jnp.min(jnp.maximum(t_lower, t_upper), axis=-1)

    hit = t_exit >= jnp.maximum(t_enter, 0.0)

    # A ray parallel to a slab with its origin exactly on the slab plane gives
    # 0 * inf = NaN. Such a ray is conservatively counted as a hit so that BVH
    # traversal continues and no triangle is missed. The check is elementwise:
    # one degenerate ray must not turn every other ray in the batch into a hit.
    undetermined = jnp.logical_or(jnp.isnan(t_enter), jnp.isnan(t_exit))
    return jnp.logical_or(hit, undetermined)
@partial(jax.jit, static_argnames=("stack_size", "max_hit_size"))
def ray_traversal_bvh(bvh : BVH, points : jnp.ndarray, directions : jnp.ndarray, stack_size : int = 64, max_hit_size : int = 64):
    '''
    Traverse a BVH with a set of rays defined by points and directions.

    All rays are traversed in lock-step, each with its own explicit stack of
    internal nodes. A ray whose stack has run empty is masked out for the
    remaining iterations, so it neither records hits nor modifies its stack while
    other rays are still being traversed.

    `stack_size` and `max_hit_size` determine array shapes and are therefore
    static under jit (a new value triggers a re-compilation).

    Parameters
    -----------
    bvh : BVH
        The BVH to traverse.
    points : jnp.ndarray
        The ray origins. Shape (N_points, 3)
    directions : jnp.ndarray
        The ray directions. Shape (N_points, 3)
    stack_size : int
        The size of the stack to use for traversal.
    max_hit_size : int
        The maximum number of hits to record per ray.

    Returns
    -----------
    jnp.ndarray
        An array of shape (N_points, max_hit_size) containing the leaf indices (in BVH order; index
        `bvh.order` with them to obtain primitive indices) of the AABBs that the rays intersect.
        If a ray hits fewer than max_hit_size AABBs, the remaining entries are filled with -1.
    '''
    N_leafs         = (bvh.leafs.shape[0] + 1) // 2
    N_points        = points.shape[0]
    n_points_arange = jnp.arange(N_points)

    def condition(state):
        stack_idx, _, _, _, loop_idx = state
        # Continue while any ray still has nodes on its stack (with an iteration cap).
        return jnp.any(stack_idx >= 0) & (loop_idx < N_leafs)

    def loop(state):
        stack_idx, stack, hits, no_hits, loop_idx = state

        # Rays with an empty stack (stack_idx < 0) are finished. Without this mask
        # the negative index would wrap around and the ray would keep testing
        # (and possibly re-traversing) unrelated nodes, producing duplicate hits.
        active   = stack_idx >= 0
        safe_top = jnp.maximum(stack_idx, 0)

        # Finished rays read the root as a harmless placeholder; everything
        # derived from it is masked by `active`.
        current_idx = jnp.where(active, stack[n_points_arange, safe_top], N_leafs)
        left_idx    = bvh.left_idx[current_idx]
        right_idx   = bvh.right_idx[current_idx]

        left_contains  = jnp.logical_and(active, ray_intersects_aabb(points, directions, bvh.aabb[left_idx]))   # shape (N_points, )
        right_contains = jnp.logical_and(active, ray_intersects_aabb(points, directions, bvh.aabb[right_idx]))  # shape (N_points, )

        left_is_leaf  = bvh.leafs[left_idx]
        right_is_leaf = bvh.leafs[right_idx]

        # Record leaf hits (left first, then right) in the next free hit slots.
        left_leaf_hit     = jnp.logical_and(left_contains, left_is_leaf)
        right_leaf_hit    = jnp.logical_and(right_contains, right_is_leaf)
        hits_with_left    = jnp.where(left_leaf_hit[:, None], hits.at[n_points_arange, no_hits].set(left_idx), hits)
        no_hits_with_left = jnp.where(left_leaf_hit, no_hits + 1, no_hits)
        hits_with_both    = jnp.where(right_leaf_hit[:, None], hits_with_left.at[n_points_arange, no_hits_with_left].set(right_idx), hits_with_left)
        no_hits_with_both = jnp.where(right_leaf_hit, no_hits_with_left + 1, no_hits_with_left)

        # Internal children hit by the ray have to be descended into.
        traverse_left   = jnp.logical_and(left_contains, jnp.logical_not(left_is_leaf))
        traverse_right  = jnp.logical_and(right_contains, jnp.logical_not(right_is_leaf))
        transverse_any  = jnp.logical_or(traverse_left, traverse_right)
        transverse_both = jnp.logical_and(traverse_left, traverse_right)

        # One child: the top of the stack is replaced. Two children: one extra entry
        # is pushed. No child: the top is popped.
        stepped_stack_idx = stack_idx + transverse_both.astype(jnp.int32) - (~transverse_any).astype(jnp.int32)
        stepped_stack = jnp.where(
            transverse_any[:, None],
            jnp.where(
                transverse_both[:, None],
                stack.at[n_points_arange, safe_top + 1].set(right_idx).at[n_points_arange, safe_top].set(left_idx),
                jnp.where(
                    traverse_left[:, None],
                    stack.at[n_points_arange, safe_top].set(left_idx),
                    stack.at[n_points_arange, safe_top].set(right_idx)
                )
            ),
            stack.at[n_points_arange, safe_top].set(-1)
        )

        # Finished rays keep their state untouched.
        new_stack_idx = jnp.where(active, stepped_stack_idx, stack_idx)
        new_stack     = jnp.where(active[:, None], stepped_stack, stack)

        return new_stack_idx, new_stack, hits_with_both, no_hits_with_both, loop_idx + 1

    stack = jnp.full((N_points, stack_size, ), -1)
    hits  = jnp.full((N_points, max_hit_size,), -1)

    initial_stack = stack.at[..., 0].set(N_leafs) # start with root node
    initial_state = (jnp.zeros(N_points, dtype=int), initial_stack, hits, jnp.zeros(N_points, dtype=int), 0)

    final_stack_idx, final_stack, final_hits, final_no_hits, n_loops = jax.lax.while_loop(condition, loop, initial_state)
    return final_hits
@partial(jax.jit, static_argnames=("stack_size", "max_hit_size"))
def ray_traversal_bvh_single(bvh : BVH, point : jnp.ndarray, direction : jnp.ndarray, stack_size : int = 128, max_hit_size : int = 128):
    '''
    Traverse a BVH with a single ray defined by point and direction.

    Use ray_traversal_bvh_vectorized to handle multiple rays. This is only ~10% slower than ray_traversal_bvh for large number of rays.

    `stack_size` and `max_hit_size` determine array shapes and are therefore
    static under jit (a new value triggers a re-compilation).

    Parameters
    -----------
    bvh : BVH
        The BVH to traverse.
    point : jnp.ndarray
        The ray origin. Shape (3,)
    direction : jnp.ndarray
        The ray direction. Shape (3,)
    stack_size : int
        The size of the stack to use for traversal.
    max_hit_size : int
        The maximum number of hits to record.

    Returns
    -----------
    jnp.ndarray
        An array of shape (max_hit_size,) containing the leaf indices (in BVH order; index
        `bvh.order` with them to obtain primitive indices) of the AABBs that the ray intersects.
        If the ray hits fewer than max_hit_size AABBs, the remaining entries are filled with -1.
    '''
    N_leafs = (bvh.leafs.shape[0] + 1) // 2

    def condition(state):
        stack_idx, _, _, _, loop_idx = state
        # Continue while the stack is not empty (with an iteration cap).
        return jnp.any(stack_idx >= 0) & (loop_idx < N_leafs)

    def loop(state):
        stack_idx, stack, hits, no_hits, loop_idx = state

        current_idx = stack[stack_idx]
        left_idx    = bvh.left_idx[current_idx]
        right_idx   = bvh.right_idx[current_idx]

        left_contains  = ray_intersects_aabb(point, direction, bvh.aabb[left_idx])   # scalar
        right_contains = ray_intersects_aabb(point, direction, bvh.aabb[right_idx])  # scalar

        left_is_leaf  = bvh.leafs[left_idx]
        right_is_leaf = bvh.leafs[right_idx]

        # Record leaf hits (left first, then right) in the next free hit slots.
        left_leaf_hit     = jnp.logical_and(left_contains, left_is_leaf)
        right_leaf_hit    = jnp.logical_and(right_contains, right_is_leaf)
        hits_with_left    = jnp.where(left_leaf_hit, hits.at[no_hits].set(left_idx), hits)
        no_hits_with_left = jnp.where(left_leaf_hit, no_hits + 1, no_hits)
        hits_with_both    = jnp.where(right_leaf_hit, hits_with_left.at[no_hits_with_left].set(right_idx), hits_with_left)
        no_hits_with_both = jnp.where(right_leaf_hit, no_hits_with_left + 1, no_hits_with_left)

        # Internal children hit by the ray have to be descended into.
        traverse_left   = jnp.logical_and(left_contains, jnp.logical_not(left_is_leaf))
        traverse_right  = jnp.logical_and(right_contains, jnp.logical_not(right_is_leaf))
        transverse_any  = jnp.logical_or(traverse_left, traverse_right)
        transverse_both = jnp.logical_and(traverse_left, traverse_right)

        # One child: the top of the stack is replaced. Two children: one extra entry
        # is pushed. No child: the top is popped.
        new_stack_idx = stack_idx + transverse_both.astype(jnp.int32) - (~transverse_any).astype(jnp.int32)
        new_stack = jnp.where(
            transverse_any,
            jnp.where(
                transverse_both,
                stack.at[stack_idx + 1].set(right_idx).at[stack_idx].set(left_idx),
                jnp.where(
                    traverse_left,
                    stack.at[stack_idx].set(left_idx),
                    stack.at[stack_idx].set(right_idx)
                )
            ),
            stack.at[stack_idx].set(-1)
        )
        return new_stack_idx, new_stack, hits_with_both, no_hits_with_both, loop_idx + 1

    stack = jnp.full((stack_size, ), -1)
    hits  = jnp.full((max_hit_size,), -1)

    initial_stack = stack.at[..., 0].set(N_leafs) # start with root node
    initial_state = (0, initial_stack, hits, 0, 0)

    final_stack_idx, final_stack, final_hits, final_no_hits, n_loops = jax.lax.while_loop(condition, loop, initial_state)
    return final_hits


# Batched version of ray_traversal_bvh_single: the BVH (argument 0) is shared,
# points and directions may carry arbitrary leading batch dimensions.
ray_traversal_bvh_vectorized = jax.jit(jnp.vectorize(ray_traversal_bvh_single, excluded=(0,), signature='(3),(3)->(max_hit_size)'))

# =========================================================================================================
@jax.jit
def ray_triangle_intersection_single(point : jnp.ndarray, direction : jnp.ndarray, triangle : jnp.ndarray, eps=1e-8, eps_bary_center = 1e-8):
    """
    Intersect one ray with one triangle.

    The ray parameter t and the barycentric coordinates (u, v) of the
    intersection are all obtained from the same determinant.

    Parameters
    ----------
    point: jnp.ndarray
        Ray origin. Shape (3,)
    direction: jnp.ndarray
        Ray direction. Shape (3,)
    triangle: jnp.ndarray
        Triangle vertices. Shape (3, 3)
    eps:
        Threshold below which |det| counts as ray parallel to the triangle; also the
        minimum accepted distance t.
    eps_bary_center:
        Tolerance on the barycentric coordinate bounds.

    Returns
    -------
    t: distance along ray (jnp.inf if no hit)
    """
    vertex_0 = triangle[0, :]
    edge_1   = triangle[1, :] - vertex_0                      # (3,)
    edge_2   = triangle[2, :] - vertex_0                      # (3,)

    p_cross      = jnp.cross(direction, edge_2)               # (3,)
    determinant  = jnp.dot(edge_1, p_cross)                   # (,)
    not_parallel = jnp.abs(determinant) > eps                 # (,)
    # Avoid propagating the division when the ray is (nearly) parallel.
    inv_det      = jnp.where(not_parallel, 1.0 / determinant, 0.0)

    from_vertex = point - vertex_0                            # (3,)
    q_cross     = jnp.cross(from_vertex, edge_1)              # (3,)

    u = jnp.dot(from_vertex, p_cross) * inv_det
    v = jnp.dot(direction, q_cross) * inv_det
    t = jnp.dot(edge_2, q_cross) * inv_det

    is_hit = (not_parallel &
              (u >= -eps_bary_center) &
              (v >= -eps_bary_center) &
              ((u + v) <= 1.0 + eps_bary_center) &
              (t > eps))

    return jnp.where(is_hit, t, jnp.inf)


# Broadcasting version: any leading batch dimensions on points, directions and triangles.
ray_triangle_intersection_vectorized = jax.jit(jnp.vectorize(ray_triangle_intersection_single, signature='(3),(3),(3,3)->()'))
@jax.jit
def find_minimum_distance_to_mesh(points : jnp.ndarray, directions : jnp.ndarray, mesh):
    '''
    Convenience function returning, for every ray, the distance to the first
    intersection with a triangle mesh.

    A BVH of the mesh is built, candidate triangles are collected per ray by BVH
    traversal, and the exact ray-triangle test is evaluated on the candidates only.

    Parameters
    -----------
    points : jnp.ndarray
        Ray origins. Shape (N_points, 3)
    directions : jnp.ndarray
        Ray directions. Shape (N_points, 3)
    mesh : tuple of (positions, connectivity)
        positions: jnp.ndarray of shape (M, 3) representing the 3D coordinates of mesh vertices.
        connectivity: jnp.ndarray of shape (K, 3) representing the triangle connectivity of the mesh.

    Returns
    -----------
    jnp.ndarray
        An array of shape (N_points,) containing the minimum distance from each ray to the mesh. If no intersection occurs, the distance is jnp.inf.
    '''
    vertices, triangles = mesh[0], mesh[1]

    bvh = build_lbvh(vertices, triangles)

    # Leaf indices (BVH order) of the boxes hit by every ray.
    candidate_leaves = ray_traversal_bvh_vectorized(bvh, points, directions)

    # Vertex coordinates of the candidate triangles, with the candidate axis moved
    # to the front so that it broadcasts against the (N_points, 3) rays.
    candidate_triangles = vertices[triangles[bvh.order[candidate_leaves]]]
    candidates_first    = jnp.moveaxis(candidate_triangles, -3, 0)

    distances = ray_triangle_intersection_vectorized(points, directions, candidates_first)
    return jnp.nanmin(distances, axis=0)
# ==========================================================================================================
#                                       Closest point on triangle
# ==========================================================================================================
def closest_point_on_triangle(p : jnp.ndarray, a : jnp.ndarray, b : jnp.ndarray, c : jnp.ndarray):
    '''
    Compute the closest point on triangle abc to point p.

    Not vectorized itself. Reimplements the Embree reference algorithm without
    explicit if conditionals: the candidate for every region (three vertices, three
    edges, interior) is always evaluated and the applicable one is selected afterwards.

    Parameters
    ----------
    p : jnp.ndarray
        (3,) array of point coordinates
    a : jnp.ndarray
        (3,) array of triangle vertex a coordinates
    b : jnp.ndarray
        (3,) array of triangle vertex b coordinates
    c : jnp.ndarray
        (3,) array of triangle vertex c coordinates

    Returns
    -------
    closest_point : jnp.ndarray
        (3,) array of closest point coordinates
    '''
    edge_ab = b - a
    edge_ac = c - a

    # Projections of (p - vertex) onto the two edges leaving a, for each vertex.
    from_a = p - a
    d1 = jnp.dot(edge_ab, from_a)
    d2 = jnp.dot(edge_ac, from_a)

    from_b = p - b
    d3 = jnp.dot(edge_ab, from_b)
    d4 = jnp.dot(edge_ac, from_b)

    from_c = p - c
    d5 = jnp.dot(edge_ab, from_c)
    d6 = jnp.dot(edge_ac, from_c)

    vc = d1 * d4 - d3 * d2
    vb = d5 * d2 - d1 * d6
    va = d3 * d6 - d5 * d4

    # Interior of the triangle (fallback when no vertex / edge region applies).
    inv_sum        = 1.0 / (va + vb + vc)
    interior_point = a + (vb * inv_sum) * edge_ab + (vc * inv_sum) * edge_ac

    # (region condition, closest point in that region), highest priority first.
    regions = [
        ((d1 <= 0.0) & (d2 <= 0.0), a),
        ((d3 >= 0.0) & (d4 <= d3), b),
        ((d6 >= 0.0) & (d5 <= d6), c),
        ((vc <= 0.0) & (d1 >= 0.0) & (d3 <= 0.0), a + d1 / (d1 - d3) * edge_ab),
        ((vb <= 0.0) & (d2 >= 0.0) & (d6 <= 0.0), a + d2 / (d2 - d6) * edge_ac),
        ((va <= 0.0) & ((d4 - d3) >= 0.0) & ((d5 - d6) >= 0.0),
         b + (d4 - d3) / ((d4 - d3) + (d5 - d6)) * (c - b)),
    ]

    # Apply from lowest to highest priority so that the first matching region wins.
    closest = interior_point
    for in_region, candidate in reversed(regions):
        closest = jnp.where(in_region, candidate, closest)
    return closest
def _point_aabb_distance(p : jnp.ndarray, aabb : jnp.ndarray):
    '''
    Compute squared distance from point to AABB. Returns 0 if point is inside AABB.

    Parameters
    -----------
    p : jnp.ndarray
        (3,) array of point coordinates
    aabb : jnp.ndarray
        (2, 3) array of AABB min and max corner coordinates

    Returns
    -----------
    distance : jnp.ndarray
        Squared distance from point to AABB
    '''
    # Per-axis distance from point to box (0 if inside the slab of that axis)
    diff = jnp.maximum(0.0, jnp.maximum(aabb[..., 0, :] - p, p - aabb[..., 1, :]))
    return jnp.sum(diff**2)


@partial(jax.jit, static_argnames=("stack_size", "max_hit_size"))
def _bvh_closest_point(bvh : BVH, point: jnp.ndarray, mesh, stack_size : int = 64, max_hit_size : int = 64):
    '''
    Get closest point on triangle mesh for a single point using BVH traversal.

    For multiple points, use bvh_closest_point_vectorized (handles arbitrary point shapes)

    `stack_size` determines an array shape and is therefore static under jit.
    `max_hit_size` is not used by this function; it is kept for signature
    compatibility with the other traversal functions.

    Parameters
    ----------
    bvh : BVH
        BVH structure.
    point : jnp.ndarray
        Query point with shape (3,).
    mesh : tuple
        Mesh tuple `(positions, connectivity)` with triangle connectivity.
    stack_size : int
        The size of the stack to use for traversal.
    max_hit_size : int
        Unused.

    Returns
    --------
    closest_point : jnp.ndarray
        Closest point on the mesh.
    distance : jnp.ndarray
        Distance (not squared) from query point to closest point.
    triangle_index : jnp.ndarray
        Triangle index (in the original connectivity) of the closest point.
    '''
    N_leafs        = (bvh.leafs.shape[0] + 1) // 2
    surface_points = mesh[0][mesh[1]]  # (N_triangles, 3, 3)

    def condition(state):
        stack_idx, _, _, _, loop_idx = state
        # Continue while the stack is not empty (with an iteration cap).
        return jnp.any(stack_idx >= 0) & (loop_idx < N_leafs)

    def loop(state):
        # d_min is the smallest squared distance found so far, d_idx its triangle.
        stack_idx, stack, d_min, d_idx, loop_idx = state

        current_idx = stack[stack_idx]
        left_idx    = bvh.left_idx[current_idx]
        right_idx   = bvh.right_idx[current_idx]

        left_aabb_dist  = _point_aabb_distance(point, bvh.aabb[left_idx])
        right_aabb_dist = _point_aabb_distance(point, bvh.aabb[right_idx])

        # A subtree can only improve the result if its box is closer than the current best.
        left_better  = left_aabb_dist < d_min
        right_better = right_aabb_dist < d_min

        left_is_leaf  = bvh.leafs[left_idx]
        right_is_leaf = bvh.leafs[right_idx]

        def set_if_leaf(d_min, d_idx, idx, is_leaf):
            # if it is clipped it just returns the distance to the last triangle.
            # this gets masked out anyway by is_leaf.
            index_in_mesh = bvh.order[jnp.clip(idx, 0, N_leafs - 1)]
            leaf          = surface_points[index_in_mesh]
            closest_leaf  = closest_point_on_triangle(point, leaf[0], leaf[1], leaf[2])
            d_leaf        = jnp.sum((closest_leaf - point)**2)
            improves      = is_leaf & (d_leaf < d_min)
            d_idx_new     = jnp.where(improves, index_in_mesh, d_idx)
            d_min_new     = jnp.where(improves, d_leaf, d_min)
            return d_min_new, d_idx_new

        d_min_left, d_idx_left = set_if_leaf(d_min, d_idx, left_idx, left_is_leaf)
        d_min_both, d_idx_both = set_if_leaf(d_min_left, d_idx_left, right_idx, right_is_leaf)

        traverse_left   = jnp.logical_and(left_better, jnp.logical_not(left_is_leaf))
        traverse_right  = jnp.logical_and(right_better, jnp.logical_not(right_is_leaf))
        transverse_any  = jnp.logical_or(traverse_left, traverse_right)
        transverse_both = jnp.logical_and(traverse_left, traverse_right)

        # if we hit one, stack stays the same, if we hit both, we add one, if we hit none, we go one lower
        new_stack_idx = stack_idx + transverse_both.astype(jnp.int32) - (~transverse_any).astype(jnp.int32)
        new_stack = jnp.where(
            transverse_any,
            jnp.where(
                transverse_both,
                stack.at[stack_idx + 1].set(right_idx).at[stack_idx].set(left_idx),
                jnp.where(
                    traverse_left,
                    stack.at[stack_idx].set(left_idx),
                    stack.at[stack_idx].set(right_idx)
                )
            ),
            stack.at[stack_idx].set(-1)
        )
        return new_stack_idx, new_stack, d_min_both, d_idx_both, loop_idx + 1

    stack         = jnp.full((stack_size, ), -1)
    initial_stack = stack.at[..., 0].set(N_leafs) # start with root node
    initial_state = (0, initial_stack, jnp.inf, -1, 0) # stack_idx, stack, d_min, d_idx, loop_idx

    final_stack_idx, final_stack, d_min, d_idx, n_loops = jax.lax.while_loop(condition, loop, initial_state)

    closest_point = closest_point_on_triangle(point, surface_points[d_idx, 0], surface_points[d_idx, 1], surface_points[d_idx, 2])
    return closest_point, jnp.sqrt(d_min), d_idx


# Batched version of _bvh_closest_point: the BVH (argument 0) and the mesh (argument 2)
# are shared, the query points may carry arbitrary leading batch dimensions.
bvh_closest_point_vectorized = jax.jit(jnp.vectorize(_bvh_closest_point, excluded=(0,2), signature='(3)->(3),(),()'))
@jax.jit
def get_closest_points_on_mesh(points : jnp.ndarray, mesh) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    '''
    Get closest points on triangle mesh for a set of points.

    Builds a BVH of the mesh and runs the vectorized closest-point query on it.

    Parameters
    -----------
    points: jnp.ndarray [N,3]
        Array of points
    mesh: tuple of (positions, connectivity)
        positions: jnp.ndarray of shape (M, 3) representing the 3D coordinates of mesh vertices.
        connectivity: jnp.ndarray of shape (K, 3) representing the triangle connectivity of the mesh.

    Returns
    -----------
    closest_points: jnp.ndarray [N,3]
        Array of closest points on the mesh for each input point.
    distances: jnp.ndarray [N,]
        Distances (not squared) from each input point to its closest point on the mesh.
    triangle_indices: jnp.ndarray [N,]
        Indices of the triangles on the mesh corresponding to the closest points.
    '''
    bvh = build_lbvh(mesh[0], mesh[1])  # BVH
    return bvh_closest_point_vectorized(bvh, points, mesh)