YOLOv5、YOLOv8改进:CotNet Transformer

news/2024/7/10 22:50:00 标签: YOLO, transformer, 深度学习

1.简介

京东AI研究院提出的一种新的注意力结构。将CoT Block代替了ResNet结构中的3x3卷积,在分类检测分割等任务效果都出类拔萃

论文地址:https://arxiv.org/pdf/2107.12292.pdf

源代码地址:https://github.com/JDAI-CV/CoTNet

具有自注意力的Transformer引发了自然语言处理领域的革命,最近还激发了Transformer式架构设计的出现,并在众多计算机视觉任务中取得了具有竞争力的结果。

大多数现有设计直接在2D特征图上使用自注意力来获得基于每个空间位置的独立查询和键对的注意力矩阵,但未充分利用相邻键之间的丰富上下文。在今天分享的工作中,研究者设计了一个新颖的Transformer风格的模块,即Contextual Transformer (CoT)块,用于视觉识别。这种设计充分利用输入键之间的上下文信息来指导动态注意力矩阵的学习,从而增强视觉表示能力。从技术上讲,CoT块首先通过3×3卷积对输入键进行上下文编码,从而产生输入的静态上下文表示。

上图a是传统的self-attention仅利用孤立的查询-键对来测量注意力矩阵,但未充分利用键之间的丰富上下文。 b就是CoT块

研究者进一步将编码的键与输入查询连接起来,通过两个连续的1×1卷积来学习动态多头注意力矩阵。学习到的注意力矩阵乘以输入值以实现输入的动态上下文表示。静态和动态上下文表示的融合最终作为输出。CoT块很吸引人,因为它可以轻松替换ResNet架构中的每个3 × 3卷积,产生一个名为Contextual Transformer Networks (CoTNet)的Transformer式主干。通过对广泛应用(例如图像识别、对象检测和实例分割)的大量实验,验证了CoTNet作为更强大的主干的优越性

Attention注意力机制与self-attention自注意力机制

为什么要注意力机制?

在Attention诞生之前,已经有CNN和RNN及其变体模型了,那为什么还要引入attention机制?主要有两个方面的原因,如下:

(1)计算能力的限制:当要记住很多“信息“,模型就要变得更复杂,然而目前计算能力依然是限制神经网络发展的瓶颈。

(2)优化算法的限制:LSTM只能在一定程度上缓解RNN中的长距离依赖问题,且信息“记忆”能力并不高。

什么是注意力机制

在介绍什么是注意力机制之前,先让大家看一张图片。当大家看到下面图片,会首先看到什么内容?当过载信息映入眼帘时,我们的大脑会把注意力放在主要的信息上,这就是大脑的注意力机制。

同样,当我们读一句话时,大脑也会首先记住重要的词汇,这样就可以把注意力机制应用到自然语言处理任务中,于是人们就通过借助人脑处理信息过载的方式,提出了Attention机制。

self attention是注意力机制中的一种,也是transformer中的重要组成部分。自注意力机制是注意力机制的变体,其减少了对外部信息的依赖,更擅长捕捉数据或特征的内部相关性。自注意力机制在文本中的应用,主要是通过计算单词间的互相影响,来解决长距离依赖问题。

 传统的自注意力很好地触发了不同空间位置的特征交互,具体取决于输入本身。然而,在传统的自注意力机制中,所有成对的查询键关系都是通过孤立的查询键对独立学习的,而无需探索其间的丰富上下文。这严重限制了自注意力学习在2D特征图上进行视觉表示学习的能力。

为了缓解这个问题,研究者构建了一个新的Transformer风格的构建块,即上图 (b)中的 Contextual Transformer (CoT) 块,它将上下文信息挖掘和自注意力学习集成到一个统一的架构中。

2.YOLOv5改进

2.1 YOLOv5的yaml配置文件修改

增加以下yolov5s_cotnet.yaml文件

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, CoT3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, CoT3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, CoT3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, CoT3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

2.2 common.py配置

./models/common.py文件增加以下模块

