YOLOV7改进-添加EIOU,SIOU,AlphaIOU,FocalEIOU

news/2024/7/11 1:53:23 标签: YOLO, 深度学习, 计算机视觉, yolov7改进
  1. 打开utils->general.py
    在这里插入图片描述
    在这里插入图片描述
  2. 找到bbox_iou(),345行左右,将下面的与源码进行替换
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, WIoU=False, Focal=False, alpha=1, gamma=0.5, scale=False, eps=1e-7):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps
    if scale:
        self = WIoU_Scale(1 - (inter / union))

    # IoU
    # iou = inter / union # ori iou
    iou = torch.pow(inter/(union + eps), alpha) # alpha iou
    if CIoU or DIoU or GIoU or EIoU or SIoU or WIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or EIoU or SIoU or WIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squared
            rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha_ciou = v / (v - iou + (1 + eps))
                if Focal:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter/(union + eps), gamma)  # Focal_CIoU
                else:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoU
            elif EIoU:
                rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                cw2 = torch.pow(cw ** 2 + eps, alpha)
                ch2 = torch.pow(ch ** 2 + eps, alpha)
                if Focal:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter/(union + eps), gamma) # Focal_EIou
                else:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2) # EIou
            elif SIoU:
                # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
                s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
                s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
                sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
                sin_alpha_1 = torch.abs(s_cw) / sigma
                sin_alpha_2 = torch.abs(s_ch) / sigma
                threshold = pow(2, 0.5) / 2
                sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
                angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
                rho_x = (s_cw / cw) ** 2
                rho_y = (s_ch / ch) ** 2
                gamma = angle_cost - 2
                distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
                omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
                omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
                shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
                if Focal:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(inter/(union + eps), gamma) # Focal_SIou
                else:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha) # SIou
            elif WIoU:
                if Focal:
                    raise RuntimeError("WIoU do not support Focal.")
                elif scale:
                    return getattr(WIoU_Scale, '_scaled_loss')(self), (1 - iou) * torch.exp((rho2 / c2)), iou # WIoU https://arxiv.org/abs/2301.10051
                else:
                    return iou, torch.exp((rho2 / c2)) # WIoU v1
            if Focal:
                return iou - rho2 / c2, torch.pow(inter/(union + eps), gamma)  # Focal_DIoU
            else:
                return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        if Focal:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter/(union + eps), gamma)  # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdf
        else:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    if Focal:
        return iou, torch.pow(inter/(union + eps), gamma)  # Focal_IoU
    else:
        return iou  # IoU

3、打开loss.py,找到ComputeLoss()
在这里插入图片描述
在这里插入图片描述

4、在这里注释掉,并添加

if type(iou) is tuple:
	lbox+=(iou[1].detach()*(1-iou[0])).mean()
    iou=iou[0]
else:
    lbox+=(1.0-iou).mean()

在这里插入图片描述

5、使用对应的iou,修改这个CIOU
在这里插入图片描述
6、使用Focal-iou的话,ok
在这里插入图片描述
这里对应调用之前更换的bbox_iou()里的参数


http://www.niftyadmin.cn/n/5068017.html

相关文章

【Java 进阶篇】JDBC 数据库连接池 C3P0 详解

数据库连接池是数据库编程中常用的一种技术,它可以有效地管理数据库连接,提高数据库访问的性能和效率。在 Java 编程中,有多种数据库连接池可供选择,其中之一就是 C3P0。本文将详细介绍 C3P0 数据库连接池的使用,包括原…

树上启发式合并 待补

对于每个子树&#xff0c;直接遍历所有轻儿子&#xff0c;继承重儿子 会了板子后&#xff0c;修改维护的东西和莫队是一样的 洛谷 U41492 #include <bits/stdc.h> #define ll long long #define ull unsigned long long constexpr int N1e55; std::vector<int> e…

openwrt使用教程

openwrt 插件安装 首先 我们需要明确自己什么版本的cpu 进入docker 然后 cat /proc/cpuinfo# 查看CPU信息 uname -m# 查看CPU架构 cat /proc/meminfo# 查看内存使用情况 df -h# 查看磁盘的使用率 uname -a# 查看内核信息 opkg print-architecture# 可接受的架构arm a5 比较奇…

并网逆变器+VSG控制+预同步控制+电流电流双环控制(Simulink仿真实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

C++ stack和queue及优先级队列

stack的介绍 stack是作为容器适配器被实现的&#xff0c;容器适配器即是对特定类封装作为其底层的容器&#xff0c;并提供一组特定的成员函数来访问其元素&#xff0c;将特定类作为其底层的&#xff0c;元素特定容器的尾部(即栈顶)被压入和弹出 stack的底层容器可以是任何标准…

【Windows】安装Microsoft Store,Microsoft Store离线包

用了大半年的Windows&#xff0c;今天发现没有Microsoft store 安装方法有2&#xff1a; 方法一、到这里找相关程序 https://store.rg-adguard.net/ …… 方法二、用下边的离线包 二选一下载&#xff1a; 离线包下载&#xff1a;https://wwon.lanzout.com/i4xPL1atnotc离…

哈希表的总结

今天刷了力扣的第一题&#xff08;1. 两数之和 - 力扣&#xff08;LeetCode&#xff09;&#xff09;&#xff0c;是一道用暴力解法就可以完成的题目&#xff08;两个for循环&#xff09;,但是官方解答给出了用哈希表的解法&#xff0c;用空间换时间&#xff0c;时间复杂度从O(…

【Unity】3D贪吃蛇游戏制作/WebGL本地测试及项目部署

本文是Unity3D贪吃蛇游戏从制作到部署的相关细节 项目开源代码&#xff1a;https://github.com/zstar1003/3D_Snake 试玩链接&#xff1a;http://xdxsb.top/Snake_Game_3D 效果预览&#xff1a; 试玩链接中的内容会和该效果图略有不同&#xff0c;后面会详细说明。 游戏规则 …