# Copyright 2023 Wieger Wesselink.
# Distributed under the Boost Software License, Version 1.0.
# (See accompanying file LICENSE or http://www.boost.org/LICENSE_1_0.txt)
"""Optimizers used to adjusts the model's parameters based on the gradients.
Only SGD, Momentum and Nesterov variants are provided. The parser creates
factory callables from textual specifications like "Momentum(mu=0.9)".
"""
from typing import Any, Callable, List
from nerva_torch.utilities import parse_function_call
import torch


class Optimizer(object):
    """Minimal optimizer interface used by layers to update parameters."""

    def update(self, eta):
        raise NotImplementedError


class CompositeOptimizer(Optimizer):
    """Combines multiple optimizers to update different parameter groups."""

    def __init__(self, optimizers: List[Optimizer]):
        self.optimizers = optimizers

    def update(self, eta):
        """Update all contained optimizers with the given learning rate."""
        for optimizer in self.optimizers:
            optimizer.update(eta)

    def __repr__(self) -> str:
        optimizers_str = ", ".join(str(opt) for opt in self.optimizers)
        return f"CompositeOptimizer([{optimizers_str}])"

    __str__ = __repr__
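
# Example (illustrative sketch): a layer with weight W and bias b can give each
# parameter group its own optimizer and wrap them in a CompositeOptimizer, so a
# single update(eta) call advances both. W, DW, b and Db stand for existing
# parameter and gradient tensors here.
#
#   optimizer = CompositeOptimizer([GradientDescentOptimizer(W, DW),
#                                   GradientDescentOptimizer(b, Db)])
#   optimizer.update(0.01)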


class GradientDescentOptimizer(Optimizer):
    """Standard gradient descent optimizer: x -= eta * grad."""

    def __init__(self, x, Dx):
        self.x = x
        self.Dx = Dx

    def update(self, eta):
        """Apply gradient descent update step."""
        self.x -= eta * self.Dx

    def __repr__(self) -> str:
        return "GradientDescent()"

    __str__ = __repr__
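
# Example (illustrative sketch): the optimizer stores references to the parameter
# tensor x and the gradient tensor Dx, so updates happen in place and the layer
# only has to refill Dx before each call.
#
#   W = torch.zeros(3, 2)
#   DW = torch.ones(3, 2)      # in a real network this is filled by backprop
#   opt = GradientDescentOptimizer(W, DW)
#   opt.update(0.1)            # every entry of W is now -0.1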


class MomentumOptimizer(GradientDescentOptimizer):
    """Gradient descent with momentum for accelerated convergence."""

    def __init__(self, x, Dx, mu):
        super().__init__(x, Dx)
        self.mu = mu
        self.delta_x = torch.zeros_like(x)

    def update(self, eta):
        """Apply momentum update step."""
        self.delta_x = self.mu * self.delta_x - eta * self.Dx
        self.x += self.delta_x

    def __repr__(self) -> str:
        return f"Momentum(mu={float(self.mu)})"

    __str__ = __repr__


class NesterovOptimizer(MomentumOptimizer):
    """Nesterov accelerated gradient descent optimizer."""

    def __init__(self, x, Dx, mu):
        super().__init__(x, Dx, mu)

    def update(self, eta):
        """Apply Nesterov accelerated gradient update step."""
        self.delta_x = self.mu * self.delta_x - eta * self.Dx
        self.x += self.mu * self.delta_x - eta * self.Dx

    def __repr__(self) -> str:
        return f"Nesterov(mu={float(self.mu)})"

    __str__ = __repr__


def parse_optimizer(text: str) -> Callable[[Any, Any], Optimizer]:
    """Parse a textual optimizer specification into a factory function.

    Returns a callable that takes (x, Dx) and produces an Optimizer.
    Supported names: GradientDescent, Momentum(mu=...), Nesterov(mu=...).
    """
    try:
        func = parse_function_call(text)
        if func.name == 'GradientDescent':
            return lambda x, Dx: GradientDescentOptimizer(x, Dx)
        elif func.name == 'Momentum':
            mu = func.as_scalar('mu')
            return lambda x, Dx: MomentumOptimizer(x, Dx, mu)
        elif func.name == 'Nesterov':
            mu = func.as_scalar('mu')
            return lambda x, Dx: NesterovOptimizer(x, Dx, mu)
    except Exception:
        pass
    raise RuntimeError(f'Could not parse optimizer "{text}"')
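

# A minimal demonstration of the factory workflow described in the module
# docstring; the spec string, tensor shapes and learning rate below are just
# example values.
if __name__ == "__main__":
    W = torch.zeros(2, 3)
    DW = torch.ones(2, 3)          # in a real network this is filled by backprop
    make_optimizer = parse_optimizer("Momentum(mu=0.9)")
    optimizer = make_optimizer(W, DW)
    for _ in range(3):
        optimizer.update(0.1)      # eta = 0.1
    print(optimizer)               # prints: Momentum(mu=0.9)
    print(W)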