注意力机制添加方法

news/2024/7/10 23:01:35 标签: 深度学习, YOLO, 注意力机制, 目标检测

 

要将注意力机制模块添加到YoloV5工程项目中的yolo.py中,可参考以下四种情况。

以下4个elif代码来自https://yolov5.blog.csdn.net/article/details/129108082

elif m in [SimAM, ECA, SpatialGroupEnhance,TripletAttention]:
    args = [*args[:]]

elif m in [CoordAtt, GAMAttention]:
    c1, c2 = ch[f], args[0]
    if c2 != no:
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2, *args[1:]]
    
elif m in [SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA]:
    c1 = ch[f]
    args = [c1, *args[0:]]
    
elif m in [S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention]:
    c1 = ch[f]
    args = [c1]

 

根据这4种情况,我们在yaml文件中,填写args时(比如下图中RefConv的[1024,3,1]以及SE中的[1024]),需要填入的参数个数是不同的

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],  #2
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],  #4
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],  #6
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],  # 8
   [-1,1,SE,[1024]],  #9
   [[-1, 1, SimAM, [1e-4]],  # 10
   [-1, 1, SPPF, [1024, 5]],  # 11
   [-1, 1, RefConv, [1024, 3, 1]],  # 12
  ]

具体来说,是要将elif模块添加到yolo.py文件中的parse_model函数里。在编写elif模块代码时,我们需要关注的是,你的注意力机制模块代码(在common.py)中的__init__里面的参数,是否设置了“输入通道数”和“输出通道数”,这两个参数。

 

一、先上结论

情况1:_init_中不包含“输入通道数”和“输出通道数”,但含有其它参数

以下这些模块:SimAM,ECA,SpatialGroupEnhance,TripletAttention 全都满足情况1

elif m in [SimAM, ECA, SpatialGroupEnhance,TripletAttention]:
    args = [*args[:]]

以下是SimAM,ECA,SpatialGroupEnhance,TripletAttention 的__init__

class SimAM(torch.nn.Module):
    def __init__(self, e_lambda=1e-4):

class ECA(nn.Module):
    def __init__(self, k_size=3):
    
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):

class SpatialGroupEnhance(nn.Module):
    def __init__(self, groups=8):

class TripletAttention(nn.Module):
    def __init__(self, no_spatial=False):

这里解释一下args = [*args[:]]是什么意思

在Python中,*args 是一个特殊的语法,用于在函数定义中处理不确定数量的位置参数。args 是一个元组,包含了所有传递给函数的位置参数。

args[:] 是一个切片操作,它会创建一个 args 的浅拷贝。这意味着如果你修改了 args[:] 的内容,原始的 args 不会被改变。

[*args[:]] 则是将 args[:] 中的元素解包(unpack)成一个列表。这样做的目的通常是为了创建一个新的列表,而不是修改原始的 args。

例如:

args = [1, 2, 3, 4]
new_args = [*args[:]]
print(new_args) #输出为[1,2,3,4]

情况2:_init_中同时包含“输入通道数”和“输出通道数”,且含有其它参数

以下这些模块:CoordAtt, GAMAttention 全都满足情况2

elif m in [CoordAtt, GAMAttention]:
    c1, c2 = ch[f], args[0]
    if c2 != no:
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2, *args[1:]]

以下是CoordAtt, GAMAttention 的__init__

class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):

class GAMAttention(nn.Module):
    def __init__(self, c1, c2, group=True, rate=4):

情况3:_init_中只包含“输入通道数”,不包含“输出通道数”,且含有其它参数

以下这些模块:SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA全都满足情况3

elif m in [SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA]:
    c1 = ch[f]
    args = [c1, *args[0:]]

以下是SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA的__init__

#SE机制它的输入通道和输出通道是一样的,所以在实现上可以只传入输入通道c1,但如果也给出输出通道c2的参数也是可以的。下面这两种都是在迪菲赫尔曼博客中实现过的SE模块
class SE(nn.Module):
    def __init__(self, c1, ratio=16):
class SE(nn.Module):
    def __init__(self, c1, c2, ratio=16):

class ShuffleAttention(nn.Module):
    def __init__(self, channel=512, G=8):

class CBAM(nn.Module):
    def __init__(self, c1, ratio=16, kernel_size=7):

class SKAttention(nn.Module):
    def __init__(self, channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):

class DoubleAttention(nn.Module):
    def __init__(self, in_channels, reconstruct=True):

class CoTAttention(nn.Module):
    def __init__(self, dim=512, kernel_size=3):

class EffectiveSEModule(nn.Module):
    def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'):

class GlobalContext(nn.Module):
    def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False,
                 rd_ratio=1. / 8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
