在yolov5源码中添加注意力机制

news/2024/7/11 1:11:25 标签: YOLO, python

yolov5源码中添加注意力机制

  • 1 项目环境配置
    • 1.1 yolov5 源码下载
    • 1.2 创建虚拟环境
    • 1.3 安装依赖
  • 2 常用的注意力机制
    • 2.1 SE 注意力机制
    • 2.2 CBAM 注意力机制
    • 2.3 ECA 注意力机制
    • 2.4 CA 注意力机制
  • 3 添加方式
    • 3.1 修改 common.py 文件
    • 3.2 修改 yolo.py 文件
    • 3.3 修改 yolov5s.yaml 文件
    • 3.4 修改 train.py 文件

1 项目环境配置

1.1 yolov5 源码下载

点击下载

1.2 创建虚拟环境

win+r打开Windows终端界面输入(其中yolov5为我命名的虚拟环境名称):

mkvirtualenv yolov5

进入虚拟环境

python">workon yolov5

没有此模块无法创建虚拟环境的请移步:Python 的虚拟环境

1.3 安装依赖

  1. 依赖前提:有python环境以及pytorch

本人环境:python3.9,cuda11.7
安装 pytorch 移步官网

在这里插入图片描述

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117

避免不必要的错误,建议使用 pip 安装

  1. 安装项目依赖

进入项目文件夹,终端键入:

pip install -r requirements.txt

环境搭建完成!

2 常用的注意力机制

2.1 SE 注意力机制

python"># SE
class SE(nn.Module):
    def __init__(self, c1, c2, ratio=16):
        super(SE, self).__init__()
        #c*1*1
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.l1 = nn.Linear(c1, c1 // ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // ratio, c1, bias=False)
        self.sig = nn.Sigmoid()
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)

2.2 CBAM 注意力机制

python"># CBAM
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out
    
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        # (特征图的大小-算子的size+2*padding)/步长+1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        # 1*h*w
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        #2*h*w
        x = self.conv(x)
        #1*h*w
        return self.sigmoid(x)
    
class CBAM(nn.Module):
    def __init__(self, c1, c2, ratio=16, kernel_size=7):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(c1, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)
    def forward(self, x):
        out = self.channel_attention(x) * x
        # c*h*w
        # c*h*w * 1*h*w
        out = self.spatial_attention(out) * out
        return out

2.3 ECA 注意力机制

python">class ECA(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """

    def __init__(self, c1,c2, k_size=3):
        super(ECA, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # Multi-scale information fusion
        y = self.sigmoid(y)

        return x * y.expand_as(x)

2.4 CA 注意力机制

python"># CA
class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)
    def forward(self, x):
        return self.relu(x + 3) / 6
class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)
    def forward(self, x):
        return x * self.sigmoid(x)

class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        #c*1*W
        x_h = self.pool_h(x)
        #c*H*1
        #C*1*h
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        y = torch.cat([x_h, x_w], dim=2)
        #C*1*(h+w)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_w * a_h
        return out

3 添加方式

3.1 修改 common.py 文件

修改 yolov5-master/models/common.py文件,将上述提供的注意力机制代码块直接加到 common.py 文件夹的末尾,此处以SE注意力机制为例

在这里插入图片描述

3.2 修改 yolo.py 文件

修改 yolov5-master/models/yolo.py文件,将注意力机制类名SE添加到 yolo.py 文件的 parse_model方法中如下集合里

在这里插入图片描述

3.3 修改 yolov5s.yaml 文件

修改 yolov5-master/models/yolov5s.yaml文件,将SE注意力机制模块添加到你想添加的位置,常见的有C3模块的后面,以及在主干网络 backboneSPPF 的前一层,这里我将SE注意力机制模块添加在主干网络 backboneSPPF 的前一层

修改前:
在这里插入图片描述

修改后:

在这里插入图片描述

另外,由于我将SE注意力机制模块添加在了第 9 层(层索引为 9,起始层索引为 0),那么,原来的第 9 层,以及第 9 层之后的层数都要加 1

加1前:

在这里插入图片描述

加1后:

在这里插入图片描述

3.4 修改 train.py 文件

修改 yolov5-master/train.py 文件,在默认参数 --cfg后面的 default中添加我们前面修改过的 yolov5s.yaml文件

修改前:
在这里插入图片描述
修改后:

在这里插入图片描述


http://www.niftyadmin.cn/n/1778801.html

相关文章

警惕命名与关键字同名带来的问题

2019独角兽企业重金招聘Python工程师标准>>> api名 在Account中创建一个按钮,执行一段js脚本,其中使用了{!Account.Id}来获得当前记录的Id。 当点击时,整个脚本执行完后程序结果与预期不一致,但有些记录程序结果与预期…

Android与Swift iOS开发:语言与框架对比

Swift是现在Apple主推的语言,2014年新推出的语言,比Scala等“新”语言还要年轻10岁。2015年秋已经开源。目前在linux上可用,最近已经支持Android NDK;在树莓派上有SwiftyGPIO库,可以通过GPIO控制一些硬件。 Object C i…

深度解读CDN高防防御机制,看它如何为服务器保驾护航?

随着互联网的飞速发展,人们对服务器的要求也越来越高,CDN高防也就此应运而生。那么CDN高防它到底是什么,它又是如何为服务器提供保障的呢?下面就为大家详细的介绍下CDN高防防御网络攻击的机制。 CDN高防是通过广泛分布的CDN节点和…

CDN高防是如何保护游戏服务器?它的这些“秘密武器”你知道多少呢?

在现在时代,互联网经济飞速发展,带火了很多很多的产业。比如现在最火的电竞产业,它的爆火让曾经大家都以打游戏就是不务正业的观念得到改变。让打游戏不用再偷偷的去黑网吧,电子竞技专业的学生可以直接在教室里光明正大的打&#…

第二阶段团队项目冲刺站立会议(二)

昨天做了什么: 第一天好像是最难开头的一天,但同时也是最好过的一天。昨日的进展并没有明确的进步。 今天准备做什么: 首先从竞赛开始抓起,不仅考虑到项目是竞赛的附加,也考虑到实现上的相似之处。、 遇到了什么问题&a…

LINUX下的tty,console与串口

LINUX下的tty,console与串口

CDN高防为什么会被互联网企业钟爱?它是如何保障服务器安全?

CDN高防在当今信息科技发展的时代,为什么会受到广大互联网企业的钟爱呢?它是如何防御网络犯罪分子攻击的呢,又是如何保障服务器安全,让网站平稳、高效的运营的呢?下面的文章为你详细解析CDN高防的奥秘。 首先我们先从最…

忘记mysql登录密码解决方案匹配(liunx服务器)

1 CRT 登录liunx服务器 修改mysql的配置,改为无密码登录 $ vi /etc/my.cnf 在[mysqld]的段中加上一句:skip-grant-tables 例如:[mysqld]datadir/var/lib/mysqlsocket/var/lib/mysql/mysql.sockskip-grant-tables 2 重启mysql $ service my…