nerva_jax documentation
A tiny, educational set of neural network components built on JAX.
Install and build
# from repository root
python -m pip install -U sphinx sphinx-rtd-theme
# build HTML docs into docs_sphinx/_build/html
sphinx-build -b html docs_sphinx docs_sphinx/_build/html
API reference
nerva_jax: minimal neural network components built on top of JAX. Its modules cover:

- Activation functions and utilities used by the MLP implementation.
- In-memory data loader helpers and one-hot conversions.
- Neural network layers used by the MultilayerPerceptron class.
- Learning-rate schedulers.
- Analytic loss functions and their gradients used during training.
- Matrix operations built on top of JAX to support the math in the library.
- A simple multilayer perceptron (MLP) class.
- Optimizers that adjust the model's parameters based on the gradients.
- Softmax and log-softmax functions together with numerically stable variants.
- Training helpers for the MLP, including a basic SGD loop and CLI glue.
- Miscellaneous utilities (formatting, timing, parsing, I/O).
- Weight and bias initialization helpers for linear layers.
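The numerically stable softmax variants mentioned above rely on a standard trick: subtracting the row maximum before exponentiating. A minimal sketch (the function names `stable_softmax` and `stable_log_softmax` are illustrative, not necessarily the library's actual API):

```python
import jax.numpy as jnp

def stable_softmax(x, axis=-1):
    # Subtracting the per-row maximum before exponentiating prevents
    # overflow in exp() without changing the result.
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    e = jnp.exp(shifted)
    return e / jnp.sum(e, axis=axis, keepdims=True)

def stable_log_softmax(x, axis=-1):
    # log-softmax computed directly on the shifted values, avoiding
    # an intermediate softmax that could underflow to zero.
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=axis, keepdims=True))
```

With inputs like `[1000.0, 1000.0]` a naive implementation would overflow to `nan`; the shifted version returns `[0.5, 0.5]`.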
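The one-hot conversions used by the data loader helpers can be sketched as follows (a minimal illustration; `to_one_hot` is a hypothetical name, not the library's documented API):

```python
import jax.numpy as jnp

def to_one_hot(labels, num_classes):
    # Compare each integer label against the class indices 0..num_classes-1;
    # the boolean equality matrix, cast to float, is the one-hot encoding.
    return (labels[:, None] == jnp.arange(num_classes)).astype(jnp.float32)
```

JAX also ships `jax.nn.one_hot`, which does the same job.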
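"Analytic loss functions and their gradients" refers to losses whose derivatives are coded by hand rather than obtained via autodiff. For softmax cross-entropy the well-known closed form of the gradient is `(softmax(Z) - T) / batch_size`. A sketch under assumed conventions (rows are samples, `T` is one-hot; the function names are illustrative):

```python
import jax
import jax.numpy as jnp

def softmax_cross_entropy(Z, T):
    # Z: raw scores (batch, classes); T: one-hot targets.
    # Mean negative log-likelihood, computed via a stable log-softmax.
    logp = Z - jnp.max(Z, axis=1, keepdims=True)
    logp = logp - jnp.log(jnp.sum(jnp.exp(logp), axis=1, keepdims=True))
    return -jnp.mean(jnp.sum(T * logp, axis=1))

def softmax_cross_entropy_gradient(Z, T):
    # Analytic gradient of the loss above: (softmax(Z) - T) / batch_size.
    e = jnp.exp(Z - jnp.max(Z, axis=1, keepdims=True))
    Y = e / jnp.sum(e, axis=1, keepdims=True)
    return (Y - T) / Z.shape[0]
```

The analytic form can be checked against `jax.grad` of the loss, which is a useful sanity test for any hand-derived gradient.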
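An optimizer's job, as described above, is to adjust parameters from gradients. A minimal sketch of one common rule, SGD with momentum (illustrative only; the library's optimizer interface may differ):

```python
def sgd_momentum_update(param, grad, velocity, lr=0.01, mu=0.9):
    # Classic momentum: keep a decaying accumulation of past gradients
    # and step the parameter along it. Works on scalars and JAX arrays.
    velocity = mu * velocity - lr * grad
    return param + velocity, velocity
```

With `mu=0` this reduces to plain SGD: `param - lr * grad`.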
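Learning-rate schedulers map an epoch (or step) index to a learning rate. A common example, sketched here with an assumed closure-based shape (not necessarily the library's scheduler interface):

```python
def exponential_decay(lr0, decay_rate):
    # Returns a schedule function: lr(epoch) = lr0 * decay_rate ** epoch.
    def schedule(epoch):
        return lr0 * decay_rate ** epoch
    return schedule
```

A training loop would call `schedule(epoch)` once per epoch and pass the result to the optimizer.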
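As an example of a weight initialization helper for linear layers, here is a sketch of Glorot/Xavier uniform initialization in JAX (one common scheme; which schemes the library actually provides is not stated above, and `xavier_init` is a hypothetical name):

```python
import jax
import jax.numpy as jnp

def xavier_init(key, in_features, out_features):
    # Glorot/Xavier uniform: draw weights from U(-limit, limit) with
    # limit = sqrt(6 / (fan_in + fan_out)), which keeps activation
    # variance roughly constant across layers; biases start at zero.
    limit = jnp.sqrt(6.0 / (in_features + out_features))
    W = jax.random.uniform(key, (out_features, in_features),
                           minval=-limit, maxval=limit)
    b = jnp.zeros(out_features)
    return W, b
```

Note the explicit PRNG key: JAX random functions are pure, so the caller supplies and splits keys rather than relying on global state.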