nerva_jax.weight_initializers

Weight and bias initialization helpers for linear layers.

Functions

bias_normal(b[, mean, std])
    Normal (Gaussian) initialization with given mean and std.

bias_uniform(b_[, a, b])
    Uniform initialization within [a, b).

bias_zero(b)
    Initialize biases to zero.

set_layer_weights(layer, text)
    Initialize a layer's parameters according to a named scheme.

weights_he_normal(W)
    He / Kaiming normal initialization (for ReLU).

weights_he_uniform(W)
    He / Kaiming uniform initialization (for ReLU, less common).

weights_normal(W[, mean, std])
    Normal (Gaussian) initialization with given mean and std.

weights_uniform(W[, a, b])
    Uniform initialization within [a, b).

weights_xavier_normal(W)
    Xavier / Glorot normal initialization (for tanh/sigmoid).

weights_xavier_uniform(W)
    Xavier / Glorot uniform initialization (for tanh/sigmoid).

weights_zero(W)
    Initialize weights to zero.

nerva_jax.weight_initializers.bias_uniform(b_: jax.numpy.ndarray, a: float = 0.0, b: float = 1.0)

Uniform initialization within [a, b).

nerva_jax.weight_initializers.bias_normal(b: jax.numpy.ndarray, mean: float = 0.0, std: float = 1.0)

Normal (Gaussian) initialization with given mean and std.

nerva_jax.weight_initializers.bias_zero(b: jax.numpy.ndarray)

Initialize biases to zero.
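
A minimal usage sketch for the bias helpers. JAX arrays are immutable, so it is assumed here that each helper returns a freshly initialized array shaped like its argument rather than modifying it in place; the layer size is illustrative.

    import jax.numpy as jnp

    from nerva_jax.weight_initializers import bias_normal, bias_uniform, bias_zero

    b = jnp.zeros(32)  # bias vector for a layer with 32 outputs

    # Assumption: each helper returns a new array with the shape of its argument.
    b = bias_zero(b)                        # all zeros (a common default)
    b = bias_uniform(b, a=-0.05, b=0.05)    # samples from [-0.05, 0.05)
    b = bias_normal(b, mean=0.0, std=0.01)  # small zero-mean Gaussian values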

nerva_jax.weight_initializers.weights_uniform(W: jax.numpy.ndarray, a: float = 0.0, b: float = 1.0)

Uniform initialization within [a, b).

nerva_jax.weight_initializers.weights_normal(W: jax.numpy.ndarray, mean: float = 0.0, std: float = 1.0)

Normal (Gaussian) initialization with given mean and std.

nerva_jax.weight_initializers.weights_zero(W: jax.numpy.ndarray)

Initialize weights to zero.
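
As with the bias helpers, a hedged usage sketch for the plain weight initializers, again assuming each function returns a new array shaped like W. The weight matrix is taken here to be stored as (K, D), i.e. fan-out by fan-in, which is an assumption about the layout this library uses.

    import jax.numpy as jnp

    from nerva_jax.weight_initializers import weights_normal, weights_uniform, weights_zero

    W = jnp.zeros((64, 128))  # assumed layout: K = 64 outputs, D = 128 inputs

    # Assumption: each helper returns a new array with the shape of its argument.
    W = weights_uniform(W, a=-0.1, b=0.1)  # samples from [-0.1, 0.1)
    W = weights_normal(W, std=0.02)        # zero-mean Gaussian with std 0.02
    W = weights_zero(W)                    # all zeros (rarely useful for weights)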

nerva_jax.weight_initializers.weights_xavier_uniform(W: jax.numpy.ndarray)

Xavier / Glorot uniform initialization (for tanh/sigmoid).

K = fan-out (output size), D = fan-in (input size).

nerva_jax.weight_initializers.weights_xavier_normal(W: jax.numpy.ndarray)

Xavier / Glorot normal initialization (for tanh/sigmoid).

K = fan-out (output size), D = fan-in (input size).
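
For reference, the standard Glorot formulas behind these two schemes, written in the K/D notation above. This is a sketch of the sampling they are presumed to perform; plain NumPy is used for brevity, and whether nerva_jax draws with NumPy or jax.random is not specified on this page.

    import numpy as np

    K, D = 64, 128  # fan-out and fan-in of the layer (illustrative sizes)

    # Xavier / Glorot uniform: U(-limit, limit) with limit = sqrt(6 / (D + K))
    limit = np.sqrt(6.0 / (D + K))
    W_uniform = np.random.uniform(-limit, limit, size=(K, D))

    # Xavier / Glorot normal: N(0, std^2) with std = sqrt(2 / (D + K))
    std = np.sqrt(2.0 / (D + K))
    W_normal = std * np.random.randn(K, D)

Both choices keep the variance of activations roughly constant across layers when the activation is approximately linear around zero, which is why they are recommended for tanh and sigmoid.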

nerva_jax.weight_initializers.weights_he_normal(W: jax.numpy.ndarray)

He / Kaiming normal initialization (for ReLU).

K = fan-out (output size), D = fan-in (input size).

nerva_jax.weight_initializers.weights_he_uniform(W: jax.numpy.ndarray)

He / Kaiming uniform initialization (for ReLU, less common).

K = fan-out (output size), D = fan-in (input size).
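
Likewise, the standard Kaiming formulas, which scale by fan-in only to compensate for ReLU zeroing out roughly half of its inputs. Again a NumPy sketch of the presumed sampling, not the library's actual code.

    import numpy as np

    K, D = 64, 128  # fan-out and fan-in of the layer (illustrative sizes)

    # He / Kaiming normal: N(0, std^2) with std = sqrt(2 / D)
    std = np.sqrt(2.0 / D)
    W_normal = std * np.random.randn(K, D)

    # He / Kaiming uniform: U(-limit, limit) with limit = sqrt(6 / D)
    limit = np.sqrt(6.0 / D)
    W_uniform = np.random.uniform(-limit, limit, size=(K, D))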

nerva_jax.weight_initializers.set_layer_weights(layer, text: str)

Initialize a layer's parameters according to the initialization scheme named by text.
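
A hedged usage sketch. The exact scheme names accepted by text are not documented on this page, so the string below is a plausible guess; consult the source for the strings the function actually recognizes.

    from nerva_jax.weight_initializers import set_layer_weights

    # layer is assumed to be an already constructed nerva_jax linear layer
    # with a weight matrix W and bias vector b.
    set_layer_weights(layer, "Xavier")  # hypothetical scheme name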