nerva_jax.datasets

In-memory data loader helpers and one-hot conversions.

The DataLoader defined here mirrors a small subset of the PyTorch DataLoader API but operates on in-memory tensors loaded from .npz files.

Functions

create_npz_dataloaders(filename[, batch_size])

Create a pair of data loaders (training and test) from an .npz file containing the arrays Xtrain, Ttrain, Xtest and Ttest.

from_one_hot(one_hot)

Convert one-hot encoded rows to a class index tensor.

infer_num_classes(Ttrain, Ttest)

Infer the total number of classes from the target arrays.

max_(X)

Return the maximum element of X as a Python scalar.

to_one_hot(x, num_classes)

Convert a class index tensor to a one-hot matrix with num_classes columns.

Classes

DataLoader(Xdata, Tdata, batch_size[, ...])

A minimal in-memory data loader with an interface similar to torch.utils.data.DataLoader.

nerva_jax.datasets.to_one_hot(x: jax.numpy.ndarray, num_classes: int)

Convert a class index tensor to a one-hot matrix with num_classes columns.
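The conversion can be pictured with the following sketch. It is a hypothetical re-implementation (to_one_hot_sketch is not part of nerva_jax), and the float32 output dtype and the acceptance of (N, 1) labels are assumptions. It compares each label with the column indices 0..num_classes-1 and checks the result against JAX's own jax.nn.one_hot.

```python
import jax
import jax.numpy as jnp


def to_one_hot_sketch(x: jnp.ndarray, num_classes: int) -> jnp.ndarray:
    """Hypothetical re-implementation: row i has a single 1 in column x[i]."""
    x = jnp.asarray(x).reshape(-1)  # assumption: accept labels of shape (N,) or (N, 1)
    # Broadcast (N, 1) labels against (1, num_classes) column indices.
    return (x[:, None] == jnp.arange(num_classes)[None, :]).astype(jnp.float32)


labels = jnp.array([2, 0, 1])
one_hot = to_one_hot_sketch(labels, num_classes=4)
print(one_hot.tolist())
# [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]

# JAX ships an equivalent helper taking (labels, num_classes).
assert jnp.array_equal(one_hot, jax.nn.one_hot(labels, 4))
```

In this sketch a label greater than or equal to num_classes produces an all-zero row instead of an error, so an undersized num_classes can go unnoticed until shapes disagree later.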

nerva_jax.datasets.from_one_hot(one_hot: jax.numpy.ndarray) → jax.numpy.ndarray

Convert one-hot encoded rows to a class index tensor.
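A natural way to recover class indices is to take the column of the largest entry in each row. The sketch below assumes that approach; from_one_hot_sketch is a hypothetical name, and the library may implement the conversion differently.

```python
import jax.numpy as jnp


def from_one_hot_sketch(one_hot: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical re-implementation: column index of the largest entry per row."""
    return jnp.argmax(one_hot, axis=1)


one_hot = jnp.array([[0.0, 0.0, 1.0],
                     [1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0]])
print(from_one_hot_sketch(one_hot).tolist())  # [2, 0, 1]
```

Because argmax only needs the position of the maximum, the same expression also turns rows of model scores (not strictly one-hot) into predicted class indices.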

class nerva_jax.datasets.DataLoader(Xdata: jax.numpy.ndarray, Tdata: jax.numpy.ndarray, batch_size: int, num_classes=0)

Bases: object

A minimal in-memory data loader with an interface similar to torch.utils.data.DataLoader.
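To illustrate what an in-memory loader of this kind does, here is a hypothetical MiniLoader (not the nerva_jax class). It slices (X, T) batches from arrays held in memory. Sequential order, no shuffling, no one-hot encoding, and keeping a smaller final batch (as torch.utils.data.DataLoader does with its default drop_last=False) are all assumptions of this sketch.

```python
import math

import jax.numpy as jnp


class MiniLoader:
    """Hypothetical loader: yields (X, T) batches from in-memory arrays, in order."""

    def __init__(self, Xdata: jnp.ndarray, Tdata: jnp.ndarray, batch_size: int):
        assert Xdata.shape[0] == Tdata.shape[0], "X and T must have the same number of rows"
        self.Xdata = Xdata
        self.Tdata = Tdata
        self.batch_size = batch_size

    @property
    def dataset_size(self) -> int:
        # Total number of examples (rows of Xdata)
        return self.Xdata.shape[0]

    def __len__(self) -> int:
        # Number of batches; the last batch may be smaller than batch_size
        return math.ceil(self.dataset_size / self.batch_size)

    def __iter__(self):
        for start in range(0, self.dataset_size, self.batch_size):
            end = start + self.batch_size
            yield self.Xdata[start:end], self.Tdata[start:end]


X = jnp.arange(20.0).reshape(10, 2)
T = jnp.arange(10)
loader = MiniLoader(X, T, batch_size=4)
print(len(loader))                        # 3
print([xb.shape[0] for xb, _ in loader])  # [4, 4, 2]
```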

Notes / Warning:

  • When Tdata contains class indices (shape (N,) or (N,1)), this loader will one-hot encode the labels. If num_classes is not provided (the default num_classes=0), it will be inferred as max(Tdata) + 1.

  • On small datasets or subsets where some classes are absent, this inference can underestimate the true number of classes and produce one-hot targets with too few columns. This may cause dimension mismatches with the model output during training/evaluation.

  • To avoid this, pass num_classes explicitly whenever you know the total number of classes.
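The underestimation described in these notes can be reproduced with plain JAX. The sketch applies the max(Tdata) + 1 rule to a small subset in which the highest class never occurs. The label values and the assumed total of 4 classes are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Labels of a small subset in which class 3 (the largest) never appears,
# although the full dataset is assumed to have 4 classes (0..3).
subset_labels = jnp.array([0, 2, 1, 2, 0])
true_num_classes = 4

inferred = int(jnp.max(subset_labels)) + 1  # the max(Tdata) + 1 rule
print(inferred)                             # 3, one column short

# A model with 4 output units produces (N, 4) scores, which no longer
# match (N, 3) one-hot targets built from the inferred count.
targets_inferred = jax.nn.one_hot(subset_labels, inferred)
targets_explicit = jax.nn.one_hot(subset_labels, true_num_classes)
print(targets_inferred.shape, targets_explicit.shape)  # (5, 3) (5, 4)
```

With the documented constructor, the explicit form corresponds to DataLoader(Xdata, Tdata, batch_size, num_classes=4).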

property dataset_size

Total number of examples.

nerva_jax.datasets.max_(X: jax.numpy.ndarray) → int | float

Return the maximum element of X as a Python scalar.
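A concrete Python scalar matters here because a class count derived from it (max_(T) + 1) is used as an array width, and JAX shapes must be static Python integers rather than traced arrays. One plausible way to obtain such a scalar is shown below. max_sketch is a hypothetical name, and whether the library uses .item() is an assumption.

```python
import jax.numpy as jnp


def max_sketch(X: jnp.ndarray):
    """Hypothetical re-implementation: largest element as a plain Python int or float."""
    return jnp.max(X).item()


print(max_sketch(jnp.array([[1, 7], [3, 2]])))  # 7
print(type(max_sketch(jnp.array([0.5, 2.5]))))  # <class 'float'>
```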

nerva_jax.datasets.infer_num_classes(Ttrain: jax.numpy.ndarray, Ttest: jax.numpy.ndarray) → int

Infer the total number of classes from the target arrays.

  • If either Ttrain or Ttest is one-hot encoded (2D with width > 1), use that width.

  • Otherwise, treat both arrays as class indices and return the largest index over Ttrain and Ttest plus 1.
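The two rules above can be written out as the following sketch. infer_num_classes_sketch is hypothetical, and checking Ttrain before Ttest when both are one-hot is an assumption (the rules do not say which width wins if they differ).

```python
import jax.numpy as jnp


def infer_num_classes_sketch(Ttrain: jnp.ndarray, Ttest: jnp.ndarray) -> int:
    """Hypothetical re-implementation of the two inference rules."""
    for T in (Ttrain, Ttest):               # assumption: Ttrain is checked first
        if T.ndim == 2 and T.shape[1] > 1:  # one-hot: use its width
            return int(T.shape[1])
    # Class indices of shape (N,) or (N, 1): largest label over both sets + 1
    return int(max(jnp.max(Ttrain).item(), jnp.max(Ttest).item())) + 1


print(infer_num_classes_sketch(jnp.array([0, 1, 1]), jnp.array([2, 0])))  # 3
print(infer_num_classes_sketch(jnp.eye(5)[:3], jnp.eye(5)[:2]))          # 5
```

Taking the maximum over both sets reduces, but does not remove, the risk of undercounting: a class absent from both Ttrain and Ttest is still missed.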

nerva_jax.datasets.create_npz_dataloaders(filename: str, batch_size: int = True) → Tuple[DataLoader, DataLoader]

Create a pair of data loaders (training and test) from an .npz file containing the arrays Xtrain, Ttrain, Xtest and Ttest.
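The expected file layout can be produced and read back with NumPy. The sketch writes a tiny synthetic archive with np.savez and loads the four named arrays as JAX arrays. load_npz_arrays and the file name toy.npz are hypothetical, and the data are random values used only to show the shapes.

```python
import os
import tempfile

import jax.numpy as jnp
import numpy as np


def load_npz_arrays(filename: str):
    """Hypothetical helper: read the four arrays that create_npz_dataloaders expects."""
    with np.load(filename) as data:
        return tuple(jnp.asarray(data[key]) for key in ("Xtrain", "Ttrain", "Xtest", "Ttest"))


# Build a tiny synthetic file with the expected keys (random data, illustration only).
rng = np.random.default_rng(0)
path = os.path.join(tempfile.mkdtemp(), "toy.npz")
np.savez(path,
         Xtrain=rng.normal(size=(8, 3)).astype(np.float32),
         Ttrain=rng.integers(0, 4, size=8),
         Xtest=rng.normal(size=(4, 3)).astype(np.float32),
         Ttest=rng.integers(0, 4, size=4))

Xtrain, Ttrain, Xtest, Ttest = load_npz_arrays(path)
print(Xtrain.shape, Ttrain.shape, Xtest.shape, Ttest.shape)  # (8, 3) (8,) (4, 3) (4,)
```

Presumably each (X, T) pair is then wrapped in a DataLoader; the order of the returned tuple (training loader first, then test loader) is an assumption based on the return type.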