class GatherExcite(nn.Module):
    def __init__(
            self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
            rd_ratio=1. / 16, rd_channels=None, rd_divisor=1, add_maxpool=False,
            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):

class MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14, heads=4, pos_emb=False):    

情况4:_init_中只包含“输入通道数”,不包含“输出通道数”,且不含有其他参数(注意对比情况3)

以下这些模块:S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention全都满足情况4

elif m in [S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention]:
    c1 = ch[f]
    args = [c1]

以下是S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention的__init__

class S2Attention(nn.Module):
    def __init__(self, channels=512):

class NAMAttention(nn.Module):
    def __init__(self, channels):

class CrissCrossAttention(nn.Module):
    def __init__(self, in_dim):

class SequentialPolarizedSelfAttention(nn.Module):
    def __init__(self, channel=512):

class ParallelPolarizedSelfAttention(nn.Module):
    def __init__(self, channel=512):
 
class ParNetAttention(nn.Module):
    def __init__(self, channel=512):

 

二、解释代码

在理解代码前,我们需要知道,在parse_model函数中,args列表的前两个位置被设计为存放输入通道数(c1)和输出通道数(c2)。这是因为在创建这些模块时,我们通常会按照这个顺序传递参数。例如,对于nn.Conv2d,其构造函数的签名为Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True),其中in_channels和out_channels就对应于我们的c1和c2。因此,当我们在parse_model函数中创建这些模块时,我们需要先获取c1和c2,然后将它们放在args的前两个位置,以确保它们能被正确地传递给模块的构造函数。

情况1

'''
这段代码的目的是为了处理SimAM模块的参数。

elif m in [SimAM]:
这行代码检查当前的模块m是否是SimAM模块。如果是,那么就执行下一行代码。

args = [*args[:]]:
这行代码创建了args的一个浅拷贝。在Python中,args[:]会创建一个新的列表,这个新列表包含了args中的所有元素。*操作符会将这个新列表解包,然后我们再用[]将解包后的元素重新组装成一个新的列表。所以,[*args[:]]就等于args[:],它们都会创建args的一个浅拷贝。

综上,这段代码使args列表保持不变,因为SimAM模块不需要修改输入和输出通道数。
'''

elif m in [SimAM, ECA, SpatialGroupEnhance,TripletAttention]:
    args = [*args[:]]

 

情况2

'''
这段代码的目的是为了处理CoordAtt和GAMAttention模块的参数。

elif m in [CoordAtt, GAMAttention]:
这行代码检查当前的模块m是否是CoordAtt模块或GAMAttention模块。如果是,那么就执行下面的代码。

c1, c2 = ch[f], args[0]:
这行代码从ch和args两个列表中获取了两个值c1和c2。ch是一个列表,它保存了模型中每一层的输出通道数。f是一个索引,它指向ch中的某个元素,这个元素表示当前层的输入通道数。所以ch[f]就是当前层的输入通道数。args也是一个列表,它保存了当前层的参数。在模型配置文件中,每一层的参数都被保存在一个列表中,例如[64, 6, 2, 2]。这个列表的第一个元素通常是当前层的输出通道数。所以args[0]就是当前层的输出通道数。

if c2 != no:
这行代码检查当前层的输出通道数c2是否不等于no。no是模型的总输出通道数。

c2 = make_divisible(c2 * gw, 8):
如果c2不等于no,那么就重新计算c2的值。gw是模型的宽度倍数,make_divisible(c2 * gw, 8)会将c2 * gw调整为最接近的8的倍数。这是因为某些硬件(如GPU)在处理通道数为8的倍数的数据时,可以获得更好的性能。

args = [c1, c2, *args[1:]]:
这行代码创建了一个新的参数列表args。新的args列表的第一个元素是c1,第二个元素是c2,剩下的元素是原始args列表的第二个元素及其后面的所有元素。*args[1:]是Python的解包(unpack)操作,它可以将列表args[1:]中的所有元素解包出来。所以,[c1, c2, *args[1:]]就等于[c1, c2]和args[1:]两个列表的连接。
'''

elif m in [CoordAtt, GAMAttention]:
    c1, c2 = ch[f], args[0]
    if c2 != no:
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2, *args[1:]] 


情况3

