YOLOv9有效提点|加入SE、CBAM、ECA、SimAM等几十种注意力机制(一)


专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,主力高效涨点!!!


一、本文介绍

        本文将以SE注意力机制为例,演示如何在YOLOv9种添加注意力机制!


 《Squeeze-and-Excitation Networks》

        SENet提出了一种基于“挤压和激励”(SE)的注意力模块,用于改进卷积神经网络(CNN)的性能。SE块可以适应地重新校准通道特征响应,通过建模通道之间的相互依赖关系来增强CNN的表示能力。这些块可以堆叠在一起形成SENet架构,使其在多个数据集上具有非常有效的泛化能力。


《CBAM:Convolutional Block Attention Module》

        CBAM模块能够同时关注CNN的通道和空间两个维度,对输入特征图进行自适应细化。这个模块轻量级且通用,可以无缝集成到任何CNN架构中,并可以进行端到端训练。实验表明,使用CBAM可以显著提高各种模型的分类和检测性能。


《ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks》

        通道注意力模块ECA,可以提升深度卷积神经网络的性能,同时不增加模型复杂性。通过改进现有的通道注意力模块,作者提出了一种无需降维的局部交互策略,并自适应选择卷积核大小。ECA模块在保持性能的同时更高效,实验表明其在多个任务上具有优势。


《SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks》

        SimAM一种概念简单且非常有效的注意力模块。不同于现有的通道/空域注意力模块,该模块无需额外参数为特征图推导出3D注意力权值。具体来说,SimAM的作者基于著名的神经科学理论提出优化能量函数以挖掘神经元的重要性。该模块的另一个优势在于:大部分操作均基于所定义的能量函数选择,避免了过多的结构调整

适用检测目标:   YOLOv9模块通用改进


二、改进步骤

        以下以SE注意力机制为例在YOLOv9中加入注意力代码,其他注意力机制同理!

 2.1 复制代码

        将SE的代码辅助到models包下common.py文件中。

 2.2 修改yolo.py文件

        在yolo.py脚本的第700行(可能因YOLOv9版本变化而变化)增加下方代码。

python">        elif m in (SE,):
            args.insert(0, ch[f])

2.3 创建配置文件

        创建模型配置文件(yaml文件),将我们所作改进加入到配置文件中(这一步的配置文件可以复制models  - > detect 下的yaml修改。)。对YOLO系列yaml文件不熟悉的同学可以看我往期的yaml详解教学!

YOLO系列 “.yaml“文件解读-CSDN博客

# YOLOv9
# Powered bu https://blog.csdn.net/StopAndGoyyy
# parameters
nc: 80  # number of classes
depth_multiple: 1  # model depth multiple
width_multiple: 1  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, Silence, []],  
   
   # conv down
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3

   # avg-conv down
   [-1, 1, ADown, [256]],  # 4-P3/8

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5

   # avg-conv down
   [-1, 1, ADown, [512]],  # 6-P4/16

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7

   # avg-conv down
   [-1, 1, ADown, [512]],  # 8-P5/32

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9
  ]

