# Source code for nerva_jax.learning_rate

# Copyright 2022 Wieger Wesselink.
# Distributed under the Boost Software License, Version 1.0.
# (See accompanying file LICENSE or http://www.boost.org/LICENSE_1_0.txt)

"""Learning-rate schedulers.

These schedulers are intentionally minimal and stateless unless noted
(TimeBased updates its internal lr). The parser accepts textual forms such as
"Constant(0.1)", "StepBased(0.1,0.5,10)", or "MultiStepLR(0.1;1,3,5;0.1)".
"""

import math
import re
from typing import List


class LearningRateScheduler:
    """Base interface for learning-rate schedules indexed by epoch.

    A concrete schedule overrides ``__call__`` so that calling the scheduler
    with an epoch number yields the learning rate to use for that epoch.
    """

    def __call__(self, epoch: int) -> float:
        # Abstract hook: every concrete scheduler supplies its own rule.
        raise NotImplementedError
class ConstantScheduler(LearningRateScheduler):
    """Schedule that ignores the epoch and always yields one fixed rate."""

    def __init__(self, lr: float):
        # The single rate handed back for every epoch.
        self.lr = lr

    def __str__(self):
        return f'ConstantScheduler(lr={self.lr})'

    def __call__(self, epoch: int) -> float:
        # ``epoch`` is accepted only to satisfy the scheduler interface.
        return self.lr
class TimeBasedScheduler(LearningRateScheduler):
    """Time-based decay whose effect compounds across calls.

    Each call divides the *currently stored* rate by ``1 + decay * epoch`` and
    keeps the result, i.e. ``lr_k = lr_{k-1} / (1 + decay * epoch_k)``. The
    value returned therefore depends on the full history of calls, not only on
    the ``epoch`` argument, and ``__str__`` reports the most recent rate.
    """

    def __init__(self, lr: float, decay: float):
        self.lr = lr
        self.decay = decay

    def __str__(self):
        return f'TimeBasedScheduler(lr={self.lr}, decay={self.decay})'

    def __call__(self, epoch: int) -> float:
        divisor = 1 + self.decay * float(epoch)
        # Stateful by design: the decayed rate replaces the stored one.
        self.lr = self.lr / divisor
        return self.lr
class StepBasedScheduler(LearningRateScheduler):
    """Step decay schedule.

    Yields ``lr * drop_rate ** floor((1 + epoch) / change_rate)``: the initial
    rate is multiplied by ``drop_rate`` once for every full block of
    ``change_rate`` epochs covered by ``epoch + 1``. The stored ``lr`` is
    never modified.
    """

    def __init__(self, lr: float, drop_rate: float, change_rate: float):
        self.lr = lr
        self.drop_rate = drop_rate
        self.change_rate = change_rate

    def __str__(self):
        return f'StepBasedScheduler(lr={self.lr}, drop_rate={self.drop_rate}, change_rate={self.change_rate})'

    def __call__(self, epoch: int) -> float:
        drops = math.floor((1.0 + epoch) / self.change_rate)
        factor = math.pow(self.drop_rate, drops)
        return self.lr * factor
class MultiStepLRScheduler(LearningRateScheduler):
    """Multi-step decay: multiply lr by gamma once per milestone reached.

    For a given ``epoch`` the result is ``lr * gamma ** k``, where ``k`` is the
    number of entries ``m`` in ``milestones`` with ``epoch >= m``. Milestones
    do not need to be sorted; a repeated milestone counts once per occurrence.
    The stored ``lr`` is not modified.
    """

    def __init__(self, lr: float, milestones: List[int], gamma: float):
        self.lr = lr
        self.milestones = milestones
        self.gamma = gamma

    def __str__(self):
        return f'MultiStepLRScheduler(lr={self.lr}, milestones={self.milestones}, gamma={self.gamma})'

    def __call__(self, epoch: int) -> float:
        eta = self.lr
        # Scan every milestone instead of stopping at the first unreached one,
        # so an unsorted list such as [5, 1, 3] still decays at epochs 1 and 3.
        # For sorted milestones the sequence of multiplications is unchanged.
        for milestone in self.milestones:
            if epoch >= milestone:
                eta *= self.gamma
        return eta
class ExponentialScheduler(LearningRateScheduler):
    """Exponential decay schedule: ``lr * exp(-change_rate * epoch)``.

    Stateless: the result depends only on ``epoch`` and the constructor
    arguments.
    """

    def __init__(self, lr: float, change_rate: float):
        self.lr = lr
        self.change_rate = change_rate

    def __str__(self):
        return f'ExponentialScheduler(lr={self.lr}, change_rate={self.change_rate})'

    def __call__(self, epoch: int) -> float:
        exponent = -self.change_rate * float(epoch)
        return self.lr * math.exp(exponent)
def parse_learning_rate(text: str) -> LearningRateScheduler:
    """Parse a textual learning-rate scheduler specification.

    Accepted forms (rates, decay and gamma are floats; milestones are a
    comma-separated list of ints):

    - ``Constant(lr)``
    - ``TimeBased(lr,decay)``
    - ``StepBased(lr,drop_rate,change_rate)``
    - ``MultiStepLR(lr;milestones;gamma)``, e.g. ``MultiStepLR(0.1;1,3,5;0.1)``
    - ``Exponential(lr,change_rate)``

    Patterns are applied with ``re.match``, so the string is not required to
    end at the closing parenthesis.

    Args:
        text: The specification string.

    Returns:
        The scheduler instance described by ``text``.

    Raises:
        RuntimeError: if ``text`` names no known scheduler or its arguments
            cannot be parsed. When a parse step failed, the underlying error
            is chained as ``__cause__``.
    """
    try:
        if text.startswith('Constant'):
            m = re.match(r'Constant\((.*)\)', text)
            lr = float(m.group(1))
            return ConstantScheduler(lr)
        elif text.startswith('TimeBased'):
            m = re.match(r'TimeBased\((.*),(.*)\)', text)
            lr = float(m.group(1))
            decay = float(m.group(2))
            return TimeBasedScheduler(lr, decay)
        elif text.startswith('StepBased'):
            m = re.match(r'StepBased\((.*),(.*),(.*)\)', text)
            lr = float(m.group(1))
            drop_rate = float(m.group(2))
            change_rate = float(m.group(3))
            return StepBasedScheduler(lr, drop_rate, change_rate)
        elif text.startswith('MultiStepLR'):
            m = re.match(r'MultiStepLR\((.*);(.*);(.*)\)', text)
            lr = float(m.group(1))
            milestones = [int(x) for x in m.group(2).split(',')]
            gamma = float(m.group(3))
            return MultiStepLRScheduler(lr, milestones, gamma)
        elif text.startswith('Exponential'):
            m = re.match(r'Exponential\((.*),(.*)\)', text)
            lr = float(m.group(1))
            change_rate = float(m.group(2))
            return ExponentialScheduler(lr, change_rate)
    except (AttributeError, TypeError, ValueError) as err:
        # AttributeError: a pattern did not match (m is None) or text is not
        #                 string-like (e.g. None);
        # TypeError:      text is a non-str sequence such as bytes;
        # ValueError:     an argument is not a valid float/int literal.
        # Anything else (e.g. KeyboardInterrupt) is no longer swallowed.
        raise RuntimeError(f"could not parse learning rate scheduler '{text}'") from err
    # Reached only when no known scheduler name prefixes the text.
    raise RuntimeError(f"could not parse learning rate scheduler '{text}'")