nerva_jax.utilities
Miscellaneous utilities (formatting, timing, parsing, I/O).
Functions
- Disable GPU usage for JAX.
- Load a dictionary from a file in .npz format.
- parse_function_call: Parse a string of the form NAME(key1=value1, key2=value2, ...).
- pp: Pretty-print a tensor with name and shape information, using NumPy formatting.
- pp_numpy: Internal helper that pretty-prints using NumPy arrays only.
- Save a dictionary of JAX arrays to a compressed .npz file.
- Configure NumPy print options for readable output.
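The .npz save and load utilities listed above can be approximated with NumPy alone. The sketch below is an assumption, not the module's code: the real function names are not shown on this page, so `save_dict_to_npz` and `load_dict_from_npz` are hypothetical. It converts each value with `np.asarray`, which also accepts JAX arrays, and stores one archive entry per dictionary key.

```python
import numpy as np


# Hypothetical names: the actual nerva_jax function names are not shown here.
def save_dict_to_npz(filename: str, data: dict) -> None:
    """Save a dict of arrays to a compressed .npz file, one entry per key."""
    # np.asarray converts JAX arrays (or any array-like) to NumPy arrays.
    np.savez_compressed(filename, **{key: np.asarray(value) for key, value in data.items()})


def load_dict_from_npz(filename: str) -> dict:
    """Load every array stored in a .npz file back into a plain dict."""
    # Indexing the NpzFile reads each array into memory before the file is closed.
    with np.load(filename) as npz:
        return {key: npz[key] for key in npz.files}
```

Note that `np.savez_compressed` appends `.npz` to a filename that lacks it, so passing a name that already ends in `.npz` keeps the save and load paths identical.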
Classes
- nerva_jax.utilities.pp_numpy(name: str, arr: ndarray)
Internal helper: pretty-print using NumPy arrays only.
- nerva_jax.utilities.pp(name: str, x: jax.numpy.ndarray)
Pretty-print a tensor with name and shape info, using NumPy formatting.
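The relationship between pp and pp_numpy described above can be sketched as follows: pp converts its input to a NumPy array and delegates the printing to pp_numpy. This is a minimal sketch under assumptions; the output layout and the print options (`precision=3`, `suppress=True`, `linewidth=120`) are illustrative choices, not the options nerva_jax actually configures.

```python
import numpy as np


def pp_numpy(name: str, arr: np.ndarray) -> None:
    """Print an array's name and shape, followed by its NumPy-formatted values."""
    # Temporary print options (assumed values): 3 decimals, no scientific notation.
    with np.printoptions(precision=3, suppress=True, linewidth=120):
        print(f"{name} (shape={arr.shape})")
        print(arr)


def pp(name: str, x) -> None:
    """Convert a (JAX) array to NumPy, then delegate to pp_numpy."""
    # np.asarray works on JAX arrays as well as on NumPy arrays and nested lists.
    pp_numpy(name, np.asarray(x))
```

Using `np.printoptions` as a context manager restores the global print options afterwards, so pretty-printing does not change how other arrays are displayed.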
- nerva_jax.utilities.parse_function_call(text: str) -> FunctionCall
Parse a string of the form NAME(key1=value1, key2=value2, …). If there are no arguments, the parentheses may be omitted. If there is only one parameter, NAME(value) may be passed instead of NAME(key=value).
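The parsing rules above can be illustrated with a small regex-based sketch. Several details are assumptions rather than the module's documented behavior: the `FunctionCall` dataclass here is a hypothetical stand-in with `name` and `arguments` fields, values are kept as strings and are assumed not to contain commas, and a lone positional value is stored under the placeholder key `"value"`, because the real parameter name is not known to the parser in this sketch.

```python
import re
from dataclasses import dataclass, field


@dataclass
class FunctionCall:
    """Hypothetical stand-in for nerva_jax's FunctionCall: a name plus string arguments."""
    name: str
    arguments: dict[str, str] = field(default_factory=dict)


# NAME, optionally followed by "(...)"; the argument list is captured as one string.
_CALL_PATTERN = re.compile(r"^\s*(\w+)\s*(?:\((.*)\))?\s*$")


def parse_function_call(text: str) -> FunctionCall:
    match = _CALL_PATTERN.match(text)
    if match is None:
        raise ValueError(f"cannot parse function call: {text!r}")
    name, args_text = match.group(1), match.group(2)
    call = FunctionCall(name)
    if args_text is None or not args_text.strip():
        return call  # NAME or NAME(): no arguments
    # Assumption: values never contain commas, so a plain split suffices.
    parts = [part.strip() for part in args_text.split(",")]
    if len(parts) == 1 and "=" not in parts[0]:
        # NAME(value): store the value under the placeholder key "value" (assumption).
        call.arguments["value"] = parts[0]
        return call
    for part in parts:
        key, sep, value = part.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {part!r}")
        call.arguments[key.strip()] = value.strip()
    return call
```

For example, `parse_function_call("Adam(lr=0.01, beta1=0.9)")` yields the name `"Adam"` with arguments `{"lr": "0.01", "beta1": "0.9"}`, while `parse_function_call("ReLU")` yields `"ReLU"` with no arguments. Converting the string values to numbers is left to the caller.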