'''
这段代码的目的是为了处理SE, ShuffleAttention等模块的参数。

elif m in [SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA]:
这行代码检查当前的模块m是否是SE, ShuffleAttention等模块中的一个。如果是,那么就执行下面的代码。

c1 = ch[f]:
这行代码从ch列表中获取了一个值c1。ch是一个列表,它保存了模型中每一层的输出通道数。f是一个索引,它指向ch中的某个元素,这个元素表示当前层的输入通道数。所以ch[f]就是当前层的输入通道数。

args = [c1, *args[0:]]:
这行代码创建了一个新的参数列表args。新的args列表的第一个元素是c1,剩下的元素是原始args列表的第一个元素及其后面的所有元素。*args[0:]是Python的解包(unpack)操作,它可以将列表args[0:]中的所有元素解包出来。所以,[c1, *args[0:]]就等于[c1]和args[0:]两个列表的连接。

综上,这段代码将当前层的输入通道数c1添加到参数列表args的开始位置,因为这些模块的初始化函数通常需要输入通道数作为第一个参数。

'''


elif m in [SE, ShuffleAttention, CBAM, SKAttention, DoubleAttention, CoTAttention, EffectiveSEModule,GlobalContext, GatherExcite, MHSA]:
    c1 = ch[f]
    args = [c1, *args[0:]]


 

情况4

'''
这段代码的目的是为了处理S2Attention, NAMAttention等模块的参数。

elif m in [S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention]:
这行代码检查当前的模块m是否是S2Attention, NAMAttention等模块中的一个。如果是,那么就执行下面的代码。

c1 = ch[f]:
这行代码从ch列表中获取了一个值c1。ch是一个列表,它保存了模型中每一层的输出通道数。f是一个索引,它指向ch中的某个元素,这个元素表示当前层的输入通道数。所以ch[f]就是当前层的输入通道数。

args = [c1]:这行代码创建了一个新的参数列表args。新的args列表只有一个元素,就是c1。

综上,这段代码将当前层的输入通道数c1作为唯一的参数传递给这些模块,因为这些模块的初始化函数通常只需要输入通道数作为参数。
'''


elif m in [S2Attention, NAMAttention, CrissCrossAttention, SequentialPolarizedSelfAttention,ParallelPolarizedSelfAttention, ParNetAttention]:
    c1 = ch[f]
    args = [c1]


 

 

 


http://www.niftyadmin.cn/n/5248626.html

相关文章

【学习笔记】群作用与轨道-稳定化子定理

群作用,轨道-稳定化子定理 不妨通过一个简单的例子来引入群作用的概念,恕我直言这个东西真的很神奇 引入 令 S S S是一个非空集合,我们考虑所有 S → S S\rightarrow S S→S的双射 f f f所组成的集合,记为 P e r m ( S ) Perm(S…

linux系统web服务以及apache介绍

web服务 WEB服务WEB服务简介WEB 服务协议 ApacheApache 介绍 apache安装apache目录介绍访问控制虚拟主机基于ip基于域名基于端口临时添加ip WEB服务 WEB服务简介 目前最主流的三个Web服务器是Apache、Nginx、 IIS。 - WEB服务器一般指网站服务器,可以向浏览器等We…

web项目服务器后台运行

阿里官方方法 在Linux系统的ECS实例内,当断开SSH客户端后,如何保持进程继续运行的解决方案_云服务器 ECS-阿里云帮助中心 (aliyun.com)

《算法通关村——原来贪心如此简单》

《算法通关村——原来贪心如此简单》 贪心如此简单,我们就通过几个题目了解一下吧。 455. 分发饼干 假设你是一位很棒的家长,想要给你的孩子们一些小饼干。但是,每个孩子最多只能给一块饼干。 对每个孩子 i,都有一个胃口值 g[…

uniapp页面跳转函数

1.在需要跳转的类加一个点击事件click 2.写一个跳转函数 ToSetting() {uni.navigateTo({url:"/pages/tabbar-5-detial/setting/setting"})} ok了。 跳转的页面是静态页面,即没有从上一个页面获取数据。 最初级的页面跳转。。。

k8s之镜像拉取时使用secret

k8s之secret使用 一、说明二、secret使用2.1 secret类型2.2 创建secret2.3 配置secret 一、说明 从公司搭建的网站镜像仓库,使用k8s部署服务时拉取镜像失败,显示未授权: 需要在拉取镜像时添加认证信息. 关于secret信息,参考: https://www.…

使用阿里云国际CDN加速后网站无法访问的排查步骤

使用阿里云国际CDN加速后网站无法访问的排查步骤,下面是一些常见的问题,以:www.c.9he.com为例,如果解决不了来信服务器厂商解决。 检查CDN访问异常是CDN节点的问题还是源站问题 如果是源站访问异常,请直接排查源站服务…

个人信息展示网站需求分析报告

目录 一. 概述1.1 设计目的1.2 术语定义 二. 需求分析三. 系统功能需求3.1 功能总览3.2 业务流程图1.系统用例图2.系统流程 四.开发技术4.1 技术组成 五.界面及运行环境1.用户界面2.运行环境 一. 概述 1.1 设计目的 兴趣使然。将知识点综合运用。CSDN有功能限制,因…