'''model registry and list'''
import sys
import fnmatch
from collections import defaultdict
# Names exported by ``from <this module> import *``.
__all__ = [
    'list_models', 'is_model', 'model_entrypoint',
    'list_modules', 'is_model_in_modules', 'is_model_pretrained',
]

# ---------------------------------------------------------------------------
# Registry state (written by ``register_model``)
# ---------------------------------------------------------------------------
# short module name (last dotted component) -> set of model names from it
_module_to_models = defaultdict(set)
# model name -> short module name it was registered from
_model_to_module = dict()
# model name -> registered entrypoint function
_model_entrypoints = dict()
# model names whose module ``default_cfgs`` entry had a truthy 'url'
_model_has_pretrained = set()
def register_model(fn):
    """Decorator that records ``fn`` as a model entrypoint in the registry.

    The model is registered under ``fn.__name__`` and grouped under the last
    dotted component of ``fn.__module__``. Side effects:

    * the model name is appended to the defining module's ``__all__``
      (which is created as a one-item list if the module has none);
    * the model is marked as having pretrained weights when the defining
      module has a ``default_cfgs`` mapping whose entry for this model
      contains a truthy ``'url'``.

    Args:
        fn: The model factory function to register.

    Returns:
        ``fn`` itself, unchanged, so this can be used as ``@register_model``.
    """
    owner = sys.modules[fn.__module__]
    # Only the final dotted component is kept as the module's registry key.
    # (str.split always returns at least one element, so [-1] is safe.)
    short_module = fn.__module__.split('.')[-1]
    name = fn.__name__

    # Export the model name from its defining module.
    if not hasattr(owner, '__all__'):
        owner.__all__ = [name]
    else:
        owner.__all__.append(name)

    # Populate the global lookup tables.
    _model_entrypoints[name] = fn
    _model_to_module[name] = short_module
    _module_to_models[short_module].add(name)

    # A model counts as pretrained only if its default cfg carries a non-empty url.
    if hasattr(owner, 'default_cfgs') and name in owner.default_cfgs:
        entry = owner.default_cfgs[name]
        if 'url' in entry and entry['url']:
            _model_has_pretrained.add(name)
    return fn
def list_models(filter='', module='', pretrained=False, exclude_filters=''):
    """Return a sorted list of registered model names, optionally filtered.

    Args:
        filter (str or list/tuple of str): ``fnmatch`` wildcard pattern(s).
            A model is kept if it matches any of them. Empty keeps all models.
        module (str): Only consider models registered from the module whose
            last dotted name component equals this. Empty considers all.
        pretrained (bool): If True, keep only models that were marked as
            having a pretrained ``url`` when they were registered.
        exclude_filters (str or list/tuple of str): ``fnmatch`` wildcard
            pattern(s). Any model matching one of them is dropped after the
            include ``filter`` has been applied.

    Returns:
        list[str]: The matching model names, sorted in ascending order.

    Example:
        list_models('gluon_resnet*')      -- models starting with 'gluon_resnet'
        list_models('*resnext*', 'resnet') -- 'resnext' models from 'resnet' module
    """
    if module:
        # Use .get() instead of indexing: _module_to_models is a defaultdict,
        # so indexing with an unknown module name would insert an empty
        # "phantom" module that list_modules() would then report.
        candidates = list(_module_to_models.get(module, ()))
    else:
        candidates = list(_model_entrypoints)

    if filter:
        include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
        models = set()
        for pattern in include_filters:
            models.update(fnmatch.filter(candidates, pattern))
    else:
        models = set(candidates)

    if exclude_filters:
        if not isinstance(exclude_filters, (tuple, list)):
            exclude_filters = [exclude_filters]
        for pattern in exclude_filters:
            # fnmatch.filter returns a new list, so mutating `models` is safe here.
            models.difference_update(fnmatch.filter(models, pattern))

    if pretrained:
        models &= _model_has_pretrained

    return sorted(models)
def is_model(model_name):
    """Tell whether a model entrypoint is registered under ``model_name``.

    Args:
        model_name (str): Name to look up.

    Returns:
        bool: True when the name is a key of the entrypoint registry.
    """
    known_names = _model_entrypoints.keys()
    return model_name in known_names
def model_entrypoint(model_name):
    """Look up the entrypoint function registered under ``model_name``.

    Args:
        model_name (str): A registered model name.

    Returns:
        The callable that was passed to ``register_model`` for this name.

    Raises:
        KeyError: If no model is registered under ``model_name``.
    """
    entrypoint = _model_entrypoints[model_name]
    return entrypoint
def list_modules():
    """Return the sorted names of modules that contain registered models.

    Module names are the last dotted component of each model function's
    ``__module__``, as recorded by ``register_model``.

    Returns:
        list[str]: Module names with at least one registered model, sorted.
    """
    # Skip empty entries: _module_to_models is a defaultdict, so any lookup of
    # an unknown module name elsewhere leaves behind an empty set, and such a
    # module does not "contain models".
    return sorted(name for name, models in _module_to_models.items() if models)
def is_model_in_modules(model_name, module_names):
    """Check whether a model was registered from any of the given modules.

    Args:
        model_name (str): Name of the model to look for.
        module_names (tuple, list or set): Module names (last dotted
            component, as recorded by ``register_model``) to search in.

    Returns:
        bool: True if ``model_name`` is registered under at least one of
        ``module_names``. False for an empty ``module_names``.

    Raises:
        AssertionError: If ``module_names`` is not a tuple, list or set
            (only when assertions are enabled, i.e. not under ``python -O``).
    """
    assert isinstance(module_names, (tuple, list, set))
    # .get() rather than indexing: _module_to_models is a defaultdict, and
    # indexing would insert an empty entry for every unknown name probed here.
    return any(model_name in _module_to_models.get(name, ()) for name in module_names)
def is_model_pretrained(model_name):
    """Tell whether ``model_name`` was registered as having pretrained weights.

    In this module, ``register_model`` marks a model as pretrained when its
    module's ``default_cfgs`` entry has a truthy ``'url'``.

    Args:
        model_name (str): Name to look up.

    Returns:
        bool: True if the model is in the pretrained set.
    """
    pretrained_names = _model_has_pretrained
    return model_name in pretrained_names