# YOLOv9 head
head:
  [
   # elan-spp block
   [-1, 1, SPPELAN, [512, 256]],  # 10

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)

   # avg-conv-down merge
   [-1, 1, ADown, [256]],
   [[-1, 13], 1, Concat, [1]],  # cat head P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)

   # avg-conv-down merge
   [-1, 1, ADown, [512]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)
   
   
   # multi-level reversible auxiliary branch
   
   # routing
   [5, 1, CBLinear, [[256]]], # 23
   [7, 1, CBLinear, [[256, 512]]], # 24
   [9, 1, CBLinear, [[256, 512, 512]]], # 25
   
   # conv down
   [0, 1, Conv, [64, 3, 2]],  # 26-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28

   # avg-conv down fuse
   [-1, 1, ADown, [256]],  # 29-P3/8
   [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30  

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 32-P4/16
   [[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33 

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 35-P5/32
   [[25, -1], 1, CBFuse, [[2]]], # 36

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37
   [-1, 1, SE, [16]],  # 38

   
   
   # detection head

   # detect
   [[31, 34, 38, 16, 19, 22], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]

3.1 训练过程

        最后,复制我们创建的模型配置,填入训练脚本(train_dual)中(不会训练的同学可以参考我之前的文章。),运行即可。

YOLOv9 最简训练教学!-CSDN博客

​​

​​


SE代码

python">class SE(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

CBAM代码

python">class CBAMBlock(nn.Module):

    def __init__(self, channel=512, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out

ECA代码

python">class ECAAttention(nn.Module):

    def __init__(self, kernel_size=3):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        y = self.gap(x)  # bs,c,1,1
        y = y.squeeze(-1).permute(0, 2, 1)  # bs,1,c
        y = self.conv(y)  # bs,1,c
        y = self.sigmoid(y)  # bs,1,c
        y = y.permute(0, 2, 1).unsqueeze(-1)  # bs,c,1,1
        return x * y.expand_as(x)

SimAM代码

python">class SimAM(torch.nn.Module):
    def __init__(self, e_lambda=1e-4):
        super(SimAM, self).__init__()

        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()

        n = w * h - 1

        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5

        return x * self.activaton(y)

如果觉得本文章有用的话给博主点个关注吧!



http://www.niftyadmin.cn/n/5402314.html

相关文章

【系统分析师】-需求工程

一、需求工程 需求工程分为需求开发和需求管理。 需求开发:需求获取,需求分析,需求定义、需求验证。 需求管理:变更控制、版本控制、需求跟踪,需求状态跟踪。(对需求基线的管理) 1.1需求获取…

IDEA spring-boot-devtools 热部署

1、IDEA 编写 Spring Boot 项目时&#xff0c;修改了Java文件&#xff0c;浏览器无法实时访问修改后的内容时&#xff0c;此时可以设置热部署插件。 2、在 pom.xml 文件中添加热部署依赖&#xff0c;<plugins> 中设置插件 fork 为 true <dependencies><!--添加…

了解处理器

了解处理器 摘要写在前面1. 计算机简介1.1.计算机发展简史1.2.计算机分类1.3.PC机结构 2.初识处理器2.1.处理器的硬件模型2.2.处理器的编程模型2.3.处理器的分层模型2.4.如何选择处理器 3.指令集体系结构3.1.处理器编程模型3.2.指令集发展历程3.3.指令集分类3.4.汇编语言格式3.…

在Redhat 7 Linux上安装llama.cpp [ 错误stdatomic.h: No such file or directory]

前期准备 在github上下载llama.cpp或克隆。 GitHub - ggerganov/llama.cpp: LLM inference in C/C ​ git clone https://github.com/ggerganov/llama.cpp.gitcd llama.cpp 执行make命令编译llama.cpp make 在huggingface里下载量化了的 gguf格式的llama2模型。 https:/…

SpringCloud gateway限流无效,redis版本低的问题

在使用springCloud gateway的限流功能的时候&#xff0c;配置RedisRateLimiter限流无效&#xff0c;后来发现是Redis版本过低导致的问题&#xff0c;实测 Redis版本为3.0.504时限流无效&#xff0c;改用7.0.x版本的Redis后限流生效。查了资料发现很多人都遇见过这个问题&#x…

STM32标准库——(14)I2C通信协议、MPU6050简介

1.I2C通信 I2C 通讯协议(Inter&#xff0d;Integrated Circuit)是由Phiilps公司开发的&#xff0c;由于它引脚少&#xff0c;硬件实现简单&#xff0c;可扩展性强&#xff0c; 不需要USART、CAN等通讯协议的外部收发设备&#xff0c;现在被广泛地使用在系统内多个集成电路(IC)间…

Git多人合作的推送流程

多人合作时&#xff0c;使用Git进行代码推动&#xff08;push&#xff09;需要一定的协调和规范&#xff0c;以确保代码库的整体健康。以下是一个常见的多人合作时的Git代码推动流程&#xff1a; 同步主分支&#xff1a; 在推送之前&#xff0c;确保你的本地主分支&#xff08;…

从metashape导出深度图,从深度图恢复密集点云

从metashape导出深度图&#xff0c;从深度图恢复密集点云 1.从metashape导出深度图 参考&#xff1a;https://blog.csdn.net/WHU_StudentZhong/article/details/123107072?spm1001.2014.3001.5502 2.从深度图建立密集点云 首先从metashape导出blockExchange格式的xml文件&…