Source code for nndt.datasets.dataset

import os
import warnings
from pathlib import Path

from nndt import datasets
from nndt.datasets.utils import (
    _check_md5,
    _download_from_google,
    _download_from_url,
    _extract_7z_file,
)


class Dataset:
    """
    Downloadable dataset that is fetched from one of several mirrors.

    ``hash``, ``urls`` and ``_dict`` are initialised to ``None`` here and are
    read by :meth:`load` and :meth:`dataset_list`; they are presumably filled
    in by subclasses or by the caller before use (not visible in this module).

    :param name: dataset name, used to build the default download directory
    :param to_path: directory to download and extract into; when ``None``,
        :meth:`load` falls back to ``./.datasets/<name>/``
    """

    def __init__(self, name=None, to_path=None):
        self.name = name
        self.to_path = to_path
        # Expected MD5 of the downloaded archive, passed to _check_md5().
        self.hash = None
        # Iterable of mirror URLs, tried in order by load().
        self.urls = None
        # Mapping whose keys are the names reported by dataset_list().
        self._dict = None

    def dataset_list(self):
        """
        Get available subsets of the models

        :return: list of the available datasets for the download, i.e. every
            key of ``_dict`` that does not contain ``"_test"``; an empty list
            when ``_dict`` has not been populated
        """
        if not self._dict:
            # Nothing registered yet: report "no datasets" instead of failing
            # with a TypeError on iterating None.
            return []
        return [key for key in self._dict if "_test" not in key]

    def load(self) -> str:
        """
        Load the dataset and return the path to its location.

        Mirrors from ``urls`` are tried in order; the first one whose archive
        downloads, passes the MD5 check and extracts successfully wins. A
        failing mirror only produces a warning.

        :return: path to dataset location
        :raises ConnectionError: if no mirror could be used
        """
        if self.to_path is None:
            self.to_path = f"./.datasets/{self.name}/"
        Path(self.to_path).mkdir(parents=True, exist_ok=True)

        # ``urls`` may still be None on an unconfigured instance; treat that
        # the same as "no mirror available".
        for url in self.urls or ():
            if self._fetch_and_extract(url):
                print("Loading complete")
                return self.to_path

        raise ConnectionError(
            "Looks like you can't reach any mirror, "
            f"please report this issue to: {datasets.source_url}"
        )

    def _fetch_and_extract(self, url) -> bool:
        """
        Download, verify and unpack the archive from a single mirror.

        :param url: mirror URL; URLs containing ``"drive.google"`` go through
            the Google Drive downloader, everything else through the plain
            URL downloader
        :return: ``True`` on success, ``False`` if this mirror failed (a
            warning with the reason is emitted)
        """
        archive = None
        try:
            if "drive.google" in url:
                archive = _download_from_google(url, self.to_path)
            else:
                print("Downloading...")
                archive = _download_from_url(url, self.to_path)
            # Explicit check rather than ``assert``: asserts are removed under
            # ``python -O`` and an AssertionError carries no message.
            if not _check_md5(archive, self.hash):
                raise ValueError(f"MD5 checksum mismatch for {archive}")
            _extract_7z_file(archive, self.to_path)
        except Exception as e:
            # Any failure just means "try the next mirror".
            warnings.warn(str(e))
            if archive is not None:
                try:
                    # Do not leave a corrupt or unused archive behind.
                    os.remove(archive)
                except OSError:
                    # Best-effort cleanup; the archive may not exist on disk.
                    pass
            return False

        # The extracted files are what callers need; drop the archive.
        os.remove(archive)
        return True