nerva_jax.optimizers
Optimizers adjust the model’s parameters based on the computed gradients. Only SGD, momentum, and Nesterov variants are provided. The parser creates factory callables from textual specifications such as “Momentum(mu=0.9)”.
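The parsing step can be sketched as follows. This is an illustrative standalone helper, not nerva_jax’s actual parser: the function name `parse_optimizer_spec`, the regex grammar, and the float-only keyword arguments are all assumptions.

```python
import re

def parse_optimizer_spec(text: str):
    """Illustrative sketch: turn 'Momentum(mu=0.9)' into (name, kwargs).

    NOTE: this is NOT nerva_jax's actual parser; the name and grammar
    here are assumptions for demonstration only.
    """
    m = re.fullmatch(r"(\w+)(?:\((.*)\))?", text.strip())
    if m is None:
        raise ValueError(f"invalid optimizer specification: {text!r}")
    name, args = m.group(1), m.group(2)
    kwargs = {}
    if args:
        for pair in args.split(","):
            key, value = pair.split("=")
            kwargs[key.strip()] = float(value)  # assume numeric arguments
    return name, kwargs
```

The returned pair could then be mapped to one of the optimizer classes below to build a factory callable.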
Functions

- Parse a textual optimizer specification into a factory function.

Classes

- CompositeOptimizer: Combines multiple optimizers to update different parameter groups.
- GradientDescentOptimizer: Standard gradient descent optimizer: x -= eta * grad.
- MomentumOptimizer: Gradient descent with momentum for accelerated convergence.
- NesterovOptimizer: Nesterov accelerated gradient descent optimizer.
- Optimizer: Minimal optimizer interface used by layers to update parameters.
- class nerva_jax.optimizers.Optimizer[source]
  Bases: object
  Minimal optimizer interface used by layers to update parameters.
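The interface itself is not spelled out above. A plausible minimal shape, assuming a single `update(eta)` method (the method name and signature are guesses inferred from how the subclasses are described, not the verified nerva_jax API):

```python
class Optimizer:
    """Sketch of a minimal optimizer interface.

    ASSUMPTION: concrete optimizers implement update(eta), which applies
    one parameter-update step with learning rate eta. The real nerva_jax
    interface may differ.
    """

    def update(self, eta: float) -> None:
        raise NotImplementedError("subclasses must implement update()")
```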
- class nerva_jax.optimizers.CompositeOptimizer(optimizers: List[Optimizer])[source]
  Bases: Optimizer
  Combines multiple optimizers to update different parameter groups.
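A composite optimizer of this kind can be sketched as a simple fan-out: one update call is delegated to every child optimizer. This standalone version assumes each child exposes an `update(eta)` method; it is a sketch, not the library’s code.

```python
class CompositeOptimizer:
    """Sketch: delegates one update step to each wrapped optimizer.

    ASSUMPTION: every child optimizer exposes update(eta).
    """

    def __init__(self, optimizers):
        self.optimizers = list(optimizers)

    def update(self, eta: float) -> None:
        # Each child updates its own parameter group with the same
        # learning rate.
        for optimizer in self.optimizers:
            optimizer.update(eta)
```

This lets a layer with several parameter groups (e.g. weights and biases) be driven by a single update call.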
- class nerva_jax.optimizers.GradientDescentOptimizer(obj, attr_x: str, attr_Dx: str)[source]
  Bases: Optimizer
  Standard gradient descent optimizer: x -= eta * grad.
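Given the constructor signature `(obj, attr_x, attr_Dx)`, the optimizer plausibly reads the parameter and its gradient by attribute name and applies the documented rule `x -= eta * grad`. A standalone sketch (the `update` method name is an assumption):

```python
class GradientDescentOptimizer:
    """Sketch of attribute-based SGD: x -= eta * grad.

    `obj` holds the parameter under attribute `attr_x` and its gradient
    under attribute `attr_Dx`. This mirrors the documented constructor;
    the update() method name is an assumption.
    """

    def __init__(self, obj, attr_x: str, attr_Dx: str):
        self.obj = obj
        self.attr_x = attr_x
        self.attr_Dx = attr_Dx

    def update(self, eta: float) -> None:
        x = getattr(self.obj, self.attr_x)
        Dx = getattr(self.obj, self.attr_Dx)
        # Plain gradient step: move against the gradient.
        setattr(self.obj, self.attr_x, x - eta * Dx)
```

For example, with `obj.W = 1.0` and `obj.DW = 0.5`, one `update(0.1)` call leaves `obj.W` at 0.95.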
- class nerva_jax.optimizers.MomentumOptimizer(obj, attr_x: str, attr_Dx: str, mu: float)[source]
  Bases: GradientDescentOptimizer
  Gradient descent with momentum for accelerated convergence.
- class nerva_jax.optimizers.NesterovOptimizer(obj, attr_x: str, attr_Dx: str, mu: float)[source]
  Bases: MomentumOptimizer
  Nesterov accelerated gradient descent optimizer.