nerva_jax.matrix_operations

Matrix operations built on top of JAX to support the math in the library.

The function names intentionally mirror those used in the accompanying documentation. The functions operate on 1D and 2D arrays and keep broadcasting explicit for clarity.

Functions

Diag(x)

Create a diagonal matrix with x as its diagonal.

apply(f, X)

Element-wise application of function f to X.

column_repeat(x, n)

Repeat column vector x horizontally n times.

columns_max(X)

Returns a column vector with the maximum values of each row in X.

columns_mean(X)

Returns a column vector with the mean values of each row in X.

columns_sum(X)

Sum each column of X (returns a row vector).

diag(X)

Extract the diagonal of X as a vector.

dot(x, y)

Dot product of vectors x and y.

elements_sum(X)

Returns the sum of the elements of X.

exp(X)

Element-wise exponential exp(X).

hadamard(X, Y)

Element-wise product X ⊙ Y.

identity(n)

Returns the n×n identity matrix.

inv_sqrt(X)

Element-wise inverse square root X^(-1/2) with epsilon for stability.

is_column_vector(x)

Check if x can be treated as a column vector.

is_row_vector(x)

Check if x can be treated as a row vector.

is_square(X)

Check if X is a square matrix.

is_vector(x)

Check if x is a 1D array.

log(X)

Element-wise natural logarithm log(X).

log_sigmoid(X)

Element-wise log(sigmoid(X)) computed stably.

ones(m[, n])

Returns an m×n matrix with all elements equal to 1.

product(X, Y)

Matrix multiplication X @ Y.

reciprocal(X)

Element-wise reciprocal 1/X.

row_repeat(x, m)

Repeat row vector x vertically m times.

rows_max(X)

Returns a row vector with the maximum values of each column in X.

rows_mean(X)

Returns a row vector with the mean values of each column in X.

rows_sum(X)

Sum each row of X (returns a column vector).

sqrt(X)

Element-wise square root √X.

square(X)

Element-wise square X².

vector_size(x)

Get the size of x along its first dimension.

zeros(m[, n])

Returns an m×n matrix with all elements equal to 0.

nerva_jax.matrix_operations.is_vector(x: jax.numpy.ndarray) → bool

Check if x is a 1D array.

nerva_jax.matrix_operations.is_column_vector(x: jax.numpy.ndarray) → bool

Check if x can be treated as a column vector.

nerva_jax.matrix_operations.is_row_vector(x: jax.numpy.ndarray) → bool

Check if x can be treated as a row vector.

nerva_jax.matrix_operations.vector_size(x: jax.numpy.ndarray) → int

Get the size of x along its first dimension.

nerva_jax.matrix_operations.is_square(X: jax.numpy.ndarray) → bool

Check if X is a square matrix.
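The shape checks above can be sketched with plain `jax.numpy` shape inspection. This is a hypothetical version, not the library's code; in particular, the rule that a row or column vector may be either a 1D array or a 2D array with a single row or column is an assumption.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for the shape checks; the library's own rules may differ.


def is_vector(x: jnp.ndarray) -> bool:
    # A vector is taken to be a 1D array.
    return x.ndim == 1


def is_column_vector(x: jnp.ndarray) -> bool:
    # Assumption: a 1D array, or a 2D array with exactly one column.
    return x.ndim == 1 or (x.ndim == 2 and x.shape[1] == 1)


def is_row_vector(x: jnp.ndarray) -> bool:
    # Assumption: a 1D array, or a 2D array with exactly one row.
    return x.ndim == 1 or (x.ndim == 2 and x.shape[0] == 1)


def vector_size(x: jnp.ndarray) -> int:
    # Number of entries along the first dimension.
    return x.shape[0]


def is_square(X: jnp.ndarray) -> bool:
    # A 2D array with as many rows as columns.
    return X.ndim == 2 and X.shape[0] == X.shape[1]


v = jnp.arange(3.0)  # shape (3,)
M = jnp.zeros((2, 2))
print(is_vector(v), vector_size(v), is_square(M))  # True 3 True
```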

nerva_jax.matrix_operations.dot(x: jax.numpy.ndarray, y: jax.numpy.ndarray)

Dot product of vectors x and y.
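For two 1D arrays of equal length, the dot product is the sum of their element-wise products. A minimal sketch, assuming `dot` simply delegates to `jnp.dot`:

```python
import jax.numpy as jnp


def dot(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    # Sum of x[i] * y[i]; returns a 0-d array.
    return jnp.dot(x, y)


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
print(dot(x, y))  # 1*4 + 2*5 + 3*6 = 32.0
```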

nerva_jax.matrix_operations.zeros(m: int, n=None) → jax.numpy.ndarray

Returns an m×n matrix with all elements equal to 0.

nerva_jax.matrix_operations.ones(m: int, n=None) → jax.numpy.ndarray

Returns an m×n matrix with all elements equal to 1.

nerva_jax.matrix_operations.identity(n: int) → jax.numpy.ndarray

Returns the n×n identity matrix.
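A sketch of the three constructors using `jnp.zeros`, `jnp.ones` and `jnp.eye`. The behavior when `n` is omitted is not described here; the code assumes it yields a 1D array of length `m`.

```python
import jax.numpy as jnp


def zeros(m: int, n=None) -> jnp.ndarray:
    # Assumption: with n omitted, return a 1D vector of length m.
    return jnp.zeros(m) if n is None else jnp.zeros((m, n))


def ones(m: int, n=None) -> jnp.ndarray:
    # Same assumption about an omitted n as in zeros.
    return jnp.ones(m) if n is None else jnp.ones((m, n))


def identity(n: int) -> jnp.ndarray:
    # Ones on the main diagonal, zeros elsewhere.
    return jnp.eye(n)


print(zeros(2, 3).shape, ones(4).shape)  # (2, 3) (4,)
```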

nerva_jax.matrix_operations.product(X: jax.numpy.ndarray, Y: jax.numpy.ndarray) → jax.numpy.ndarray

Matrix multiplication X @ Y.
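Matrix multiplication requires the number of columns of X to equal the number of rows of Y. A minimal sketch using Python's `@` operator, which `jax.numpy` arrays support; the explicit shape check is an illustrative addition, not necessarily part of the library:

```python
import jax.numpy as jnp


def product(X: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
    # Inner dimensions must agree: (m, k) @ (k, n) -> (m, n).
    assert X.shape[-1] == Y.shape[0], "inner dimensions differ"
    return X @ Y


X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
P = jnp.array([[0.0, 1.0], [1.0, 0.0]])  # permutation matrix
# Right-multiplying by P swaps the columns of X, giving [[2, 1], [4, 3]].
print(product(X, P))
```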

nerva_jax.matrix_operations.hadamard(X: jax.numpy.ndarray, Y: jax.numpy.ndarray) → jax.numpy.ndarray

Element-wise product X ⊙ Y.
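Unlike `product`, the Hadamard product multiplies matching entries. Since the module aims to keep broadcasting explicit, this sketch rejects mismatched shapes instead of letting JAX broadcast them; that check is an assumption about the intended contract.

```python
import jax.numpy as jnp


def hadamard(X: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical strictness: require identical shapes, no implicit broadcasting.
    assert X.shape == Y.shape, "shapes must match"
    return X * Y


X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
Y = jnp.array([[0.0, 1.0], [1.0, 0.0]])
print(hadamard(X, Y))  # entries 0, 2, 3, 0
```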

nerva_jax.matrix_operations.diag(X: jax.numpy.ndarray) → jax.numpy.ndarray

Extract the diagonal of X as a vector.

nerva_jax.matrix_operations.Diag(x: jax.numpy.ndarray) → jax.numpy.ndarray

Create a diagonal matrix with x as its diagonal.
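`jnp.diag` covers both directions, like `numpy.diag`: a 1D input produces a square diagonal matrix, and a 2D input returns its main diagonal. A hypothetical sketch that fixes the direction through the function name, with added dimension checks:

```python
import jax.numpy as jnp


def diag(X: jnp.ndarray) -> jnp.ndarray:
    # Matrix -> vector: the main diagonal of X.
    assert X.ndim == 2
    return jnp.diag(X)


def Diag(x: jnp.ndarray) -> jnp.ndarray:
    # Vector -> matrix: x on the diagonal, zeros elsewhere.
    assert x.ndim == 1
    return jnp.diag(x)


x = jnp.array([1.0, 2.0, 3.0])
D = Diag(x)      # 3x3 diagonal matrix
print(diag(D))   # recovers x
```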

nerva_jax.matrix_operations.elements_sum(X: jax.numpy.ndarray)

Returns the sum of the elements of X.
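Summing all elements reduces X to a scalar. A one-line sketch, assuming `jnp.sum` without an `axis` argument:

```python
import jax.numpy as jnp


def elements_sum(X: jnp.ndarray) -> jnp.ndarray:
    # No axis argument: every entry is added, giving a 0-d array.
    return jnp.sum(X)


X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(elements_sum(X))  # 10.0
```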

nerva_jax.matrix_operations.column_repeat(x: jax.numpy.ndarray, n: int) → jax.numpy.ndarray

Repeat column vector x horizontally n times.

nerva_jax.matrix_operations.row_repeat(x: jax.numpy.ndarray, m: int) → jax.numpy.ndarray

Repeat row vector x vertically m times.
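A sketch of the two repeat operations with `jnp.tile`. It assumes `x` is passed as a 1D array (or an equivalent single-column or single-row 2D array) and that the result is a 2D matrix; the library may accept or return other shapes.

```python
import jax.numpy as jnp


def column_repeat(x: jnp.ndarray, n: int) -> jnp.ndarray:
    # Reshape x to (m, 1), then tile it n times along the columns -> (m, n).
    return jnp.tile(x.reshape(-1, 1), (1, n))


def row_repeat(x: jnp.ndarray, m: int) -> jnp.ndarray:
    # Reshape x to (1, n), then tile it m times along the rows -> (m, n).
    return jnp.tile(x.reshape(1, -1), (m, 1))


x = jnp.array([1.0, 2.0, 3.0])
print(column_repeat(x, 2).shape)  # (3, 2)
print(row_repeat(x, 2).shape)     # (2, 3)
```

In everyday JAX code, broadcasting usually makes such repeats unnecessary; materializing them matches the explicit style described at the top of this page.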

nerva_jax.matrix_operations.columns_sum(X: jax.numpy.ndarray) → jax.numpy.ndarray

Sum each column of X (returns a row vector).

nerva_jax.matrix_operations.rows_sum(X: jax.numpy.ndarray) → jax.numpy.ndarray

Sum each row of X (returns a column vector).
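Read as axis reductions, `columns_sum` collapses axis 0 (one total per column) and `rows_sum` collapses axis 1 (one total per row). A sketch, assuming the row-vector and column-vector results are represented as 1D arrays:

```python
import jax.numpy as jnp


def columns_sum(X: jnp.ndarray) -> jnp.ndarray:
    # Reduce along axis 0: one total per column.
    return jnp.sum(X, axis=0)


def rows_sum(X: jnp.ndarray) -> jnp.ndarray:
    # Reduce along axis 1: one total per row.
    return jnp.sum(X, axis=1)


X = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(columns_sum(X))  # column totals: 5, 7, 9
print(rows_sum(X))     # row totals: 6, 15
```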

nerva_jax.matrix_operations.columns_max(X: jax.numpy.ndarray) → jax.numpy.ndarray

Returns a column vector with the maximum values of each row in X.

nerva_jax.matrix_operations.rows_max(X: jax.numpy.ndarray) → jax.numpy.ndarray

Returns a row vector with the maximum values of each column in X.
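Following the two descriptions above as written, `columns_max` takes the maximum of each row (axis 1) and `rows_max` the maximum of each column (axis 0). A sketch under that reading, again returning 1D arrays:

```python
import jax.numpy as jnp


def columns_max(X: jnp.ndarray) -> jnp.ndarray:
    # Maximum of each row: reduce along axis 1, per the description above.
    return jnp.max(X, axis=1)


def rows_max(X: jnp.ndarray) -> jnp.ndarray:
    # Maximum of each column: reduce along axis 0.
    return jnp.max(X, axis=0)


X = jnp.array([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]])
print(columns_max(X))  # row maxima: 5, 6
print(rows_max(X))     # column maxima: 4, 5, 6
```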

nerva_jax.matrix_operations.columns_mean(X: jax.numpy.ndarray) → jax.numpy.ndarray

Returns a column vector with the mean values of each row in X.

nerva_jax.matrix_operations.rows_mean(X: jax.numpy.ndarray) → jax.numpy.ndarray

Returns a row vector with the mean values of each column in X.
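The means follow the same reading as the maxima: `columns_mean` averages each row and `rows_mean` averages each column. A minimal sketch with `jnp.mean`, not necessarily matching the library's axis conventions:

```python
import jax.numpy as jnp


def columns_mean(X: jnp.ndarray) -> jnp.ndarray:
    # Mean of each row: reduce along axis 1, per the description above.
    return jnp.mean(X, axis=1)


def rows_mean(X: jnp.ndarray) -> jnp.ndarray:
    # Mean of each column: reduce along axis 0.
    return jnp.mean(X, axis=0)


X = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(columns_mean(X))  # row means: 2, 5
print(rows_mean(X))     # column means: 2.5, 3.5, 4.5
```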

nerva_jax.matrix_operations.apply(f, X: jax.numpy.ndarray) → jax.numpy.ndarray

Element-wise application of function f to X.
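One way to apply an arbitrary scalar function element-wise in JAX is `jnp.vectorize`, which maps `f` over every entry. This is a sketch, not necessarily the library's approach; for functions that are already element-wise, such as `jnp.tanh`, calling `f(X)` directly gives the same result.

```python
import jax.numpy as jnp


def apply(f, X: jnp.ndarray) -> jnp.ndarray:
    # jnp.vectorize maps the scalar function f over every entry of X.
    # f must be JAX-traceable (no Python branching on array values).
    return jnp.vectorize(f)(X)


X = jnp.array([[0.0, 1.0], [2.0, 3.0]])
print(apply(lambda t: t * t + 1.0, X))  # entries 1, 2, 5, 10
```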

nerva_jax.matrix_operations.exp(X: jax.numpy.ndarray) → jax.numpy.ndarray

Element-wise exponential exp(X).

nerva_jax.matrix_operations.log(X: jax.numpy.ndarray) → jax.numpy.ndarray

Element-wise natural logarithm log(X).

nerva_jax.matrix_operations.reciprocal(X: jax.numpy.ndarray) → jax.numpy.ndarray

Element-wise reciprocal 1/X.

nerva_jax.matrix_operations.square(X: jax.numpy.ndarray) → jax.numpy.ndarray

Element-wise square X².

nerva_jax.matrix_operations.sqrt(X: jax.numpy.ndarray) → jax.numpy.ndarray

Element-wise square root √X.
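The element-wise maps above correspond directly to `jax.numpy` operations. A sketch of all five, assuming thin wrappers are all that is needed, checked on a small matrix:

```python
import jax.numpy as jnp


def exp(X: jnp.ndarray) -> jnp.ndarray:
    return jnp.exp(X)


def log(X: jnp.ndarray) -> jnp.ndarray:
    # Natural logarithm; entries must be positive for a finite result.
    return jnp.log(X)


def reciprocal(X: jnp.ndarray) -> jnp.ndarray:
    # Entries equal to zero map to inf.
    return 1.0 / X


def square(X: jnp.ndarray) -> jnp.ndarray:
    return X * X


def sqrt(X: jnp.ndarray) -> jnp.ndarray:
    return jnp.sqrt(X)


X = jnp.array([[1.0, 4.0], [9.0, 16.0]])
print(sqrt(X))        # entries 1, 2, 3, 4
print(reciprocal(X))  # entries 1, 0.25, 1/9, 0.0625
```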

nerva_jax.matrix_operations.inv_sqrt(X: jax.numpy.ndarray) → jax.numpy.ndarray

Element-wise inverse square root X^(-1/2) with epsilon for stability.
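A sketch of the stabilized inverse square root. Both the value of epsilon and where it is added (inside the square root here) are assumptions; the library may use a different constant or placement, and the `eps` keyword is hypothetical.

```python
import jax.numpy as jnp


def inv_sqrt(X: jnp.ndarray, eps: float = 1e-7) -> jnp.ndarray:
    # Hypothetical eps, added inside the square root so that zero entries
    # give a large finite value instead of inf.
    return 1.0 / jnp.sqrt(X + eps)


X = jnp.array([4.0, 0.0])
print(inv_sqrt(X))  # about 0.5, then about 3162 (= 1 / sqrt(1e-7))
```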

nerva_jax.matrix_operations.log_sigmoid(X: jax.numpy.ndarray) → jax.numpy.ndarray

Element-wise log(sigmoid(X)) computed stably.
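Computing `log(sigmoid(x))` literally can fail for large negative `x`: `sigmoid(x)` can round to 0 and the logarithm then becomes `-inf`. One stable formulation (a standard identity, not necessarily the one nerva_jax uses) is log σ(x) = -log(1 + e^(-x)) = -logaddexp(0, -x):

```python
import jax
import jax.numpy as jnp


def log_sigmoid(X: jnp.ndarray) -> jnp.ndarray:
    # log(1 / (1 + exp(-x))) = -log(exp(0) + exp(-x)) = -logaddexp(0, -x).
    # jnp.logaddexp is computed without overflowing for large |x|.
    return -jnp.logaddexp(0.0, -X)


X = jnp.array([-200.0, 0.0, 5.0])
print(log_sigmoid(X))  # approximately -200, -0.693, -0.0067
# Agrees with JAX's built-in version.
print(jnp.allclose(log_sigmoid(X), jax.nn.log_sigmoid(X)))  # True
```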