nerva_jax.multilayer_perceptron

A simple multilayer perceptron (MLP) class.

Functions

parse_multilayer_perceptron(...)

Construct an MLP from textual layer specs and size/optimizer configs.

Classes

MultilayerPerceptron([layers])

Multilayer perceptron

class nerva_jax.multilayer_perceptron.MultilayerPerceptron(layers=None)[source]

Bases: object

Multilayer perceptron

feedforward(X: jax.numpy.ndarray) → jax.numpy.ndarray[source]
backpropagate(Y: jax.numpy.ndarray, DY: jax.numpy.ndarray) → None[source]
optimize(eta: float)[source]
info()[source]
load_weights_and_bias(filename: str)[source]

Loads the weights and biases from a file in .npz format.

The weight matrices are stored using the keys W1, W2, … and the bias vectors using the keys b1, b2, …

filename: the name of the file

save_weights_and_bias(filename: str)[source]

Saves the weights and biases to a file in compressed .npz format.

The weight matrices are stored using the keys W1, W2, … and the bias vectors using the keys b1, b2, …
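The key layout described above can be reproduced with plain NumPy. The following is a minimal sketch, not the library's implementation: the helper names save_params and load_params are hypothetical, and the (outputs, inputs) weight shape is an assumption made only for illustration.

```python
import os
import tempfile

import numpy as np


def save_params(filename, weights, biases):
    """Store W1, W2, ... and b1, b2, ... in one compressed .npz archive."""
    arrays = {}
    for i, (W, b) in enumerate(zip(weights, biases), start=1):
        arrays[f"W{i}"] = W
        arrays[f"b{i}"] = b
    np.savez_compressed(filename, **arrays)


def load_params(filename):
    """Read W1, W2, ... and b1, b2, ... back in layer order."""
    with np.load(filename) as data:
        # Count the weight keys to find out how many layers were stored.
        n = sum(1 for key in data.files if key.startswith("W"))
        weights = [data[f"W{i}"] for i in range(1, n + 1)]
        biases = [data[f"b{i}"] for i in range(1, n + 1)]
    return weights, biases


# Two small layers; the (outputs, inputs) orientation is an assumption.
rng = np.random.default_rng(0)
sizes = [4, 3, 2]
weights = [rng.normal(size=(n_out, n_in)) for n_in, n_out in zip(sizes, sizes[1:])]
biases = [np.zeros(n_out) for n_out in sizes[1:]]

path = os.path.join(tempfile.mkdtemp(), "mlp.npz")
save_params(path, weights, biases)
loaded_weights, loaded_biases = load_params(path)
```

A file written this way can be inspected with np.load(path).files, which lists the stored keys.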

nerva_jax.multilayer_perceptron.parse_multilayer_perceptron(layer_specifications: List[str], linear_layer_sizes: List[int], optimizers: List[str], linear_layer_weight_initializers: List[str]) → MultilayerPerceptron[source]

Construct an MLP from textual layer specs and size/optimizer configs.

layer_specifications: e.g. ["ReLU", "BatchNormalization", "LogSoftmax"]
linear_layer_sizes: e.g. [784, 128, 10] for two linear layers
optimizers: one per layer (including BatchNormalization)
linear_layer_weight_initializers: one per linear layer
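The length relationships between the four arguments can be checked before calling the parser. This is a sketch under stated assumptions: check_mlp_config is a hypothetical helper, not part of nerva_jax, and it assumes that every specification other than "BatchNormalization" denotes a linear layer, as the example values suggest. The optimizer and initializer strings are placeholders, not values confirmed by this page.

```python
from typing import List, Tuple


def check_mlp_config(
    layer_specifications: List[str],
    linear_layer_sizes: List[int],
    optimizers: List[str],
    linear_layer_weight_initializers: List[str],
) -> List[Tuple[int, int]]:
    """Validate list lengths and return (inputs, outputs) per linear layer."""
    # Assumption: only "BatchNormalization" is a non-linear-layer spec.
    n_linear = sum(1 for spec in layer_specifications if spec != "BatchNormalization")

    # N linear layers need N + 1 sizes (input size plus one output size each).
    if len(linear_layer_sizes) != n_linear + 1:
        raise ValueError("expected one more size than there are linear layers")
    # One optimizer per layer, BatchNormalization layers included.
    if len(optimizers) != len(layer_specifications):
        raise ValueError("expected one optimizer per layer")
    # One weight initializer per linear layer only.
    if len(linear_layer_weight_initializers) != n_linear:
        raise ValueError("expected one weight initializer per linear layer")

    return list(zip(linear_layer_sizes, linear_layer_sizes[1:]))


shapes = check_mlp_config(
    layer_specifications=["ReLU", "BatchNormalization", "LogSoftmax"],
    linear_layer_sizes=[784, 128, 10],
    optimizers=["GradientDescent"] * 3,             # placeholder strings
    linear_layer_weight_initializers=["Xavier"] * 2,  # placeholder strings
)
```

With the example values from above, the two linear layers map 784 inputs to 128 outputs and 128 inputs to 10 outputs.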