yolov7改进之使用QFocalLoss

news/2024/7/11 1:17:28 标签: YOLO, 计算机视觉, 深度学习, 人工智能, python

深度学习三大件:数据、模型、Loss。一个好的Loss有利于让模型更容易学到需要的特征,不过深度学习已经白热化了,Loss这块对一个成熟任务的提升是越来越小了。虽然如此,也不妨碍我们在难以从数据和模型层面入手时,从这个方面尝试了。

BCEBlurWithLogitsLoss

yolov7中loss由三部分构成,cls loss, obj loss, box loss。分别是类别损失、框的置信度损失、框的位置损失。其中cls loss和obj loss都是使用BCEBlurWithLogitsLoss,这个loss的源代码如下:

class BCEBlurWithLogitsLoss(nn.Module):
    # BCEwithLogitLoss() with reduced missing label effects.
    def __init__(self, alpha=0.05):
        super(BCEBlurWithLogitsLoss, self).__init__()
        self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none')  # must be nn.BCEWithLogitsLoss()
        self.alpha = alpha

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        pred = torch.sigmoid(pred)  # prob from logits
        dx = pred - true  # reduce only missing label effects
        # dx = (pred - true).abs()  # reduce missing label and false label effects
        alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
        loss *= alpha_factor
        return loss.mean()

与普通的交叉熵损失相比,这个损失可以降低missing label(也就是本身是个正样本,但是没有标注)以及false label(错误标注,代码中注释部分)。做法是降低这些样本的loss权重,具体而言就是通过pred和label的差异来衡量,差异太大就降低权重。

FocalLoss

yolov7中也提供了另一种Loss, 在hyp参数中设置大于0的gamma就可以开启,FocalLoss的思想是加大困难样本的权重,同时为正负样本分配不同的权重,公式如下:
F L ( p t ) = − a t ( 1 − p t ) γ l o g ( p t ) FL(p_t)=-a_t(1-p_t)^\gamma log(p_t) FL(pt)=at(1pt)γlog(pt)
讨论对于一个二分类的问题,也就是两个类别讨论。当一个样本被分错时,也就是当标签类y = 1时,p = 0.3,根据上式可以看到,y=1 , p= 0.3 , 则 p t = 0.3 p_t = 0.3 pt=0.3那么 ( 1 − p t ) γ (1 - p_t)^\gamma (1pt)γ就很大(通常 γ \gamma γ取2)。这也就说明,分错的这个类表示难分的类。

QFocalLoss

FocalLoss存在一些问题:
classification score 和 IoU/centerness score 训练测试不一致。

这个不一致主要体现在两个方面:

1) 用法不一致。训练的时候,分类和质量估计各自训记几个儿的,但测试的时候却又是乘在一起作为NMS score排序的依据,这个操作显然没有end-to-end,必然存在一定的gap。

2) 对象不一致。借助Focal Loss的力量,分类分支能够使得少量的正样本和大量的负样本一起成功训练,但是质量估计通常就只针对正样本训练。那么,对于one-stage的检测器而言,在做NMS score排序的时候,所有的样本都会将分类score和质量预测score相乘用于排序,那么必然会存在一部分分数较低的“负样本”的质量预测是没有在训练过程中有监督信号的,有就是说对于大量可能的负样本,他们的质量预测是一个未定义行为。这就很有可能引发这么一个情况:一个分类score相对低的真正的负样本,由于预测了一个不可信的极高的质量score,而导致它可能排到一个真正的正样本(分类score不够高且质量score相对低)的前面。具体可以参考QFocalLoss作者的博客, 知乎《大白话 Generalized Focal Loss》。
对于第一个问题,为了保证training和test一致,同时还能够兼顾分类score和质量预测score都能够训练到所有的正负样本,那么一个方案呼之欲出:就是将两者的表示进行联合。这个合并也非常有意思,从物理上来讲,我们依然还是保留分类的向量,但是对应类别位置的置信度的物理含义不再是分类的score,而是改为质量预测的score。这样就做到了两者的联合表示,同时,暂时不考虑优化的问题,我们就有可能完美地解决掉第一个问题

简单的说,做过检测的应该知道,一个检测结果的置信度=类别的置信度* 框的置信度。但是训练的时候,这两个置信度是分开的,这个不一致可能出现类别置信度低,但框置信度高,。QFocalLoss的出发点就是在训练时,分类损失中的label修改成label* 框置信度。这样修改会出现0~1之间的label,原版的FocalLoss是不支持离散label的,所以作者进行了魔改:
image.png
yolov7中有人提交了一个QFocalLoss代码,不过并没有启用,并且从代码来看,并没有正确实现作者的主要意图:将分类score和框置信度score联合训练。所以我没有使用yolov7中的实现,而是对其进行了修改:

