nerva_jax.utilities

Miscellaneous utilities (formatting, timing, parsing, I/O).

Functions

disable_gpu()

Disable GPU usage for JAX.

load_dict_from_npz(filename)

Load a dictionary of JAX arrays from a .npz file.

parse_function_call(text)

Parse a string of the form NAME(key1=value1, key2=value2, ...).

pp(name, x)

Pretty-print a tensor with name and shape info, using NumPy formatting.

pp_numpy(name, arr)

Internal helper: pretty-print using NumPy arrays only.

save_dict_to_npz(filename, data)

Save a dictionary of JAX arrays to a compressed .npz file.

set_jax_options()

Configure NumPy print options for readable output.

Classes

FunctionCall(name, arguments)

StopWatch()

nerva_jax.utilities.set_jax_options()

Configure NumPy print options for readable output.
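The documentation does not say which print options are set. The sketch below is a guess at what such a function might look like, not the library's code. The name `set_print_options_sketch` and the values for `precision`, `suppress`, and `linewidth` are all hypothetical. Because `pp` formats its output through NumPy, options like these would affect what it prints.

```python
import numpy as np


def set_print_options_sketch() -> None:
    """Hypothetical stand-in for set_jax_options; option values are assumptions."""
    # 4 decimals, fixed-point output for tiny values (no 1e-08 style),
    # and wide lines so that matrix rows are not wrapped.
    np.set_printoptions(precision=4, suppress=True, linewidth=160)


set_print_options_sketch()
print(np.array([[1e-8, 0.123456], [3.0, 42.0]]))
```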

nerva_jax.utilities.disable_gpu()

Disable GPU usage for JAX.
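The mechanism `disable_gpu` uses is not documented. One common way to keep JAX on the CPU is the `JAX_PLATFORMS` environment variable. It only takes effect if it is set before JAX initializes its backends, which in practice means before `import jax`. The sketch below assumes that approach. The function name is hypothetical, and clearing `CUDA_VISIBLE_DEVICES` is an extra precaution, not something the source states.

```python
import os


def disable_gpu_sketch() -> None:
    """Hypothetical stand-in for disable_gpu; call it before importing jax."""
    # Restrict JAX to the CPU backend.
    os.environ["JAX_PLATFORMS"] = "cpu"
    # Hide CUDA devices from JAX and any other CUDA-aware library.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""


disable_gpu_sketch()
# import jax  # imported afterwards, JAX should report only CPU devices
```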

nerva_jax.utilities.pp_numpy(name: str, arr: ndarray)

Internal helper: pretty-print using NumPy arrays only.

nerva_jax.utilities.pp(name: str, x: jax.numpy.ndarray)

Pretty-print a tensor with name and shape info, using NumPy formatting.
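The split between `pp` and `pp_numpy` suggests that `pp` converts the JAX array to NumPy and then delegates. The sketch below follows that reading. The exact output layout and the helper names `format_array` and `pp_sketch` are assumptions. `np.asarray` accepts JAX arrays and copies them to host memory as NumPy arrays.

```python
import numpy as np


def format_array(name: str, arr) -> str:
    """Build a 'name (shape, dtype)' header followed by NumPy's formatting."""
    a = np.asarray(arr)  # works for NumPy arrays, JAX arrays and nested lists
    body = np.array2string(a, separator=", ")
    return f"{name} (shape={a.shape}, dtype={a.dtype}):\n{body}"


def pp_sketch(name: str, x) -> None:
    """Hypothetical stand-in for pp: convert, then print NumPy-formatted text."""
    print(format_array(name, x))


pp_sketch("W", np.zeros((2, 3)))
```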

class nerva_jax.utilities.StopWatch

Bases: object

seconds()

Return the elapsed time in seconds since creation or the last reset.

reset()

Reset the timer to the current time.
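A minimal sketch of a timer with the same two methods. It assumes `time.perf_counter` as the clock, which the source does not specify, and the class name is hypothetical.

```python
import time


class StopWatchSketch:
    """Hypothetical stand-in for StopWatch, based on time.perf_counter."""

    def __init__(self) -> None:
        self.start = time.perf_counter()

    def seconds(self) -> float:
        # Elapsed wall-clock time since construction or the last reset().
        return time.perf_counter() - self.start

    def reset(self) -> None:
        self.start = time.perf_counter()


watch = StopWatchSketch()
time.sleep(0.01)
print(f"elapsed: {watch.seconds():.3f} s")
```

When timing JAX computations with a timer like this, keep in mind that JAX dispatches work asynchronously. Call `.block_until_ready()` on the result before reading `seconds()`, or the measurement may stop before the computation has finished.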

class nerva_jax.utilities.FunctionCall(name: str, arguments: dict)

Bases: object

has_key(key: str) → bool

Check if the given key exists in parsed arguments.

get_value(key: str) → str

as_scalar(key: str, default_value: float = None) → float

as_string(key: str, default_value: str = '') → str

nerva_jax.utilities.parse_function_call(text: str) → FunctionCall

Parse a string of the form NAME(key1=value1, key2=value2, …). If there are no arguments, the parentheses may be omitted. If there is only one parameter, it may be passed as NAME(value) instead of NAME(key=value).
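The grammar described above can be sketched with a regular expression and a split on commas. This is not the library's parser, and several details are assumptions:

- Values are kept as strings and only converted by the accessors.
- A lone unnamed value `NAME(value)` is stored under a hypothetical empty key `""`, and `get_value` returns it for whichever key is requested.
- `as_scalar` raises `KeyError` when the key is missing and no default is given.

The class and function names are hypothetical. Splitting on commas means the sketch cannot handle values that themselves contain commas.

```python
import re
from typing import Dict, Optional

# NAME, optionally followed by "( ... )"; group 2 is None when parentheses are omitted.
_CALL_RE = re.compile(r"^\s*(\w+)\s*(?:\((.*)\))?\s*$")


class FunctionCallSketch:
    """Hypothetical stand-in for FunctionCall; argument values are stored as strings."""

    def __init__(self, name: str, arguments: Dict[str, str]) -> None:
        self.name = name
        self.arguments = arguments

    def has_key(self, key: str) -> bool:
        return key in self.arguments

    def get_value(self, key: str) -> str:
        if key in self.arguments:
            return self.arguments[key]
        # Assumption: a single unnamed argument answers to any key.
        if list(self.arguments) == [""]:
            return self.arguments[""]
        raise KeyError(key)

    def as_scalar(self, key: str, default_value: Optional[float] = None) -> float:
        try:
            return float(self.get_value(key))
        except KeyError:
            if default_value is None:
                raise
            return default_value

    def as_string(self, key: str, default_value: str = "") -> str:
        try:
            return self.get_value(key)
        except KeyError:
            return default_value


def parse_function_call_sketch(text: str) -> FunctionCallSketch:
    match = _CALL_RE.match(text)
    if match is None:
        raise ValueError(f"cannot parse function call: {text!r}")
    name, body = match.group(1), match.group(2)
    arguments: Dict[str, str] = {}
    if body is not None and body.strip():
        parts = [part.strip() for part in body.split(",")]
        for part in parts:
            if "=" in part:
                key, value = part.split("=", 1)
                arguments[key.strip()] = value.strip()
            elif len(parts) == 1:
                arguments[""] = part  # NAME(value) form
            else:
                raise ValueError(f"argument without a key: {part!r}")
    return FunctionCallSketch(name, arguments)


call = parse_function_call_sketch("SGD(lr=0.1, momentum=0.9)")
print(call.name, call.as_scalar("lr"), call.as_scalar("nesterov", 0.0))
```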

nerva_jax.utilities.load_dict_from_npz(filename: str) → Dict[str, jax.numpy.ndarray]

Load a dictionary of JAX arrays from a .npz file.
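A sketch of loading, assuming the file was written by something like `np.savez_compressed`. The function name is hypothetical, and so is the `to_array` parameter. It defaults to `np.asarray` so the sketch runs without JAX. To get JAX arrays, as the documented return type indicates, pass `jax.numpy.asarray` instead.

```python
from typing import Callable, Dict

import numpy as np


def load_dict_from_npz_sketch(filename: str, to_array: Callable = np.asarray) -> Dict:
    # np.load returns a lazily-read NpzFile for .npz archives;
    # the with-block closes the underlying file afterwards.
    with np.load(filename) as data:
        # Materialize every array before the file is closed.
        return {key: to_array(data[key]) for key in data.files}
```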

nerva_jax.utilities.save_dict_to_npz(filename: str, data: Dict[str, jax.numpy.ndarray])

Save a dictionary of JAX arrays to a compressed .npz file.
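A sketch of the saving side, assuming `np.savez_compressed` is used. The function name is hypothetical. Each dictionary key becomes the name of one array inside the archive. NumPy appends the `.npz` extension when a string filename lacks it.

```python
from typing import Dict

import numpy as np


def save_dict_to_npz_sketch(filename: str, data: Dict) -> None:
    # np.asarray turns JAX arrays (or lists) into host-side NumPy arrays.
    arrays = {key: np.asarray(value) for key, value in data.items()}
    # Each entry is stored as a compressed .npy member named after its key.
    np.savez_compressed(filename, **arrays)
```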