Source code for jax_sbgeom.jax_utils.utils

import jax 
import jax.numpy as jnp

import numpy as onp
def stack_jacfwd(fun, argnums):
    """Return a forward-mode Jacobian function whose per-argument blocks are stacked.

    ``jax.jacfwd(fun, argnums=argnums)`` is built once. The returned callable
    evaluates it and passes the result to ``jnp.stack(..., axis=-1)``. When
    ``argnums`` is a sequence, ``jax.jacfwd`` yields one Jacobian per selected
    argument. Stacking them along a new trailing axis therefore gives a single
    array whose last index selects the argument (in ``argnums`` order).

    Parameters
    ----------
    fun : callable
        Function to differentiate; must be traceable by ``jax.jacfwd``.
    argnums : int or sequence of int
        Positional-argument indices to differentiate with respect to. The value
        is forwarded unchanged to ``jax.jacfwd``.

    Returns
    -------
    callable
        A function taking the same positional arguments as ``fun``. It returns
        ``jnp.stack(jax.jacfwd(fun, argnums=argnums)(*args), axis=-1)``.

    Notes
    -----
    NOTE(review): callers presumably pass a tuple for ``argnums``. With a single
    int, ``jax.jacfwd`` returns one array rather than a tuple. ``jnp.stack`` then
    iterates over that array's leading axis, which moves it to the end. Confirm
    that no caller relies on the int form.
    """
    forward_jacobian = jax.jacfwd(fun, argnums=argnums)

    def stacked_jacobian(*args):
        per_argument_jacobians = forward_jacobian(*args)
        # Collapse the per-argument Jacobians into one array; the new last
        # axis enumerates the entries of `argnums`.
        return jnp.stack(per_argument_jacobians, axis=-1)

    return stacked_jacobian