class QFocalLoss(nn.Module):
    # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, beta = 2.0):
        super(QFocalLoss, self).__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.beta = beta
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, target):
        assert len(target) ==2, "target must be a tuple of (class, score)"
        label, score = target
        iou_target = label*score.view(-1,1)
        pred_sigmoid = torch.sigmoid(pred)  # prob from logits
        scale_factor = torch.abs(pred_sigmoid - iou_target)
        beta = self.beta
        loss = self.loss_fcn(pred, iou_target) * scale_factor.pow(beta)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss

主要区别在于此处target由类别标签和预测的框置信度共同构成。
为了使用QFocalLoss, 需要在调用时,将预测框的执行度也传入,具体而言,其实就是预测框和GT的iou, 在代码里本身就已经有了:

 # Objectness
                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio

                # Classification
                selected_tcls = targets[i][:, 1].long()
                if self.nc > 1:  # cls loss (only if multiple classes)
                    t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
                    t[range(n), selected_tcls] = self.cp
                    
                    if isinstance(self.BCEcls, QFocalLoss):
                        lcls += self.BCEcls(ps[:, 5::], (t, iou.detach()))
                    else:
                        lcls += self.BCEcls(ps[:, 5:], t)  # BCE

再把代码中原本用来触发FocalLoss的地方,修改为触发QFocalLoss即可。

使用体验

从个人在私有任务使用看,使用QFocalLoss对于map基本没有提升,甚至误检增多了(recall升高)。分析可能是和FocalLoss一样的问题,由于提高了困难样本的权重,会造成过于关注一些模棱两可的样本。对于小样本来说,估计会有提升,对于一般任务,不如使用BCEBlurWithLogitsLoss。


http://www.niftyadmin.cn/n/5141896.html

相关文章

嵌入式Linux系统的闪存设备和文件系统学习纪要

嵌入式Linux系统的闪存设备和文件系统学习纪要 Linux下的文件系统结构如下: NAND Flash 是一种非易失性存储器(Non-Volatile Memory),常用于闪存设备和固态硬盘(SSD)中。以下是几种常见的 NAND Flash 种类&…

golang的类型断言

前言:原因很简单,写的代码panic了。报错如下。为此专门看下golang的类型断言。 “[PANIC]interface conversion: interface {} is string, not float64”。 1、类型断言(assertion) 所谓“类型断言”即判断一个变量是不是某个类型的实例(简单来讲就是判…

蓝桥杯(C++ 扫雷)

题目&#xff1a; 思想&#xff1a; 1、遍历每个点是否有地雷&#xff0c;有地雷则直接返回为9&#xff0c;无地雷则遍历该点的周围八个点&#xff0c;计数一共有多少个地雷&#xff0c;则返回该数。 代码&#xff1a; #include<iostream> using namespace std; int g[…

要是能重来,你还会选择程序员吗?

昨天面试了2个应届毕业生&#xff0c;一男一女&#xff0c;男的我觉得技术还可以&#xff0c;就录了&#xff0c;女的没有通过&#xff0c;完事后我去厕所边上的楼道抽烟&#xff0c;却发现女孩子蹲在地上哭的一塌糊涂。 我听得很清楚她跟那个男同学说的话&#xff0c;她已经忘…

WINCC7.5-根据时间跨度选择趋势

yyyy-MM-dd hh:mm:ss “yyyy”&#xff1a;表示四位数的年份&#xff0c;例如&#xff1a;2022。 “MM”&#xff1a;表示两位数的月份&#xff0c;从01到12。 “dd”&#xff1a;表示两位数的日期&#xff0c;从01到31。 “hh”&#xff1a;表示12小时制的小时数&#xff0c;从…

一行代码搞定禁用web开发者工具

在如今的互联网时代&#xff0c;网页源码的保护显得尤为重要&#xff0c;特别是前端代码&#xff0c;几乎就是明文展示&#xff0c;很容易造成源码泄露&#xff0c;黑客和恶意用户往往会利用浏览器的开发者工具来窃取网站的敏感信息。为了有效防止用户打开浏览器的Web开发者工具…

Android 类似淘宝的吸顶特效 NestedScrollView+RecycleView

运行图 布局的设计 要实现上面的效果需要搞定NestedScrollView和RecycleView的滑动冲突。有人要问RecycleView为何要滑动自动撑大不就好了么&#xff1f;这个问题其实对于有限的资源加载来说是很好的解决方案&#xff0c;但是如果涉及到的是图文结合的并且有大批量的数据的时候…

(CV)论文列表

CNN卷积神经网络之SKNet及代码 https://blog.csdn.net/qq_41917697/article/details/122791002 【CVPR2022 oral】MixFormer: Mixing Features across Windows and Dimensions 【精选】【CVPR2022 oral】MixFormer: Mixing Features across Windows and Dimensions-CSDN博客