Source code for nerva_jax.optimizers

# Copyright 2023 Wieger Wesselink.
# Distributed under the Boost Software License, Version 1.0.
# (See accompanying file LICENSE or http://www.boost.org/LICENSE_1_0.txt)

"""Optimizers used to adjusts the model's parameters based on the gradients.

   Only SGD, Momentum and Nesterov variants are provided. The parser creates
   factory callables from textual specifications like "Momentum(mu=0.9)".
"""

from typing import Any, Callable, List

import jax.numpy as jnp
from nerva_jax.utilities import parse_function_call


class Optimizer(object):
    """Minimal optimizer interface used by layers to update parameters."""

    def update(self, eta):
        raise NotImplementedError


class CompositeOptimizer(Optimizer):
    """Combines multiple optimizers to update different parameter groups."""

    def __init__(self, optimizers: List[Optimizer]):
        self.optimizers = optimizers

    def update(self, eta):
        """Update all contained optimizers with the given learning rate."""
        for optimizer in self.optimizers:
            optimizer.update(eta)

    def __repr__(self) -> str:
        optimizers_str = ", ".join(str(opt) for opt in self.optimizers)
        return f"CompositeOptimizer([{optimizers_str}])"

    __str__ = __repr__


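# A minimal sketch of composing per-parameter optimizers for a layer-like
# object (hypothetical `layer` with attributes W/DW and b/Db; illustrative
# only, not part of the original module):
#
#     optimizer = CompositeOptimizer([
#         GradientDescentOptimizer(layer, 'W', 'DW'),
#         MomentumOptimizer(layer, 'b', 'Db', mu=0.9),
#     ])
#     optimizer.update(0.01)  # one step for both W and b

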
class GradientDescentOptimizer(Optimizer):
    """Standard gradient descent optimizer: x -= eta * grad."""

    def __init__(self, obj, attr_x: str, attr_Dx: str):
        """Store the names of the x and Dx attributes."""
        self.obj = obj
        self.attr_x = attr_x
        self.attr_Dx = attr_Dx

    def update(self, eta):
        """Apply a gradient descent update step."""
        x = getattr(self.obj, self.attr_x)
        Dx = getattr(self.obj, self.attr_Dx)
        x1 = x - eta * Dx
        setattr(self.obj, self.attr_x, x1)

    def __repr__(self) -> str:
        return "GradientDescent()"

    __str__ = __repr__


class MomentumOptimizer(GradientDescentOptimizer):
    """Gradient descent with momentum for accelerated convergence."""

    def __init__(self, obj, attr_x: str, attr_Dx: str, mu: float):
        super().__init__(obj, attr_x, attr_Dx)
        self.mu = mu
        x = getattr(self.obj, self.attr_x)
        self.delta_x = jnp.zeros_like(x)

    def update(self, eta):
        """Apply a momentum update step."""
        x = getattr(self.obj, self.attr_x)
        Dx = getattr(self.obj, self.attr_Dx)
        self.delta_x = self.mu * self.delta_x - eta * Dx
        x1 = x + self.delta_x
        setattr(self.obj, self.attr_x, x1)

    def __repr__(self) -> str:
        return f"Momentum(mu={float(self.mu)})"

    __str__ = __repr__


class NesterovOptimizer(MomentumOptimizer):
    """Nesterov accelerated gradient descent optimizer."""

    def __init__(self, obj, attr_x: str, attr_Dx: str, mu: float):
        super().__init__(obj, attr_x, attr_Dx, mu)

    def update(self, eta):
        """Apply a Nesterov accelerated gradient update step."""
        x = getattr(self.obj, self.attr_x)
        Dx = getattr(self.obj, self.attr_Dx)
        self.delta_x_prev = self.delta_x
        self.delta_x = self.mu * self.delta_x - eta * Dx
        x1 = x + self.mu * self.delta_x - eta * Dx
        setattr(self.obj, self.attr_x, x1)

    def __repr__(self) -> str:
        return f"Nesterov(mu={float(self.mu)})"

    __str__ = __repr__


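# The Nesterov update above uses the look-ahead-free recurrence
#
#     delta_x <- mu * delta_x - eta * Dx
#     x       <- x + mu * delta_x - eta * Dx
#
# i.e. the velocity is updated first and the parameter step applies the
# momentum-extrapolated velocity, so no separate interim point is stored.

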
def parse_optimizer(text: str) -> Callable[[Any, str, str], Optimizer]:
    """Parse a textual optimizer specification into a factory function.

    Returns a callable that takes (obj, attr_x, attr_Dx) and produces an
    Optimizer. Supported names: GradientDescent, Momentum(mu=...),
    Nesterov(mu=...).
    """
    try:
        func = parse_function_call(text)
        if func.name == 'GradientDescent':
            return lambda obj, attr_x, attr_Dx: GradientDescentOptimizer(obj, attr_x, attr_Dx)
        elif func.name == 'Momentum':
            mu = func.as_scalar('mu')
            return lambda obj, attr_x, attr_Dx: MomentumOptimizer(obj, attr_x, attr_Dx, mu)
        elif func.name == 'Nesterov':
            mu = func.as_scalar('mu')
            return lambda obj, attr_x, attr_Dx: NesterovOptimizer(obj, attr_x, attr_Dx, mu)
    except Exception:
        pass
    raise RuntimeError(f'Could not parse optimizer "{text}"')
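

# A minimal usage sketch (not part of the original module; `Params` is a
# hypothetical container used only for illustration):
if __name__ == '__main__':
    class Params:
        def __init__(self):
            self.W = jnp.ones((2, 2))   # parameter
            self.DW = jnp.ones((2, 2))  # its gradient

    params = Params()
    make_optimizer = parse_optimizer('Momentum(mu=0.9)')
    optimizer = make_optimizer(params, 'W', 'DW')
    # First step: delta_x starts at zero, so W <- W - 0.1 * DW.
    optimizer.update(eta=0.1)
    print(params.W)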