nerva_jax.loss_functions

Analytic loss functions and their gradients used during training.

Functions are provided in vector (lowercase name, single sample) and matrix (capitalized name, batch of samples) forms. Concrete LossFunction classes wrap these for use in the training loop.

Functions

Cross_entropy_loss(Y, T)

Cross entropy loss for matrices: -sum(T ⊙ log(Y)).

Cross_entropy_loss_gradient(Y, T)

Gradient of cross entropy loss for matrices.

Logistic_cross_entropy_loss(Y, T)

Logistic cross entropy loss for matrices.

Logistic_cross_entropy_loss_gradient(Y, T)

Gradient of logistic cross entropy loss for matrices.

Negative_log_likelihood_loss(Y, T)

Negative log likelihood loss for matrices.

Negative_log_likelihood_loss_gradient(Y, T)

Gradient of negative log likelihood loss for matrices.

Softmax_cross_entropy_loss(Y, T)

Softmax cross entropy loss for matrices.

Softmax_cross_entropy_loss_gradient(Y, T)

Gradient of softmax cross entropy loss for matrices.

Softmax_cross_entropy_loss_gradient_one_hot(Y, T)

Gradient of softmax cross entropy for one-hot targets (matrices).

Squared_error_loss(Y, T)

Squared error loss for matrices: ||Y - T||², i.e. the vector loss summed over the rows.

Squared_error_loss_gradient(Y, T)

Gradient of squared error loss for matrices.

Stable_softmax_cross_entropy_loss(Y, T)

Stable softmax cross entropy loss for matrices.

Stable_softmax_cross_entropy_loss_gradient(Y, T)

Gradient of stable softmax cross entropy loss for matrices.

Stable_softmax_cross_entropy_loss_gradient_one_hot(Y, T)

Gradient of stable softmax cross entropy for one-hot targets (matrices).

cross_entropy_loss(y, t)

Cross entropy loss for vectors: -t^T log(y).

cross_entropy_loss_gradient(y, t)

Gradient of cross entropy loss for vectors.

logistic_cross_entropy_loss(y, t)

Logistic cross entropy loss for vectors.

logistic_cross_entropy_loss_gradient(y, t)

Gradient of logistic cross entropy loss for vectors.

negative_log_likelihood_loss(y, t)

Negative log likelihood loss for vectors.

negative_log_likelihood_loss_gradient(y, t)

Gradient of negative log likelihood loss for vectors.

parse_loss_function(text)

Parse a loss function name into a LossFunction instance.

softmax_cross_entropy_loss(y, t)

Softmax cross entropy loss for vectors.

softmax_cross_entropy_loss_gradient(y, t)

Gradient of softmax cross entropy loss for vectors.

softmax_cross_entropy_loss_gradient_one_hot(y, t)

Gradient of softmax cross entropy for one-hot targets.

squared_error_loss(y, t)

Squared error loss for vectors: ||y - t||².

squared_error_loss_gradient(y, t)

Gradient of squared error loss for vectors.

stable_softmax_cross_entropy_loss(y, t)

Stable softmax cross entropy loss for vectors.

stable_softmax_cross_entropy_loss_gradient(y, t)

Gradient of stable softmax cross entropy loss for vectors.

stable_softmax_cross_entropy_loss_gradient_one_hot(y, t)

Gradient of stable softmax cross entropy for one-hot targets.

Classes

CrossEntropyLossFunction()

Cross entropy loss function for classification with probabilities.

LogisticCrossEntropyLossFunction()

Logistic cross entropy loss for binary classification.

LossFunction()

Interface for loss functions with value and gradient on batch matrices.

NegativeLogLikelihoodLossFunction()

Negative log likelihood loss for probabilistic outputs.

SoftmaxCrossEntropyLossFunction()

Softmax cross entropy loss for classification with logits.

SquaredErrorLossFunction()

Squared error loss function for regression tasks.

StableSoftmaxCrossEntropyLossFunction()

Numerically stable softmax cross entropy loss for classification.

nerva_jax.loss_functions.squared_error_loss(y, t)[source]

Squared error loss for vectors: ||y - t||².

nerva_jax.loss_functions.squared_error_loss_gradient(y, t)[source]

Gradient of squared error loss for vectors.

nerva_jax.loss_functions.Squared_error_loss(Y, T)[source]

Squared error loss for matrices: ||Y - T||², i.e. the vector loss summed over the rows.

nerva_jax.loss_functions.Squared_error_loss_gradient(Y, T)[source]

Gradient of squared error loss for matrices.
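
The formulas above can be sketched in plain NumPy (the library itself operates on jax.numpy arrays; this illustrates the math, not the package's implementation):

```python
import numpy as np

def squared_error_loss(y, t):
    # ||y - t||^2 for a single output vector y and target t
    return float(np.sum((y - t) ** 2))

def squared_error_loss_gradient(y, t):
    # d/dy ||y - t||^2 = 2 (y - t)
    return 2 * (y - t)

y = np.array([0.2, 0.8])
t = np.array([0.0, 1.0])
print(squared_error_loss(y, t))           # ≈ 0.08
print(squared_error_loss_gradient(y, t))  # ≈ [0.4, -0.4]
```

The matrix form simply applies the same formula to a whole batch at once: `np.sum((Y - T) ** 2)`.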

nerva_jax.loss_functions.cross_entropy_loss(y, t)[source]

Cross entropy loss for vectors: -t^T log(y).

nerva_jax.loss_functions.cross_entropy_loss_gradient(y, t)[source]

Gradient of cross entropy loss for vectors.

nerva_jax.loss_functions.Cross_entropy_loss(Y, T)[source]

Cross entropy loss for matrices: -sum(T ⊙ log(Y)).

nerva_jax.loss_functions.Cross_entropy_loss_gradient(Y, T)[source]

Gradient of cross entropy loss for matrices.
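
The vector formula -t^T log(y) and its matrix counterpart -sum(T ⊙ log(Y)) can be sketched in plain NumPy (an illustration of the math, not the package's jax.numpy implementation); the gradient follows directly by elementwise differentiation:

```python
import numpy as np

def cross_entropy_loss(y, t):
    # -t^T log(y): y is a probability vector, t a target distribution
    return float(-t @ np.log(y))

def cross_entropy_loss_gradient(y, t):
    # d/dy [-t^T log(y)] = -t / y (elementwise)
    return -t / y

def Cross_entropy_loss(Y, T):
    # -sum(T ⊙ log(Y)): the vector loss summed over the rows of a batch
    return float(-np.sum(T * np.log(Y)))
```

For a one-hot target the vector loss reduces to the negative log-probability of the correct class, e.g. `cross_entropy_loss([0.25, 0.75], [0, 1]) == -log(0.75)`.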

nerva_jax.loss_functions.softmax_cross_entropy_loss(y, t)[source]

Softmax cross entropy loss for vectors.

nerva_jax.loss_functions.softmax_cross_entropy_loss_gradient(y, t)[source]

Gradient of softmax cross entropy loss for vectors.

nerva_jax.loss_functions.softmax_cross_entropy_loss_gradient_one_hot(y, t)[source]

Gradient of softmax cross entropy for one-hot targets.

nerva_jax.loss_functions.Softmax_cross_entropy_loss(Y, T)[source]

Softmax cross entropy loss for matrices.

nerva_jax.loss_functions.Softmax_cross_entropy_loss_gradient(Y, T)[source]

Gradient of softmax cross entropy loss for matrices.

nerva_jax.loss_functions.Softmax_cross_entropy_loss_gradient_one_hot(Y, T)[source]

Gradient of softmax cross entropy for one-hot targets (matrices).
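
Softmax cross entropy composes the cross entropy loss with a softmax over raw logits; for one-hot targets the gradient collapses to the well-known softmax(y) - t. A NumPy sketch of the math (not the package's jax.numpy implementation):

```python
import numpy as np

def softmax(y):
    e = np.exp(y)
    return e / np.sum(e)

def softmax_cross_entropy_loss(y, t):
    # -t^T log(softmax(y)): y holds raw logits, not probabilities
    return float(-t @ np.log(softmax(y)))

def softmax_cross_entropy_loss_gradient_one_hot(y, t):
    # for one-hot t the gradient simplifies to softmax(y) - t
    return softmax(y) - t
```

Since both softmax(y) and a one-hot t sum to 1, the one-hot gradient always sums to zero.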

nerva_jax.loss_functions.stable_softmax_cross_entropy_loss(y, t)[source]

Stable softmax cross entropy loss for vectors.

nerva_jax.loss_functions.stable_softmax_cross_entropy_loss_gradient(y, t)[source]

Gradient of stable softmax cross entropy loss for vectors.

nerva_jax.loss_functions.stable_softmax_cross_entropy_loss_gradient_one_hot(y, t)[source]

Gradient of stable softmax cross entropy for one-hot targets.

nerva_jax.loss_functions.Stable_softmax_cross_entropy_loss(Y, T)[source]

Stable softmax cross entropy loss for matrices.

nerva_jax.loss_functions.Stable_softmax_cross_entropy_loss_gradient(Y, T)[source]

Gradient of stable softmax cross entropy loss for matrices.

nerva_jax.loss_functions.Stable_softmax_cross_entropy_loss_gradient_one_hot(Y, T)[source]

Gradient of stable softmax cross entropy for one-hot targets (matrices).
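
The "stable" variants differ from the plain ones in how the softmax is evaluated. The document does not spell out the implementation, but the standard stabilization trick is to shift the logits by their maximum before exponentiating, which leaves the result unchanged mathematically while preventing overflow:

```python
import numpy as np

def stable_softmax(y):
    # subtracting max(y) does not change the softmax value
    # but keeps exp() from overflowing for large logits
    z = y - np.max(y)
    e = np.exp(z)
    return e / np.sum(e)

big = np.array([1000.0, 1001.0])
print(stable_softmax(big))  # finite probabilities; a naive softmax would overflow
```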

nerva_jax.loss_functions.logistic_cross_entropy_loss(y, t)[source]

Logistic cross entropy loss for vectors.

nerva_jax.loss_functions.logistic_cross_entropy_loss_gradient(y, t)[source]

Gradient of logistic cross entropy loss for vectors.

nerva_jax.loss_functions.Logistic_cross_entropy_loss(Y, T)[source]

Logistic cross entropy loss for matrices.

nerva_jax.loss_functions.Logistic_cross_entropy_loss_gradient(Y, T)[source]

Gradient of logistic cross entropy loss for matrices.
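
One common definition of logistic cross entropy applies the sigmoid to raw outputs before taking the cross entropy; the exact convention used by the library is not shown here, so the following NumPy sketch is an assumption based on that standard form:

```python
import numpy as np

def sigmoid(y):
    return 1.0 / (1.0 + np.exp(-y))

def logistic_cross_entropy_loss(y, t):
    # assumed form: -t^T log(sigmoid(y))
    return float(-t @ np.log(sigmoid(y)))

def logistic_cross_entropy_loss_gradient(y, t):
    # d/dy_i [-t_i log σ(y_i)] = t_i (σ(y_i) - 1)
    return t * (sigmoid(y) - 1.0)
```

For example, at y = 0 with target t = 1 the loss is -log(1/2) = log 2 and the gradient is -1/2.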

nerva_jax.loss_functions.negative_log_likelihood_loss(y, t)[source]

Negative log likelihood loss for vectors.

nerva_jax.loss_functions.negative_log_likelihood_loss_gradient(y, t)[source]

Gradient of negative log likelihood loss for vectors.

nerva_jax.loss_functions.Negative_log_likelihood_loss(Y, T)[source]

Negative log likelihood loss for matrices.

nerva_jax.loss_functions.Negative_log_likelihood_loss_gradient(Y, T)[source]

Gradient of negative log likelihood loss for matrices.
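
A common formulation of the negative log likelihood loss for a probability vector y is -log(y^T t); the library's exact definition is not shown here, so this NumPy sketch assumes that form:

```python
import numpy as np

def negative_log_likelihood_loss(y, t):
    # assumed form: -log(y^T t), the negative log of the probability
    # mass the prediction assigns to the target
    return float(-np.log(y @ t))

def negative_log_likelihood_loss_gradient(y, t):
    # d/dy [-log(y^T t)] = -t / (y^T t)
    return -t / (y @ t)
```

For a one-hot target this coincides with cross entropy: both reduce to the negative log-probability of the correct class.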

class nerva_jax.loss_functions.LossFunction[source]

Bases: object

Interface for loss functions with value and gradient on batch matrices.

gradient(Y: jax.numpy.ndarray, T: jax.numpy.ndarray) → jax.numpy.ndarray[source]

class nerva_jax.loss_functions.SquaredErrorLossFunction[source]

Bases: LossFunction

Squared error loss function for regression tasks.

gradient(Y: jax.numpy.ndarray, T: jax.numpy.ndarray) → jax.numpy.ndarray[source]

class nerva_jax.loss_functions.CrossEntropyLossFunction[source]

Bases: LossFunction

Cross entropy loss function for classification with probabilities.

gradient(Y: jax.numpy.ndarray, T: jax.numpy.ndarray) → jax.numpy.ndarray[source]

class nerva_jax.loss_functions.SoftmaxCrossEntropyLossFunction[source]

Bases: LossFunction

Softmax cross entropy loss for classification with logits.

gradient(Y: jax.numpy.ndarray, T: jax.numpy.ndarray) → jax.numpy.ndarray[source]

class nerva_jax.loss_functions.StableSoftmaxCrossEntropyLossFunction[source]

Bases: LossFunction

Numerically stable softmax cross entropy loss for classification.

gradient(Y: jax.numpy.ndarray, T: jax.numpy.ndarray) → jax.numpy.ndarray[source]

class nerva_jax.loss_functions.LogisticCrossEntropyLossFunction[source]

Bases: LossFunction

Logistic cross entropy loss for binary classification.

gradient(Y: jax.numpy.ndarray, T: jax.numpy.ndarray) → jax.numpy.ndarray[source]

class nerva_jax.loss_functions.NegativeLogLikelihoodLossFunction[source]

Bases: LossFunction

Negative log likelihood loss for probabilistic outputs.

gradient(Y: jax.numpy.ndarray, T: jax.numpy.ndarray) → jax.numpy.ndarray[source]

nerva_jax.loss_functions.parse_loss_function(text: str) → LossFunction[source]

Parse a loss function name into a LossFunction instance.

Supported names: SquaredError, CrossEntropy, SoftmaxCrossEntropy, LogisticCrossEntropy, NegativeLogLikelihood.
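
The parser presumably maps the listed names onto the corresponding LossFunction classes. A minimal sketch of such a dispatch, using the documented names (the dictionary-based approach and the error behavior are assumptions about the implementation; the stub classes stand in for the real ones):

```python
# hypothetical name -> class dispatch mirroring the supported names above;
# the real nerva_jax implementation may differ
class LossFunction: ...
class SquaredErrorLossFunction(LossFunction): ...
class CrossEntropyLossFunction(LossFunction): ...
class SoftmaxCrossEntropyLossFunction(LossFunction): ...
class LogisticCrossEntropyLossFunction(LossFunction): ...
class NegativeLogLikelihoodLossFunction(LossFunction): ...

_LOSSES = {
    "SquaredError": SquaredErrorLossFunction,
    "CrossEntropy": CrossEntropyLossFunction,
    "SoftmaxCrossEntropy": SoftmaxCrossEntropyLossFunction,
    "LogisticCrossEntropy": LogisticCrossEntropyLossFunction,
    "NegativeLogLikelihood": NegativeLogLikelihoodLossFunction,
}

def parse_loss_function(text: str) -> LossFunction:
    try:
        return _LOSSES[text]()
    except KeyError:
        raise ValueError(f"unknown loss function: {text}")
```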