​ Runner是MMdetection中的一种深度学习算法“工厂”,是对深度学习算法各个组件的“容器”。简单来说,所有的机器学习算法所包含的无非就是数据、模型、训练策略、评估、推理这五个部分。Runner就是将这五个部分组合在一起的工具。

​ 其实光是Runner,可以说的东西不多,但是其背后的设计思路,以及对深度学习算法的概括是值的学习的。所以这篇文章不仅是提炼一下Runner中的操作,也是对一个算法“数据、模型、训练策略、评估、推理”这五个部分的一种总结。

1. Runner整体架构

​ 奉上一个我整理的Runner在MMdetection中的整体架构图,也是根据“数据、模型、训练策略、评估、推理”这五个部分分门别类的,在阅读源码的时候就可以根据这个架构来理清楚思路。

Runner框架梳理

2. 代码逻辑

​ Runner的源码封装在MMCV库当中,目前主要是有epoch_runneriter_runner两种,本质上差不多,只是大家的习惯不同罢了,其实在epoch_runner中也有内置计算了iter数。

​ 给小白简单解释下epoch和iter的差别,比如训练数据是1000张图片,Mini-batch是10,那么训练1个epoch就是训练100次,也就是100个iter。所以如果你打算训练1w个iter的话等价于训练100个epoch。

​ 我个人常用的是epoch_runner,那么就以epoch_runner为例理一理各个模块的使用。由于代码过多,所以有些类型判断、warning,参数注释我都删去了,只要理清楚整个代码运行逻辑就行。

@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    def run_iter(self, data_batch, train_mode, **kwargs):
        if train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,**kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()"'
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True)
            self.call_hook('after_train_iter')
            self._iter += 1
        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        # 以下几个call_hook就是几个典型的应用,HOOK的用法可以详细看下这个专栏里对于HOOK机制剖析的文章
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            with torch.no_grad():
                self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.
        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
        """
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
              # epoch runner内部还是
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.call_hook('before_run')

        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                # 训练则mode='train'
                # 评估则mode='val'
                epoch_runner = getattr(self, mode)
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        # 保存模型的相关代码
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        elif isinstance(meta, dict):
            meta.update(epoch=self.epoch + 1, iter=self.iter)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filename, dst_file)

3. 细节短说

  • 数据:对于数据的处理,主要看datasets中的代码,数据处理这块主要是逻辑上要严谨,代码上要细致,功能上要全面。比如对开源的不同数据集的处理方式,输出格式的统一,在做transform的时候要考虑一些边界问题,多数据集训练多卡训练。不同sampler的方式,对训练IO通信的优化(预加载数据到内存),样本均衡。
  • 模型:算法框架被分为SingleStage和TwoStage两种,当然每一种都可以细分为AnchorFree或者是AnchorBased。算法这块内容有点多,但是主要还是LOSS和网络结构两个部分的创新。对于算法这块重头戏,打算在今年过年之前,对MMdetection中的所有算法都仔细研读并整理成文章。
  • 训练策略:这部分也算是框架内容上,比如warm_up,lr的下降,早停,打印log(支持终端输出和tensorboard),可配置optimizer,保存模型、加载预训练模型、继续训练模型等。
  • 评估策略:学术集的话一般是有提供评估代码,比如COCO和cityscapes,我们要做的只是把输出的格式处理成相应的格式就行。但是实际工程当中的话评估策略需要根据业务要求来定。想起在学校的时候说到分类必谈f1-score,但是到了实际工程中f1-score却很少用到,我自己也尝试过,发现这个指标很多时候不如TPR-FPR来得靠谱。
  • 推理:目标检测inference最常用的是多尺度预测,当然还有一些细节,比如检测框和网络predict之间是有相应的变换的,推理的时候需要反变换为原图尺度。其实推理这部分和评估很多地方是重复的,只是在代码结构上他们是分开的,而且在有些任务中推理和评估还有一些不同之处,因此才把他们区分开来。

4. 总结

​ 我想看懂runner的逻辑应该还是很简单的,其实最关键的还是理解HOOK机制,理解了HOOK机制才能学会从"使用"这个框架变成"改变这个框架",HOOK机制的剖析可以看这篇文章,原本按照我自己的写作习惯是要自己写一套demo的,但是画完Runner的架构梳理图之后觉得似乎没什么必要。

​ 一方面是Runner本身只是一个逻辑梳理的流程,把深度学习的五大组成部分很好的融合在一起,并且秉持着高内聚、低耦合的原则,提供了很好的代码范例,我觉得我很难写出比这个质量更高的代码。另一方面是觉得MMDetection的源码阅读中Registry、HOOK、Runner仅仅只是属于"代码架构"上的东西,作为一名算法工程师,应该花更多的时间在算法细节、设计上面。

​ 不管工作再忙,也要坚持写作,坚持输出,写完这几个架构上的东西,是时候上主菜了,下周开始每周分享一篇paper reading,力争在过年之前把MMDetection支持的所有算法从原理和代码上都扣得干干净净。

 都看到这了,不防点个免费的赞~,谢谢~~~