YOLOV5添加 ECA CA SE CBAM 等八种注意力机制(小白可用)

news/2024/7/11 0:38:27 标签: YOLO, cnn, transformer

目录

CBAM注意力机制原理及代码实现

代码实现

 yaml文件

修改后的结构图

SE注意力机制

SE结构图

完整代码实现

报错


⭐欢迎大家订阅我的专栏一起学习⭐

🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/Q4aka

http://t.csdnimg.cn/sVHxv

💡魔改网络、复现论文、优化创新💡

原理

尽管Ultralytics 推出了最新版本的 YOLOv8 模型。但YOLOv5作为一个经典的目标检测的算法,仍经常被提及。注意力机制是提高模型性能最热门的方法之一。本次将介绍几种常见的注意力机制,这些注意力机制在大多数的数据集上均能有效的提升目标检测的精度/召回率/准确率。

CBAM注意力机制原理及代码实现
CBAM注意力机制结构图


CBAM(Convolutional Block Attention Module)是一种用于卷积神经网络(CNN)的注意力机制,它能够增强网络对输入特征的关注度,提高网络性能。CBAM 主要包含两个子模块:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。

以下是CBAM注意力机制的基本原理:

1. 通道注意力模块(Channel Attention Module):
输入:经过卷积层的特征图。
处理步骤:
对每个通道进行全局平均池化,得到通道的全局平均值。
通过两个全连接层,将全局平均值映射为两个权重向量(一个用于缩放,一个用于偏置)。
将这两个权重向量与原始特征图相乘,以加权调整每个通道的重要性。

2. 空间注意力模块(Spatial Attention Module):**
输入:通道注意力模块的输出。
处理步骤:
     对每个通道的特征图进行分别的最大池化和平均池化,得到两个空间特征图。
     将这两个空间特征图相加,通过一个卷积层产生一个权重图。
     将原始特征图与权重图相乘,以加权调整每个空间位置的重要性。

3. 整合:
   将通道注意力模块和空间注意力模块的输出相乘,得到最终的注意力增强特征图。
   将这个注意力增强的特征图传递给网络的下一层进行进一步处理。

CBAM的关键优势在于它能够同时考虑通道和空间信息,有助于网络更好地理解和利用输入特征。这种注意力机制有助于提高网络在视觉任务上的性能,使其能够更有针对性地关注重要的特征。

代码实现
class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
 
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.act(self.fc(self.pool(x)))
 
 
class SpatialAttention(nn.Module):
    """Spatial-attention module."""
 
    def __init__(self, kernel_size=7):
        """Initialize Spatial-attention module with kernel size argument."""
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()
 
    def forward(self, x):
        """Apply channel and spatial attention on input for feature recalibration."""
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
 
 
class CBAM(nn.Module):
    """Convolutional Block Attention Module."""
 
    def __init__(self, c1, kernel_size=7):  # ch_in, kernels
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)
 
    def forward(self, x):
        """Applies the forward pass through C1 module."""
        return self.spatial_attention(self.channel_attention(x))
 yaml文件
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, CBAM, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
修改后的结构图

改进后完整的YOLOv5结构图
SE注意力机制

卷积神经网络建立在卷积运算的基础上,它通过在局部感受野内将空间和通道信息融合在一起来提取信息特征。为了提高网络的表示能力,最近的几种方法已经显示了增强空间编码的好处。在这项工作中,我们专注于通道关系,并提出了一种新颖的架构单元,我们将其称为“挤压和激励”(SE)块,它通过显式建模通道之间的相互依赖性来自适应地重新校准通道方面的特征响应。我们证明,通过将这些块堆叠在一起,我们可以构建在具有挑战性的数据集上具有极好的泛化能力的 SENet 架构。至关重要的是,我们发现 SE 模块能够以最小的额外计算成本为现有最先进的深度架构带来显着的性能改进。 SENets 构成了我们 ILSVRC 2017 分类提交的基础,该分类提交赢得了第一名,并将 top-5 错误率显着降低至 2.251%,与 2016 年获胜条目相比相对提高了约 25%。

我们通过引入一个新的架构单元(我们将其称为“挤压和激励”(SE)块)来研究架构设计的另一个方面 - 通道关系。我们的目标是通过显式建模网络卷积特征通道之间的相互依赖性来提高网络的表示能力。为了实现这一目标,我们提出了一种允许网络执行特征重新校准的机制,通过该机制它可以学习使用全局信息来选择性地强调信息丰富的特征并抑制不太有用的特征。

SE结构图
模型结构