class CoT3(nn.Module):
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*[CoTBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

class CoTBottleneck(nn.Module):
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super(CoTBottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = CoT(c_, 3)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

class CoT(nn.Module):
    # Contextual Transformer Networks https://arxiv.org/abs/2107.12292
    def __init__(self, dim=512,kernel_size=3):
        super().__init__()
        self.dim=dim
        self.kernel_size=kernel_size

        self.key_embed=nn.Sequential(
            nn.Conv2d(dim,dim,kernel_size=kernel_size,padding=kernel_size//2,groups=4,bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed=nn.Sequential(
            nn.Conv2d(dim,dim,1,bias=False),
            nn.BatchNorm2d(dim)
        )

        factor=4
        self.attention_embed=nn.Sequential(
            nn.Conv2d(2*dim,2*dim//factor,1,bias=False),
            nn.BatchNorm2d(2*dim//factor),
            nn.ReLU(),
            nn.Conv2d(2*dim//factor,kernel_size*kernel_size*dim,1)
        )


    def forward(self, x):
        bs,c,h,w=x.shape
        k1=self.key_embed(x) #bs,c,h,w
        v=self.value_embed(x).view(bs,c,-1) #bs,c,h,w

        y=torch.cat([k1,x],dim=1) #bs,2c,h,w
        att=self.attention_embed(y) #bs,c*k*k,h,w
        att=att.reshape(bs,c,self.kernel_size*self.kernel_size,h,w)
        att=att.mean(2,keepdim=False).view(bs,c,-1) #bs,c,h*w
        k2=F.softmax(att,dim=-1)*v
        k2=k2.view(bs,c,h,w)
        
        return k1+k2

2.3 yolo.py配置修改

然后找到./models/yolo.py文件下里的parse_model函数,将加入的模块名CoT3加入进去
在 models/yolo.py文件夹下

定位到parse_model函数中
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):内部
对应位置 下方只需要增加 CoT3模块

 后续专栏会持续更新v7和v8内容


http://www.niftyadmin.cn/n/5050830.html

相关文章

leetcode 23. 合并 K 个升序链表

2023.9.25 本题要合并k个有序链表,最朴素的方法可以联想到之前做的合并两个有序链表。 用一个for循环遍历lists数组,利用合并两个有序链表的代码,不断合并lists中的链表,最后返回头节点即可。 代码如下: /*** Definit…

Mooctest

开发者 测试框架junit 1.字符串不能除 2.a给了c 3. 4. 5.输入是否>0 6.注释

GE IS220PVIBH1A 336A4940CSP16 电源模块

GE IS220PVIBH1A 336A4940CSP16 电源模块是通用电气(GE)的一种电源模块,用于工业控制和电力系统中,提供电源供应和保护功能。以下是这种类型电源模块的一般特点和功能: 电源供应:GE IS220PVIBH1A 336A4940C…

虚拟机数据恢复:Stellar Data Recovery for Virtual Machine

虚拟机数据恢复-----Crack Version Stellar Data Recovery for Virtual Machine 软件可从 VMware (.vmdk)、ORACLE (.vdi) 和 Microsoft (.vhd) 虚拟映像文件中恢复丢失和删除的数据。 从 VMDK、VDI 和 VHD 虚拟映像文件中恢复数据提供原始恢复选项来恢复数据从已删除或无法识…

Logstash、sharding-proxy组件高级配置

记录Logstash数据同步插件在分库分表场景下相关高可用、高并发配置 一、Logstash 1.配置文件控制任务数 vim /etc/logstash/logstash.yml pipeline.workers: 24 pipeline.batch.size: 10000 pipeline.batch.delay: 10 Logstash建议在修改配置项以提高性能的时候,每…

【C++STL基础入门】stack栈的基础使用

文章目录 前言一、栈是什么?二、STL中栈的使用2.1 栈的头文件2.2 栈的构造函数 三、stack属性3.1 empty()函数3.2 size()函数 总结 前言 C STL(Standard Template Library)是C标准库中的一个强大的工具集,提供了各种常用的数据结…

JAVA 实用开源工具集持续梳理中......

1. 接口文档生成器 1.1. knife 介绍 knife4j,示例, 官网 Knife4j是一个集Swagger2 和 OpenAPI3为一体的增强解决方案 使用示例 参见快速开始 2. 快速开发框架 2.1. magic-api 介绍 git地址 magic-api 是一个基于Java的接口快速开发框架,编写接口将通过magic-api提供…

java排序start~

数组排序 Integer[] a {1,2,3,4,5,6};Arrays.sort(a, (a1, b1)->{return b1-a1;//倒序排序});System.out.println(Arrays.toString(a)); Java中Arrays.sort()的三种常用用法(自定义排序规则)_arrays.sort() 自定义排序_请叫我算术嘉的博客-CSDN博客…