mmdetection模型构建及Registry注册器机制好久没有做⽬标检测了,最近突然⼜接到了检测任务,跟同事讨论时,发现⾃⼰竟然忘了很多细节,
于是想趁训练模型的间隙,重新梳理下⽬标检测。我选择了mmdetection来学习,除了⽬标检测本⾝,
这个框架中很多python的使⽤技巧和框架的设计模式也是值得学习。最近⼀年基本都在使⽤python,
希望能将这些技巧应⽤在以后的⼯作之中。mmdetection封装的很好,很⽅便使⽤,⽐如我想训练的
话只需如下的⼀条指令。在train.py中,通过build_detector来构建模型(参数来⾃ faster_rcnn_r50_fpn_1x_voc0712.py),
python tools/train.py  configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py
build_detector的定义如下,最后通过build_from_cfg来构建模型,这⾥看到了让⼈困惑的Registry.
from mmdet.cv_core.utils import Registry, build_from_cfg
屋顶漏水防水补漏
from torch import nn
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
def build(cfg, registry, default_args=None):
叶全真个人简历"""Build a module.
Args:
cfg (dict, list[dict]): The config of modules, is is either a dict
or a list of configs.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
我的乐园400字作文四年级if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
⼀、Registry是⼲什么的
Registry完成了从字符串到类的映射,这样模型信息、训练时的参数信息,只需要写⼊到⼀个配置⽂件⾥,然后使⽤注册器来实例化即可。
⼆、如何实现
通过装饰器来实现。在mmcv/mmcv/registry.py中,我们看到了Registry类。其中完成字符串到类的映射,实际上就是下⾯的成员函数来实现的,核⼼代码就⼀句,将要注册的类添加到字典⾥,key为类的名字(字符串)。下⾯通过⼀个⼩例⼦,
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
self._module_dict[module_name] = module_class
来看看它的构建过程。在导⼊下⾯这个⽂件时,⾸先创建FRUIT实例,接着通过装饰器(这⾥是⽤成员函数装饰类)来注册Apple类,调⽤
register_module,然后调⽤_register(注意:参数cls即为类Apple),最后调⽤_register_module完成Apple的添加。完成后,FRUIT就有了个字典成员:['Apple']=APPle。在build_from_cfg中,传⼊模型参数,即可通过FRUIT构建Apple的实例化对象。
class Registry():
def__init__(self, name):
self._name = name
self._module_dict = dict()
def _register_module(self, module_class, module_name, force):
self._module_dict[module_name] = module_class
def register_module(self, name=None, force=False, module=None):
print('register module ...')
def _register(cls):
print('cls ', cls)
self._register_module(
如何生成目录
module_class=cls, module_name=name, force=force)
return cls
return _register
FRUIT = Registry('fruit')
@ister_module()
class Apple():
def__init__(self, name):
self.name = name
def build_from_cfg(cfg, registry, default_args=None):
大昭寺
args = py()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type')
if is_str(obj_type):
obj_cls = (obj_type)
return obj_cls(**args)
三、Registry在mmdetection中是如何构建模型的
我们来看⼀下构建模型的流程:
1、在train.py中通过build_detector构建模型,其中del, ain_cfg如下,包括模型信息和训练信息。
model = build_detector(
2、最关键的部分来了。⾸先通过build_detector构建模型,其中传⼊的DETECTORS是Registry的实例,在该实例中,包含了所有已经实现的检测器,如图。那么它是在哪⾥实现添加这些检测的类的呢?
def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
看了前⾯那个⼩例⼦我们就能猜到,⼀定是在这些检测类上,⽤Registry对其进⾏了注册,看看faster rcnn的实现,证明了我们的猜想。这样只要
在定义这些类时,对其进⾏注册,那么就会⾃动加⼊到DETECTORS这个实例的成员字典⾥,⾮常的巧妙。当我们想实例化某个检测⽹络时,传⼊其字符名称
即可。
既然都看到这⾥了,就进⼀步看看⽹络时如何继续构建的吧。mmdetection将⽹络分成了⼏个部分,backbone,head,neck等。
在TwoStageDetector(
faster rcnn的基类)中,可以看到分别构建了这⼏个部分。head, neck, loss等,同样是通过Registry来注册实现的。最后就是将这⼏个部分组合起来即可。
@ister_module()
class TwoStageDetector(BaseDetector):
"""Base class for two-stage detectors.
Two-stage detectors typically consisting of a region proposal network and a
task-specific regression head.
"""
def__init__(self,
backbone,
neck=None,
rpn_head=None,
roi_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(TwoStageDetector, self).__init__()
self.backbone = build_backbone(backbone)
if neck is not None:
if rpn_head is not None:
rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
rpn_head_ = py()
rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
self.rpn_head = build_head(rpn_head_)
if roi_head is not None:
# update train and test cfg here for now
# TODO: refactor assigner & sampler
rcnn_train_cfg = if train_cfg is not None else None
roi_head.update(train_cfg=rcnn_train_cfg)
roi_head.update(test_cfg=)
self.init_weights(pretrained=pretrained)
四、Registry的应⽤
在我最近的⼀个数据处理的项⽬中,有三类数据,sample, measure 和image。如果我想得到某个数据类型的实例,我是通过if来判断的。那如果数据类别很多呢?就像检测器这样有⼏⼗种,再⽤if就显得很蠢了。借⽤Registry机制,可以轻松解决这个问题。