SE 构建块的基本结构如图所示。对于任何给定的变换 Ftr : X → U, X ∈ RH′×W ′×C′ , U ∈ RH×W ×C ,,我们可以构建相应的 SE 块来执行特征重新校准。特征 U 首先通过挤压操作,该操作聚合跨空间维度 H × W 的特征图以生成通道描述符。该描述符嵌入了通道特征响应的全局分布,使得来自网络的全局感受野的信息能够被其较低层利用。接下来是激励操作,其中通过基于通道依赖性的自门机制为每个通道学习特定于样本的激活,控制每个通道的激励。然后对特征图 U 进行重新加权以生成 SE 块的输出,然后可以将其直接馈送到后续层中。

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
class SEConv(nn.Module):
    # channel attentive convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, reduction=16):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
        self.se = SELayer(c2, reduction)
 
    def forward(self, x):
        residual = x
        out = self.bn(self.conv(x))
        out = self.se(out)
        out += residual
        return self.act(out)
class SEBottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = SEConv(c1, c_, 1, 1)
        self.cv2 = SEConv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
 
    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

其他注意力机制未完待续...

完整代码实现

链接: https://pan.baidu.com/s/1LYdIwI71ZLCo4GqM6L5PBg?pwd=4yr5 提取码: 4yr5 

yaml文件记得选择对应的注意力机制

报错

如果报错,查看

解决Yolov5的RuntimeError: result type Float can‘t be cast to the desired output type long int 问题_yolov5 runtimeerror: result type float can't be ca-CSDN博客


http://www.niftyadmin.cn/n/5428315.html

相关文章

基于SCIP的约束处理器Conshdlr添加惰性约束——以TSP问题为例

文章目录 1. TSP案例引入2. 考虑惰性约束的求解效率对比2.1 求解基础TSP模型2.2 基于SCIP的Conshdlr添加惰性约束1. TSP案例引入 在运筹学建模和求解过程中,“lazy constraints”(惰性约束)是一种动态添加约束的策略,松弛部分约束后求解得到的“可行解”,不断地进行可行性…

基于HarmonyOS ArkTS中秋国庆祝福程序、以代码之名,写阖家团圆祝福

中秋、国庆双节将至,作为程序员,以代码之名,表达对于阖家团圆的祝福。本节将演示如何在基于HarmonyOS ArkUI的SwiperController、Image、Swiper等组件来实现节日祝福轮播程序。 规则要求具体要求如下: 1、根据主题,用…

在【IntelliJ IDEA】中配置【Tomcat】【2023版】【中文】【图文详解】

作为一款功能强大的集成开发环境(IDE),IntelliJ IDEA为Web服务器提供了卓越的支持,从而极大地简化了程序员在Web开发过程中的工作流程。学习Java Web开发实质上就是掌握如何创造动态Web资源,这些资源在完成开发后&…

安卓多个listView拖动数据交换位置和拖动

注意这里只是给出大概思路&#xff0c;具体可以参考修改自己想要的 public class MainActivity extends AppCompatActivity {private ListView listView1;private ListView listView2;private ArrayAdapter<String> adapter1;private ArrayAdapter<String> adapter…

数字孪生与智慧城市:实现城市治理现代化的新路径

随着信息技术的迅猛发展&#xff0c;智慧城市已成为城市发展的必然趋势。数字孪生技术作为智慧城市建设的重要支撑&#xff0c;以其独特的优势为城市治理现代化提供了新的路径。本文将探讨数字孪生技术在智慧城市中的应用&#xff0c;以及如何实现城市治理的现代化。 一、数字…

#QT(显示组件、日期时间组件)

1.IDE&#xff1a;QTCreator 2.实验&#xff1a; 3.记录 4.代码 #include "widget.h" #include "ui_widget.h" #include <QDateTime> #include <QCalendar> #include <QCalendarWidget> Widget::Widget(QWidget *parent): QWidget(pare…

Debezium日常分享系列之:Debezium2.5稳定版本之Mysql连接器的工作原理

Debezium日常分享系列之&#xff1a;Debezium2.5稳定版本之Mysql连接器的工作原理 一、Mysql连接器的工作原理1.支持的 MySQL 拓扑2.MariaDB 支持3.Schema history topic4.Schema change topic5.Snapshots1&#xff09;使用全局读锁的初始快照2&#xff09;Debezium MySQL 连接…

Nodejs 第五十四章(net)

net模块是Node.js的核心模块之一&#xff0c;它提供了用于创建基于网络的应用程序的API。net模块主要用于创建TCP服务器和TCP客户端&#xff0c;以及处理网络通信。 TCP&#xff08;Transmission Control Protocol&#xff09;是一种面向连接的、可靠的传输协议&#xff0c;用于…