mindcv.data.dataset_download 源代码

"""
Dataset download
"""

import os
from mindcv.utils.download import DownLoad

# Names exported by ``from mindcv.data.dataset_download import *``:
# the dataset download helper classes defined in this module.
__all__ = [
    "MnistDownload",
    "Cifar10Download",
    "Cifar100Download"
]


class MnistDownload(DownLoad):
    """Utility class for downloading the MNIST dataset.

    Args:
        root: The root path where the downloaded dataset is placed.
    """

    # Base URL that each archive name in ``resources`` is appended to.
    url_path = 'http://yann.lecun.com/exdb/mnist/'
    # (archive file name, md5 string handed to the download helper) pairs.
    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]

    def __init__(self, root: str):
        super().__init__()
        self.root = root
        # The MNIST files are looked up directly under ``root``.
        self.path = root

    def download(self):
        """Download the MNIST dataset if it doesn't exist.

        The download is skipped when, for every entry in ``resources``, a
        file named like the archive without its last extension
        (e.g. ``train-images-idx3-ubyte``) already exists under ``root``.
        Only the presence of those files is checked here; the md5 values are
        passed on to ``download_and_extract_archive`` and not verified by
        this method itself.
        """
        # Strip the trailing ".gz" to get the name of the extracted file.
        extracted_paths = (
            os.path.join(self.root, os.path.splitext(archive)[0])
            for archive, _ in self.resources
        )
        if all(os.path.isfile(path) for path in extracted_paths):
            return

        # URLs always use "/" as separator, so join them explicitly instead
        # of relying on os.path.join, whose separator depends on the OS.
        base_url = self.url_path.rstrip('/')
        for filename, md5 in self.resources:
            url = base_url + '/' + filename
            self.download_and_extract_archive(
                url,
                download_path=self.root,
                filename=filename,
                md5=md5,
                remove_finished=True,
            )
class Cifar10Download(DownLoad):
    """Helper that fetches the binary version of the Cifar10 dataset.

    Args:
        root: The root path where the downloaded dataset is placed.
    """

    # (archive URL, md5 string handed to the download helper).
    url = (
        'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz',
        'c32a1d4ab5d03f1284b67883e8d87530',
    )
    # Directory under ``root`` in which the files of ``resources`` are expected.
    base_dir = 'cifar-10-batches-bin'
    # Files whose presence means the dataset is already available.
    resources = [
        'data_batch_1.bin',
        'data_batch_2.bin',
        'data_batch_3.bin',
        'data_batch_4.bin',
        'data_batch_5.bin',
        'test_batch.bin',
        'batches.meta.txt',
    ]

    def __init__(self, root: str):
        super().__init__()
        self.root = root
        self.path = os.path.join(self.root, self.base_dir)

    def download(self):
        """Download the Cifar10 dataset if it doesn't exist.

        Nothing is downloaded when every file listed in ``resources`` is
        already present in ``<root>/<base_dir>``. Only file presence is
        checked here.
        """
        data_dir = os.path.join(self.root, self.base_dir)
        missing = [
            name
            for name in self.resources
            if not os.path.isfile(os.path.join(data_dir, name))
        ]
        if not missing:
            return

        # At least one expected file is absent: fetch the whole archive.
        archive_url = self.url[0]
        archive_md5 = self.url[1]
        self.download_and_extract_archive(
            archive_url,
            download_path=self.root,
            md5=archive_md5,
            remove_finished=True,
        )
class Cifar100Download(DownLoad):
    """Helper that fetches the binary version of the Cifar100 dataset.

    Args:
        root: The root path where the downloaded dataset is placed.
    """

    # (archive URL, md5 string handed to the download helper).
    url = (
        'http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz',
        '03b5dce01913d631647c71ecec9e9cb8',
    )
    # Directory under ``root`` in which the files of ``resources`` are expected.
    base_dir = 'cifar-100-binary'
    # Files whose presence means the dataset is already available.
    resources = [
        'train.bin',
        'test.bin',
        'fine_label_names.txt',
        'coarse_label_names.txt',
    ]

    def __init__(self, root: str):
        super().__init__()
        self.root = root
        self.path = os.path.join(self.root, self.base_dir)

    def download(self):
        """Download the Cifar100 dataset if it doesn't exist.

        The archive is fetched only when at least one file listed in
        ``resources`` is absent from ``<root>/<base_dir>``. Only file
        presence is checked here.
        """

        def _is_present(name):
            # True when the expected dataset file already exists on disk.
            return os.path.isfile(os.path.join(self.root, self.base_dir, name))

        needs_download = any(not _is_present(name) for name in self.resources)
        if needs_download:
            self.download_and_extract_archive(
                self.url[0],
                download_path=self.root,
                md5=self.url[1],
                remove_finished=True,
            )