mindcv.optim.optim_factory 源代码

''' optim factory '''
import os
from typing import Optional
from mindspore import nn
from mindspore import load_checkpoint, load_param_into_net
from .adan import Adan
from .adamw import AdamW
from .nadam import NAdam

__all__ = ["create_optimizer"]


def init_group_params(params, weight_decay):
    """Split parameters into a weight-decay group and a no-weight-decay group.

    A parameter goes to the no-decay group when its name contains 'beta',
    'gamma' or 'bias'; every other parameter goes to the decay group.

    Args:
        params: iterable of parameter objects exposing a ``name`` attribute.
        weight_decay: weight decay value attached to the decay group.

    Returns:
        A list of three dicts: the decay group (``'params'`` and
        ``'weight_decay'``), the no-decay group (``'params'`` with
        ``'weight_decay'`` set to 0.0), and ``{'order_params': params}``
        carrying the original ``params`` object.
    """
    no_decay_keys = ('beta', 'gamma', 'bias')
    decay_params = []
    no_decay_params = []

    for param in params:
        if any(key in param.name for key in no_decay_keys):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        # Set 0.0 explicitly: per MindSpore's grouped-parameter rules, a group
        # without its own 'weight_decay' key falls back to the optimizer-level
        # weight_decay, which create_optimizer also sets to the non-zero value.
        # Without this key the "no decay" group would still be decayed.
        {'params': no_decay_params, 'weight_decay': 0.0},
        {'order_params': params},
    ]


def create_optimizer(
    params,
    opt: str = 'adam',
    lr: Optional[float] = 1e-3,
    weight_decay: float = 0,
    momentum: float = 0.9,
    nesterov: bool = False,
    filter_bias_and_bn: bool = True,
    loss_scale: float = 1.0,
    schedule_decay: float = 4e-3,
    checkpoint_path: str = '',
    eps: float = 1e-10,
    **kwargs
):
    r"""Creates optimizer by name.

    Args:
        params: network parameters. Union[list[Parameter],list[dict]], which must be the list of parameters
            or list of dicts. When the list element is a dictionary, the key of the dictionary can be
            "params", "lr", "weight_decay","grad_centralization" and "order_params".
        opt: wrapped optimizer (case-insensitive). You could choose like 'sgd', 'nesterov', 'momentum',
            'adam', 'adamw', 'nadam', 'adan', 'rmsprop', 'adagrad', 'lamb'. 'adam' is the default choice
            for convolution-based networks. 'adamw' is recommended for ViT-based networks.
            Default: 'adam'.
        lr: learning rate: float or lr scheduler. Fixed and dynamic learning rate are supported.
            Default: 1e-3.
        weight_decay: weight decay factor. It should be noted that weight decay can be a constant value
            or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic weight decay is
            similar to dynamic learning rate, users need to customize a weight decay schedule only with
            global step as input, and during training, the optimizer calls the instance of
            WeightDecaySchedule to get the weight decay value of current step. Default: 0.
        momentum: momentum if the optimizer supports. Passed to 'sgd', 'momentum'/'nesterov' and
            'rmsprop'. Default: 0.9.
        nesterov: Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            Passed to 'sgd', 'momentum'/'nesterov' and 'adam'. Default: False.
        filter_bias_and_bn: whether to regroup ``params`` with ``init_group_params`` so that parameters
            whose name contains 'beta', 'gamma' or 'bias' (batch norm parameters and biases) are kept in
            a group separate from the weight-decayed parameters. Only takes effect when ``weight_decay``
            is truthy. Default: True.
        loss_scale: A floating point value for the loss scale, which must be larger than 0.0. Must be 1.0
            for 'lamb'. Default: 1.0.
        schedule_decay: passed as ``schedule_decay`` to the 'nadam' optimizer only. Default: 4e-3.
        checkpoint_path: path of a checkpoint to load into the created optimizer. Loading is skipped
            when the value is empty or the path does not exist. Default: ''.
        eps: passed as ``epsilon`` to the 'rmsprop' optimizer only. Default: 1e-10.
        **kwargs: extra keyword arguments forwarded to the optimizer constructor (not forwarded for
            'momentum'/'nesterov').

    Returns:
        Optimizer object

    Raises:
        ValueError: if ``opt`` is not one of the supported optimizer names.
        AssertionError: if ``opt`` is 'lamb' and ``loss_scale`` is not 1.0.
    """

    opt = opt.lower()

    if weight_decay and filter_bias_and_bn:
        params = init_group_params(params, weight_decay)

    opt_args = dict(**kwargs)

    # non-adaptive: SGD, momentum, and nesterov
    if opt == 'sgd':
        # note: nn.Momentum may perform better if momentum > 0.
        optimizer = nn.SGD(params=params,
                           learning_rate=lr,
                           momentum=momentum,
                           weight_decay=weight_decay,
                           nesterov=nesterov,
                           loss_scale=loss_scale,
                           **opt_args)
    elif opt in ['momentum', 'nesterov']:
        # NOTE(review): opt='nesterov' does not enable Nesterov by itself; only
        # the `nesterov` flag does. Extra **kwargs are also not forwarded in
        # this branch, unlike the others. Confirm whether both are intended.
        optimizer = nn.Momentum(params=params,
                                learning_rate=lr,
                                momentum=momentum,
                                weight_decay=weight_decay,
                                use_nesterov=nesterov,
                                loss_scale=loss_scale)
    # adaptive
    elif opt == 'adam':
        optimizer = nn.Adam(params=params,
                            learning_rate=lr,
                            weight_decay=weight_decay,
                            loss_scale=loss_scale,
                            use_nesterov=nesterov,
                            **opt_args)
    elif opt == 'adamw':
        optimizer = AdamW(params=params,
                          learning_rate=lr,
                          weight_decay=weight_decay,
                          loss_scale=loss_scale,
                          **opt_args)
    elif opt == 'nadam':
        optimizer = NAdam(params=params,
                          learning_rate=lr,
                          weight_decay=weight_decay,
                          loss_scale=loss_scale,
                          schedule_decay=schedule_decay,
                          **opt_args)
    elif opt == 'adan':
        optimizer = Adan(params=params,
                         learning_rate=lr,
                         weight_decay=weight_decay,
                         loss_scale=loss_scale,
                         **opt_args)
    elif opt == 'rmsprop':
        optimizer = nn.RMSProp(params=params,
                               learning_rate=lr,
                               momentum=momentum,
                               weight_decay=weight_decay,
                               loss_scale=loss_scale,
                               epsilon=eps,
                               **opt_args)
    elif opt == 'adagrad':
        optimizer = nn.Adagrad(params=params,
                               learning_rate=lr,
                               weight_decay=weight_decay,
                               loss_scale=loss_scale,
                               **opt_args)
    elif opt == 'lamb':
        assert loss_scale == 1.0, 'Loss scaler is not supported by Lamb optimizer'
        optimizer = nn.Lamb(params=params,
                            learning_rate=lr,
                            weight_decay=weight_decay,
                            **opt_args)
    else:
        raise ValueError(f'Invalid optimizer: {opt}')

    # Truthiness guard: os.path.exists(None) raises TypeError, so an unset
    # (None) checkpoint path is skipped the same way as the default ''.
    if checkpoint_path and os.path.exists(checkpoint_path):
        param_dict = load_checkpoint(checkpoint_path)
        load_param_into_net(optimizer, param_dict)

    return optimizer