YOLOv8改进算法之添加CA注意力机制

news/2024/7/10 23:43:21 标签: YOLO, 算法, 人工智能

1. CA注意力机制

CA(Coordinate Attention)注意力机制是一种用于加强深度学习模型对输入数据的空间结构理解的注意力机制。CA 注意力机制的核心思想是引入坐标信息,以便模型可以更好地理解不同位置之间的关系。如下图:

1. 输入特征: CA 注意力机制的输入通常是一个特征图,它通常是卷积神经网络(CNN)中的某一层的输出,具有以下形状:[C, H, W],其中:

  • C 是通道数,表示特征图中的不同特征通道。
  • H 是高度,表示特征图的垂直维度。
  • W 是宽度,表示特征图的水平维度。

2. 全局平均池化: CA 注意力机制首先对输入特征图进行两次全局平均池化,一次在宽度方向上,一次在高度方向上。这两次操作分别得到两个特征映射:

  • 在宽度方向上的平均池化得到特征映射 [C, H, 1]
  • 在高度方向上的平均池化得到特征映射 [C, 1, W]

这两个特征映射分别捕捉了在宽度和高度方向上的全局特征。

3. 合并宽高特征: 将上述两个特征映射合并,通常通过简单的堆叠操作,得到一个新的特征层,形状为 [C, 1, H + W],其中 H + W 表示在宽度和高度两个方向上的维度合并在一起。

4. 卷积+标准化+激活函数: 对合并后的特征层进行卷积操作,通常是 1x1 卷积,以捕捉宽度和高度维度之间的关系。然后,通常会应用标准化(如批量标准化)和激活函数(如ReLU)来进一步处理特征,得到一个更加丰富的表示。

5. 再次分开: 分别从上述特征层中分离出宽度和高度方向的特征:

  • 一个分支得到特征层 [C, 1, H]
  • 另一个分支得到特征层 [C, 1, W]

6. 转置: 对分开的两个特征层进行转置操作,以恢复宽度和高度的维度,得到两个特征层分别为 [C, H, 1][C, 1, W]

7. 通道调整和 Sigmoid: 对两个分开的特征层分别应用 1x1 卷积,以调整通道数,使其适应注意力计算。然后,应用 Sigmoid 激活函数,得到在宽度和高度维度上的注意力分数。这些分数用于指示不同位置的重要性。

8. 应用注意力: 将原始输入特征图与宽度和高度方向上的注意力分数相乘,得到 CA 注意力机制的输出。

2. YOLOv8添加CA注意力机制

加入注意力机制,在ultralytics包中的nn包的modules里添加CA注意力模块,我这里选择在conv.py文件中添加CA注意力机制。

CA注意力机制代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    def __init__(self, inp, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x

        n, c, h, w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out

CA注意力机制的注册和引用如下:

 ultralytics/nn/modules/_init_.py文件中:

  ultralytics/nn/tasks.py文件夹中:

 在tasks.py中的parse_model中添加如下代码:

        elif m in {CoordAtt}:
            args=[ch[f],*args]

新建相应的yolov8s-CA.yaml文件,代码如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1,1,CoordAtt,[]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 8], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 5], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 15], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[18, 21, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

在main.py文件中进行训练:

if __name__ == '__main__':

    # 使用yaml配置文件来创建模型,并导入预训练权重.
    model = YOLO('ultralytics/cfg/models/v8/yolov8s-CA.yaml')
    # model.load('yolov8n.pt')
    model.train(**{'cfg': 'ultralytics/cfg/default.yaml', 'data': 'dataset/data.yaml'})


http://www.niftyadmin.cn/n/5060176.html

相关文章

ChatGPT终于可以进行网络搜索 内容不再限于2021年9月前

微软和谷歌已经让旗下聊天机器人进行网上搜索,并提供原始材料的链接,以提高信息共享的可信度和范围。但是,ChatGPT迄今为止只接受了有时间限制的训练数据,这些数据仅限于从互联网上收集的2021年9月之前的信息。在周三的一系列推文…

敏捷开发的实施要素和实现敏捷的实际改进

敏捷开发的实施要素如下: 个体和交互:胜过过程和工具。可以工作的软件:胜过面面俱到的文档。客户合作:胜过合同谈判。响应变化:胜过遵循计划。 敏捷开发过程是一个增量的、迭代的过程,责任人、开发人员和…

idea debug 重启弹窗提示窗口询问是否关闭运行着的服务器

目录 方法121版本的IDEA idea重新启动服务器时会有一个提示窗口询问是否关闭运行着的服务器,,这个窗口不小心点了不再提示.重新打开弹窗方法 方法1 idea编辑器由于勾选了不再提示选项导致的弹窗无法继续弹出:解决方案 1.打开项目没提示&…

【RocketMQ】基本使用:Java操作RocketMQ(rocketmq-client)

【RocketMQ】基本使用&#xff1a;Java操作RocketMQ&#xff08;rocketmq-client&#xff09; 1.引入依赖 <dependency><groupId>org.apache.rocketmq</groupId><artifactId>rocketmq-client</artifactId><version>4.3.2</version>…

信息系统项目管理师 12题

1.管理团队所获得的主要收益体现在&#xff08;&#xff09;。 ①指导团队选择和职责分配 ②管理冲突 ③解决问题 ④改进团队协作 ⑤影响团队行为 ⑥评估团队成员绩效 A.①③⑤⑥ B.②③④⑥ C.①③④⑥ D.②③⑤⑥ 【答案】D。管理团队。本过程的主要收益&#xff1a;影响…

C语言数据结构之排序整合与比较(冒泡,选择,插入,希尔,堆排序,快排及改良,归并排序,计数排序)

前言&#xff1a;排序作为数据结构中的一个重要模块&#xff0c;重要性不言而寓&#xff0c;我们的讲法为下理论掌握大致的算法结构&#xff0c;再上代码及代码讲解&#xff0c;助你一臂之力。 一&#xff0c;冒泡 冒泡排序应该是大家学习以来第一个认识的排序方法&#xff0…

【Unity3D】UGUI物体世界坐标转屏幕坐标问题

如题&#xff1a; UGUI物体世界坐标转屏幕坐标问题&#xff0c;获取UI(UGUI)屏幕坐标问题等相关问题 思路&#xff1a;必须使用Canvas身上的Camera&#xff0c;进行Camera.WorldToScreenPoint(UI物体的世界坐标Vector3)&#xff0c;会返回一个Vector3(x,y,z)&#xff0c;我们要…

leetcode之打家劫舍

leetcode 198 打家劫舍 leetcode 213 打家劫舍 II leetcode 337. 打家劫舍 III 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋&#xff0c;每间房内都藏有一定的现金。这个地方所有的房屋都 围成一圈 &#xff0c;这意味着第一个房屋和最后一个房屋是紧挨着的。同时&#…