nerva_jax

nerva_jax: Minimal neural network components built on top of JAX.

This package provides small, educational implementations of layers, activation functions, loss functions, optimizers, learning-rate schedulers, softmax utilities, and simple training helpers. It is designed for readability and experimentation rather than performance.

Modules

activation_functions

Activation functions and utilities used by the MLP implementation.
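
For instance, a ReLU and its derivative can be written directly with jax.numpy; this is an illustrative sketch, not necessarily the module's exact API:

    import jax.numpy as jnp

    def relu(x):
        # Elementwise max(0, x).
        return jnp.maximum(0.0, x)

    def relu_derivative(x):
        # 1 where x > 0, else 0; useful for manual backpropagation.
        return jnp.where(x > 0, 1.0, 0.0)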

datasets

In-memory data loader helpers and one-hot conversions.
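
The core of a one-hot conversion helper can be sketched as follows (the function name here is hypothetical; jax.nn.one_hot provides the same operation):

    import jax.numpy as jnp

    def to_one_hot(labels, num_classes):
        # Map integer labels of shape (n,) to an (n, num_classes) matrix.
        return jnp.eye(num_classes)[labels]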

layers

Neural network layers used by the MultilayerPerceptron class.
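
The forward pass of a linear layer, for example, reduces to an affine map; a minimal sketch (the module's layer classes may organize their state differently):

    import jax.numpy as jnp

    def linear_forward(X, W, b):
        # X holds one sample per row; W is (inputs, outputs), b is (outputs,).
        return X @ W + b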

learning_rate

Learning-rate schedulers.
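
A scheduler maps the epoch index to a learning rate. As an illustrative sketch, exponential decay looks like this:

    def exponential_decay(lr0, decay_rate):
        # Returns a schedule with lr(t) = lr0 * decay_rate**t.
        def schedule(epoch):
            return lr0 * decay_rate ** epoch
        return schedule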

loss_functions

Analytic loss functions and their gradients used during training.
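
For example, the softmax cross-entropy loss has the well-known analytic gradient softmax(Z) - Y with respect to the pre-activations Z. A sketch with illustrative names:

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    def softmax_cross_entropy(Z, Y):
        # Mean cross-entropy of one-hot targets Y against log-softmax(Z).
        return jnp.mean(jnp.sum(Y * (logsumexp(Z, axis=1, keepdims=True) - Z), axis=1))

    def softmax_cross_entropy_gradient(Z, Y):
        # Analytic gradient of the mean loss w.r.t. Z: (softmax(Z) - Y) / n.
        S = jnp.exp(Z - logsumexp(Z, axis=1, keepdims=True))
        return (S - Y) / Z.shape[0]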

matrix_operations

Matrix operations built on top of JAX to support the math in the library.
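
Such operations are typically thin wrappers around jax.numpy; two illustrative examples (names are hypothetical):

    import jax.numpy as jnp

    def rows_sum(X):
        # Column vector of row sums, keeping an (n, 1) shape.
        return jnp.sum(X, axis=1, keepdims=True)

    def hadamard(X, Y):
        # Elementwise (Hadamard) product.
        return X * Y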

multilayer_perceptron

A simple multilayer perceptron (MLP) class.
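
Conceptually, the MLP chains its layers' forward passes; a hypothetical sketch (the real MultilayerPerceptron interface may use different method names):

    def feedforward(layers, X):
        # Pass the batch X through each layer in order.
        for layer in layers:
            X = layer.feedforward(X)  # hypothetical method name
        return X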

optimizers

Optimizers used to adjust the model's parameters based on the gradients.
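
The simplest case is plain gradient descent; momentum adds a velocity term. A sketch of the update rules, not the module's exact classes:

    def gradient_descent_update(w, grad, lr):
        # w <- w - lr * grad
        return w - lr * grad

    def momentum_update(w, v, grad, lr, mu):
        # v <- mu * v - lr * grad, then w <- w + v.
        v = mu * v - lr * grad
        return w + v, v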

softmax_functions

Softmax and log-softmax functions together with stable variants.
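
The stable variants subtract the row maximum before exponentiating, which leaves the result unchanged but avoids overflow. A sketch:

    import jax.numpy as jnp

    def stable_softmax(Z):
        # softmax(Z) is invariant under subtracting the row max.
        E = jnp.exp(Z - Z.max(axis=1, keepdims=True))
        return E / E.sum(axis=1, keepdims=True)

    def stable_log_softmax(Z):
        # log-softmax(Z) = Z - logsumexp(Z), computed without overflow.
        M = Z - Z.max(axis=1, keepdims=True)
        return M - jnp.log(jnp.exp(M).sum(axis=1, keepdims=True))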

training

Training helpers for the MLP, including a basic SGD loop and CLI glue.
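
A minimal SGD epoch over an in-memory loader might look like the following sketch; method and parameter names are assumptions, not the actual helper API:

    def sgd_epoch(model, loader, loss_gradient, lr):
        # One pass over the data: forward, backward, parameter update.
        for X, T in loader:                       # T: one-hot targets
            Y = model.feedforward(X)              # hypothetical method names
            model.backpropagate(Y, loss_gradient(Y, T))
            model.optimize(lr)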

utilities

Miscellaneous utilities (formatting, timing, parsing, I/O).

weight_initializers

Weight and bias initialization helpers for linear layers.
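
For example, Xavier (Glorot) initialization scales random weights by the layer's fan-in and fan-out; a sketch using JAX's explicit PRNG keys:

    import jax
    import jax.numpy as jnp

    def xavier_uniform(key, fan_in, fan_out):
        # Sample W uniformly from (-a, a) with a = sqrt(6 / (fan_in + fan_out)).
        a = jnp.sqrt(6.0 / (fan_in + fan_out))
        return jax.random.uniform(key, (fan_in, fan_out), minval=-a, maxval=a)

    W = xavier_uniform(jax.random.PRNGKey(0), 784, 128)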