jax_sbgeom.jax_utils.optimize module

class OptimizationSettings(max_iterations: int, tolerance: float)

Bases: object

Settings for optimization routines.

max_iterations

Maximum number of iterations for the optimizer

Type:

int

tolerance

Tolerance for convergence

Type:

float

run_lbfgs_step(params: optax.Params, opt_state: optax.OptState, loss_fn: Callable[[optax.Params], float], optimizer: GradientTransformationExtraArgs, value_and_grad_function: Callable[[optax.Params], Tuple[float, optax.Params]])

Run a single step of L-BFGS optimization.

Parameters:
  • params (optax.Params) – Current parameters, given as an array or a pytree of arrays

  • opt_state (optax.OptState) – Current optimizer state

  • loss_fn (Callable[[optax.Params], float]) – Loss function to minimize

  • optimizer (GradientTransformationExtraArgs) – L-BFGS optimizer

  • value_and_grad_function (Callable[[optax.Params], Tuple[float, optax.Params]]) – Function returning the value and gradient of the loss function

Returns:

  • optax.Params – Updated parameters after the optimization step

  • optax.OptState – Updated optimizer state after the optimization step

run_optimization_lbfgs(initial_values: optax.Params, loss_fn: Callable[[optax.Params], float], settings: OptimizationSettings)

Run L-BFGS optimization on a given loss function.

Parameters:
  • initial_values (optax.Params) – Initial parameters, given as an array or a pytree of arrays

  • loss_fn (Callable[[optax.Params], float]) – Loss function to minimize

  • settings (OptimizationSettings) – Settings for the optimization

Returns:

Optimized parameters

Return type:

optax.Params