mindcv.models.model_factory 源代码

import os
from mindspore import load_checkpoint, load_param_into_net
from .registry import is_model, model_entrypoint

__all__ = ["create_model"]


# Checkpoint entries holding EMA weights are stored under this key prefix.
_EMA_PREFIX = 'ema.'


def _extract_ema_params(checkpoint_param):
    """Return the EMA entries of ``checkpoint_param`` keyed by their un-prefixed names.

    Only keys starting with ``"ema."`` are selected, and that prefix is stripped
    exactly once (so ``"ema.head.ema.bias"`` becomes ``"head.ema.bias"``). Each
    selected value's ``name`` attribute is renamed in place to the stripped name,
    so the parameter matches the corresponding network parameter.
    """
    ema_params = {}
    for key, value in checkpoint_param.items():
        if key.startswith(_EMA_PREFIX):
            new_name = key[len(_EMA_PREFIX):]
            value.name = new_name
            ema_params[new_name] = value
    return ema_params


def _load_checkpoint_into_model(model, checkpoint_path, use_ema):
    """Load the checkpoint at ``checkpoint_path`` into ``model``.

    When ``use_ema`` is True, only the EMA weights (with their prefix stripped)
    are loaded; otherwise the checkpoint dict is loaded as-is.

    Raises:
        ValueError: If ``use_ema`` is True but the checkpoint has no EMA entries.
    """
    checkpoint_param = load_checkpoint(checkpoint_path)
    if not use_ema:
        load_param_into_net(model, checkpoint_param)
        return

    ema_params = _extract_ema_params(checkpoint_param)
    if not ema_params:
        raise ValueError('checkpoint_param does not contain ema_parameter, please set use_ema to False.')
    load_param_into_net(model, ema_params)


def create_model(
    model_name: str,
    num_classes: int = 1000,
    pretrained: bool = False,
    in_channels: int = 3,
    checkpoint_path: str = '',
    use_ema: bool = False,
    **kwargs,
):
    r"""Creates model by name.

    Args:
        model_name (str): The name of a registered model.
        num_classes (int): The number of classes. Default: 1000.
        pretrained (bool): Whether to load the pretrained model. Default: False.
        in_channels (int): The input channels. Default: 3.
        checkpoint_path (str): The path of a checkpoint file to load into the
            created model. An empty string (or None) means no checkpoint is
            loaded. Default: "".
        use_ema (bool): Whether to load the EMA weights (checkpoint entries whose
            names start with ``"ema."``) instead of the regular weights. Only
            used when ``checkpoint_path`` is given. Default: False.
        **kwargs: Extra keyword arguments forwarded to the model's creation
            function. Entries whose value is ``None`` are dropped.

    Returns:
        The model instance returned by the registered creation function.

    Raises:
        ValueError: If both ``checkpoint_path`` and ``pretrained`` are set, or if
            ``use_ema`` is True but the checkpoint has no EMA parameters.
        RuntimeError: If ``model_name`` is not a registered model.
        FileNotFoundError: If ``checkpoint_path`` is given but does not exist.
    """
    if checkpoint_path and pretrained:
        raise ValueError('checkpoint_path is mutually exclusive with pretrained')

    if not is_model(model_name):
        raise RuntimeError(f'Unknown model {model_name}')

    # Fail fast on a bad path instead of silently returning a model
    # that never received the requested weights.
    if checkpoint_path and not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f'Checkpoint file not found: {checkpoint_path}')

    model_args = dict(num_classes=num_classes, pretrained=pretrained, in_channels=in_channels)
    # Drop None-valued kwargs so they are not passed to the creation function at all.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    create_fn = model_entrypoint(model_name)
    model = create_fn(**model_args, **kwargs)

    if checkpoint_path:
        _load_checkpoint_into_model(model, checkpoint_path, use_ema)

    return model