PyTorch Dataloader for HDF5 data

Context

I’m a newbie with HDF5, less so with PyTorch, yet I found it hard to find guidelines on good practices for loading data from HDF5 files.

So here’s my take on the issue, inspired by torchmeta.

First Attempt - TypeError: h5py objects cannot be pickled

Here’s my use case: I have .h5 files containing samples as datasets (no groups/subgroups, but I expect that wouldn’t change much).

Initially I thought, “well, let’s just open the file and load the datasets dynamically”. So I wrote something like:

import torch
from torch.utils.data import Dataset
import h5py


class H5Dataset(Dataset):
    def __init__(self, h5_path):
        # keep the file handle around for the lifetime of the dataset
        self.h5_file = h5py.File(h5_path, "r")

    def __len__(self):
        # one sample per top-level dataset in the file
        return len(self.h5_file)

    def __getitem__(self, index):
        dataset = self.h5_file[f"trajectory_{index}"]
        data = torch.from_numpy(dataset[:])  # read the full dataset into memory
        labels = dict(dataset.attrs)

        return {
            "data": data,
            "labels": labels
        }

...

loader = torch.utils.data.DataLoader(H5Dataset("/some/path.h5"), num_workers=2)
batch = next(iter(loader))

And then…

TypeError: h5py objects cannot be pickled

So that’s bad news. The issue: when using num_workers > 0, the Dataset is created in the main process and then sent to the DataLoader’s worker processes, which requires everything it holds to be picklable… and h5py.File objects are not.
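
You can actually reproduce the error without any DataLoader at all. A minimal sketch (the file name and contents here are made up for illustration):

import pickle
import h5py

# create a throwaway file so there is something to open
with h5py.File("demo.h5", "w") as f:
    f.create_dataset("trajectory_0", data=[1, 2, 3])

pickle.dumps(h5py.File("demo.h5", "r"))
# TypeError: h5py objects cannot be pickled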

Solution

One might be tempted to move the file opening into __getitem__ instead, but that means opening and reading the file once for every single sample, for the whole training run, which adds overhead and filesystem pressure.
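
For reference, that alternative would look something like this (just a sketch, assuming the path is stored on the dataset in __init__):

def __getitem__(self, index):
    # open, read one sample, close - repeated for every sample, every epoch
    with h5py.File(self.h5_path, "r") as f:
        dataset = f[f"trajectory_{index}"]
        return {"data": torch.from_numpy(dataset[:]), "labels": dict(dataset.attrs)}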

The solution is to lazy-load the files: each process opens them the first time they are needed and caches the handles for subsequent calls:

import torch
from torch.utils.data import Dataset
import h5py


class H5Dataset(Dataset):
    def __init__(self, h5_paths, limit=-1):
        self.limit = limit
        self.h5_paths = h5_paths
        self._archives = [h5py.File(h5_path, "r") for h5_path in self.h5_paths]

        # build a flat index -> (archive, dataset) mapping once, up front
        self.indices = {}
        idx = 0
        for a, archive in enumerate(self._archives):
            for i in range(len(archive)):
                self.indices[idx] = (a, i)
                idx += 1

        # drop the handles so the dataset stays picklable;
        # they will be re-opened lazily on first access
        for archive in self._archives:
            archive.close()
        self._archives = None

    @property
    def archives(self):
        if self._archives is None:  # lazy loading here!
            self._archives = [h5py.File(h5_path, "r") for h5_path in self.h5_paths]
        return self._archives

    def __getitem__(self, index):
        a, i = self.indices[index]
        archive = self.archives[a]
        dataset = archive[f"trajectory_{i}"]
        data = torch.from_numpy(dataset[:])
        labels = dict(dataset.attrs)

        return {"data": data, "labels": labels}

    def __len__(self):
        if self.limit > 0:
            return min(len(self.indices), self.limit)
        return len(self.indices)

That’s it :)
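
If you want to try it end to end, here’s a sketch that writes two small files matching the layout assumed above (the paths, shapes, and attribute names are all made up):

import numpy as np
import h5py

for path in ["/some/path.h5", "/some/path2.h5"]:
    with h5py.File(path, "w") as f:
        for i in range(10):
            dataset = f.create_dataset(f"trajectory_{i}", data=np.random.randn(20, 3))
            dataset.attrs["label"] = i % 2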

Now:

loader = torch.utils.data.DataLoader(H5Dataset(["/some/path.h5", "/some/path2.h5"]), num_workers=2)
batch = next(iter(loader))

Does not raise an error anymore!
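
The trick is that by the time the DataLoader pickles the dataset for its workers, _archives is None, so there are no h5py handles left to serialize; each worker then re-opens the files on first access. A quick sanity check (sketch):

import pickle

dataset = H5Dataset(["/some/path.h5", "/some/path2.h5"])
pickle.dumps(dataset)  # fine: no open h5py handles at this point
batch = dataset[0]     # first access opens the files lazily in this process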