nerva_jax.layers

Neural network layers used by the MultilayerPerceptron class.

Layers in this module operate on matrices in row layout (each row is a sample). They expose a minimal interface consisting of feedforward, backpropagate and optimize.
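A typical usage pattern is feedforward, then backpropagate with the loss gradient DY, then optimize. The sketch below assumes a LinearLayer and a squared-error gradient for illustration; the textual optimizer and weight-initializer names are assumptions and may differ between versions.

    import jax.numpy as jnp
    from nerva_jax.layers import LinearLayer

    X = jnp.ones((4, 3))                     # N = 4 samples with D = 3 features
    T = jnp.zeros((4, 2))                    # targets for K = 2 outputs (illustration only)

    layer = LinearLayer(3, 2)
    layer.set_optimizer("GradientDescent")   # assumed optimizer name
    layer.set_weights("Xavier")              # assumed weight-initializer name

    Y = layer.feedforward(X)                 # shape (N, K)
    DY = 2 * (Y - T) / X.shape[0]            # gradient of a mean squared error loss w.r.t. Y
    layer.backpropagate(Y, DY)               # returns None; gradients are kept on the layer
    layer.optimize(0.01)                     # one update step with learning rate eta = 0.01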

Functions

parse_linear_layer(text, D, K, optimizer, ...)

Parse a textual layer spec and create a configured Layer instance.

Classes

ActivationLayer(D, K, act)

Linear layer followed by a pointwise activation function.

BatchNormalizationLayer(D)

Batch normalization layer with per-feature gamma and beta.

Layer()

Base class for layers of a neural network with data in row layout (each row is a sample: shape (N, D)).

LinearLayer(D, K)

Linear layer: Y = X W^T + b.

LogSoftmaxLayer(D, K)

Linear layer followed by log_softmax over the last dimension.

SReLULayer(D, K, act)

Activation layer with SReLU and trainable activation parameters.

SoftmaxLayer(D, K)

Linear layer followed by softmax over the last dimension.

class nerva_jax.layers.Layer

Bases: object

Base class for layers of a neural network with data in row layout (each row is a sample: shape (N, D)).

feedforward(X: jax.numpy.ndarray) → jax.numpy.ndarray
backpropagate(Y: jax.numpy.ndarray, DY: jax.numpy.ndarray) → None
optimize(eta)
class nerva_jax.layers.LinearLayer(D: int, K: int)

Bases: Layer

Linear layer: Y = X W^T + b.

Shapes: X (N, D) -> Y (N, K), W (K, D), b (K,).

feedforward(X: jax.numpy.ndarray) → jax.numpy.ndarray
backpropagate(Y: jax.numpy.ndarray, DY: jax.numpy.ndarray) → None
input_size() → int
output_size() → int
set_optimizer(optimizer: str)
set_weights(weight_initializer)
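As a shape check, the documented forward map can be written directly in jax.numpy; this mirrors the formula Y = X W^T + b rather than the layer's internal code.

    import jax.numpy as jnp

    N, D, K = 4, 3, 2
    X = jnp.ones((N, D))
    W = jnp.zeros((K, D))    # one row of weights per output unit
    b = jnp.zeros((K,))      # one bias per output unit

    Y = X @ W.T + b          # Y = X W^T + b
    assert Y.shape == (N, K)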
class nerva_jax.layers.ActivationLayer(D: int, K: int, act: ActivationFunction)

Bases: LinearLayer

Linear layer followed by a pointwise activation function.

feedforward(X: jax.numpy.ndarray) → jax.numpy.ndarray
backpropagate(Y: jax.numpy.ndarray, DY: jax.numpy.ndarray) → None
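The forward pass is the linear map followed by the activation applied elementwise. With ReLU as the activation, the computation amounts to the following (plain jax.numpy, independent of the ActivationFunction classes):

    import jax.numpy as jnp

    X = jnp.array([[1.0, -2.0, 3.0]])    # one sample, D = 3
    W = jnp.ones((2, 3))                 # K = 2 output units
    b = jnp.zeros((2,))

    Z = X @ W.T + b                      # pre-activation, shape (1, 2)
    Y = jnp.maximum(Z, 0.0)              # ReLU applied pointwise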
class nerva_jax.layers.SReLULayer(D: int, K: int, act: SReLUActivation)

Bases: ActivationLayer

Activation layer with SReLU and trainable activation parameters.

In addition to W and b, this layer optimizes the SReLU parameters (al, tl, ar, tr).

backpropagate(Y: jax.numpy.ndarray, DY: jax.numpy.ndarray) → None
set_optimizer(optimizer: str)
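SReLU is commonly defined as a piecewise-linear function that is the identity between the thresholds tl and tr and linear with slopes al and ar outside them. The sketch below follows this standard definition; it is an assumption about the form used here, not a copy of the layer's code.

    import jax.numpy as jnp

    def srelu(x, al, tl, ar, tr):
        # identity on (tl, tr), slope al below tl, slope ar above tr
        return jnp.where(x <= tl, tl + al * (x - tl),
               jnp.where(x >= tr, tr + ar * (x - tr), x))

    x = jnp.linspace(-3.0, 3.0, 7)
    y = srelu(x, al=0.0, tl=0.0, ar=1.0, tr=1.0)   # these values reduce SReLU to ReLU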
class nerva_jax.layers.SoftmaxLayer(D: int, K: int)

Bases: LinearLayer

Linear layer followed by softmax over the last dimension.

feedforward(X: jax.numpy.ndarray) → jax.numpy.ndarray
backpropagate(Y: jax.numpy.ndarray, DY: jax.numpy.ndarray) → None
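The softmax stage corresponds to jax.nn.softmax over the last axis, so each output row is a probability vector. A sketch of the composed forward map (an illustration of the formula, not the layer's internal code):

    import jax
    import jax.numpy as jnp

    X = jnp.ones((4, 3))
    W = jnp.zeros((2, 3))
    b = jnp.zeros((2,))

    Z = X @ W.T + b
    Y = jax.nn.softmax(Z, axis=-1)   # each row of Y sums to 1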
class nerva_jax.layers.LogSoftmaxLayer(D: int, K: int)

Bases: LinearLayer

Linear layer followed by log_softmax over the last dimension.

feedforward(X: jax.numpy.ndarray) → jax.numpy.ndarray
backpropagate(Y: jax.numpy.ndarray, DY: jax.numpy.ndarray) → None
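log_softmax is the logarithm of the softmax, usually computed in the numerically stable form z - logsumexp(z); the equivalence can be checked directly:

    import jax
    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    Z = jnp.array([[1.0, 2.0, 3.0]])
    Y1 = jax.nn.log_softmax(Z, axis=-1)
    Y2 = Z - logsumexp(Z, axis=-1, keepdims=True)   # same values, stable for large inputs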
class nerva_jax.layers.BatchNormalizationLayer(D: int)

Bases: Layer

Batch normalization layer with per-feature gamma and beta.

Normalizes inputs across the batch using per-feature statistics. Shapes: X (N, D) -> Y (N, D), gamma/beta (D,).

feedforward(X: jax.numpy.ndarray) → jax.numpy.ndarray
backpropagate(Y: jax.numpy.ndarray, DY: jax.numpy.ndarray) → None

Compute gradients for gamma/beta and propagate DX through BN.

input_size() → int
output_size() → int
set_optimizer(optimizer: str)
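The forward computation follows the usual batch-normalization recipe: per-feature mean and variance over the batch, normalization, then scaling by gamma and shifting by beta. The epsilon below is an assumed stabilizing constant, not necessarily the value used by this layer.

    import jax.numpy as jnp

    X = jnp.arange(12.0).reshape(4, 3)    # N = 4 samples, D = 3 features
    gamma = jnp.ones((3,))
    beta = jnp.zeros((3,))
    eps = 1e-5                            # assumed stabilizer

    mean = X.mean(axis=0)                 # per-feature mean, shape (D,)
    var = X.var(axis=0)                   # per-feature variance, shape (D,)
    Y = gamma * (X - mean) / jnp.sqrt(var + eps) + beta   # shape (N, D)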
nerva_jax.layers.parse_linear_layer(text: str, D: int, K: int, optimizer: str, weight_initializer: str) → Layer

Parse a textual layer spec and create a configured Layer instance.

Supports Linear, Softmax, LogSoftmax, activation names (e.g. ReLU), and SReLU(…). The given optimizer and weight initializer are applied to the resulting layer.
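A usage sketch; the accepted layer specs and the textual optimizer and weight-initializer names depend on the installed version, so the strings below ("ReLU", "GradientDescent", "Xavier") are assumptions.

    from nerva_jax.layers import parse_linear_layer

    # Build a 784 -> 256 layer with a ReLU activation (assumed names throughout).
    layer = parse_linear_layer("ReLU", D=784, K=256,
                               optimizer="GradientDescent",
                               weight_initializer="Xavier")
    print(layer.input_size(), layer.output_size())   # 784 256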