论文

MUTUAL MEAN-TEACHING: PSEUDO LABEL REFINERY FOR UNSUPERVISED DOMAIN ADAPTATION ON PERSON RE-IDENTIFICATION

解决的问题

通俗说:

  • reid任务在实际的使用中,即使是用大规模的数据集训练好的模型,如果直接部署于一个新的监控系统,由于领域差异会导致效果明显的下降。
  • 无监督领域,行人重识别任务中目标域的类别通常与源域没有重叠。

具体问题:

  • 减少伪标签噪声

推理思路

  • 无监督领域中基于聚类的伪标签最有效且保持最优精度。

    [2] X. Zhang, et al. Self-training with progressive augmentation for unsupervised cross-domain person re-identification. ICCV, 2019.
    [3] F. Yang, et al. Self-similarity grouping: A simple unsupervised cross domain adaptation approach for person re-identification. ICCV, 2019.

  • 无监督聚类存在伪标签噪声的问题。
  • 伪标签如果为onehot形式的话就比较极端,要么一定对,要么一定错。
  • 因此根据论文Bag of Tricks and A Strong Baseline for Deep Person Re-identification中的trick使用softlabel的原理,对onehot标签进行软化,例如:[0,1,0,0] -> [0.1,0.6,0.2,0.1]这样的标签更具有鲁棒性。
  • 拿到了目标域的“软”伪标签直接计算损失监督自己是不可取的,这样只会错的更错,因此作者想到训练一个对称网络,在协同训练下进行相互监督来避免自身过拟合。

详细实现

在这里插入图片描述
1、pretrain部分:

  • 使用不同的参数(代码中使用不同的随机种子实现)初始化两个相同的网络,在源域训练两个模型,对应图中Net 1和Net 2。

2、端到端部分:

  • 使用pretrain的模型分别初始化Net1,Net2。第一次初始化MeanNet1参数与Net1相同,同理MeanNet2。
  • 将数据Dt采用不同的随机增强方式分别输入Net1和Net2,得到特征kmeans聚类后可以得到每个样本的硬标签。
  • 根据模型的输出Prediction和硬标签可以得到Net1和Net2的ce损失(loss_ce_1,loss_ce_2)和tri损失(loss_tri_1,loss_tri_2)。
  • 同理将数据Dt采用不同的随机增强方式分别输入MeanNet1和MeanNet2,得到得到下面一行的prediction,作为软标签。
  • 将Net1的Prediction作为预测,将MeanNet2的Prediction作为目标计算软ce损失(loss_ce_soft1)和软tri损失(loss_tri_soft1)。同理Net2得到(loss_ce_soft2)和(loss_tri_soft2)。
  • 对应损失1:1结合得到loss_ce_soft、loss_tri_soft
  • 最后总ce损失为:loss_ce_all = (loss_ce_1 + loss_ce_2)_(1-ce_soft_weight) + loss_ce_soft _ ce_soft_weight
  • 最后总tri损失为:loss_tri_all = (loss_tri_1 + loss_tri_2)_(1-tri_soft_weight) + loss_tri_soft _ tri_soft_weight
  • 其中ce_soft_weight = 0.5、tri_soft_weight = 0.8
  • 最后的loss = loss_ce_all + loss_tri_all
  • 使用这个损失反向传播更新Net1、Net2网络的参数。
  • 使用更新后的Net1和Net2的参数与上一次的Net1、Net2参数加权求和更新MeanNet1和MeanNet2的参数。
  • 迭代端到端部分2-12步。

MeanNet更新逻辑看代码比较清晰,最开始几个step主要以Net1、Net2反向传播更新的参数为主,越到后头更新的参数权重越低最终仅占0.001。

def _update_ema_variables(self, model, ema_model, alpha=0.999, global_step):
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(1 - alpha, param.data)

最终效果

在这里插入图片描述

一些问题

关于Kmeans的类别数取值:

Market1501测试集一个750类,根据实验结果k取500效果最好,700也有一些比较好的结果:
在这里插入图片描述
DUKE测试集一个702类,根据实验结果k取700效果最好:在这里插入图片描述
因此感觉Kmean聚类的类别数接近真实类别数较好。实验可以在真是类别数附近测试,如果不知道真实类别数只能盲猜了。

损失设计使用软标签和硬标签结合,是否可以去掉其中一个?

在这里插入图片描述
仅看数据集Market->Duke,IBN+resnet50:

实验 mAp rank1
- 65.7 79.3
去掉硬分类和硬三元损失 -41.2 -41.3
去掉硬分类损失 -38.2 -37.3
去掉硬三元损失 -0.1 +0.1
去掉软分类损失 -5.4 -3.6
去掉软三元损失 -4.0 -2.2

总结:硬三元损失对实验效果影响不大,其余都有一些提升,硬分类损失是框架得知目标域数据分布的关键,贡献最大。

软三元损失没有标签怎么取一对正负样本?

