nerva_jax.weight_initializers
Weight and bias initialization helpers for linear layers.
Functions
The module provides: normal (Gaussian) and uniform initialization with given parameters, for both weights and biases; zero initialization for weights and biases; Xavier / Glorot normal and uniform initialization (for tanh/sigmoid); He / Kaiming normal and uniform initialization (for ReLU; the uniform variant is less common); and a helper that initializes a layer's parameters according to a named scheme.
- nerva_jax.weight_initializers.bias_uniform(b_: jax.numpy.ndarray, a: float = 0.0, b: float = 1.0)[source]
Uniform initialization within [a, b).
- nerva_jax.weight_initializers.bias_normal(b: jax.numpy.ndarray, mean: float = 0.0, std: float = 1.0)[source]
Normal (Gaussian) initialization with given mean and std.
- nerva_jax.weight_initializers.weights_uniform(W: jax.numpy.ndarray, a: float = 0.0, b: float = 1.0)[source]
Uniform initialization within [a, b).
- nerva_jax.weight_initializers.weights_normal(W: jax.numpy.ndarray, mean: float = 0.0, std: float = 1.0)[source]
Normal (Gaussian) initialization with given mean and std.
- nerva_jax.weight_initializers.weights_zero(W: jax.numpy.ndarray)[source]
Initialize weights to zero.
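Since the signatures above take an existing array, and JAX arrays are immutable, each helper presumably returns a freshly initialized array of the same shape. Equivalent behavior can be sketched directly with `jax.random`; the key-based signatures below are illustrative, not nerva_jax's actual API:

```python
import jax
import jax.numpy as jnp

def weights_uniform(key, shape, a=0.0, b=1.0):
    # Draw entries i.i.d. from the uniform distribution on [a, b).
    return jax.random.uniform(key, shape, minval=a, maxval=b)

def weights_normal(key, shape, mean=0.0, std=1.0):
    # Draw entries i.i.d. from N(mean, std^2).
    return mean + std * jax.random.normal(key, shape)

def weights_zero(shape):
    # All-zero array. Standard for biases, but poor for hidden-layer
    # weights, since all units would then receive identical gradients.
    return jnp.zeros(shape)

key = jax.random.PRNGKey(0)
W = weights_uniform(key, (3, 4), a=-0.5, b=0.5)
```

The bias variants follow the same pattern, applied to a 1-D bias vector instead of a weight matrix.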
- nerva_jax.weight_initializers.weights_xavier_uniform(W: jax.numpy.ndarray)[source]
Xavier / Glorot uniform initialization (for tanh/sigmoid).
where K is the fan-out (output size) and D is the fan-in (input size).
- nerva_jax.weight_initializers.weights_xavier_normal(W: jax.numpy.ndarray)[source]
Xavier / Glorot normal initialization (for tanh/sigmoid).
where K is the fan-out (output size) and D is the fan-in (input size).
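Xavier / Glorot initialization scales the weight variance by both fan-in and fan-out to keep activation and gradient magnitudes roughly constant across layers: the uniform variant draws from U(-limit, limit) with limit = sqrt(6 / (D + K)), and the normal variant uses std = sqrt(2 / (D + K)). A minimal sketch, assuming W has shape (K, D) with rows indexed by outputs (as the K/D note above suggests):

```python
import math
import jax

def xavier_uniform(key, K, D):
    # Glorot/Xavier uniform: U(-limit, limit), limit = sqrt(6 / (D + K)).
    limit = math.sqrt(6.0 / (D + K))
    return jax.random.uniform(key, (K, D), minval=-limit, maxval=limit)

def xavier_normal(key, K, D):
    # Glorot/Xavier normal: N(0, std^2), std = sqrt(2 / (D + K)).
    std = math.sqrt(2.0 / (D + K))
    return std * jax.random.normal(key, (K, D))

key = jax.random.PRNGKey(42)
W = xavier_uniform(key, K=16, D=8)  # limit = sqrt(6/24) = 0.5
```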
- nerva_jax.weight_initializers.weights_he_normal(W: jax.numpy.ndarray)[source]
He / Kaiming normal initialization (for ReLU).
where K is the fan-out (output size) and D is the fan-in (input size).
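He / Kaiming initialization scales the variance by fan-in only, using std = sqrt(2 / D); the factor 2 compensates for ReLU zeroing roughly half the pre-activations. A sketch under the same shape assumption as above (illustrative signatures, not nerva_jax's API); the uniform variant shown alongside uses limit = sqrt(6 / D):

```python
import math
import jax

def he_normal(key, K, D):
    # He/Kaiming normal for ReLU: N(0, std^2), std = sqrt(2 / D).
    std = math.sqrt(2.0 / D)
    return std * jax.random.normal(key, (K, D))

def he_uniform(key, K, D):
    # He/Kaiming uniform variant: U(-limit, limit), limit = sqrt(6 / D).
    limit = math.sqrt(6.0 / D)
    return jax.random.uniform(key, (K, D), minval=-limit, maxval=limit)

key = jax.random.PRNGKey(7)
W = he_normal(key, K=32, D=64)
```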