Source code for jax_sbgeom.jax_utils.utils
import jax
import jax.numpy as jnp
import numpy as onp
[docs]
def stack_jacfwd(fun, argnums):
    """Build a function returning forward-mode Jacobians stacked on a new last axis.

    ``jax.jacfwd(fun, argnums=argnums)`` yields one Jacobian per entry of
    ``argnums``. The returned wrapper stacks them with
    ``jnp.stack(..., axis=-1)``, so the result has a trailing axis of length
    ``len(argnums)``.

    Parameters
    ----------
    fun : callable
        Function to differentiate.
    argnums : int or sequence of int
        Positional argument index/indices to differentiate with respect to.
        A single integer is treated as a one-element tuple, so the output
        still carries a trailing axis of length 1.

    Returns
    -------
    callable
        Takes the same positional and keyword arguments as ``fun``. Keyword
        arguments are forwarded to ``fun`` but not differentiated. Returns
        the stacked Jacobian array.

    Notes
    -----
    ``jnp.stack`` requires every per-argument Jacobian to have the same
    shape. This holds, for example, when all differentiated arguments share
    one shape.
    """
    # With a bare integer, jax.jacfwd returns a single array rather than a
    # tuple. jnp.stack would then iterate over that array's leading axis:
    # it transposes the Jacobian, or fails outright on a 0-d result. Wrapping
    # the integer in a tuple keeps the output layout uniform.
    if isinstance(argnums, (int, onp.integer)):
        argnums = (int(argnums),)

    jacfwd_internal = jax.jacfwd(fun, argnums=argnums)

    def jac_stack_wrap(*args, **kwargs):
        # One Jacobian per differentiated argument, stacked along a new last axis.
        return jnp.stack(jacfwd_internal(*args, **kwargs), axis=-1)

    return jac_stack_wrap