mindcv.data.dataset_factory source code

"""
Create dataset by name
"""

from typing import Optional
import os

from mindspore.dataset import MnistDataset, Cifar10Dataset, Cifar100Dataset, ImageFolderDataset, DistributedSampler
import mindspore.dataset as ds

from .dataset_download import MnistDownload, Cifar10Download, Cifar100Download
from .distributed_sampler import RepeatAugSampler
#from .dataset_reader import ImageNetDataset

__all__ = ["create_dataset"]

# Datasets natively supported by MindSpore, keyed by the lower-case name that
# `create_dataset` looks up. Each value is a pair of
# (MindSpore dataset class, downloader class used when `download=True`).
_MINDSPORE_BASIC_DATASET = {
    "mnist": (MnistDataset, MnistDownload),
    "cifar10": (Cifar10Dataset, Cifar10Download),
    "cifar100": (Cifar100Dataset, Cifar100Download),
}

#_DATASET_SIZE = {'imagenet': }

def create_dataset(
    name: str = '',
    root: str = './',
    split: str = 'train',
    shuffle: bool = True,
    num_samples: Optional[int] = None,
    num_shards: Optional[int] = None,
    shard_id: Optional[int] = None,
    num_parallel_workers: Optional[int] = None,
    download: bool = False,
    num_aug_repeats: int = 0,
    **kwargs
):
    r"""Creates dataset by name.

    Args:
        name: dataset name like MNIST, CIFAR10, ImageNet, ''. '' means a customized dataset.
            The name is matched case-insensitively. Default: ''.
        root: dataset root dir. Default: './'.
        split: data split: '' or split name string (train/val/test). If it is '', no split is
            used. Otherwise, it is a subfolder of root dir, e.g., train, val, test.
            Default: 'train'.
        shuffle: whether to shuffle the dataset. Default: True.
        num_samples: Number of elements to sample (default=None, which means sample all
            elements). Values <= 0 are treated the same as None.
        num_shards: Number of shards that the dataset will be divided into (default=None).
            When this argument is specified, `num_samples` reflects the maximum sample number
            per shard.
        shard_id: The shard ID within `num_shards` (default=None). This argument can only be
            specified when `num_shards` is also specified.
        num_parallel_workers: Number of workers to read the data (default=None, set in the
            config).
        download: whether to download the dataset. Downloading is only performed for the
            built-in MNIST/CIFAR10/CIFAR100 datasets; when `shard_id` is given, each shard
            downloads into its own `root/dataset_{shard_id}` directory. Requesting a download
            for ImageNet raises ValueError; for other custom datasets it is ignored.
            Default: False.
        num_aug_repeats: Number of dataset repetitions for repeated augmentation. If 0 or 1,
            no samples are repeated (note that any value > 0 still routes sampling through
            `RepeatAugSampler`). Otherwise, repeated augmentation is enabled and the common
            choice is 3. (Default: 0)
        **kwargs: extra keyword arguments forwarded to the underlying MindSpore dataset class.

    Note:
        For custom datasets and imagenet, the dataset dir should follow the structure like:

        .dataset_name/
        ├── split1/
        │  ├── class1/
        │  │   ├── 000001.jpg
        │  │   ├── 000002.jpg
        │  │   └── ....
        │  └── class2/
        │      ├── 000001.jpg
        │      ├── 000002.jpg
        │      └── ....
        └── split2/
           ├── class1/
           │   ├── 000001.jpg
           │   ├── 000002.jpg
           │   └── ....
           └── class2/
               ├── 000001.jpg
               ├── 000002.jpg
               └── ....

    Returns:
        Dataset object. For mnist/cifar10/cifar100 a `num_classes()` callable is attached,
        because the MindSpore dataset objects do not provide one.

    Raises:
        AssertionError: if `num_samples` is set together with a non-zero `num_aug_repeats`.
        ValueError: if `download` is requested for the imagenet dataset.
    """
    assert (num_samples is None) or (num_aug_repeats == 0), \
        'num_samples and num_aug_repeats can NOT be set together.'

    name = name.lower()
    is_basic = name in _MINDSPORE_BASIC_DATASET

    # Resolve where the data actually lives *before* building any sampler, so that the
    # repeated-augmentation sampler sizes the same directory the dataset is read from
    # (including a per-shard download directory).
    if is_basic:
        dataset_class, download_class = _MINDSPORE_BASIC_DATASET[name]
        dataset_dir = root
        if download:
            if shard_id is not None:
                # Each shard downloads into its own sub-directory.
                root = os.path.join(root, f'dataset_{str(shard_id)}')
            downloader = download_class(root)
            downloader.download()
            dataset_dir = downloader.path or root
    else:
        if name == "imagenet" and download:
            raise ValueError(
                "Imagenet dataset download is not supported. Please download imagenet from "
                "https://www.image-net.org/download.php, and pass the path of dataset "
                "directory via args.data_dir"
            )
        dataset_dir = root

    # Subset sampling: draw at most `num_samples` elements (per shard when distributed).
    if num_samples is not None and num_samples > 0:
        # TODO: rewrite ordered distributed sampler (subset sampling in distributed mode is not tested)
        if num_shards is not None and num_shards > 1:
            sampler = DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        elif shuffle:
            sampler = ds.RandomSampler(replacement=False, num_samples=num_samples)
        else:
            sampler = ds.SequentialSampler(num_samples=num_samples)
        # The sampler owns shuffling/sharding, so MindSpore must not receive them separately.
        mindspore_kwargs = dict(shuffle=None, sampler=sampler,
                                num_parallel_workers=num_parallel_workers, **kwargs)
    else:
        mindspore_kwargs = dict(shuffle=shuffle, sampler=None, num_shards=num_shards,
                                shard_id=shard_id, num_parallel_workers=num_parallel_workers,
                                **kwargs)

    # Sampler for repeated augmentation (replaces any sampler configured above).
    if num_aug_repeats > 0:
        dataset_size = get_dataset_size(name, dataset_dir, split)
        print(f'INFO: Repeated augmentation is enabled, num_aug_repeats: {num_aug_repeats}, '
              'original dataset size: ', dataset_size)
        # Since drop_remainder is usually True, we don't need to do rounding in sampling.
        sampler = RepeatAugSampler(dataset_size, num_shards=num_shards, rank_id=shard_id,
                                   num_repeats=num_aug_repeats, selected_round=0, shuffle=shuffle)
        mindspore_kwargs = dict(shuffle=None, sampler=sampler, num_shards=None, shard_id=None,
                                num_parallel_workers=num_parallel_workers, **kwargs)

    if is_basic:
        dataset = dataset_class(dataset_dir=dataset_dir, usage=split, **mindspore_kwargs)
        # Work around MindSpore basic datasets not exposing num_classes.
        num_classes = {'mnist': 10, 'cifar10': 10, 'cifar100': 100}.get(name)
        if num_classes is not None:
            dataset.num_classes = lambda: num_classes
    else:
        if os.path.isdir(dataset_dir):
            dataset_dir = os.path.join(dataset_dir, split)
        dataset = ImageFolderDataset(dataset_dir=dataset_dir, **mindspore_kwargs)

    return dataset
def get_dataset_size(name, root, split):
    """Return the number of samples in the dataset `name` stored under `root`.

    Built-in MindSpore datasets (see `_MINDSPORE_BASIC_DATASET`) are opened directly from
    `root` with `usage=split`. Any other name is treated as an image-folder dataset: if
    `root` is an existing directory, `split` is appended to it as a sub-folder.

    Args:
        name: lower-case dataset name.
        root: dataset root directory.
        split: split name (e.g. 'train', 'val') or ''.

    Returns:
        The size reported by the dataset's `get_dataset_size()`.
    """
    if name not in _MINDSPORE_BASIC_DATASET:
        folder = os.path.join(root, split) if os.path.isdir(root) else root
        return ImageFolderDataset(dataset_dir=folder).get_dataset_size()

    basic_cls = _MINDSPORE_BASIC_DATASET[name][0]
    return basic_cls(dataset_dir=root, usage=split).get_dataset_size()