Introduction

The Nerva-JAX Python Library is a library for neural networks. It is part of the Nerva library collection (https://github.com/wiegerw/nerva), which includes native C++ and Python implementations of neural networks. The library was originally intended for experimenting with truly sparse neural networks; nowadays it also aims to provide a transparent and accessible implementation of neural networks.

This document describes the implementation of the Nerva-JAX Python Library. For the initial versions of the library, I took inspiration from the lecture notes of machine learning courses by Roger Grosse, which I highly recommend. This influence may still be evident in the naming of symbols.

Installation

The nerva_jax library can be installed in two ways: from the source repository or from the Python Package Index (PyPI).

# Install from the local repository
pip install .
# Install directly from PyPI
pip install nerva-jax
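
A quick way to verify the installation is to check that the module can be imported:

python -c "import nerva_jax"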

Overview of the code

This section provides an overview of the code in the Nerva-JAX Python Library. All core functionality is contained in the nerva_jax module.

Module contents

The most important files in the nerva_jax module are listed below. Each file implements a distinct part of the neural network library.

File Description

multilayer_perceptron.py

Defines the MultilayerPerceptron class, representing a feedforward neural network with multiple layers.

layers.py

Implements various neural network layers, such as fully connected or custom layers.

activation_functions.py

Provides commonly used activation functions (e.g., ReLU, sigmoid, tanh) to introduce non-linearity.

loss_functions.py

Implements loss functions used to quantify the difference between predictions and targets (e.g., cross-entropy, MSE).

weight_initializers.py

Provides functions for initializing neural network weights, supporting different strategies for stability and performance.

optimizers.py

Defines optimizer functions that update neural network parameters based on computed gradients (e.g., SGD, momentum).

learning_rate.py

Implements learning rate schedulers to adjust the learning rate dynamically during training.

training.py

Contains the stochastic gradient descent (SGD) algorithms for training multilayer perceptrons.

Number type

The Nerva-JAX Python Library uses 32-bit floating point numbers (float32) as its default number type. This choice balances memory usage and computational efficiency on both CPUs and GPUs. All computations, including feedforward, backpropagation, and gradient updates, are performed in this precision.
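
As an illustration, arrays created with jax.numpy default to 32-bit precision (this snippet uses plain JAX and is not specific to the library):

    import jax.numpy as jnp

    X = jnp.ones((2, 3))
    print(X.dtype)  # float32, unless 64-bit mode has been enabled in JAX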

API / User guide

Classes

Class Layer

The class Layer is the base class of all neural network layers. There are several different types of layers:

Layer Description

LinearLayer

A linear layer.

ActivationLayer

A linear layer followed by a pointwise activation function.

SReLULayer

A linear layer followed by an SReLU activation function.

SoftmaxLayer

A linear layer followed by a softmax activation function.

LogSoftmaxLayer

A linear layer followed by a log-softmax activation function.

BatchNormalizationLayer

A batch normalization layer.

Class MultilayerPerceptron

A multilayer perceptron (MLP) is modeled using the class MultilayerPerceptron. It contains a list of layers, and has member functions feedforward, backpropagate and optimize that can be used for training the neural network. Constructing an MLP can be done as follows:

    M = MultilayerPerceptron()

    # configure layer 1
    layer1 = ActivationLayer(784, 1024, ReLUActivation())
    layer1.W = weights_xavier_normal(layer1.W)
    layer1.b = bias_zero(layer1.b)
    optimizer_W = MomentumOptimizer(layer1, "W", "DW", 0.9)
    optimizer_b = NesterovOptimizer(layer1, "b", "Db", 0.75)
    layer1.optimizer = CompositeOptimizer([optimizer_W, optimizer_b])

    # configure layer 2
    layer2 = ActivationLayer(1024, 512, LeakyReLUActivation(0.5))
    layer2.set_weights("XavierNormal")
    layer2.set_optimizer("Momentum(0.8)")

    # configure layer 3
    layer3 = LinearLayer(512, 10)
    layer3.set_weights("HeNormal")
    layer3.set_optimizer("GradientDescent")

    M.layers = [layer1, layer2, layer3]

This creates an MLP with three linear layers, using various activation functions, weight initializers and optimizers.

Another way to construct MLPs is provided by the function parse_multilayer_perceptron, which parses an MLP from textual specifications:

    layer_specifications = ["ReLU", "LeakyReLU(0.5)", "Linear"]
    linear_layer_sizes = [784, 1024, 512, 10]
    linear_layer_optimizers = ["Nesterov(0.9)", "Momentum(0.8)", "GradientDescent"]
    linear_layer_weight_initializers = ["XavierNormal", "XavierUniform", "HeNormal"]
    M = parse_multilayer_perceptron(layer_specifications,
                                    linear_layer_sizes,
                                    linear_layer_optimizers,
                                    linear_layer_weight_initializers)

Note that optimizers must be specified not only for linear layers, but also for batch normalization layers.
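
As a sketch, a specification containing a batch normalization layer could look as follows. This example is hypothetical: the exact list conventions (in particular which lists receive an entry for the batch normalization layer) should be checked against the library source.

    layer_specifications = ["ReLU", "BatchNormalization", "Linear"]
    linear_layer_sizes = [3072, 1024, 10]
    # one optimizer per layer, including the batch normalization layer
    linear_layer_optimizers = ["Momentum(0.9)", "Momentum(0.9)", "GradientDescent"]
    linear_layer_weight_initializers = ["XavierNormal", "HeNormal"]
    M = parse_multilayer_perceptron(layer_specifications,
                                    linear_layer_sizes,
                                    linear_layer_optimizers,
                                    linear_layer_weight_initializers)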

Class LossFunction

The class LossFunction is the base class of all loss functions. There are five loss functions available:

  • SquaredErrorLoss

  • CrossEntropyLoss

  • LogisticCrossEntropyLoss

  • NegativeLogLikelihoodLoss

  • SoftmaxCrossEntropyLoss

See the Nerva library specifications document for precise definitions of these loss functions.
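
In the training loops shown later, a loss object is used through its gradient method to obtain the derivative of the loss with respect to the network output Y. For a batch (X, T) and a model M this looks as follows (a fragment of the training loop, not a standalone program):

    loss = SoftmaxCrossEntropyLoss()
    Y = M.feedforward(X)                    # network outputs for the batch X
    DY = loss.gradient(Y, T) / X.shape[0]   # batch-averaged gradient with respect to Y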

Class ActivationFunction

The class ActivationFunction is the base class of all activation functions. The following activation functions are available:

  • ReLU

  • Sigmoid

  • Softmax

  • LogSoftmax

  • LeakyReLU

  • AllReLU

  • SReLU

  • HyperbolicTangent

See the Nerva library specifications document for precise definitions of these activation functions.

Training a neural network

The library provides two variants of stochastic gradient descent (SGD) training for multilayer perceptrons. The preferred interface is stochastic_gradient_descent, which accepts PyTorch-style DataLoader instances for training and test sets. This approach is the easiest in practice, since the DataLoader abstraction automatically handles batching, shuffling, and iteration over the dataset.

def stochastic_gradient_descent(M: MultilayerPerceptron,
                                epochs: int,
                                loss: LossFunction,
                                learning_rate: LearningRateScheduler,
                                train_loader: DataLoader,
                                test_loader: DataLoader
                                ):
    print_epoch_header()
    lr = learning_rate(0)
    compute_statistics(M, lr, loss, train_loader, test_loader, epoch=0)
    training_time = 0.0

    for epoch in range(epochs):
        timer = StopWatch()
        lr = learning_rate(epoch)  # update the learning rate at the start of each epoch

        for k, (X, T) in enumerate(train_loader):
            Y = M.feedforward(X)
            DY = loss.gradient(Y, T) / X.shape[0]

            if TrainOptions.debug:
                print_batch_debug_info(epoch, k, M, X, Y, DY)

            M.backpropagate(Y, DY)
            M.optimize(lr)

        seconds = timer.seconds()
        training_time += seconds
        compute_statistics(M, lr, loss, train_loader, test_loader, epoch=epoch + 1, elapsed_seconds=seconds)

    print_epoch_footer(training_time)
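
A call to this function might look as follows. This is a sketch: it assumes that train_loader and test_loader are already constructed DataLoader objects yielding (X, T) batches, and it uses a plain callable as the learning rate schedule, which suffices here because the training loop only evaluates learning_rate(epoch); the schedulers in learning_rate.py play the same role.

    M = parse_multilayer_perceptron(["ReLU", "ReLU", "Linear"],
                                    [3072, 1024, 512, 10],
                                    ["Momentum(0.9)", "Momentum(0.9)", "Momentum(0.9)"],
                                    ["XavierNormal", "XavierNormal", "XavierNormal"])
    loss = SoftmaxCrossEntropyLoss()
    learning_rate = lambda epoch: 0.01  # constant learning rate

    stochastic_gradient_descent(M, epochs=5, loss=loss,
                                learning_rate=learning_rate,
                                train_loader=train_loader,
                                test_loader=test_loader)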

For educational purposes, a lower-level variant stochastic_gradient_descent_plain is also available. It operates directly on raw tensors in row layout (samples as rows), giving full control over batching and shuffling, but at the cost of additional boilerplate.

def stochastic_gradient_descent_plain(M: MultilayerPerceptron,
                                      Xtrain: Matrix,
                                      Ttrain: Matrix,
                                      loss: LossFunction,
                                      learning_rate: LearningRateScheduler,
                                      epochs: int,
                                      batch_size: int,
                                      shuffle: bool
                                     ):
    N = Xtrain.shape[0]  # number of examples (row layout)
    I = list(range(N))
    K = N // batch_size  # number of full batches
    num_classes = M.layers[-1].output_size()

    for epoch in range(epochs):
        if shuffle:
            random.shuffle(I)
        lr = learning_rate(epoch)  # update learning rate each epoch

        for k in range(K):
            batch = I[k * batch_size: (k + 1) * batch_size]
            X = Xtrain[batch, :]   # shape (batch_size, input_dim)

            # Convert labels to one-hot if needed
            if len(Ttrain.shape) == 2 and Ttrain.shape[1] > 1:
                # already one-hot encoded
                T = Ttrain[batch, :]
            else:
                T = to_one_hot(Ttrain[batch], num_classes)

            Y = M.feedforward(X)
            DY = loss.gradient(Y, T) / X.shape[0]

            if TrainOptions.debug:
                print_batch_debug_info(epoch, k, M, X, Y, DY)

            M.backpropagate(Y, DY)
            M.optimize(lr)

Both functions support targets provided either as a one-dimensional tensor of class indices (the default convention used in PyTorch’s classification losses) or as a one-hot encoded matrix with as many columns as the output Y. If class indices are provided, they are internally converted to one-hot encoding using to_one_hot.
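
The conversion itself is standard one-hot encoding. For illustration, with plain JAX (the library uses its own to_one_hot helper for this):

    import jax.numpy as jnp
    from jax.nn import one_hot

    T = jnp.array([6, 9, 1])      # class indices
    T_onehot = one_hot(T, 10)     # shape (3, 10): one column per class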

Batching of the training data depends on the chosen interface. With stochastic_gradient_descent, batching and shuffling are handled automatically by the DataLoader. With stochastic_gradient_descent_plain, batching is implemented manually inside the training loop.

In each epoch, every batch (X, T) goes through the three standard steps of stochastic gradient descent:

  1. Feedforward: Given an input batch X and the current neural network parameters Θ, compute the outputs Y. In the code, this corresponds to Y = M.feedforward(X).

  2. Backpropagation: Given the outputs Y and the targets T, compute the gradient DY of the loss function with respect to Y. Then, using Y and DY, compute the gradients of the model parameters Θ. These parameter gradients are stored internally in the model rather than returned. In the code, this step is performed by M.backpropagate(Y, DY).

  3. Optimization: Use the internally stored parameter gradients to update the parameters Θ. In the code, this corresponds to M.optimize(lr).

Command line tools

The following command line tools are available. They can be found in the tools directory.

Tool Description

mlp.py

A tool for training multilayer perceptrons.

inspect_npz.py

A tool for inspecting the contents of a file in NumPy NPZ format.

The tool mlp.py

The tool mlp.py can be used for training multilayer perceptrons. An example invocation of the mlp.py tool is

python3 -u ../tools/mlp.py \
        --layers="ReLU;ReLU;Linear" \
        --layer-sizes="3072;1024;512;10" \
        --layer-weights="XavierNormal;XavierNormal;XavierNormal" \
        --optimizers="Momentum(0.9);Momentum(0.9);Momentum(0.9)" \
        --batch-size=100 \
        --epochs=5 \
        --loss=SoftmaxCrossEntropy \
        --learning-rate="Constant(0.01)" \
        --load-dataset=$dataset

This will train a CIFAR-10 model using an MLP consisting of three linear layers, with ReLU activations on the first two layers and no activation on the last. A script prepare_data.py, available in the data directory, can be used to download the dataset, flatten it, and store it in .npz format. See the section Preparing data for details.

The output may look like this:

Loading dataset from file ../data/cifar10-flattened.npz
--------------------------------------------------------------------------------
epoch |           lr |         loss |    train_acc |     test_acc |     time (s)
--------------------------------------------------------------------------------
    0 |     0.010000 |     2.408590 |     0.097760 |     0.096000 |     0.000000
    1 |     0.010000 |     1.645980 |     0.412700 |     0.410500 |     3.572163
    2 |     0.010000 |     1.548570 |     0.448900 |     0.440100 |     3.586480
    3 |     0.010000 |     1.475506 |     0.477560 |     0.465100 |     4.450712
    4 |     0.010000 |     1.431293 |     0.491620 |     0.474800 |     4.429117
    5 |     0.010000 |     1.369700 |     0.513900 |     0.494300 |     4.489507
--------------------------------------------------------------------------------
Total training time: 20.527979 s

mlp.py Command Line Options

This section gives an overview of the command line interface of the mlp.py tool.

Parameter Lists

Some options accept a list of items. Lists must be semicolon-separated. For example: --layers="ReLU;AllReLU(0.3);Linear".

Named Parameters

Some items accept parameters using function-call syntax with commas to separate arguments. Use named parameters when needed, e.g. AllReLU(alpha=0.3). If a parameter has a default value, it may be omitted: SReLU() is equivalent to SReLU(al=0,tl=0,ar=0,tr=1).

General Options
  • --help Display help information.

  • --debug Enable debug output. Prints batches, weight matrices, bias vectors, and gradients.

Random Generator Options
  • --seed <value> Set the seed value for the random number generator.

Layer Configuration Options
  • --layers <value> Specify a semicolon-separated list of layers. Example: --layers="ReLU;AllReLU(0.3);Linear".

Specification Description

Linear

Linear layer without activation

ReLU

Linear layer with ReLU activation

Sigmoid

Linear layer with sigmoid activation

Softmax

Linear layer with softmax activation

LogSoftmax

Linear layer with log-softmax activation

HyperbolicTangent

Linear layer with hyperbolic tangent activation

AllReLU(<alpha>)

Linear layer with AllReLU activation

SReLU(<al>,<tl>,<ar>,<tr>)

Linear layer with SReLU activation. Defaults: al=0,tl=0,ar=0,tr=1. Equivalent to ReLU when defaults are used.

BatchNormalization

Batch normalization layer

  • --layer-sizes <value> Specify the sizes of the linear layers (semicolon-separated). Example: --layer-sizes="3072;1024;512;10".

  • --layer-weights <value> Specify the weight initialization method for linear layers. Supported values:

Specification Description

XavierNormal

Xavier Glorot weights (normal distribution)

XavierUniform

Xavier Glorot weights (uniform distribution)

HeNormal

Kaiming He weights (normal distribution)

HeUniform

Kaiming He weights (uniform distribution)

Normal

Normal distribution

Uniform

Uniform distribution

Zero

All weights are zero (N.B. This is not recommended for training)

Training Configuration Options
  • --epochs <value> Set the number of training epochs. Default: 100.

  • --batch-size <value> Set the training batch size.

  • --optimizers <value> Specify a semicolon-separated list of optimizers for linear and batch normalization layers.

Specification Description

GradientDescent

Standard gradient descent

Momentum(mu)

Momentum optimization with parameter mu

Nesterov(mu)

Nesterov momentum optimization

  • --learning-rate <value> Specify a semicolon-separated list of learning rate schedulers. If only one is given, it applies to all layers.

Specification Description

Constant(lr)

Constant learning rate lr

TimeBased(lr, decay)

Adaptive learning rate with decay

StepBased(lr, drop_rate, change_rate)

Step-based learning rate with scheduled drops

MultistepLR(lr, milestones, gamma)

Drops learning rate at specified epoch milestones

Exponential(lr, change_rate)

Exponentially decreasing learning rate

  • --loss <value> Specify the loss function. Supported values:

Specification Description

SquaredError

Squared error loss

CrossEntropy

Cross entropy loss

LogisticCrossEntropy

Logistic cross entropy loss

SoftmaxCrossEntropy

Softmax cross entropy (matches PyTorch)

NegativeLogLikelihood

Negative log likelihood loss

  • --load-weights <value> Load weights and biases from a NumPy .npz file. Weight matrices use the keys W1, W2, … and bias vectors the keys b1, b2, … (see numpy.lib.format).

Dataset Options
  • --load-dataset <file> Load a dataset from a NumPy .npz file. The file must contain the following arrays:

    • Xtrain: training inputs

    • Ttrain: training labels

    • Xtest: test inputs

    • Ttest: test labels

Other arrays will be ignored. The shapes should match the expected input and output dimensions of the network.

The tool inspect_npz.py

The tool inspect_npz.py can be used to inspect the contents of a dataset stored in .npz format.

An example invocation of the inspect_npz.py tool is

python inspect_npz.py data/cifar10-flattened.npz

The output may look like this:

Xtrain   (50000x3072  )  inf-norm = 1.00000000
[[0.23137255 0.16862745 0.19607843 ... 0.54901961 0.32941176 0.28235294]
 [0.60392157 0.49411765 0.41176471 ... 0.54509804 0.55686275 0.56470588]
 [1.         0.99215686 0.99215686 ... 0.3254902  0.3254902  0.32941176]
 ...
 [0.1372549  0.15686275 0.16470588 ... 0.30196078 0.25882353 0.19607843]
 [0.74117647 0.72941176 0.7254902  ... 0.6627451  0.67058824 0.67058824]
 [0.89803922 0.9254902  0.91764706 ... 0.67843137 0.63529412 0.63137255]]

Ttrain   (50000       )  inf-norm = 9.00000000
[6 9 9 ... 9 1 1]

Xtest    (10000x3072  )  inf-norm = 1.00000000
[[0.61960784 0.62352941 0.64705882 ... 0.48627451 0.50588235 0.43137255]
 [0.92156863 0.90588235 0.90980392 ... 0.69803922 0.74901961 0.78039216]
 [0.61960784 0.61960784 0.54509804 ... 0.03137255 0.01176471 0.02745098]
 ...
 [0.07843137 0.0745098  0.05882353 ... 0.19607843 0.20784314 0.18431373]
 [0.09803922 0.05882353 0.09019608 ... 0.31372549 0.31764706 0.31372549]
 [0.28627451 0.38431373 0.38823529 ... 0.36862745 0.22745098 0.10196078]]

Ttest    (10000       )  inf-norm = 9.00000000
[3 8 8 ... 5 1 7]

With the command line option --shapes-only, only a summary is printed:

Xtrain   (50000x3072  )  inf-norm = 1.00000000
Ttrain   (50000       )  inf-norm = 9.00000000
Xtest    (10000x3072  )  inf-norm = 1.00000000
Ttest    (10000       )  inf-norm = 9.00000000

Data Handling

The Nerva-JAX Python Library provides utilities for reading and writing datasets, as well as the weights and biases of MLP models, in NumPy .npz format. This format provides portability between the Python and C++ implementations. Currently, storing a complete model, including its architecture, is not supported.

NPZ format

The default storage format used in the Nerva libraries is the NumPy NPZ format (see numpy.lib.format). A .npz file can store a dictionary of arrays in compressed form, which allows both datasets and model parameters to be saved efficiently.
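
For example, a dataset in the expected layout can be written and read back with plain NumPy (the arrays below are random placeholders; real datasets are produced by prepare_data.py):

    import numpy as np

    # placeholder data: 100 training and 20 test examples of dimension 3072, 10 classes
    Xtrain = np.random.rand(100, 3072).astype(np.float32)
    Ttrain = np.random.randint(0, 10, size=100)
    Xtest = np.random.rand(20, 3072).astype(np.float32)
    Ttest = np.random.randint(0, 10, size=20)

    np.savez_compressed("dataset.npz", Xtrain=Xtrain, Ttrain=Ttrain, Xtest=Xtest, Ttest=Ttest)

    data = np.load("dataset.npz")
    print(data["Xtrain"].shape, data["Ttrain"].shape)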

Preparing data

The mlp.py utility expects training and testing data in .npz format. A helper script is provided to download and preprocess commonly used datasets, including MNIST and CIFAR-10.

The script is located at data/prepare_data.py and can be run from the command line.

MNIST

To download and prepare MNIST:

python prepare_data.py --dataset=mnist --download

This will:

  • Download mnist.npz if it does not exist.

  • Create a flattened and normalized version of the dataset as mnist-flattened.npz.

The output file contains:

  • Xtrain, Xtest: flattened and normalized image data

  • Ttrain, Ttest: corresponding label vectors

CIFAR-10

To download and prepare CIFAR-10:

python prepare_data.py --dataset=cifar10 --download

This will:

  • Download the CIFAR-10 data if it is not already present.

  • Create a flattened and normalized version of the dataset as cifar10-flattened.npz.

The output file contains:

  • Xtrain, Xtest: flattened image arrays with pixel values normalized to [0, 1]

  • Ttrain, Ttest: integer class labels

Reusing Existing Files

If the required .npz files already exist, the script will detect this and skip reprocessing. It is safe to rerun the script without overwriting existing files.

Help

To see all script options:

python prepare_data.py --help

Inspecting .npz files

To inspect .npz files interactively, you can use the inspect_npz.py tool (see the Command line tools section).

When constructing DataLoader instances from integer class labels, explicitly pass num_classes. If omitted, the loader infers it as max(label) + 1, which may underestimate the number of classes for small subsets, leading to one-hot vectors with too few columns and mismatched dimensions with the model output.
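
The effect is easy to reproduce with plain JAX: a subset that happens to contain only the classes 0 to 6 yields one-hot vectors with 7 columns instead of 10 when the number of classes is inferred from the labels.

    import jax.numpy as jnp
    from jax.nn import one_hot

    labels = jnp.array([0, 3, 6, 2])        # a small subset of a 10-class problem
    inferred = int(labels.max()) + 1        # 7: underestimates the true number of classes
    print(one_hot(labels, inferred).shape)  # (4, 7)
    print(one_hot(labels, 10).shape)        # (4, 10), correct when num_classes is passed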

Storing datasets and weights

The mlp.py utility supports saving and loading datasets and model parameters in .npz format.

  • Use --save-dataset and --load-dataset to write or read datasets.

  • Use --save-weights and --load-weights to store or restore the weights and biases of an MLP.

The .npz file for datasets contains:

  • Xtrain, Ttrain: training inputs and labels

  • Xtest, Ttest: test inputs and labels

The .npz file for model parameters contains:

  • W1, W2, …​ : weight matrices for each linear layer

  • b1, b2, …​ : corresponding bias vectors

All arrays use standard NumPy formats and can be inspected or manipulated in Python using numpy.load() and numpy.savez().
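
For example, the parameters of the first layer can be read back as follows (the file name is illustrative; it corresponds to a file produced with --save-weights):

    import numpy as np

    params = np.load("weights.npz")
    W1, b1 = params["W1"], params["b1"]   # weight matrix and bias vector of the first layer
    print(W1.shape, b1.shape)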

The architecture of the model (number of layers, activation functions, etc.) is not stored in the .npz file. This must be specified separately when reloading weights.

Advanced Topics

Matrix operations

Matrix operations form the core of the implementation: activation functions, loss functions, and neural network layers all rely on a variety of them. Nerva follows a structured approach to implementing these components: all equations are expressed in terms of the matrix operations in the table below. A sketch of how some of these operations map to jax.numpy is given after the table.

Table 1. matrix operations
Operation Code Definition

\(0_{m}\)

zeros(m)

\(m \times 1\) column vector with elements equal to 0

\(0_{mn}\)

zeros(m, n)

\(m \times n\) matrix with elements equal to 0

\(1_{m}\)

ones(m)

\(m \times 1\) column vector with elements equal to 1

\(1_{mn}\)

ones(m, n)

\(m \times n\) matrix with elements equal to 1

\(\mathbb{I}_n\)

identity(n)

\(n \times n\) identity matrix

\(X^\top\)

X.transpose()

transposition

\(cX\)

c * X

scalar multiplication, \(c \in \mathbb{R}\)

\(X + Y\)

X + Y

addition

\(X - Y\)

X - Y

subtraction

\(X \cdot Z\)

X * Z

matrix multiplication, also denoted as \(XZ\)

\(x^\top y~\) or \(~x y^\top\)

dot(x,y)

dot product, \(x,y \in \mathbb{R}^{m \times 1}\) or \(x,y \in \mathbb{R}^{1 \times n}\)

\(X \odot Y\)

hadamard(X,Y)

element-wise product of \(X\) and \(Y\)

\(\mathsf{diag}(X)\)

diag(X)

column vector that contains the diagonal of \(X\)

\(\mathsf{Diag}(x)\)

Diag(x)

diagonal matrix with \(x\) as diagonal, \(x \in \mathbb{R}^{1 \times n}\) or \(x \in \mathbb{R}^{m \times 1}\)

\(1_m^\top \cdot X \cdot 1_n\)

elements_sum(X)

sum of the elements of \(X\)

\(x \cdot 1_n^\top\)

column_repeat(x, n)

\(n\) copies of column vector \(x \in \mathbb{R}^{m \times 1}\)

\(1_m \cdot x\)

row_repeat(x, m)

\(m\) copies of row vector \(x \in \mathbb{R}^{1 \times n}\)

\(1_m^\top \cdot X\)

columns_sum(X)

\(1 \times n\) row vector with sums of the columns of \(X\)

\(X \cdot 1_n\)

rows_sum(X)

\(m \times 1\) column vector with sums of the rows of \(X\)

\(\max(X)_{col}\)

columns_max(X)

\(1 \times n\) row vector with maximum values of the columns of \(X\)

\(\max(X)_{row}\)

rows_max(X)

\(m \times 1\) column vector with maximum values of the rows of \(X\)

\((1_m^\top \cdot X) / n\)

columns_mean(X)

\(1 \times n\) row vector with mean values of the columns of \(X\)

\((X \cdot 1_n) / m\)

rows_mean(X)

\(m \times 1\) column vector with mean values of the rows of \(X\)

\(f(X)\)

apply(f, X)

element-wise application of \(f: \mathbb{R} \rightarrow \mathbb{R}\) to \(X\)

\(e^X\)

exp(X)

element-wise application of \(f: x \rightarrow e^x\) to \(X\)

\(\log(X)\)

log(X)

element-wise application of the natural logarithm \(f: x \rightarrow \ln(x)\) to \(X\)

\(1 / X\)

reciprocal(X)

element-wise application of \(f: x \rightarrow 1/x\) to \(X\)

\(\sqrt{X}\)

sqrt(X)

element-wise application of \(f: x \rightarrow \sqrt{x}\) to \(X\)

\(X^{-1/2}\)

inv_sqrt(X)

element-wise application of \(f: x \rightarrow x^{-1/2}\) to \(X\)

\(\log(\sigma(X))\)

log_sigmoid(X)

element-wise application of \(f: x \rightarrow \log(\sigma(x))\) to \(X\)
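
As an illustration, some of these primitives could be realized with jax.numpy as shown below. This is a sketch, not the library's actual implementation, which may differ in details such as how row and column vectors are represented.

    import jax.numpy as jnp

    def hadamard(X, Y):
        # element-wise product of X and Y
        return X * Y

    def columns_sum(X):
        # 1 x n row vector with the sums of the columns of X
        return jnp.sum(X, axis=0, keepdims=True)

    def rows_sum(X):
        # m x 1 column vector with the sums of the rows of X
        return jnp.sum(X, axis=1, keepdims=True)

    def row_repeat(x, m):
        # m copies of the row vector x
        return jnp.tile(x, (m, 1))

    def inv_sqrt(X):
        # element-wise application of x -> x^(-1/2)
        return 1.0 / jnp.sqrt(X)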

References