首先,硬三元损失计算的时候,因为知道标签类别,采用难样本挖掘的思想,一个batch的数据由M = P * K个样本构成,其中P为类别数,K为这个类别的图像数。遍历M个样本,在正样本中找到与当前样本欧氏距离最大的样本构建正样本对,在负样本中找到与当前样本欧式距离最小的样本构建负样本对。
作者实现代码:

    def forward(self, emb1, emb2, label):
        if self.normalize_feature:
            # equal to cosine similarity
            emb1 = F.normalize(emb1)
            emb2 = F.normalize(emb2)

        mat_dist = euclidean_dist(emb1, emb1)
        assert mat_dist.size(0) == mat_dist.size(1)
        N = mat_dist.size(0)
        mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float()

        dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True)
        assert dist_an.size(0)==dist_ap.size(0)
        triple_dist = torch.stack((dist_ap, dist_an), dim=1)
        triple_dist = F.log_softmax(triple_dist, dim=1)
        if (self.margin is not None):
            loss = (- self.margin * triple_dist[:,0] - (1 - self.margin) * triple_dist[:,1]).mean()
            return loss

        mat_dist_ref = euclidean_dist(emb2, emb2)
        dist_ap_ref = torch.gather(mat_dist_ref, 1, ap_idx.view(N,1).expand(N,N))[:,0]
        dist_an_ref = torch.gather(mat_dist_ref, 1, an_idx.view(N,1).expand(N,N))[:,0]
        triple_dist_ref = torch.stack((dist_ap_ref, dist_an_ref), dim=1)
        triple_dist_ref = F.softmax(triple_dist_ref, dim=1).detach()

        loss = (- triple_dist_ref * triple_dist).mean(0).sum()
        return loss

作者首先根据是否有margin来判断输出是Net1或者Net2编码特征的三元损失还是MeanNet编码特征与Net编码特征的三元损失的结合。

该函数的输入:em1是Net1或者Net2编码的特征,em2是对称的Mean Net的编码的特征,label是聚类的硬标签。
这一行代码:dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True)
是根据硬标签和em1的距离矩阵得到每个特征与正样本距离最大的正样本的索引与负样本距离最小的索引同时计算了这些样本的距离矩阵。就是找到了对应的最难的正负样本。
triple_dist = F.log_softmax(triple_dist, dim=1)计算了Net1或Net2输出特征根据硬标签的三元损失。
同理下面一段:triple_dist_ref = F.softmax(triple_dist_ref, dim=1).detach()计算了MeanNet1或MeanNet2输出特征根据硬标签的三元损失。
最后的loss为两者相乘:loss = (- triple_dist_ref * triple_dist).mean(0).sum()

为什么用时序平均模型更新参数?

作者实验发现用时序平均模型得到的软标签更加可靠。观察发现尽管最前期epoch会有很多错误,但是可以有效的防止偏差放大。
在这里插入图片描述

论文

Improved Mutual Mean-Teaching for Unsupervised Domain Adaptive Re-ID

主要贡献

  • 提出了一个集合了伪标签和域翻译(gan)的方法。
  • 优化MMT方法,把标注了的源域图像和目标域的无标签图像结合起来训练。

Pipeline

  • 用SDA将源域的图像风格转换到目标域。
  • 使用MMT+框架对转换的图像以及目标域的图像训练。

SDA细节

SDA主体是一个CircleGAN的过程,不同的地方在于风格转换的图像在目标域编码的时候将该特征拿出来和风格转换前的特征计算triplet loss。这样的目的是让风格转换后的图像正样本对和正负样本对间的距离与风格转换前收敛到一致,这样的话不仅实现了源域图像风格迁移到目标域,还实现了源域图像和目标域图像的分布一致。
在这里插入图片描述

finetune with MMT+

如果是MMT的话加上SDA大致流程如下:
在这里插入图片描述
对原始图像使用SDA迁移风格到目标域,然后使用迁移后的图像pretrain网络Net1和Net2(用不同的参数以及数据增强)。这时候Net1和Net2有一定的对目标域图像编码的能力。然后对目标域图像聚类得到伪标签,根据MMT框架的原理训练reid模型。
作者在比赛中升级了MMT框架,升级后的流程如下:
在这里插入图片描述
改进点:

  • 之前MMT框架仅仅使用目标域的图像训练,MMT+将源域图像也加入了训练,但是两个域有领域差异,所以提出的解决办法是对两组数据使用不同的bn处理。bn完再混在一起用。
    在这里插入图片描述
  • 可能是因为gan生成图片以及在训练MMT时加入了source domain的图片的原因,作者实验发现伪标签因为噪声数据集ID分布的影响数据集的噪声更多了,对实验的结果影响较大,因此加入了MoCo loss进行训练。
    在这里插入图片描述

    MoCo loss细节

    在这里插入图片描述
  • x_q:代表某一图片(定义为P_q)的图像增强操作(旋转、平移、剪切等)后的一个矩阵;
  • x_k:代表多张图片(定义为P_K, 其中P_K包含P_q)的图像增强操作后的多个矩阵的矩阵集;
  • encoder,momentum encoder:分别代表两个编码网络,这两个网络的结构相同,参数不同;
  • q:x_q经过encoder网络编码后的一个向量;
  • k:x_k经过momentum encoder网络编码后的多个向量;
  • contrastive loss(即L_q):
    在这里插入图片描述

Lq趋于0
在这里插入图片描述趋于1
则负样本在这里插入图片描述趋于0
则负样本在这里插入图片描述趋于负无穷
所以负样本之间的夹角趋于180度。
所以最终的优化目标是正样本的夹角趋于0负样本间的夹角趋于180度。
注意这里优化的是点乘的结果其实就是优化余弦夹角,因为q和k都进行了l2 normal。

其他涨点技巧

1、在二三阶段作者分别训练了以ResNetSt50、ResNetSt101、DenseNet169-IBN和ResNetXt101-IBN为主干网络的四个模型,最终的比赛结果是用四个模型输出的特征向量concat然后再L2-normalized得到。
在这里插入图片描述
2、图片风格迁移部分,使用SDA涨点情况:
在这里插入图片描述
3、MMT+相对于MMT涨点情况:
在这里插入图片描述

相对于MMT的改进

  • MMT网络的pretrain数据来自SDA生成的。
  • 把源域和目标域的图像都拿来训练MMT+,用不同的BN解决了领域差异大的问题。
  • 加上了MoCo对比损失来减少目标域噪声ID的影响。