nerva_jax.training
Training helpers for the MLP, including a basic SGD loop and CLI glue.
Functions
- compute_accuracy – Compute mean classification accuracy for a model over a data loader.
- compute_loss – Compute mean loss for a model over a data loader using the given loss.
- compute_statistics – Compute and optionally print loss and accuracy statistics.
- print_batch_debug_info – Print detailed debug information for a training batch.
- print_epoch_line
- stochastic_gradient_descent – Run a simple stochastic gradient descent (SGD) training loop using PyTorch data loaders.
- stochastic_gradient_descent_plain – Perform plain stochastic gradient descent training for a multilayer perceptron using raw tensors in row layout (samples are rows).
- train – High-level training convenience that wires parsing, data and SGD.
Classes
- class nerva_jax.training.TrainOptions
Bases: object
- debug = False
- print_statistics = True
- print_digits = 6
- nerva_jax.training.print_epoch_line(epoch, lr, loss, train_accuracy, test_accuracy, elapsed)
- nerva_jax.training.print_batch_debug_info(epoch: int, batch_idx: int, M: MultilayerPerceptron, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, DY: jax.numpy.ndarray)
Print detailed debug information for a training batch.
- nerva_jax.training.compute_accuracy(M: MultilayerPerceptron, data_loader: DataLoader)
Compute mean classification accuracy for a model over a data loader.
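A minimal NumPy sketch of computing mean accuracy over a loader, assuming the loader is any iterable of (X, T) batches in row layout and the model is a callable returning scores of shape (batch_size, num_classes). The names mean_accuracy and predict are hypothetical, and whether nerva_jax weights every example or every batch equally is not stated here; this sketch weights every example equally.

```python
import numpy as np

def mean_accuracy(predict, batches):
    """Fraction of correctly classified examples over all (X, T) batches.

    Hypothetical helper: `predict` maps X of shape (batch_size, input_dim)
    to scores of shape (batch_size, num_classes).
    """
    correct = 0
    total = 0
    for X, T in batches:
        T = np.asarray(T)
        # Accept one-hot targets (batch_size, num_classes) or class indices
        # of shape (batch_size,) or (batch_size, 1).
        if T.ndim == 2 and T.shape[1] > 1:
            labels = T.argmax(axis=1)
        else:
            labels = T.reshape(-1)
        predicted = np.asarray(predict(X)).argmax(axis=1)
        correct += int((predicted == labels).sum())
        total += labels.shape[0]
    return correct / total
```

Counting correct examples rather than averaging per-batch accuracies keeps the result exact when the last batch is smaller than the others.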
- nerva_jax.training.compute_loss(M: MultilayerPerceptron, data_loader: DataLoader, loss: LossFunction)
Compute mean loss for a model over a data loader using the given loss.
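A sketch of the same idea for the loss, assuming (as an illustration, not as the documented LossFunction contract) that the loss object is callable as loss(Y, T) and returns the loss summed over the batch, so the mean is the total divided by the number of examples. SoftmaxCrossEntropy and mean_loss are hypothetical names; the gradient method is included because the loss passed to the SGD functions must provide one.

```python
import numpy as np

def log_softmax(Z):
    # Numerically stable row-wise log-softmax.
    Z = Z - Z.max(axis=1, keepdims=True)
    return Z - np.log(np.exp(Z).sum(axis=1, keepdims=True))

class SoftmaxCrossEntropy:
    """Hypothetical loss: summed cross-entropy with one-hot targets T."""

    def __call__(self, Y, T):
        return float(-(T * log_softmax(Y)).sum())

    def gradient(self, Y, T):
        # Derivative of the summed loss with respect to the scores Y.
        return np.exp(log_softmax(Y)) - T

def mean_loss(predict, batches, loss):
    """Average per-example loss over all (X, T) batches."""
    total = 0.0
    count = 0
    for X, T in batches:
        total += loss(predict(X), T)
        count += X.shape[0]
    return total / count
```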
- nerva_jax.training.compute_statistics(M, lr, loss, train_loader, test_loader, epoch, elapsed_seconds=0.0, print_statistics=True)
Compute and optionally print loss and accuracy statistics.
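The print_statistics flag and the TrainOptions.print_digits attribute suggest a per-epoch summary printed with a fixed number of digits. The sketch below only illustrates that pattern: the line layout is hypothetical (the format actually printed by print_epoch_line is not documented here), and format_epoch_line and report_statistics are illustrative names that take precomputed metrics instead of loaders.

```python
def format_epoch_line(epoch, lr, loss, train_accuracy, test_accuracy,
                      elapsed, digits=6):
    """Hypothetical one-line summary; `digits` plays the role of print_digits."""
    fmt = f".{digits}f"
    return (f"epoch {epoch:3d}  lr: {lr:{fmt}}  loss: {loss:{fmt}}  "
            f"train accuracy: {train_accuracy:{fmt}}  "
            f"test accuracy: {test_accuracy:{fmt}}  time: {elapsed:{fmt}}s")

def report_statistics(epoch, lr, loss, train_accuracy, test_accuracy,
                      elapsed_seconds=0.0, print_statistics=True, digits=6):
    """Build the summary line and print it only when print_statistics is set."""
    line = format_epoch_line(epoch, lr, loss, train_accuracy, test_accuracy,
                             elapsed_seconds, digits)
    if print_statistics:
        print(line)
    return line
```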
- nerva_jax.training.stochastic_gradient_descent(M: MultilayerPerceptron, epochs: int, loss: LossFunction, learning_rate: LearningRateScheduler, train_loader: DataLoader, test_loader: DataLoader)
Run a simple stochastic gradient descent (SGD) training loop using PyTorch data loaders.
- Parameters:
M (MultilayerPerceptron) – The neural network model to train.
epochs (int) – Number of training epochs.
loss (LossFunction) – The loss function instance (must provide a gradient method).
learning_rate (LearningRateScheduler) – Scheduler returning the learning rate per epoch.
train_loader (DataLoader) – DataLoader that yields mini-batches (X, T) for training, where:
X: input batch of shape (batch_size, input_dim).
T: batch of target labels, either class indices (batch_size,) or one-hot encoded (batch_size, num_classes).
test_loader (DataLoader) – DataLoader that yields test batches (X, T) for evaluation.
Notes
The learning rate is updated once per epoch using the scheduler.
Gradients are normalized by batch size before backpropagation.
Debugging output is controlled by TrainOptions.debug. When enabled, per-batch information is printed via print_batch_debug_info.
- Side Effects:
Updates model parameters in-place via M.optimize(lr).
Prints statistics and training time to standard output.
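The loop described above (learning rate fetched once per epoch, loss gradient divided by the batch size, then backpropagation and an in-place update) can be sketched as follows. Only M.optimize(lr) and the loss gradient method come from this page; the feedforward and backpropagate method names, the single-layer LinearModel standing in for a MultilayerPerceptron, and the plain callable used as a scheduler are all hypothetical.

```python
import numpy as np

def log_softmax(Z):
    Z = Z - Z.max(axis=1, keepdims=True)
    return Z - np.log(np.exp(Z).sum(axis=1, keepdims=True))

class SoftmaxCrossEntropy:
    """Hypothetical loss: summed cross-entropy and its gradient w.r.t. Y."""
    def __call__(self, Y, T):
        return float(-(T * log_softmax(Y)).sum())

    def gradient(self, Y, T):
        return np.exp(log_softmax(Y)) - T

class LinearModel:
    """Hypothetical stand-in for an MLP: one linear layer Y = X W^T + b."""
    def __init__(self, input_dim, num_classes, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.normal(scale=0.1, size=(num_classes, input_dim))
        self.b = np.zeros(num_classes)

    def feedforward(self, X):
        self.X = X  # cache the input for backpropagation
        return X @ self.W.T + self.b

    def backpropagate(self, Y, DY):
        # Y is unused by a linear layer; kept to mirror an MLP-style signature.
        self.DW = DY.T @ self.X
        self.Db = DY.sum(axis=0)

    def optimize(self, lr):
        # Plain gradient step on the stored parameters.
        self.W -= lr * self.DW
        self.b -= lr * self.Db

def sgd(M, epochs, loss, lr_schedule, batches):
    for epoch in range(epochs):
        lr = lr_schedule(epoch)  # learning rate updated once per epoch
        for X, T in batches:
            Y = M.feedforward(X)
            DY = loss.gradient(Y, T) / X.shape[0]  # normalize by batch size
            M.backpropagate(Y, DY)
            M.optimize(lr)
```

Dividing the gradient of a summed loss by the batch size makes the step size comparable across batch sizes, so one learning rate schedule can be reused when batch_size changes.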
- nerva_jax.training.stochastic_gradient_descent_plain(M: MultilayerPerceptron, Xtrain: jax.numpy.ndarray, Ttrain: jax.numpy.ndarray, loss: LossFunction, learning_rate: LearningRateScheduler, epochs: int, batch_size: int, shuffle: bool)
Perform plain stochastic gradient descent training for a multilayer perceptron using raw tensors in row layout (samples are rows).
- Parameters:
M (MultilayerPerceptron) – The neural network model to train.
Xtrain (jax.numpy.ndarray) – Training input data of shape (N, input_dim), where N is the number of training examples.
Ttrain (jax.numpy.ndarray) – Training labels, either:
- class indices of shape (N,) or (N, 1), or
- one-hot encoded labels of shape (N, num_classes).
loss (LossFunction) – The loss function instance (must provide a gradient method).
learning_rate (LearningRateScheduler) – Scheduler returning the learning rate per epoch.
epochs (int) – Number of training epochs.
batch_size (int) – Number of examples per mini-batch.
shuffle (bool) – Whether to shuffle training examples each epoch.
Notes
The learning rate is updated once per epoch using the scheduler.
Gradients are normalized by batch size before backpropagation.
Debugging output is controlled by TrainOptions.debug. When enabled, per-batch information is printed via print_batch_debug_info.
If Ttrain contains class indices, they will be converted to one-hot encoding.
- Side Effects:
Updates model parameters in-place via M.optimize(lr).
Prints statistics and training time to standard output.
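With raw tensors in row layout, the extra work compared to the loader-based variant is converting class indices to one-hot rows and slicing shuffled mini-batches each epoch. A NumPy sketch under stated assumptions: to_one_hot and minibatches are hypothetical helpers, and dropping a final incomplete batch is a choice made here, not documented behavior of stochastic_gradient_descent_plain.

```python
import numpy as np

def to_one_hot(T, num_classes):
    """Convert labels of shape (N,) or (N, 1) to one-hot rows (N, num_classes).

    Arrays that are already one-hot, shape (N, num_classes), are returned as-is.
    """
    T = np.asarray(T)
    if T.ndim == 2 and T.shape[1] == num_classes and num_classes > 1:
        return T
    indices = T.reshape(-1).astype(int)
    return np.eye(num_classes, dtype=float)[indices]

def minibatches(X, T, batch_size, shuffle, rng):
    """Yield (X_batch, T_batch) row slices; an incomplete last batch is dropped."""
    N = X.shape[0]
    order = rng.permutation(N) if shuffle else np.arange(N)
    for k in range(N // batch_size):
        idx = order[k * batch_size:(k + 1) * batch_size]
        yield X[idx], T[idx]
```

Shuffling an index permutation rather than the arrays themselves keeps each input row paired with its label.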
- nerva_jax.training.train(layer_specifications: List[str], linear_layer_sizes: List[int], linear_layer_optimizers: List[str], linear_layer_weight_initializers: List[str], batch_size: int, epochs: int, loss: str, learning_rate: str, weights_and_bias_file: str, dataset_file: str, debug: bool)
High-level training convenience function that wires together specification parsing, data loading, and SGD training.