1. Introduction

The Nerva-Rowwise C++ Library is a C++ library for neural networks. It is part of the Nerva library collection (https://github.com/wiegerw/nerva), which includes native C++ and Python implementations of neural networks. Originally, the library was intended for experimenting with truly sparse neural networks; nowadays it also aims to provide a transparent and accessible implementation of neural networks.

This document describes the implementation of the Nerva-Rowwise C++ Library. For initial versions of the library I took inspiration from lecture notes of machine learning courses by Roger Grosse, which I highly recommend. This influence may still be evident in the naming of symbols.

2. Installation

Two build systems are supported for building the Nerva-Rowwise C++ Library: CMake and B2. Both are described below.

2.1. Dependencies

The Nerva-Rowwise C++ Library has the following dependencies:

Library           Description
doctest           Unit testing framework
FMT               Formatting library
Lyra              Command line argument parser
Eigen             Linear algebra library
pybind11          Python bindings
Intel MKL         Math Kernel Library
Intel oneAPI (*)  Intel oneAPI toolkit

(*) The Intel oneAPI dependency is optional, but highly recommended. This library is needed for the SYCL computation mode. Note that oneAPI includes the icpx compiler and the MKL library. Please make sure to use the latest version of oneAPI.

2.2. Environment variables

The following environment variables may have to be set:

Environment variable  Description
ONEAPI_ROOT           Path to the oneAPI installation directory.
MKL_DIR               Path to the MKL installation directory.
MKL_NUM_THREADS       Controls the number of threads used at runtime by the MKL library.
OMP_NUM_THREADS       Controls the number of threads used at runtime by OpenMP.
NERVA_ENABLE_SYCL     Enables the SYCL computation mode.
NERVA_USE_DOUBLE      Use double instead of single precision in the computations.

Note that CMake seems to require setting the MKL_DIR variable, even if the MKL library is found in the ONEAPI_ROOT directory.

See techniques-to-set-the-number-of-threads.html for more information about setting the number of threads.

You can skip setting the ONEAPI_ROOT variable by running the OneAPI setvars.sh script beforehand. This script does not set the MKL_DIR variable, so you may have to set it manually.

2.3. Installation using CMake

Using CMake, the Nerva-Rowwise C++ Library can be built in a standard way. Note that the dependencies doctest, Eigen, FMT, Lyra, and pybind11 are automatically managed using CMake's FetchContent mechanism during the build.

2.3.1. Linux CMake install

On Linux, you can build and install the library with a command like this:

mkdir build
cd build
cmake .. \
    -DCMAKE_INSTALL_PREFIX=../install \
    -DCMAKE_BUILD_TYPE=RELEASE \
    -DMKL_DIR:PATH=$ONEAPI_ROOT/mkl/latest/lib/cmake/mkl
make -j8
make install

2.3.2. Windows CMake install

On Windows, a Visual Studio command line build can be done like this:

cmake .. ^
    -G "NMake Makefiles" ^
    -DCMAKE_INSTALL_PREFIX=..\install ^
    -DCMAKE_BUILD_TYPE=Release ^
    -DMKL_DIR="%ONEAPI_ROOT%\latest\lib\cmake\mkl"
nmake
nmake install

2.3.3. SYCL matrix operations

There is an experimental computation mode for SYCL matrix operations, which can be enabled in the mlp tool using the flag --computation=sycl. This mode is only supported on Linux in combination with the Intel icpx compiler. It can be enabled using the following build commands:

source $ONEAPI_ROOT/setvars.sh
cmake .. \
    -DCMAKE_C_COMPILER=$ONEAPI_ROOT/compiler/latest/bin/icx \
    -DCMAKE_CXX_COMPILER=$ONEAPI_ROOT/compiler/latest/bin/icpx \
    -DCMAKE_INSTALL_PREFIX=../install \
    -DCMAKE_BUILD_TYPE=RELEASE \
    -DENABLE_SYCL=ON

Initial tests of the SYCL matrix operations on CPU showed rather disappointing performance.

2.3.4. Unit tests

The unit tests can be run using the command

ctest -R nerva

2.3.5. icpx + clang compilers

If you’re using the Clang or ICX (Intel’s icpx) compilers, be aware of a longstanding compatibility issue between MKL and Eigen.

To work around this issue in Nerva, the symbol EIGEN_COLPIVOTINGHOUSEHOLDERQR_LAPACKE_H has been defined. This prevents the inclusion of the problematic header file and resolves the compatibility issue.

2.4. Installation using B2

The Nerva-Rowwise C++ Library can also be built using the B2 build system.

The following environment variables must be set:

Environment variable  Description
ONEAPI_ROOT           Path to the oneAPI installation directory.
DOCTEST_INCLUDE_DIR   Path to the doctest include directory.
EIGEN_INCLUDE_DIR     Path to the Eigen include directory.
FMT_INCLUDE_DIR       Path to the FMT include directory.
LYRA_INCLUDE_DIR      Path to the Lyra include directory.
PYBIND11_INCLUDE_DIR  Path to the pybind11 include directory.

The Python include directory is currently hard coded to /usr/include/python3.12 in the file Jamroot. This may have to be changed to the correct location on your system.

2.4.1. Building the tools

The tools can be installed with a command like this, assuming that the g++-14 compiler has been configured:

cd tools
b2 gcc-14 link=static release -j8

This will install the tools in the directory install/bin.

2.4.2. Running the tests

The unit tests can be run with a command like this:

cd tests
b2 gcc-14 link=static release -j8

3. Command line tools

The following command line tools are available. They can be found in the tools directory.

Tool         Description
mlp          A tool for training multilayer perceptrons.
mkl          A tool for benchmarking sparse and dense matrix products using the Intel MKL library.
inspect_npz  A tool for inspecting the contents of a file in NumPy NPZ format.

3.1. The tool mlp

The tool mlp can be used for training multilayer perceptrons. An example invocation of the mlp tool is

../install/bin/mlp \
    --layers="ReLU;ReLU;Linear" \
    --layer-sizes="3072;1024;1024;10" \
    --layer-weights=XavierNormal \
    --optimizers="Nesterov(0.9)" \
    --loss=SoftmaxCrossEntropy \
    --learning-rate=0.01 \
    --epochs=100 \
    --batch-size=100 \
    --threads=12 \
    --overall-density=0.05 \
    --dataset=$dataset \
    --seed=123

This will train a CIFAR-10 model using an MLP consisting of three sparse layers with activation functions ReLU, ReLU and no activation. A script prepare_data.py is available in the data directory that can be used to download the dataset, flatten it and store it in .npz format. See the section Preparing data for details.

The output may look like this:

=== Nerva c++ model ===
Sparse(input_size=3072, output_size=1024, density=0.042382877, optimizer=Nesterov(0.90000), activation=ReLU())
Sparse(input_size=1024, output_size=1024, density=0.06357384, optimizer=Nesterov(0.90000), activation=ReLU())
Dense(input_size=1024, output_size=10, optimizer=Nesterov(0.90000), activation=NoActivation())
loss = SoftmaxCrossEntropyLoss()
scheduler = ConstantScheduler(lr=0.01)
layer densities: 133325/3145728 (4.238%), 66662/1048576 (6.357%), 10240/10240 (100%)

epoch   0 lr: 0.01000000  loss: 2.30284437  train accuracy: 0.07904000  test accuracy: 0.08060000 time: 0.00000000s
epoch   1 lr: 0.01000000  loss: 2.14723837  train accuracy: 0.21136000  test accuracy: 0.21320000 time: 2.74594253s
epoch   2 lr: 0.01000000  loss: 1.91454245  train accuracy: 0.29976000  test accuracy: 0.29940000 time: 2.76982510s
epoch   3 lr: 0.01000000  loss: 1.78019225  train accuracy: 0.35416000  test accuracy: 0.35820000 time: 2.69554319s
epoch   4 lr: 0.01000000  loss: 1.68071066  train accuracy: 0.39838000  test accuracy: 0.40000000 time: 2.68532307s
epoch   5 lr: 0.01000000  loss: 1.59761505  train accuracy: 0.42820000  test accuracy: 0.43060000 time: 3.02131606s

3.1.1. Command line options of mlp

This section gives an overview of the command line interface of the mlp tool.

Parameter lists

Some command line options take a list of items as input, for example a list of layers. These items must be separated by semicolons, e.g. --layers="ReLU;ReLU;Linear".

Named parameters

Some of the items take parameters. For this we use a function call syntax with named parameters, e.g. AllReLU(alpha=0.3). If there is only one parameter, the name may be omitted: AllReLU(0.3). Parameters with default values may be omitted entirely. For example, SReLU or SReLU() is equivalent to SReLU(al=0,tl=0,ar=0,tr=1).

General options
  • -?, -h, --help Display help information.

  • --verbose, -v Show verbose output.

  • --debug, -d Show debug output. This prints batches, weight matrices, bias vectors, gradients etc.

Random generator options
  • --seed <value> A seed value for the random generator.

Layer configuration options
  • --layers <value> A semicolon-separated list of layers. For example, --layers="ReLU;AllReLU(0.3);Linear" specifies a neural network with three layers, with a ReLU, an AllReLU and no activation function. The following layers are supported:

Specification               Description
Linear                      Linear layer without activation
ReLU                        Linear layer with ReLU activation
Sigmoid                     Linear layer with sigmoid activation
Softmax                     Linear layer with softmax activation
LogSoftmax                  Linear layer with log-softmax activation
HyperbolicTangent           Linear layer with hyperbolic tangent activation
AllReLU(<alpha>)            Linear layer with AllReLU activation
SReLU(<al>,<tl>,<ar>,<tr>)  Linear layer with SReLU activation. The default values for the parameters are al=0, tl=0, ar=0, tr=1; for these values SReLU coincides with ReLU.
TReLU(<epsilon>)            Linear layer with trimmed ReLU activation
BatchNormalization          Batch normalization layer

  • --layer-sizes <value> A semicolon-separated list of the sizes of linear layers of the multilayer perceptron. For example, --layer-sizes=3072;1024;512;10 specifies the sizes of three linear layers. The first one has 3072 inputs and 1024 outputs, the second one 1024 inputs and 512 outputs, and the third one has 512 inputs and 10 outputs.

  • --densities <value> A comma-separated list of linear layer densities. By default, all linear layers are dense (i.e. have density 1.0). If only one value is specified, it will be used for all linear layers.

  • --dropouts <value> A comma-separated list of dropout rates of linear layers. By default, all linear layers have no dropout (i.e. dropout rate 0.0).

  • --overall-density <value> The overall density of the linear layers. This value should be in the interval [0,1], and it specifies the fraction of the total number of weights that is non-zero. The overall density is not distributed evenly over the layers. Instead, small layers are assigned a higher density than large layers.

  • --layer-weights <value> The generator that is used for initializing the weights of the linear layers. The following weight generators are supported:

Specification  Description
XavierNormal   Xavier Glorot weights (normal distribution)
XavierUniform  Xavier Glorot weights (uniform distribution)
HeNormal       Kaiming He weights (normal distribution)
HeUniform      Kaiming He weights (uniform distribution)
Normal         Normal distribution
Uniform        Uniform distribution
Zero           All weights are zero (N.B. this is not recommended for training)

Training configuration options
  • --epochs <value> The number of epochs of the training (default: 100).

  • --batch-size <value> The batch size of the training.

  • --no-shuffle Do not shuffle the dataset during training.

  • --no-statistics Do not display intermediate statistics during training.

  • --optimizers <value> A semicolon-separated list of optimizers used for linear and batch normalization layers. The following optimizers are supported:

Specification    Description
GradientDescent  Gradient descent optimization
Momentum(mu)     Momentum optimization with momentum parameter mu
Nesterov(mu)     Nesterov optimization with momentum parameter mu

  • --learning-rate <value> A semicolon-separated list of learning rate schedulers of linear and batch normalization layers. If only one learning rate scheduler is specified, it is applied to all layers. The following learning rate schedulers are supported:

Specification                          Description
Constant(lr)                           Constant learning rate lr
TimeBased(lr, decay)                   Adaptive learning rate with decay
StepBased(lr, drop_rate, change_rate)  Step based learning rate, where the learning rate is regularly dropped to a lower value
MultistepLR(lr, milestones, gamma)     Step based learning rate, where milestones contains the epoch numbers at which the learning rate is dropped
Exponential(lr, change_rate)           Exponentially decreasing learning rate

  • --loss <value> The loss function used for training the multilayer perceptron. The following loss functions are supported:

Specification          Description
SquaredError           Squared error loss
CrossEntropy           Cross entropy loss (N.B. prone to numerical problems!)
LogisticCrossEntropy   Logistic cross entropy loss
SoftmaxCrossEntropy    Softmax cross entropy loss. Matches the CrossEntropy loss of PyTorch. Suitable for classification experiments.
NegativeLogLikelihood  Negative log likelihood loss

  • --load-weights <value> Load weights and biases from a dictionary in NumPy .npz format. The weight matrices should be stored with keys W1,W2,…​ and the bias vectors with keys b1,b2,…​. See also numpy.lib.format.

  • --save-weights <value> Save weights and biases to a dictionary in NumPy .npz format. The weight matrices are stored with keys W1,W2,…​ and the bias vectors with keys b1,b2,…​. See also numpy.lib.format.

Dataset options
  • --load-data <value> Load the dataset from a file in NumPy .npz format.

  • --save-data <value> Save the dataset to a file in NumPy .npz format.

  • --preprocessed <directory> A directory containing datasets named epoch0.npz, epoch1.npz, …​ See I/O for information about the .npz format. This can for example be used to precompute augmented datasets. A script generate_cifar10_augmented_datasets.py is available for creating augmented CIFAR-10 datasets.

  • --cifar10 <directory> Specify the directory where the binary version of the CIFAR-10 dataset is stored. This is a directory with subdirectory cifar-10-batches-bin for the C++ version or cifar-10-batches-py for the Python version of the dataset.

  • --mnist <directory> Specify the directory where the MNIST dataset is stored. It should be stored in a file named mnist.npz, that can be downloaded here.

  • --normalize Normalize the dataset.

  • --generate-data <name> Specify a synthetic dataset that is generated on the fly. The following datasets are supported:

Specification  Description                                     Features  Classes
checkerboard   A checkerboard pattern, see also checkerboard.  2         2
mini           A dataset with random values.                   3         2

  • --dataset-size <value> The size of the generated dataset (default: 1000).

Pruning and growing options
  • --prune <strategy> The strategy used for pruning sparse weight matrices. The following strategies are supported:

Specification               Description
Magnitude(<drop_fraction>)  Magnitude based pruning. A fraction of the weights with the smallest absolute value is pruned.
SET(<drop_fraction>)        SET pruning. Positive and negative weights are treated separately; both a fraction of the positive and a fraction of the negative weights is pruned.
Threshold(<threshold>)      Weights with absolute value below the given threshold are pruned.

  • --grow <strategy> The strategy used for growing in sparse weight matrices. The following strategies are supported:

Specification  Description
Random         Weights are added at random positions (outside the support of the sparse matrix).

  • --grow-weights <value> The weight generation function used for growing weights. See --layer-weights for supported values. The default value is Xavier.

Computation options
  • --computation <value> The computation mode that is used for backpropagation. This is used for performance measurements. The following computation modes are available:

Specification  Description
eigen          All computations are done using the Eigen library. Note that by setting the flag EIGEN_USE_MKL_ALL, Eigen will attempt to use MKL library calls.
mkl            Some computations are implemented using MKL functions.
blas           Some computations are implemented using BLAS functions.
sycl           Some computations are implemented using SYCL functions.

  • --clip <value> A threshold value used to set small elements of weight matrices to zero.

  • --threads <value> The number of threads used by the MKL and OMP libraries.

  • --gradient-step <value> If this value is set, gradient checks are performed with the given step size. This is very slow, and should only be used for debugging.

Miscellaneous options
  • --info Print detailed information about the multilayer perceptron.

  • --timer Print timer messages. The following values are supported:

Value     Description
disabled  No timing information is displayed
brief     At the end, a report with accumulated timing measurements is displayed
full      In addition, individual timing measurements are displayed

  • --precision <value> The precision used for printing matrix elements.

  • --edgeitems <value> The edgeitems used for printing matrices. This sets the number of border rows and columns that are printed.

3.2. The tool mkl

The tool mkl is used for benchmarking sparse and dense matrix products. An example of running the mkl tool is

../install/bin/mkl --arows=1000 --acols=1000 --brows=1000 --threads=12 --algorithm=sdd --repetitions=3 --densities="0.5,0.2,0.1,0.05"

This will use various algorithms to calculate the product A = B * C with A a sparse matrix and B and C dense matrices.

The output may look like this:

--- testing A = B * C (sdd_product) ---
A = 1000x1000 sparse
B = 1000x1000 dense  layout=column-major
C = 1000x1000 dense  layout=column-major

density(A) = 0.5
 0.01147s ddd_product A=column-major, B=column-major, C=column-major
 0.00793s ddd_product A=column-major, B=column-major, C=column-major
 0.00854s ddd_product A=column-major, B=column-major, C=column-major
 0.04049s sdd_product(batchsize=5, density(A)=0.499599, B=column-major, C=column-major)
 0.01998s sdd_product(batchsize=5, density(A)=0.499599, B=column-major, C=column-major)
 0.01178s sdd_product(batchsize=5, density(A)=0.499599, B=column-major, C=column-major)
 0.01114s sdd_product(batchsize=10, density(A)=0.499599, B=column-major, C=column-major)
 0.01099s sdd_product(batchsize=10, density(A)=0.499599, B=column-major, C=column-major)
 0.00666s sdd_product(batchsize=10, density(A)=0.499599, B=column-major, C=column-major)
 0.00375s sdd_product(batchsize=100, density(A)=0.499599, B=column-major, C=column-major)
 0.00734s sdd_product(batchsize=100, density(A)=0.499599, B=column-major, C=column-major)
 0.00332s sdd_product(batchsize=100, density(A)=0.499599, B=column-major, C=column-major)
 0.20097s sdd_product_forloop_eigen(density(A)=0.499599, B=column-major, C=column-major)
 0.19891s sdd_product_forloop_eigen(density(A)=0.499599, B=column-major, C=column-major)
 0.19893s sdd_product_forloop_eigen(density(A)=0.499599, B=column-major, C=column-major)
 0.23286s sdd_product_forloop_mkl(density(A)=0.499599, B=column-major, C=column-major)
 0.23298s sdd_product_forloop_mkl(density(A)=0.499599, B=column-major, C=column-major)
 0.23281s sdd_product_forloop_mkl(density(A)=0.499599, B=column-major, C=column-major)

Note that the very first invocation of an MKL function can be slow.

3.3. The tool inspect_npz

The tool inspect_npz is a simple tool to show the contents of a file in NumPy NPZ format. The tool mlp uses this format to load and save datasets, and to load and save the weight matrices and bias vectors of linear layers. The output may look like this:

W1 (1024x3072) norm = 0.03827324
   [-0.00850412,  0.00766624, -0.00379110,  ..., -0.02755435,  0.00842837,  0.00725122]
   [ 0.03012662, -0.01122476,  0.03765349,  ...,  0.02167689, -0.03734717, -0.01376905]
   [-0.03415587, -0.00498827,  0.00635345,  ..., -0.03036389, -0.01967963,  0.03339641]
   ...,
   [ 0.02993325, -0.00795984,  0.00388659,  ...,  0.01343446, -0.01625269,  0.00398590]
   [ 0.03800971, -0.01185982, -0.00944855,  ...,  0.02083720, -0.00217844,  0.02398606]
   [-0.00879488, -0.01937520, -0.02830209,  ...,  0.03606736, -0.01065827,  0.03293588]
b1= (1024)
   [-0.01735129, -0.01381215,  0.01708755,  ..., -0.01117092, -0.00264273, -0.00976263]
W2 (512x1024) norm = 0.06249978
   [-0.02440289,  0.01362467,  0.03782336,  ...,  0.01342138, -0.01060697, -0.05055390]
   [ 0.06187645, -0.00854158,  0.02849235,  ...,  0.05861567,  0.00708143, -0.06170959]
   [-0.00756755,  0.04718670, -0.02303848,  ...,  0.01513476,  0.00205931,  0.05441900]
   ...,
   [-0.04223771,  0.00852190, -0.00465803,  ...,  0.03600422,  0.00484904, -0.02281546]
   [ 0.03211500, -0.02740303, -0.04652309,  ...,  0.00307061,  0.02427530, -0.02245107]
   [ 0.05210501, -0.00423148, -0.00633851,  ...,  0.02453317,  0.02723335,  0.03589169]
b2= (512)
   [-0.01871627,  0.01150464, -0.01767523,  ..., -0.00220927, -0.01791467, -0.02616516]
W3 (10x512) norm = 0.10718583
   [-0.03256247, -0.09669271, -0.06564181,  ...,  0.00394586, -0.02191557,  0.08828022]
   [-0.09986399, -0.03712691,  0.04332626,  ..., -0.02475236, -0.07359495, -0.09421349]
   [-0.03308030,  0.01280271,  0.09341474,  ..., -0.03470980, -0.03936023,  0.02204999]
   ...,
   [-0.10063093, -0.04294113, -0.04938528,  ...,  0.08151620, -0.00991420,  0.09686699]
   [ 0.04347997, -0.08046009,  0.02828473,  ...,  0.06899156, -0.08314995,  0.07181197]
   [ 0.00575207, -0.06347645, -0.07257712,  ..., -0.00293436, -0.00266003, -0.08468610]
b3= (10)
   [-0.02117447, -0.00115431, -0.03672279,  ..., -0.02902718, -0.02759255,  0.03007624]

4. Overview of the code

This section gives an overview of the C++ code in the Nerva-Rowwise C++ Library, and some information that is needed for understanding the code.

4.1. Number type

The Nerva-Rowwise C++ Library uses a type called scalar as its number type. By default, it is defined as a 32-bit float. This can be changed by defining the symbol NERVA_USE_DOUBLE, in which case 64-bit doubles are used. The corresponding code is essentially:

#ifdef NERVA_USE_DOUBLE
using scalar = double;
#else
using scalar = float;
#endif

A more generic approach would be to add a template argument for the number type to most classes and functions. This has been tried in the past, but since it had a negative impact on the readability of the code, it was later removed.

4.2. Header files

The most important header files are given in the table below.

Header file                 Description
multilayer_perceptron.h     A multilayer perceptron class.
layers.h                    Neural network layers.
activation_functions.h      Activation functions.
loss_functions.h            Loss functions.
weights.h                   Weight initialization functions.
optimizers.h                Optimizer functions, for updating neural network parameters using their gradients.
learning_rate_schedulers.h  Learning rate schedulers, for updating the learning rate during training.
training.h                  A stochastic gradient descent algorithm.
prune.h                     Algorithms for pruning sparse weight matrices, used for dynamic sparse training.
grow.h                      Algorithms for (re-)growing sparse weights, used for dynamic sparse training.

4.3. Classes

4.3.1. Class multilayer_perceptron

A multilayer perceptron (MLP) is modeled using the class multilayer_perceptron. It contains a list of layers, and has member functions feedforward, backpropagate and optimize that can be used for training the neural network. Constructing an MLP can be done manually, as is illustrated in the tests:

void construct_mlp(multilayer_perceptron& M,
                   const eigen::matrix& W1,
                   const eigen::matrix& b1,
                   const eigen::matrix& W2,
                   const eigen::matrix& b2,
                   const eigen::matrix& W3,
                   const eigen::matrix& b3,
                   const std::vector<long>& sizes,
                   long N
                  )
{
  long batch_size = N;

  auto layer1 = std::make_shared<relu_layer<eigen::matrix>>(sizes[0], sizes[1], batch_size);
  M.layers.push_back(layer1);
  auto optimizer_W1 = std::make_shared<gradient_descent_optimizer<eigen::matrix>>(layer1->W, layer1->DW);
  auto optimizer_b1 = std::make_shared<gradient_descent_optimizer<eigen::matrix>>(layer1->b, layer1->Db);
  layer1->optimizer = make_composite_optimizer(optimizer_W1, optimizer_b1);
  layer1->W = W1;
  layer1->b = b1;

  auto layer2 = std::make_shared<relu_layer<eigen::matrix>>(sizes[1], sizes[2], batch_size);
  M.layers.push_back(layer2);
  auto optimizer_W2 = std::make_shared<gradient_descent_optimizer<eigen::matrix>>(layer2->W, layer2->DW);
  auto optimizer_b2 = std::make_shared<gradient_descent_optimizer<eigen::matrix>>(layer2->b, layer2->Db);
  layer2->optimizer = make_composite_optimizer(optimizer_W2, optimizer_b2);
  layer2->W = W2;
  layer2->b = b2;

  auto layer3 = std::make_shared<linear_layer<eigen::matrix>>(sizes[2], sizes[3], batch_size);
  M.layers.push_back(layer3);
  auto optimizer_W3 = std::make_shared<gradient_descent_optimizer<eigen::matrix>>(layer3->W, layer3->DW);
  auto optimizer_b3 = std::make_shared<gradient_descent_optimizer<eigen::matrix>>(layer3->b, layer3->Db);
  layer3->optimizer = make_composite_optimizer(optimizer_W3, optimizer_b3);
  layer3->W = W3;
  layer3->b = b3;
}

This will create an MLP with three linear layers that have weight matrices W1, W2, W3 and bias vectors b1, b2, b3. The parameter sizes contains the input and output sizes of the three layers. Note that the layers and the optimizers are stored using smart pointers. This is done to facilitate the Nerva Python interface. Constructing an MLP like this is quite verbose. An easier way to construct MLPs is provided by the function make_layers, that offers a string based interface.

  multilayer_perceptron M;
  std::vector<std::string> layer_specifications = {"ReLU", "ReLU", "Linear"};
  std::vector<std::size_t> linear_layer_sizes = {2, 2, 2, 2};
  std::vector<double> linear_layer_densities = {0.6, 0.4, 1.0};
  std::vector<double> linear_layer_dropouts = {0.0, 0.0, 0.0};
  std::vector<std::string> linear_layer_weights = {"XavierNormal", "XavierUniform", "HeNormal"};
  std::vector<std::string> optimizers = {"Nesterov(0.9)", "Momentum(0.9)", "GradientDescent"};
  long batch_size = 5;
  std::mt19937 rng{std::random_device{}()};
  M.layers = make_layers(layer_specifications,
                         linear_layer_sizes,
                         linear_layer_densities,
                         linear_layer_dropouts,
                         linear_layer_weights,
                         optimizers,
                         batch_size,
                         rng);

Note that the random number generator argument is used for the generation of the weights. See mlp command line options for an overview of the supported string arguments.

4.3.2. Class neural_network_layer

The class neural_network_layer is the base class of all neural network layers. It has attributes for the input matrix X and the corresponding gradient DX. Usually a layer has some additional parameters that can be learned by training the neural network. The most important member functions of neural_network_layer are given below.

  /// Do a feedforward step given the input `X`, and store the output in `result`.
  virtual void feedforward(eigen::matrix& result) = 0;

  /// Do a backpropagate step given the output `Y`, and its gradient `DY`.
  /// This will calculate the gradient `DX` of the input `X`, and the gradients
  /// of the layer parameters.
  virtual void backpropagate(const eigen::matrix& Y, const eigen::matrix& DY) = 0;

  /// Update the layer parameters using their gradients.
  virtual void optimize(scalar eta) = 0;

4.3.3. Class loss_function

The class loss_function is the base class of all loss functions. Although a loss function is similar to a layer, the interface is different:

  /// Calculate the loss for output `Y` and target `T`.
  [[nodiscard]] virtual scalar value(const eigen::matrix& Y, const eigen::matrix& T) const = 0;

  /// Calculate the gradient of the loss for output `Y` and target `T`.
  [[nodiscard]] virtual eigen::matrix gradient(const eigen::matrix& Y, const eigen::matrix& T) const = 0;

So instead of the names feedforward and backpropagate, we use value and gradient.

There are five loss functions available:

  • squared_error_loss

  • cross_entropy_loss

  • logistic_cross_entropy_loss

  • softmax_cross_entropy_loss

  • negative_log_likelihood_loss

4.3.4. Activation functions

Currently, there is no common base class for activation functions. For example, the ReLU activation function is implemented like this:

struct relu_activation
{
  template <typename Matrix>
  auto operator()(const Matrix& X) const
  {
    return Relu(X);
  }

  template <typename Matrix>
  auto gradient(const Matrix& X) const
  {
    return Relu_gradient(X);
  }

  [[nodiscard]] std::string to_string() const
  {
    return "ReLU()";
  }
};
Currently, there are some inconsistencies between the interfaces of layers, loss functions and activation functions. This may be changed in the future.

4.4. Training a neural network

The class stochastic_gradient_descent_algorithm can be used to train a neural network. It takes as input a multilayer perceptron, a dataset, a loss function, a learning rate scheduler, and a struct containing options like the number of epochs. The main loop looks like this:

for (unsigned int epoch = 0; epoch < options.epochs; ++epoch)
{
  on_start_epoch(epoch);

  eigen::matrix DY(L, options.batch_size);

  for (long batch_index = 0; batch_index < K; batch_index++)
  {
    on_start_batch(batch_index);

    eigen::eigen_slice batch(I.begin() + batch_index * options.batch_size, options.batch_size);
    auto X = data.Xtrain(batch, Eigen::indexing::all);
    auto T = data.Ttrain(batch, Eigen::indexing::all);

    M.feedforward(X, Y);
    DY = loss->gradient(Y, T) / options.batch_size;
    M.backpropagate(Y, DY);
    M.optimize(learning_rate);

    on_end_batch(batch_index);
  }
  on_end_epoch(epoch);
}

In every epoch, the dataset is divided into K batches. A batch X consists of batch_size examples, with corresponding targets T (i.e. the expected outputs). Each batch goes through the three steps of stochastic gradient descent:

  1. feedforward: Given an input batch X and the neural network parameters Θ, compute the output Y.

  2. backpropagation: Given the output Y corresponding to input X and targets T, compute the gradient DY of the loss with respect to Y. Then from Y and DY, compute the gradients of the parameters Θ.

  3. optimization: Given the gradients of the parameters Θ, update the parameters Θ.

4.4.1. Event functions

The algorithm uses a number of event functions:

Event              Description
on_start_training  Called at the start of the training
on_end_training    Called at the end of the training
on_start_epoch     Called at the start of each epoch
on_end_epoch       Called at the end of each epoch
on_start_batch     Called at the start of each batch
on_end_batch       Called at the end of each batch

The user can respond to these events by deriving from the class stochastic_gradient_descent_algorithm. Typical use cases for these event functions are the following:

  • Update the learning rate.

  • Renew dropout masks.

  • Prune and grow sparse weights.

Such operations are typically done after each epoch or after a given number of batches.

The following actions take place at the start of every epoch:

  • A preprocessed dataset is loaded from disk, which is done to avoid the expensive computation of augmented data at every epoch.

  • The learning rate is updated if a learning rate scheduler is set.

  • Dropout masks are renewed.

  • Sparse weight matrices are pruned and regrown if a regrow function is specified.

  • Small weights in the subnormal range are clipped to zero if the clip option is set.

An example can be found in the tool mlp:

    void on_start_epoch(unsigned int epoch) override
    {
      if (epoch > 0 && !reload_data_directory.empty())
      {
        reload_data(epoch);
      }

      if (lr_scheduler)
      {
        learning_rate = lr_scheduler->operator()(epoch);
      }

      if (epoch > 0)
      {
        renew_dropout_masks(M, rng);
      }

      if (epoch > 0 && regrow_function)
      {
        (*regrow_function)(M);
      }

      if (epoch > 0 && options.clip > 0)
      {
        M.clip(options.clip);
      }
    }

4.5. Timers

The Nerva-Rowwise C++ Library has two timer classes, defined in the header file timer.h:

class             description

map_timer         A timer for timing different operations. Each operation is identified by a name, and for each name all timing results are stored.
resumable_timer   A map_timer that can be suspended and resumed.

The Nerva-Rowwise C++ Library uses a predefined timer nerva_timer that is defined in the header file nerva_timer.h. The mlp tool uses this timer to keep track of the time spent on feedforward, backpropagate and optimize calls during training, and optionally of other computations. Each computation is identified with a unique name. If the option --timer=brief is set, the accumulated times of all computations will be displayed:

--- timing results ---
backpropagate        = 5.6162
batchnorm1           = 0.0895
batchnorm2           = 0.0418
batchnorm3           = 0.1030
batchnorm4           = 0.8833
feedforward          = 1.4613
optimize             = 0.1137
total time           = 8.3089

For fine-grained measurements two macros NERVA_TIMER_START and NERVA_TIMER_STOP are defined for starting and stopping the timer. An example can be found in the backpropagate call of batch normalization layers:

  void backpropagate(const eigen::matrix& Y, const eigen::matrix& DY) override
  {
    using eigen::diag;
    using eigen::hadamard;
    using eigen::row_repeat;
    using eigen::columns_sum;
    using eigen::identity;
    using eigen::ones;
    using eigen::inv_sqrt;
    auto N = X.rows();

    NERVA_TIMER_START("batchnorm1")
    DZ = hadamard(row_repeat(gamma, N), DY);
    NERVA_TIMER_STOP("batchnorm1")

    NERVA_TIMER_START("batchnorm2")
    Dbeta = columns_sum(DY);
    NERVA_TIMER_STOP("batchnorm2")

    NERVA_TIMER_START("batchnorm3")
    Dgamma = columns_sum(hadamard(Z, DY));
    NERVA_TIMER_STOP("batchnorm3")

    NERVA_TIMER_START("batchnorm4")
    DX = hadamard(row_repeat(inv_sqrt_Sigma / N, N), (N * identity<eigen::matrix>(N) - ones<eigen::matrix>(N, N)) * DZ - hadamard(Z, row_repeat(diag(Z.transpose() * DZ).transpose(), N)));
    NERVA_TIMER_STOP("batchnorm4")
  }

To avoid any overhead, these macros can be disabled by defining the symbol NERVA_DISABLE_TIMER. If the option --timer=full is set, all individual timings will be displayed:

    feedforward-1    0.001753s
     batchnorm1-1    0.000096s
     batchnorm2-1    0.000043s
     batchnorm3-1    0.000099s
     batchnorm4-1    0.001125s
  backpropagate-1    0.006895s
       optimize-1    0.000184s
    feedforward-2    0.001773s
     batchnorm1-2    0.000117s
     batchnorm2-2    0.000051s
     batchnorm3-2    0.000124s
     batchnorm4-2    0.001280s
  backpropagate-2    0.006300s
       optimize-2    0.000115s
    feedforward-3    0.001471s
     batchnorm1-3    0.000071s
     batchnorm2-3    0.000023s
     batchnorm3-3    0.000088s
     batchnorm4-3    0.000667s

The calls are numbered to make it easy to compare different runs. Unsurprisingly, the timing output shows that the computation labeled batchnorm4 takes the majority of the time.

5. Matrix operations

The most important part of the implementation of neural networks consists of matrix operations. In the implementation of activation functions, loss functions and neural network layers, many different matrix operations are needed. In Nerva a structured approach is followed to implement these components. All equations are expressed in terms of the matrix operations in the table below.

Table 1. matrix operations

Operation | Code | Definition
\(0_{m}\) | zeros(m) | \(m \times 1\) column vector with elements equal to 0
\(0_{mn}\) | zeros(m, n) | \(m \times n\) matrix with elements equal to 0
\(1_{m}\) | ones(m) | \(m \times 1\) column vector with elements equal to 1
\(1_{mn}\) | ones(m, n) | \(m \times n\) matrix with elements equal to 1
\(\mathbb{I}_n\) | identity(n) | \(n \times n\) identity matrix
\(X^\top\) | X.transpose() | transposition
\(cX\) | c * X | scalar multiplication, \(c \in \mathbb{R}\)
\(X + Y\) | X + Y | addition
\(X - Y\) | X - Y | subtraction
\(X \cdot Z\) | X * Z | matrix multiplication, also denoted as \(XZ\)
\(x^\top y\) or \(x y^\top\) | dot(x,y) | dot product, \(x,y \in \mathbb{R}^{m \times 1}\) or \(x,y \in \mathbb{R}^{1 \times n}\)
\(X \odot Y\) | hadamard(X,Y) | element-wise product of \(X\) and \(Y\)
\(\mathsf{diag}(X)\) | diag(X) | column vector containing the diagonal of \(X\)
\(\mathsf{Diag}(x)\) | Diag(x) | diagonal matrix with \(x\) as diagonal, \(x \in \mathbb{R}^{1 \times n}\) or \(x \in \mathbb{R}^{m \times 1}\)
\(1_m^\top \cdot X \cdot 1_n\) | elements_sum(X) | sum of the elements of \(X\)
\(x \cdot 1_n^\top\) | column_repeat(x, n) | \(n\) copies of column vector \(x \in \mathbb{R}^{m \times 1}\)
\(1_m \cdot x\) | row_repeat(x, m) | \(m\) copies of row vector \(x \in \mathbb{R}^{1 \times n}\)
\(1_m^\top \cdot X\) | columns_sum(X) | \(1 \times n\) row vector with the sums of the columns of \(X\)
\(X \cdot 1_n\) | rows_sum(X) | \(m \times 1\) column vector with the sums of the rows of \(X\)
\(\max(X)_{col}\) | columns_max(X) | \(1 \times n\) row vector with the maximum values of the columns of \(X\)
\(\max(X)_{row}\) | rows_max(X) | \(m \times 1\) column vector with the maximum values of the rows of \(X\)
\((1_m^\top \cdot X) / m\) | columns_mean(X) | \(1 \times n\) row vector with the mean values of the columns of \(X\)
\((X \cdot 1_n) / n\) | rows_mean(X) | \(m \times 1\) column vector with the mean values of the rows of \(X\)
\(f(X)\) | apply(f, X) | element-wise application of \(f: \mathbb{R} \rightarrow \mathbb{R}\) to \(X\)
\(e^X\) | exp(X) | element-wise application of \(x \rightarrow e^x\) to \(X\)
\(\log(X)\) | log(X) | element-wise application of the natural logarithm \(x \rightarrow \ln(x)\) to \(X\)
\(1 / X\) | reciprocal(X) | element-wise application of \(x \rightarrow 1/x\) to \(X\)
\(\sqrt{X}\) | sqrt(X) | element-wise application of \(x \rightarrow \sqrt{x}\) to \(X\)
\(X^{-1/2}\) | inv_sqrt(X) | element-wise application of \(x \rightarrow x^{-1/2}\) to \(X\)
\(\log(\sigma(X))\) | log_sigmoid(X) | element-wise application of \(x \rightarrow \log(\sigma(x))\) to \(X\)

Using this table leads to concise and uniform code. For example, the backpropagation implementation of a softmax layer looks like this:

        DZ = hadamard(Y, DY - column_repeat(diag(Y * DY.transpose()), K));
        DW = DZ.transpose() * X;
        Db = columns_sum(DZ);
        DX = DZ * W;

See the paper Batch Matrix-form Equations and Implementation of Multilayer Perceptrons for an overview of how these matrix operations are used.

5.1. Eigen library

The Nerva-Rowwise C++ Library uses the Eigen library for representing matrices. The matrix operations in table matrix operations have been implemented using Eigen, see the file matrix_operations.h.

5.2. MKL library

Using the Eigen library alone is not sufficient for obtaining high performance. Therefore, the Nerva-Rowwise C++ Library uses the Intel Math Kernel Library (MKL) as a backend. The Eigen library supports MKL by means of the compiler flag EIGEN_USE_MKL_ALL, see also TopicUsingIntelMKL.html. Note that the MKL library is included in the Intel oneAPI base toolkit.

The MKL library supports a number of highly efficient, but extremely low-level interfaces for matrix operations. See blas-and-sparse-blas-routines.html for an overview. The Nerva-Rowwise C++ Library contains matrix classes that hide those low-level details from the user. The table below gives an overview of them.

Header file           Description

mkl_dense_vector.h    A class dense_vector_view that wraps a raw pointer to a vector.
mkl_dense_matrix.h    A class dense_matrix_view that wraps a raw pointer to a matrix, and a class dense_matrix that stores a dense matrix.
mkl_sparse_matrix.h   A class sparse_matrix_csr [1] that stores a sparse matrix in compressed sparse row (CSR) format.

In C++23 the implementation of sparse matrices in CSR format can be greatly simplified, as shown by Ben Brock.
The sparse CSR matrix functions in the MKL library take an argument of the opaque type sparse_matrix_t, which stores unspecified properties of a sparse matrix. This parameter is poorly documented, and it is unclear when it needs to be recalculated. To be safe, the Nerva-Rowwise C++ Library recalculates it after every change to a sparse matrix, which may cause some inefficiency. See also the function sparse_matrix_csr::construct_csr and mkl_sparse_?_create_csr.

6. I/O

The Nerva-Rowwise C++ Library has support for reading and writing datasets and the weights and biases of a model in NumPy NPZ format. This format is used for portability between the C++ and Python implementations. There is no support yet for storing a complete model, including its architecture.

6.1. NPZ format

The default storage format used in the Nerva libraries is the NumPy NPZ format, see numpy.lib.format. The reason for choosing this format is portability between C++ and Python implementations. A file in .npz format can be used to store a dictionary of arrays in a compressed format.

6.2. Preparing data

The mlp tool requires training and testing data to be stored in .npz format. To help with this, a script is provided to download and preprocess datasets commonly used in experiments, including MNIST and CIFAR-10.

The script is located at ../../../data/prepare_data.py and can be run from the command line.

6.2.1. MNIST

To download and prepare the MNIST dataset, run:

python prepare_data.py --dataset=mnist --download

This will:

  • Download mnist.npz from the official source if not already present.

  • Create a flattened and normalized version of the dataset as mnist-flattened.npz.

The output file contains:

  • Xtrain, Xtest: flattened and normalized image data

  • Ttrain, Ttest: corresponding label vectors

6.2.2. CIFAR-10

To download and prepare the CIFAR-10 dataset, run:

python prepare_data.py --dataset=cifar10 --download

This will:

  • Download the CIFAR-10 dataset from the official source if not already present.

  • Create a flattened and normalized version of the dataset as cifar10-flattened.npz.

As with MNIST, the .npz file will contain:

  • Xtrain, Xtest: flattened image arrays with pixel values normalized to [0, 1]

  • Ttrain, Ttest: integer class labels

6.2.3. Reusing Existing Files

If the required .npz files already exist, the script will detect this and skip reprocessing. You can safely rerun the script without overwriting files.

6.2.4. Help

For help with usage, run:

python data/prepare_data.py --help

This displays all options, including how to customize the output directory.

6.2.5. Inspecting .npz files

To inspect the contents of a .npz file (such as mnist-flattened.npz or cifar10-flattened.npz), you can use the inspect_npz.py utility included in the distribution:

python tools/inspect_npz.py data/mnist-flattened.npz

This prints the shape and values of each array stored in the file. To print only the names, shapes, and norms without dumping the full contents, use:

python tools/inspect_npz.py data/mnist-flattened.npz --shapes-only

6.3. Storing datasets and weights

The mlp tool supports saving and loading both datasets and model parameters in the NumPy .npz format. This ensures compatibility between the Python and C++ implementations by storing everything in a standard dictionary of arrays.

  • Use --save-data and --load-data to write or read datasets.

  • Use --save-weights and --load-weights to store or restore the weights and biases of a trained model.

The .npz file for datasets contains the following keys:

  • Xtrain, Ttrain: input features and target labels for the training set

  • Xtest, Ttest: input features and target labels for the test set

The .npz file for model parameters stores each layer’s weights and biases under the keys:

  • W1, W2, …​: weight matrices for the first, second, etc. linear layer

  • b1, b2, …​: corresponding bias vectors

These arrays use standard NumPy formats and can be inspected or manipulated easily in Python using numpy.load() and numpy.savez().

Note: The architecture of the model (e.g., number of layers or activation functions) is not stored in the .npz file. This must be specified separately when reloading weights.

7. Performance

This section discusses various aspects that play a role for the performance of a neural network library.

7.1. Mini-batches

In textbooks and tutorials, the training of a neural network is usually explained in terms of individual examples. But in order to achieve high performance, it is absolutely necessary to use mini-batches. On Wikipedia this is explained as follows:

A compromise between computing the true gradient and the gradient at a single sample is to compute the gradient against more than one training sample (called a "mini-batch") at each step. This can perform significantly better than "true" stochastic gradient descent described, because the code can make use of vectorization libraries rather than computing each step separately.

To support mini-batches, the Nerva-Rowwise C++ Library defines all equations that play a role in the execution of a neural network in matrix form, including the backpropagation equations, see the paper Batch Matrix-form Equations and Implementation of Multilayer Perceptrons. For the latter, many neural network frameworks rely on Automatic differentiation, see also [1]. We use explicit backpropagation equations to implement truly sparse layers and to provide an instructive resource for those studying neural network execution.

7.2. Matrix products

The performance of training a neural network largely depends on the calculation of matrix products during the backpropagation step of linear layers. In order to do this efficiently, the Intel Math Kernel Library (MKL) is used. Currently, this dependency is hard-coded, but there are plans to make it optional. To experiment with other implementations, like SYCL or BLAS, a global setting is used that is discussed in the next section.

7.3. Subnormal numbers

Experiments with sparse neural networks have shown that performance can be negatively affected by subnormal numbers. The example program subnormal_numbers.cpp illustrates the problem. The table below is the result of the following experiment. The dot product of two large vectors of floating-point numbers is computed. One vector is filled with random values between 0 and 1, and the other with powers of 10, ranging from 1 down to 1e−45. For values larger than 1e−35, this calculation takes about 0.044 seconds. For smaller values we end up in the range of subnormal numbers, which causes the runtime to increase more than eight-fold, to 0.37 seconds. In our experiments we observed that when layers with high sparsity are used, subnormal values may appear in the weight matrices, and their number increases every epoch.

--- multiplication1 ---
time =   0.044372 | value = 1.0e+00    | sum = -5.49552e+03
time =   0.044567 | value = 1.0e-01    | sum = -5.49572e+02
time =   0.044243 | value = 1.0e-02    | sum = -5.49304e+01
time =   0.044434 | value = 1.0e-03    | sum = -5.49612e+00
time =   0.044253 | value = 1.0e-04    | sum = -5.49862e-01
time =   0.044765 | value = 1.0e-05    | sum = -5.49653e-02
time =   0.044698 | value = 1.0e-06    | sum = -5.49624e-03
time =   0.044683 | value = 1.0e-07    | sum = -5.49642e-04
time =   0.044703 | value = 1.0e-08    | sum = -5.49491e-05
time =   0.044821 | value = 1.0e-09    | sum = -5.49454e-06
time =   0.044705 | value = 1.0e-10    | sum = -5.49557e-07
time =   0.044657 | value = 1.0e-11    | sum = -5.49730e-08
time =   0.045235 | value = 1.0e-12    | sum = -5.49563e-09
time =   0.045120 | value = 1.0e-13    | sum = -5.49706e-10
time =   0.045010 | value = 1.0e-14    | sum = -5.49719e-11
time =   0.044988 | value = 1.0e-15    | sum = -5.49464e-12
time =   0.044943 | value = 1.0e-16    | sum = -5.49629e-13
time =   0.044795 | value = 1.0e-17    | sum = -5.49573e-14
time =   0.044147 | value = 1.0e-18    | sum = -5.49449e-15
time =   0.044166 | value = 1.0e-19    | sum = -5.49589e-16
time =   0.044380 | value = 1.0e-20    | sum = -5.49722e-17
time =   0.044036 | value = 1.0e-21    | sum = -5.49430e-18
time =   0.043405 | value = 1.0e-22    | sum = -5.49577e-19
time =   0.043615 | value = 1.0e-23    | sum = -5.49548e-20
time =   0.043544 | value = 1.0e-24    | sum = -5.49570e-21
time =   0.043547 | value = 1.0e-25    | sum = -5.49694e-22
time =   0.043536 | value = 1.0e-26    | sum = -5.49365e-23
time =   0.043560 | value = 1.0e-27    | sum = -5.49488e-24
time =   0.043500 | value = 1.0e-28    | sum = -5.49657e-25
time =   0.043524 | value = 1.0e-29    | sum = -5.49783e-26
time =   0.044128 | value = 1.0e-30    | sum = -5.49559e-27
time =   0.043585 | value = 1.0e-31    | sum = -5.49745e-28
time =   0.043530 | value = 1.0e-32    | sum = -5.49488e-29
time =   0.043609 | value = 1.0e-33    | sum = -5.49569e-30
time =   0.043805 | value = 1.0e-34    | sum = -5.49446e-31
time =   0.046169 | value = 1.0e-35    | sum = -5.49661e-32
time =   0.070594 | value = 1.0e-36    | sum = -5.49664e-33
time =   0.247938 | value = 1.0e-37    | sum = -5.49684e-34
time =   0.368848 | value = 1.0e-38    | sum = -5.49553e-35
time =   0.369819 | value = 1.0e-39    | sum = -5.49426e-36
time =   0.368434 | value = 1.0e-40    | sum = -5.49607e-37
time =   0.368747 | value = 1.0e-41    | sum = -5.49801e-38
time =   0.369033 | value = 1.0e-42    | sum = -5.50173e-39
time =   0.370241 | value = 9.9e-44    | sum = -5.47762e-40
time =   0.370065 | value = 9.8e-45    | sum = -4.97559e-41
time =   0.370310 | value = 1.4e-45    | sum = -1.44152e-41

This problem has been discussed on Google Groups. A possible solution is to instruct the compiler to flush subnormal values to zero, but there does not seem to be a portable way to achieve this. In the Nerva-Rowwise C++ Library different solutions have been tried. One of them is to periodically flush weights in the subnormal range to zero using the --clip command line option of the mlp tool. The problem of subnormal numbers is also discussed in [2].

7.4. Nerva computation mode

In general, the performance of Eigen is very good. But occasionally, the generated code for a matrix expression can be quite poor. Especially in case of backpropagation calculations this can have a huge impact on the performance. The Nerva-Rowwise C++ Library uses a global setting NervaComputation to experiment with other implementations. For example, the function softmax_layer::feedforward contains this:

      if (NervaComputation == computation::eigen)
      {
        Z = X * W.transpose() + row_repeat(b, N);
        result = stable_softmax()(Z);
      }
      else
      {
        mkl::ddd_product(Z, X, W.transpose());
        Z += row_repeat(b, N);
        result = stable_softmax()(Z);
      }

In this case, the default version computation::eigen turned out to have very poor performance, so a direct call to an MKL routine is used instead. The NervaComputation setting is also used to experiment with BLAS and SYCL implementations; see the file optimizers.h for some examples.

The command line tool mlp has an option --computation to set the computation mode.

8. Sparse neural networks

Sparse neural network layers are often simulated using binary masks, see [3]. This is caused by the lack of support for sparse tensors in popular neural network frameworks. Note that PyTorch is currently developing sparse tensors. The Nerva-Rowwise C++ Library supports truly sparse layers, meaning that the weight matrices of sparse layers are stored in a sparse matrix format. Another example of truly sparse layers is given by [4].

8.1. Sparse matrices

In a programming context, we say that the support of a sparse matrix is the set of positions (or indices) in the matrix that are explicitly stored. Elements inside the support can have a non-zero value, while elements outside the support have the value zero by definition.

Sparse matrices in the Nerva-Rowwise C++ Library are stored in CSR format. This representation stores an array of column indices and an array of row offsets to define the support, plus an array of the corresponding values. CSR matrices are unstructured sparse matrices, meaning they have non-zero elements located at arbitrary positions. Alternatively, there are structured sparse matrices, take for example butterfly matrices [5].

8.2. Sparse evolutionary training

Sparse evolutionary training (SET) is a method for efficiently training sparse neural networks, see e.g. [6]. The idea behind this method is to start the training with a random sparse topology, and to periodically prune and regrow some of the weights.

8.3. Sparse initialization

In SET, the sparsity is not divided evenly over the sparse layers. Instead, small layers are assigned a higher density than larger ones. In [6] formula (3), Erdős–Rényi graph topology is suggested to calculate the densities of the sparse layers given a desired overall density of the sparse layers combined. In the Nerva-Rowwise C++ Library this is implemented in the function compute_sparse_layer_densities, see layer_algorithms.h. The original Python implementation can be found here, along with several other sparse initialization strategies. In the tool mlp the option --overall-density is used for assigning Erdős–Rényi densities to the sparse layers. See [mlp_output] for an example of this. The overall density of 0.05 is converted into densities [0.042382877, 0.06357384, 1.0] for the individual layers.

8.4. Pruning weights

Pruning weights is about removing parameters from a neural network, see also Pruning (artificial_neural_network). In our context removing parameters is about removing elements from the support of a sparse weight matrix. The effect of this is that the values corresponding to these elements are zeroed.

8.4.1. Threshold pruning

In threshold pruning, all weights \(w_{ij}\) with \(|w_{ij}| \leq t\) for a given threshold \(t\) are pruned from a weight matrix \(W\).

8.4.2. Magnitude based pruning

Magnitude based pruning is a special case of threshold pruning. In magnitude based pruning, the threshold \(t\) is computed such that for a given fraction \(\zeta\) of the weights we have \(|w_{ij}| \leq t\). To ensure that the desired fraction of weights is removed, our implementation takes into account that there can be multiple weights with \(|w_{ij}| = t\).

8.4.3. SET based pruning

In SET based pruning, magnitude pruning is applied to positive weights and negative weights separately. So a fraction \(\zeta\) of the positive weights and a fraction \(\zeta\) of the negative weights are pruned.

8.5. Growing weights

Growing weights is about adding parameters to a neural network. In our context adding parameters is about adding elements to the support of a sparse weight matrix.

8.5.1. Random growing

In random growing, a given number of elements is chosen randomly from the positions outside the support of a weight matrix. These new elements are then added to the support. Since the new elements need to be initialized, a weight initializer needs to be chosen to generate values for them.

A specific implementation of random growing for matrices in CSR format has been developed that uses reservoir sampling to determine the new elements to add to the support.

8.6. Classes for pruning and growing

In the Nerva-Rowwise C++ Library, the classes prune_function and grow_function are used to represent generic pruning and growing strategies:

struct prune_function
{
  /// Removes elements from the support of a sparse matrix
  /// @param W A sparse matrix
  /// @return The number of elements removed from the support
  virtual std::size_t operator()(mkl::sparse_matrix_csr<scalar>& W) const = 0;

  virtual ~prune_function() = default;
};
struct grow_function
{
  /// Adds `count` elements to the support of matrix `W`
  virtual void operator()(mkl::sparse_matrix_csr<scalar>& W, std::size_t count) const = 0;

  virtual ~grow_function() = default;
};

In the command line tool mlp the user can select specific implementations of these prune and grow functions. They are called at the start of each epoch of training via an attribute regrow_function that applies pruning and growing to the sparse layers of an MLP. See also the [on_start_epoch] event.

8.7. Experiments with sparse training

In [7] we report on some of our experiments with sparse neural networks.

An example of a dynamic sparse training experiment is:

../install/bin/mlp \
    --layers="ReLU;ReLU;Linear" \
    --layer-sizes="3072;1024;1024;10" \
    --layer-weights=XavierNormal \
    --optimizers="Nesterov(0.9)" \
    --loss=SoftmaxCrossEntropy \
    --learning-rate=0.01 \
    --epochs=100 \
    --batch-size=100 \
    --threads=12 \
    --overall-density=0.05 \
    --prune="Magnitude(0.1)" \
    --cifar10=../data \
    --seed=123

At the start of every epoch 20% of the weights are pruned, and the same number of weights is added back at different locations. The output may look like this:

=== Nerva c++ model ===
Sparse(input_size=3072, output_size=1024, density=0.042382877, optimizer=Nesterov(0.90000), activation=ReLU())
Sparse(input_size=1024, output_size=1024, density=0.06357384, optimizer=Nesterov(0.90000), activation=ReLU())
Dense(input_size=1024, output_size=10, optimizer=Nesterov(0.90000), activation=NoActivation())
loss = SoftmaxCrossEntropyLoss()
scheduler = ConstantScheduler(lr=0.01)
layer densities: 133325/3145728 (4.238%), 66662/1048576 (6.357%), 10240/10240 (100%)

epoch   0 lr: 0.01000000  loss: 2.30284437  train accuracy: 0.07904000  test accuracy: 0.08060000 time: 0.00000000s
epoch   1 lr: 0.01000000  loss: 2.14723837  train accuracy: 0.21136000  test accuracy: 0.21320000 time: 5.48583113s
pruning + growing 26665/133325 weights
pruning + growing 13332/66662 weights
epoch   2 lr: 0.01000000  loss: 1.91203228  train accuracy: 0.30918000  test accuracy: 0.30900000 time: 5.00460376s
pruning + growing 26665/133325 weights
pruning + growing 13332/66662 weights

9. Extending the library

The Nerva-Rowwise C++ Library can be extended in several obvious ways, such as adding new layers, activation functions, loss functions, learning rate schedulers and pruning or growing functions. This can be done by inheriting from the appropriate base class and implementing the required virtual functions. The table below provides an overview:

Functionality               Base class

A layer                     neural_network_layer
An activation function      activation_function
A loss function             loss_function
A learning rate scheduler   learning_rate_scheduler
A pruning function          prune_function
A growing function          grow_function

It is recommended to follow the approach advocated in the Nerva libraries. Each implementation should be based on a mathematical specification, as explained in the paper Batch Matrix-form Equations and Implementation of Multilayer Perceptrons. After defining the mathematical equations, you can use the table of matrix operations to convert the equations into code.

Another crucial step is validation and testing. The symbolic mathematics library SymPy can be used to validate the equations. The nerva-sympy repository contains test cases for activation functions, loss functions, layers, and even for the derivation of equations.

10. References

[1] A. G. Baydin, B. A. Pearlmutter, A. A. Radul, and J. M. Siskind, “Automatic Differentiation in Machine Learning: a Survey,” J. Mach. Learn. Res., vol. 18, pp. 153:1–153:43, 2017, [Online]. Available: https://jmlr.org/papers/v18/17-468.html.

[2] N. J. Higham and T. Mary, “Mixed precision algorithms in numerical linear algebra,” Acta Numer., vol. 31, pp. 347–414, 2022, [Online]. Available: https://doi.org/10.1017/S0962492922000022.

[3] S. Curci, D. C. Mocanu, and M. Pechenizkiy, “Truly Sparse Neural Networks at Scale,” CoRR, vol. abs/2102.01732, 2021, [Online]. Available: https://arxiv.org/abs/2102.01732.

[4] M. Nikdan, T. Pegolotti, E. Iofinova, E. Kurtic, and D. Alistarh, “SparseProp: Efficient Sparse Backpropagation for Faster Training of Neural Networks at the Edge,” in International Conference on Machine Learning, ICML 2023, 23-29 July 2023, Honolulu, Hawaii, USA, 2023, vol. 202, pp. 26215–26227, [Online]. Available: https://proceedings.mlr.press/v202/nikdan23a.html.

[5] A. Gonon, L. Zheng, P. Carrivain, and Q.-T. Le, “Make Inference Faster: Efficient GPU Memory Management for Butterfly Sparse Matrix Multiplication,” CoRR, vol. abs/2405.15013, 2024, [Online]. Available: https://doi.org/10.48550/arXiv.2405.15013.

[6] S. Liu, D. C. Mocanu, A. R. R. Matavalam, Y. Pei, and M. Pechenizkiy, “Sparse evolutionary deep learning with over one million artificial neurons on commodity hardware,” Neural Comput. Appl., vol. 33, no. 7, pp. 2589–2604, 2021, [Online]. Available: https://doi.org/10.1007/s00521-020-05136-7.

[7] W. Wesselink, B. Grooten, Q. Xiao, C. de Campos, and M. Pechenizkiy, “Nerva: a Truly Sparse Implementation of Neural Networks.” 2024, [Online]. Available: https://arxiv.org/abs/2407.17437.


1. Note that the name sparse_matrix could not be used due to a conflict with a #define of the same name buried deep inside the MKL library code.