nerva_jax.matrix_operations
Matrix operations built on top of JAX (jax.numpy) to support the math in the library.
The functions here intentionally mirror the names used in the accompanying documentation. They operate on 1D/2D arrays and keep broadcasting explicit for clarity.
Functions
| Diag | Create diagonal matrix with x as diagonal. |
| apply | Element-wise application of function f to X. |
| column_repeat | Repeat column vector x horizontally n times. |
| columns_max | Returns a column vector with the maximum values of each row in X. |
| columns_mean | Returns a column vector with the mean values of each row in X. |
| columns_sum | Sum over columns (returns row vector). |
| diag | Extract diagonal of X as a vector. |
| dot | Dot product of vectors x and y. |
| elements_sum | Returns the sum of the elements of X. |
| exp | Element-wise exponential exp(X). |
| hadamard | Element-wise product X ⊙ Y. |
| identity | Returns the n×n identity matrix. |
|  | Element-wise inverse square root X^(-1/2) with epsilon for stability. |
| is_column_vector | Check if x can be treated as a column vector. |
| is_row_vector | Check if x can be treated as a row vector. |
| is_square | Check if X is a square matrix. |
| is_vector | Check if x is a 1D array. |
| log | Element-wise natural logarithm log(X). |
|  | Element-wise log(sigmoid(X)) computed stably. |
| ones | Returns an m×n matrix with all elements equal to 1. |
| product | Matrix multiplication X @ Y. |
| reciprocal | Element-wise reciprocal 1/X. |
| row_repeat | Repeat row vector x vertically m times. |
| rows_max | Returns a row vector with the maximum values of each column in X. |
| rows_mean | Returns a row vector with the mean values of each column in X. |
| rows_sum | Sum over rows (returns column vector). |
| sqrt | Element-wise square root √X. |
| square | Element-wise square X². |
| vector_size | Get size along first dimension. |
| zeros | Returns an m×n matrix with all elements equal to 0. |
- nerva_jax.matrix_operations.is_vector(x: jax.numpy.ndarray) → bool
Check if x is a 1D array.
- nerva_jax.matrix_operations.is_column_vector(x: jax.numpy.ndarray) → bool
Check if x can be treated as a column vector.
- nerva_jax.matrix_operations.is_row_vector(x: jax.numpy.ndarray) → bool
Check if x can be treated as a row vector.
- nerva_jax.matrix_operations.vector_size(x: jax.numpy.ndarray) → int
Get the size of x along its first dimension.
- nerva_jax.matrix_operations.is_square(X: jax.numpy.ndarray) → bool
Check if X is a square matrix.
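The docstrings do not pin down what "can be treated as" a column or row vector means. The sketch below assumes that a 1D array counts as both, while a 2D array qualifies only with shape (m, 1) or (1, n) respectively. The function bodies are hypothetical re-implementations in jax.numpy, not the nerva_jax source.

```python
import jax.numpy as jnp

# Hypothetical re-implementations of the shape predicates, illustrating one
# plausible reading of the docstrings; the real definitions may differ.

def is_vector(x: jnp.ndarray) -> bool:
    # A vector is a 1D array.
    return x.ndim == 1

def is_column_vector(x: jnp.ndarray) -> bool:
    # Assumption: a 1D array or an (m, 1) matrix can be treated as a column vector.
    return x.ndim == 1 or (x.ndim == 2 and x.shape[1] == 1)

def is_row_vector(x: jnp.ndarray) -> bool:
    # Assumption: a 1D array or a (1, n) matrix can be treated as a row vector.
    return x.ndim == 1 or (x.ndim == 2 and x.shape[0] == 1)

def vector_size(x: jnp.ndarray) -> int:
    # Size along the first dimension.
    return x.shape[0]

def is_square(X: jnp.ndarray) -> bool:
    # A square matrix is 2D with equal numbers of rows and columns.
    return X.ndim == 2 and X.shape[0] == X.shape[1]

x = jnp.array([1.0, 2.0, 3.0])   # shape (3,)
c = x.reshape(-1, 1)             # shape (3, 1)
r = x.reshape(1, -1)             # shape (1, 3)
print(is_vector(x), is_column_vector(c), is_row_vector(c))  # True True False
print(vector_size(c), is_square(jnp.eye(3)))                # 3 True
```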
- nerva_jax.matrix_operations.dot(x: jax.numpy.ndarray, y: jax.numpy.ndarray)
Dot product of vectors x and y.
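dot carries no return annotation. Assuming it behaves like jnp.dot on two 1D arrays (an assumption, not confirmed by the source), the result is a 0D array holding the inner product:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

# For 1D arrays, jnp.dot computes sum_i x_i * y_i = 1*4 + 2*5 + 3*6.
d = jnp.dot(x, y)
print(d)  # 32.0
```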
- nerva_jax.matrix_operations.zeros(m: int, n=None) → jax.numpy.ndarray
Returns an m×n matrix with all elements equal to 0.
- nerva_jax.matrix_operations.ones(m: int, n=None) → jax.numpy.ndarray
Returns an m×n matrix with all elements equal to 1.
- nerva_jax.matrix_operations.identity(n: int) → jax.numpy.ndarray
Returns the n×n identity matrix.
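The meaning of the default n=None is not stated. The hypothetical sketch below assumes it yields a 1D vector of length m; the library may treat it differently (for example as a square m×m matrix), so check the source before relying on that case.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for zeros/ones/identity.
# Assumption: n=None produces a 1D vector of length m.

def zeros(m: int, n=None) -> jnp.ndarray:
    return jnp.zeros(m) if n is None else jnp.zeros((m, n))

def ones(m: int, n=None) -> jnp.ndarray:
    return jnp.ones(m) if n is None else jnp.ones((m, n))

def identity(n: int) -> jnp.ndarray:
    return jnp.eye(n)

print(zeros(2, 3).shape, ones(4).shape, identity(3).shape)  # (2, 3) (4,) (3, 3)
```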
- nerva_jax.matrix_operations.product(X: jax.numpy.ndarray, Y: jax.numpy.ndarray) → jax.numpy.ndarray
Matrix multiplication X @ Y.
- nerva_jax.matrix_operations.hadamard(X: jax.numpy.ndarray, Y: jax.numpy.ndarray) → jax.numpy.ndarray
Element-wise product X ⊙ Y.
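The two products are easy to confuse. This sketch applies the operations the docstrings name (@ for product, element-wise * for hadamard) to the same 2×2 matrices. product needs X.shape[1] == Y.shape[0]; hadamard is most naturally used on arrays of equal shape (jnp's * would also broadcast).

```python
import jax.numpy as jnp

X = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
Y = jnp.array([[5.0, 6.0],
               [7.0, 8.0]])

P = X @ Y   # matrix product: P[i, j] = sum_k X[i, k] * Y[k, j]
H = X * Y   # Hadamard product: H[i, j] = X[i, j] * Y[i, j]

# P == [[19, 22], [43, 50]]
# H == [[ 5, 12], [21, 32]]
```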
- nerva_jax.matrix_operations.diag(X: jax.numpy.ndarray) → jax.numpy.ndarray
Extract diagonal of X as a vector.
- nerva_jax.matrix_operations.Diag(x: jax.numpy.ndarray) → jax.numpy.ndarray
Create diagonal matrix with x as diagonal.
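diag and Diag differ only in capitalization. Assuming they correspond to the two modes of jnp.diag (extract from a 2D input, build from a 1D input), the following sketch shows both and the round trip diag(Diag(x)) = x:

```python
import jax.numpy as jnp

X = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
x = jnp.array([5.0, 6.0])

# jnp.diag is overloaded: on a 2D input it extracts the diagonal (like diag),
# on a 1D input it builds a diagonal matrix (like Diag).
d = jnp.diag(X)   # [1., 4.]
D = jnp.diag(x)   # [[5., 0.], [0., 6.]]

# Round trip: extracting the diagonal of the built matrix gives back x.
assert jnp.array_equal(jnp.diag(D), x)
```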
- nerva_jax.matrix_operations.elements_sum(X: jax.numpy.ndarray)
Returns the sum of the elements of X.
- nerva_jax.matrix_operations.column_repeat(x: jax.numpy.ndarray, n: int) → jax.numpy.ndarray
Repeat column vector x horizontally n times.
- nerva_jax.matrix_operations.row_repeat(x: jax.numpy.ndarray, m: int) → jax.numpy.ndarray
Repeat row vector x vertically m times.
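A minimal sketch of the two repeat operations, assuming the input may be a 1D array or already shaped as a column (m, 1) or row (1, n) vector. jnp.tile is used for illustration and may not be what the library calls. The final assertion shows the explicit form of an addition that jax.numpy would otherwise perform by implicit broadcasting.

```python
import jax.numpy as jnp

# Hypothetical sketches of column_repeat / row_repeat.

def column_repeat(x: jnp.ndarray, n: int) -> jnp.ndarray:
    # Place n copies of the column vector side by side: shape (m, n).
    return jnp.tile(x.reshape(-1, 1), (1, n))

def row_repeat(x: jnp.ndarray, m: int) -> jnp.ndarray:
    # Stack m copies of the row vector on top of each other: shape (m, n).
    return jnp.tile(x.reshape(1, -1), (m, 1))

x = jnp.array([1.0, 2.0, 3.0])
C = column_repeat(x, 2)   # shape (3, 2): every column equals x
R = row_repeat(x, 2)      # shape (2, 3): every row equals x

X = jnp.arange(6.0).reshape(2, 3)
# Explicit version of the implicitly broadcast sum X + x.
assert jnp.array_equal(X + row_repeat(x, 2), X + x)
```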
- nerva_jax.matrix_operations.columns_sum(X: jax.numpy.ndarray) → jax.numpy.ndarray
Sum over columns (returns row vector).
- nerva_jax.matrix_operations.rows_sum(X: jax.numpy.ndarray) → jax.numpy.ndarray
Sum over rows (returns column vector).
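In matrix notation, columns_sum(X) = 1ᵀX and rows_sum(X) = X1, where 1 is a vector of ones of matching length. The sketch below assumes the results keep their row/column orientation (keepdims=True); the library may return 1D arrays instead. elements_sum, documented above, is included for comparison.

```python
import jax.numpy as jnp

X = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# columns_sum: one total per column, kept as a (1, n) row vector.
col_totals = jnp.sum(X, axis=0, keepdims=True)   # [[5., 7., 9.]]
# rows_sum: one total per row, kept as an (m, 1) column vector.
row_totals = jnp.sum(X, axis=1, keepdims=True)   # [[6.], [15.]]
# elements_sum: a single scalar.
total = jnp.sum(X)                               # 21.0

# The same column totals written as 1^T X with a row of ones.
assert jnp.array_equal(jnp.ones((1, 2)) @ X, col_totals)
# Summing either partial result gives the overall total.
assert jnp.sum(col_totals) == total and jnp.sum(row_totals) == total
```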
- nerva_jax.matrix_operations.columns_max(X: jax.numpy.ndarray) → jax.numpy.ndarray
Returns a column vector with the maximum values of each row in X.
- nerva_jax.matrix_operations.rows_max(X: jax.numpy.ndarray) → jax.numpy.ndarray
Returns a row vector with the maximum values of each column in X.
- nerva_jax.matrix_operations.columns_mean(X: jax.numpy.ndarray) → jax.numpy.ndarray
Returns a column vector with the mean values of each row in X.
- nerva_jax.matrix_operations.rows_mean(X: jax.numpy.ndarray) → jax.numpy.ndarray
Returns a row vector with the mean values of each column in X.
- nerva_jax.matrix_operations.apply(f, X: jax.numpy.ndarray) → jax.numpy.ndarray
Element-wise application of function f to X.
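A hypothetical sketch of apply using jnp.vectorize, which maps a scalar function over every element. Because JAX traces f, it must be written with jnp operations (for instance jnp.maximum or jnp.where instead of a Python if on the value); for such functions, calling f(X) directly gives the same result.

```python
import jax.numpy as jnp

# Hypothetical sketch; not the nerva_jax implementation.
def apply(f, X: jnp.ndarray) -> jnp.ndarray:
    # jnp.vectorize turns a scalar function into an element-wise one.
    return jnp.vectorize(f)(X)

def relu(t):
    # Written with jnp operations so JAX can trace it.
    return jnp.maximum(t, 0.0)

X = jnp.array([[-1.0, 2.0],
               [3.0, -4.0]])
Y = apply(relu, X)   # [[0., 2.], [3., 0.]]
```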
- nerva_jax.matrix_operations.exp(X: jax.numpy.ndarray) → jax.numpy.ndarray
Element-wise exponential exp(X).
- nerva_jax.matrix_operations.log(X: jax.numpy.ndarray) → jax.numpy.ndarray
Element-wise natural logarithm log(X).
- nerva_jax.matrix_operations.reciprocal(X: jax.numpy.ndarray) → jax.numpy.ndarray
Element-wise reciprocal 1/X.
- nerva_jax.matrix_operations.square(X: jax.numpy.ndarray) → jax.numpy.ndarray
Element-wise square X².
- nerva_jax.matrix_operations.sqrt(X: jax.numpy.ndarray) → jax.numpy.ndarray
Element-wise square root √X.
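These wrappers presumably map onto the jax.numpy functions of the same name; that correspondence is an assumption. The sketch checks that the pairs invert each other. For the stable log(sigmoid(X)) listed in the summary table, it contrasts a naive composition with JAX's own jax.nn.log_sigmoid, which evaluates -softplus(-x) and stays finite where sigmoid underflows to 0.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1.0, 4.0],
               [9.0, 16.0]])

E = jnp.exp(X)
L = jnp.log(X)
R = jnp.reciprocal(X)   # same as 1.0 / X
S = jnp.square(X)       # same as X * X
Q = jnp.sqrt(X)         # [[1., 2.], [3., 4.]]

# The pairs invert each other element by element.
assert jnp.allclose(jnp.log(E), X)
assert jnp.allclose(jnp.exp(L), X)
assert jnp.allclose(jnp.sqrt(S), X)
assert jnp.allclose(jnp.square(Q), X)
assert jnp.allclose(R * X, jnp.ones_like(X))

# log(sigmoid(z)): sigmoid(-1000) underflows to 0, so the naive log is -inf,
# while log_sigmoid(z) = -softplus(-z) stays finite.
z = jnp.array([-1000.0, 0.0, 1000.0])
naive = jnp.log(jax.nn.sigmoid(z))
stable = jax.nn.log_sigmoid(z)   # approximately [-1000., -0.693, 0.]
```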