jax_sbgeom.jax_utils.optimize module
- class OptimizationSettings(max_iterations: int, tolerance: float)
Bases: object
Settings for optimization routines.
- max_iterations: int
Maximum number of iterations for the optimizer.
- tolerance: float
Tolerance for convergence.
- run_lbfgs_step(params: optax.Params, opt_state: optax.OptState, loss_fn: Callable[[optax.Params], float], optimizer: GradientTransformationExtraArgs, value_and_grad_function: Callable[[optax.Params], Tuple[float, optax.Params]])
Run a single step of L-BFGS optimization.
- Parameters:
  - params (optax.Params) – Current parameters for optimization.
  - opt_state (optax.OptState) – Current optimizer state.
  - loss_fn (Callable[[optax.Params], float]) – Loss function to minimize.
  - optimizer (GradientTransformationExtraArgs) – L-BFGS optimizer.
  - value_and_grad_function (Callable[[optax.Params], Tuple[float, optax.Params]]) – Function that computes the value and gradient of the loss function.
- Returns:
optax.Params – Updated parameters after the optimization step
optax.OptState – Updated optimizer state after the optimization step
- run_optimization_lbfgs(initial_values: optax.Params, loss_fn: Callable[[optax.Params], float], settings: OptimizationSettings)
Run L-BFGS optimization on a given loss function.
- Parameters:
  - initial_values (optax.Params) – Initial parameters for optimization.
  - loss_fn (Callable[[optax.Params], float]) – Loss function to minimize.
  - settings (OptimizationSettings) – Settings for the optimization.
- Returns:
Optimized parameters
- Return type:
optax.Params