nerva_numpy.learning_rate
Learning-rate schedulers.
These schedulers are intentionally minimal and stateless unless noted (TimeBased updates its internal lr). The parser accepts textual forms such as “Constant(0.1)”, “StepBased(0.1,0.5,10)”, or “MultiStepLR(0.1;1,3,5;0.1)”.
Functions
parse_learning_rate(text) | Parse a textual learning-rate scheduler specification.
Classes
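ConstantScheduler(lr) | Constant learning rate: returns the same lr for any epoch.
ExponentialScheduler(lr, change_rate) | Exponential decay: lr * exp(-change_rate * epoch).
LearningRateScheduler() | Interface for epoch-indexed learning-rate schedules.
MultiStepLRScheduler(lr, milestones, gamma) | Multi-step decay: multiply lr by gamma at specified milestone epochs.
StepBasedScheduler(lr, drop_rate, change_rate) | Step decay: lr * drop_rate ^ floor((1+epoch)/change_rate).
TimeBasedScheduler(lr, decay) | Time-based decay: lr = lr / (1 + decay * epoch).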
- class nerva_numpy.learning_rate.LearningRateScheduler
Bases: object
Interface for epoch-indexed learning-rate schedules.
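This page does not show how a schedule is queried, so the following is only a sketch under an assumption: that a scheduler is called with an epoch index and returns the learning rate for that epoch. The names LearningRateSchedulerSketch, HalvingScheduler and lr_trace are hypothetical and are not part of nerva_numpy.

```python
from abc import ABC, abstractmethod
from typing import List


class LearningRateSchedulerSketch(ABC):
    """Hypothetical stand-in for LearningRateScheduler.

    Assumption: a schedule is queried as scheduler(epoch) and returns
    the learning rate to use for that epoch.
    """

    @abstractmethod
    def __call__(self, epoch: int) -> float:
        ...


class HalvingScheduler(LearningRateSchedulerSketch):
    """Hypothetical example schedule: halve the base lr every epoch."""

    def __init__(self, lr: float) -> None:
        self.lr = lr

    def __call__(self, epoch: int) -> float:
        return self.lr * 0.5 ** epoch


def lr_trace(scheduler: LearningRateSchedulerSketch, epochs: int) -> List[float]:
    """Query the scheduler once per epoch, as a training loop would."""
    return [scheduler(epoch) for epoch in range(epochs)]


print(lr_trace(HalvingScheduler(0.1), 3))  # [0.1, 0.05, 0.025]
```

Keeping the interface to a single epoch-indexed call means a training loop never needs to know which concrete decay rule it is using.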
- class nerva_numpy.learning_rate.ConstantScheduler(lr: float)
Bases: LearningRateScheduler
Constant learning rate: returns the same lr for any epoch.
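A minimal sketch of the constant schedule, assuming the scheduler is queried as scheduler(epoch); ConstantSchedulerSketch is a hypothetical name, not the library class.

```python
class ConstantSchedulerSketch:
    """Hypothetical sketch: the same lr is returned for every epoch."""

    def __init__(self, lr: float) -> None:
        self.lr = lr

    def __call__(self, epoch: int) -> float:
        # The epoch index is accepted for interface compatibility but ignored.
        return self.lr


schedule = ConstantSchedulerSketch(0.1)
print([schedule(epoch) for epoch in range(3)])  # [0.1, 0.1, 0.1]
```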
- class nerva_numpy.learning_rate.TimeBasedScheduler(lr: float, decay: float)
Bases: LearningRateScheduler
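Time-based decay: lr = lr / (1 + decay * epoch). Because this scheduler overwrites its stored lr with the result, the decay compounds across successive calls.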
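The sketch below illustrates the stateful update, assuming (as the module description suggests) that each call overwrites the stored lr, and assuming the scheduler is queried as scheduler(epoch). TimeBasedSchedulerSketch is a hypothetical name.

```python
class TimeBasedSchedulerSketch:
    """Hypothetical sketch of time-based decay: lr = lr / (1 + decay * epoch).

    The stored lr is overwritten on every call, so the reductions compound.
    """

    def __init__(self, lr: float, decay: float) -> None:
        self.lr = lr
        self.decay = decay

    def __call__(self, epoch: int) -> float:
        # Divide the *current* lr, not the initial one.
        self.lr = self.lr / (1.0 + self.decay * epoch)
        return self.lr


schedule = TimeBasedSchedulerSketch(lr=1.0, decay=1.0)
trace = [schedule(epoch) for epoch in range(4)]
print(trace)
```

In exact arithmetic the trace is 1, 1/2, 1/6, 1/24, whereas a stateless lr0 / (1 + decay * epoch) would give 1/4 at epoch 3. A consequence of the stateful design is that calling the scheduler twice for the same epoch decays the lr twice.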
- class nerva_numpy.learning_rate.StepBasedScheduler(lr: float, drop_rate: float, change_rate: float)
Bases: LearningRateScheduler
Step decay: lr * drop_rate ^ floor((1+epoch)/change_rate), i.e. lr is multiplied by drop_rate once every change_rate epochs.
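A stateless sketch of the step-decay formula, assuming 0-based epochs and a scheduler(epoch) call convention; StepBasedSchedulerSketch is a hypothetical name. The parameters match the StepBased(0.1,0.5,10) example from the module description.

```python
import math


class StepBasedSchedulerSketch:
    """Hypothetical sketch: lr * drop_rate ** floor((1 + epoch) / change_rate)."""

    def __init__(self, lr: float, drop_rate: float, change_rate: float) -> None:
        self.lr = lr
        self.drop_rate = drop_rate
        self.change_rate = change_rate

    def __call__(self, epoch: int) -> float:
        # Number of completed change_rate-sized steps, counting epoch as 1-based.
        steps = math.floor((1 + epoch) / self.change_rate)
        return self.lr * self.drop_rate ** steps


schedule = StepBasedSchedulerSketch(lr=0.1, drop_rate=0.5, change_rate=10)
print([schedule(epoch) for epoch in (0, 8, 9, 19)])  # [0.1, 0.1, 0.05, 0.025]
```

Because of the 1 + epoch term, the first drop happens at epoch 9, the tenth epoch when counting from 0.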
- class nerva_numpy.learning_rate.MultiStepLRScheduler(lr: float, milestones: List[int], gamma: float)
Bases: LearningRateScheduler
Multi-step decay: multiply lr by gamma at specified milestone epochs.
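A stateless sketch of multi-step decay, computing lr * gamma ** k where k is the number of milestones already reached. Two points are assumptions not confirmed by this page: that a milestone counts as reached when epoch >= milestone, and that the scheduler is queried as scheduler(epoch). MultiStepLRSchedulerSketch is a hypothetical name.

```python
from bisect import bisect_right
from typing import List


class MultiStepLRSchedulerSketch:
    """Hypothetical sketch: lr is multiplied by gamma once per milestone reached.

    Assumption: milestone m counts as reached when epoch >= m.
    """

    def __init__(self, lr: float, milestones: List[int], gamma: float) -> None:
        self.lr = lr
        self.milestones = sorted(milestones)
        self.gamma = gamma

    def __call__(self, epoch: int) -> float:
        # bisect_right counts the milestones m with m <= epoch.
        passed = bisect_right(self.milestones, epoch)
        return self.lr * self.gamma ** passed


schedule = MultiStepLRSchedulerSketch(lr=0.1, milestones=[1, 3, 5], gamma=0.1)
print([schedule(epoch) for epoch in range(6)])
```

With the MultiStepLR(0.1;1,3,5;0.1) parameters this yields, up to floating-point rounding, 0.1 at epoch 0, 0.01 at epochs 1 and 2, 0.001 at epochs 3 and 4, and 0.0001 from epoch 5 on.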
- class nerva_numpy.learning_rate.ExponentialScheduler(lr: float, change_rate: float)
Bases: LearningRateScheduler
Exponential decay: lr * exp(-change_rate * epoch).
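A stateless sketch of the exponential formula, assuming a scheduler(epoch) call convention; ExponentialSchedulerSketch is a hypothetical name. Choosing change_rate = ln 2 makes the lr halve every epoch, which makes the decay easy to check.

```python
import math


class ExponentialSchedulerSketch:
    """Hypothetical sketch: lr * exp(-change_rate * epoch)."""

    def __init__(self, lr: float, change_rate: float) -> None:
        self.lr = lr
        self.change_rate = change_rate

    def __call__(self, epoch: int) -> float:
        return self.lr * math.exp(-self.change_rate * epoch)


schedule = ExponentialSchedulerSketch(lr=0.1, change_rate=math.log(2))
print([schedule(epoch) for epoch in range(4)])
```

In general the lr halves every ln(2) / change_rate epochs, so change_rate = ln 2 gives approximately 0.1, 0.05, 0.025, 0.0125.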
- nerva_numpy.learning_rate.parse_learning_rate(text: str) → LearningRateScheduler
Parse a textual learning-rate scheduler specification.
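Accepted forms include Constant(lr), TimeBased(lr,decay), StepBased(lr,drop_rate,change_rate), MultiStepLR(lr;milestones;gamma) and Exponential(lr,change_rate). Arguments are separated by commas, except in MultiStepLR, where semicolons separate lr, milestones and gamma, and commas separate the individual milestones.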
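The sketch below covers only the parsing step: it splits a specification into the scheduler name and its numeric fields, following the accepted forms listed above. It is not the library's implementation; parse_learning_rate_sketch and its (name, args) return value are hypothetical, whereas the real parse_learning_rate returns a LearningRateScheduler.

```python
import re
from typing import Any, Dict, List, Tuple

# Expected argument counts, taken from the accepted forms listed above.
_ARITY: Dict[str, int] = {
    "Constant": 1,
    "TimeBased": 2,
    "StepBased": 3,
    "MultiStepLR": 3,
    "Exponential": 2,
}

# Matches "Name(body)" with optional surrounding whitespace.
_SPEC = re.compile(r"^\s*(\w+)\s*\((.*)\)\s*$")


def parse_learning_rate_sketch(text: str) -> Tuple[str, List[Any]]:
    """Hypothetical sketch: split 'Name(args)' into the name and its fields."""
    match = _SPEC.match(text)
    if match is None:
        raise ValueError(f"not a scheduler specification: {text!r}")
    name, body = match.group(1), match.group(2)
    if name not in _ARITY:
        raise ValueError(f"unknown scheduler: {name!r}")
    if name == "MultiStepLR":
        # Semicolons separate lr, the milestone list and gamma;
        # commas separate the milestones themselves.
        fields = body.split(";")
        if len(fields) != 3:
            raise ValueError(f"expected lr;milestones;gamma in {text!r}")
        lr, milestones, gamma = fields
        return name, [float(lr), [int(m) for m in milestones.split(",")], float(gamma)]
    args = [float(field) for field in body.split(",")]
    if len(args) != _ARITY[name]:
        raise ValueError(f"{name} expects {_ARITY[name]} arguments, got {len(args)}")
    return name, args


print(parse_learning_rate_sketch("StepBased(0.1,0.5,10)"))
# ('StepBased', [0.1, 0.5, 10.0])
print(parse_learning_rate_sketch("MultiStepLR(0.1;1,3,5;0.1)"))
# ('MultiStepLR', [0.1, [1, 3, 5], 0.1])
```

Using a different top-level separator for MultiStepLR lets the milestone list keep its commas without any quoting or bracket matching.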