nerva_jax.training

Training helpers for the MLP, including a basic SGD loop and CLI glue.

Functions

compute_accuracy(M, data_loader)

Compute mean classification accuracy for a model over a data loader.

compute_loss(M, data_loader, loss)

Compute mean loss for a model over a data loader using the given loss.

compute_statistics(M, lr, loss, ...[, ...])

Compute and optionally print loss and accuracy statistics.

print_batch_debug_info(epoch, batch_idx, M, ...)

Print detailed debug information for a training batch.

print_epoch_footer(total_time)

print_epoch_header()

print_epoch_line(epoch, lr, loss, ...)

stochastic_gradient_descent(M, epochs, loss, ...)

Run a simple stochastic gradient descent (SGD) training loop using PyTorch data loaders.

stochastic_gradient_descent_plain(M, Xtrain, ...)

Perform plain stochastic gradient descent training for a multilayer perceptron using raw tensors in row layout (samples are rows).

train(layer_specifications, ...)

High-level training convenience that wires together argument parsing, data loading and SGD.

Classes

TrainOptions()

class nerva_jax.training.TrainOptions[source]

Bases: object

debug = False
print_statistics = True
print_digits = 6
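
Example (illustrative): adjusting the options before training. It is an assumption that the training helpers read these class attributes directly; only the attribute names above are taken from this page.

    from nerva_jax.training import TrainOptions

    # Print per-batch debug information and use fewer digits in the statistics output.
    TrainOptions.debug = True
    TrainOptions.print_statistics = True
    TrainOptions.print_digits = 4
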
nerva_jax.training.print_epoch_header()[source]

Print the header line of the per-epoch statistics table.

nerva_jax.training.print_epoch_line(epoch, lr, loss, train_accuracy, test_accuracy, elapsed)[source]

Print one line of per-epoch statistics: learning rate, loss, train and test accuracy, and elapsed time for the given epoch.

nerva_jax.training.print_batch_debug_info(epoch: int, batch_idx: int, M: MultilayerPerceptron, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, DY: jax.numpy.ndarray)[source]

Print detailed debug information for a training batch.

nerva_jax.training.compute_accuracy(M: MultilayerPerceptron, data_loader: DataLoader)[source]

Compute mean classification accuracy for a model over a data loader.

nerva_jax.training.compute_loss(M: MultilayerPerceptron, data_loader: DataLoader, loss: LossFunction)[source]

Compute mean loss for a model over a data loader using the given loss.

nerva_jax.training.compute_statistics(M, lr, loss, train_loader, test_loader, epoch, elapsed_seconds=0.0, print_statistics=True)[source]

Compute and optionally print loss and accuracy statistics.
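
Example (illustrative): evaluating a model with the helpers above. The model M, the loss function and the two data loaders are assumed to have been constructed elsewhere; only the calls follow the signatures documented here.

    from nerva_jax.training import compute_accuracy, compute_loss, compute_statistics

    # Mean accuracy and loss over the data loaders.
    train_acc = compute_accuracy(M, train_loader)
    test_acc = compute_accuracy(M, test_loader)
    test_loss = compute_loss(M, test_loader, loss)

    # Compute and print one line of statistics for the first epoch.
    compute_statistics(M, lr=0.01, loss=loss, train_loader=train_loader,
                       test_loader=test_loader, epoch=0, print_statistics=True)
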

nerva_jax.training.stochastic_gradient_descent(M: MultilayerPerceptron, epochs: int, loss: LossFunction, learning_rate: LearningRateScheduler, train_loader: DataLoader, test_loader: DataLoader)[source]

Run a simple stochastic gradient descent (SGD) training loop using PyTorch data loaders.

Parameters:
  • M (MultilayerPerceptron) – The neural network model to train.

  • epochs (int) – Number of training epochs.

  • loss (LossFunction) – The loss function instance (must provide gradient method).

  • learning_rate (LearningRateScheduler) – Scheduler returning the learning rate per epoch.

  • train_loader (DataLoader) –

    DataLoader that yields mini-batches (X, T) for training.

    • X: input batch of shape (batch_size, input_dim).

    • T: batch of target labels, either class indices (batch_size,) or one-hot encoded (batch_size, num_classes).

  • test_loader (DataLoader) – DataLoader that yields test batches (X, T) for evaluation.

Notes

  • The learning rate is updated once per epoch using the scheduler.

  • Gradients are normalized by batch size before backpropagation.

  • Debugging output is controlled by TrainOptions.debug. When enabled, per-batch information is printed via print_batch_debug_info.

Side Effects:
  • Updates model parameters in-place via M.optimize(lr).

  • Prints statistics and training time to standard output.
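
Example (illustrative): a minimal training run. The model, loss function, learning rate scheduler and PyTorch data loaders are assumed to have been constructed elsewhere; their constructors are not part of this page.

    from nerva_jax.training import stochastic_gradient_descent

    # Assumed to exist: M (MultilayerPerceptron), loss (LossFunction),
    # scheduler (LearningRateScheduler), train_loader / test_loader (DataLoader).
    stochastic_gradient_descent(
        M,
        epochs=10,
        loss=loss,
        learning_rate=scheduler,
        train_loader=train_loader,
        test_loader=test_loader,
    )
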

nerva_jax.training.stochastic_gradient_descent_plain(M: MultilayerPerceptron, Xtrain: jax.numpy.ndarray, Ttrain: jax.numpy.ndarray, loss: LossFunction, learning_rate: LearningRateScheduler, epochs: int, batch_size: int, shuffle: bool)[source]

Perform plain stochastic gradient descent training for a multilayer perceptron using raw tensors in row layout (samples are rows).

Parameters:
  • M (MultilayerPerceptron) – The neural network model to train.

  • Xtrain (jax.numpy.ndarray) – Training input data of shape (N, input_dim), where N is the number of training examples.

  • Ttrain (jax.numpy.ndarray) –

    Training labels, either:

    • class indices of shape (N,) or (N, 1), or

    • one-hot encoded labels of shape (N, num_classes).

  • loss (LossFunction) – The loss function instance (with gradient method).

  • learning_rate (LearningRateScheduler) – Scheduler returning the learning rate per epoch.

  • epochs (int) – Number of training epochs.

  • batch_size (int) – Number of examples per mini-batch.

  • shuffle (bool) – Whether to shuffle training examples each epoch.

Notes

  • The learning rate is updated once per epoch using the scheduler.

  • Gradients are normalized by batch size before backpropagation.

  • Debugging output is controlled by TrainOptions.debug. When enabled, per-batch information is printed via print_batch_debug_info.

  • If Ttrain contains class indices, they will be converted to one-hot encoding.

Side Effects:
  • Updates model parameters in-place via M.optimize(lr).

  • Prints statistics and training time to standard output.
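
Example (illustrative): training on raw JAX arrays in row layout. The zero-filled data is a placeholder, and M, loss and scheduler are assumed to have been constructed elsewhere.

    import jax.numpy as jnp

    from nerva_jax.training import stochastic_gradient_descent_plain

    # Toy data: 1000 samples with 784 features, labels as class indices.
    Xtrain = jnp.zeros((1000, 784))
    Ttrain = jnp.zeros(1000, dtype=jnp.int32)  # converted to one-hot internally

    stochastic_gradient_descent_plain(
        M,
        Xtrain,
        Ttrain,
        loss=loss,
        learning_rate=scheduler,
        epochs=5,
        batch_size=100,
        shuffle=True,
    )
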

nerva_jax.training.train(layer_specifications: List[str], linear_layer_sizes: List[int], linear_layer_optimizers: List[str], linear_layer_weight_initializers: List[str], batch_size: int, epochs: int, loss: str, learning_rate: str, weights_and_bias_file: str, dataset_file: str, debug: bool)[source]

High-level training convenience that wires together argument parsing, data loading and SGD.
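
Example (illustrative): a direct call to train. The specification strings below (layer names, optimizer, weight initializer, loss and learning rate descriptors) and the file names are assumptions made for illustration; they are not confirmed by this page.

    from nerva_jax.training import train

    train(
        layer_specifications=["ReLU", "ReLU", "Linear"],
        linear_layer_sizes=[784, 1024, 512, 10],
        linear_layer_optimizers=["Momentum(0.9)", "Momentum(0.9)", "Momentum(0.9)"],
        linear_layer_weight_initializers=["Xavier", "Xavier", "Xavier"],
        batch_size=100,
        epochs=10,
        loss="SoftmaxCrossEntropy",
        learning_rate="Constant(0.01)",
        weights_and_bias_file="weights.npz",
        dataset_file="mnist.npz",
        debug=False,
    )
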