YOLOv5、YOLOv7 注意力机制改进SEAM、MultiSEAM、TripletAttention

news/2024/7/10 23:56:10 标签: YOLO, 注意力机制

用于学习记录

文章目录

  • 前言
  • 一、SEAM、MultiSEAM
    • 1.1 models/common.py
    • 1.2 models/yolo.py
    • 1.3 models/SEAM.yaml
    • 1.4 models/MultiSEAM.yaml
    • 1.5 SEAM 训练结果图
    • 1.6 MultiSEAM 训练结果图
  • 二、TripletAttention
    • 2.1 models/common.py
    • 2.2 models/yolo.py
    • 2.3 yolov7/cfg/training/TripletAttention.yaml
    • 2.4 TripletAttention 训练结果图


前言


一、SEAM、MultiSEAM

1.1 models/common.py

在这里插入图片描述

class Residual(nn.Module):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn
 
    def forward(self, x):
        return self.fn(x) + x
 
class SEAM(nn.Module):
    def __init__(self, c1, c2, n, reduction=16):
        super(SEAM, self).__init__()
        if c1 != c2:
            c2 = c1
        self.DCovN = nn.Sequential(
            # nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1, groups=c1),
            # nn.GELU(),
            # nn.BatchNorm2d(c2),
            *[nn.Sequential(
                Residual(nn.Sequential(
                    nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=3, stride=1, padding=1, groups=c2),
                    nn.GELU(),
                    nn.BatchNorm2d(c2)
                )),
                nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=1, stride=1, padding=0, groups=1),
                nn.GELU(),
                nn.BatchNorm2d(c2)
            ) for i in range(n)]
        )
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(c2, c2 // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c2 // reduction, c2, bias=False),
            nn.Sigmoid()
        )
 
        self._initialize_weights()
        # self.initialize_layer(self.avg_pool)
        self.initialize_layer(self.fc)
 
 
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.DCovN(x)
        y = self.avg_pool(y).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        y = torch.exp(y)
        return x * y.expand_as(x)
 
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain=1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
 
    def initialize_layer(self, layer):
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            torch.nn.init.normal_(layer.weight, mean=0., std=0.001)
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0)
 
def DcovN(c1, c2, depth, kernel_size=3, patch_size=3):
    dcovn = nn.Sequential(
        nn.Conv2d(c1, c2, kernel_size=patch_size, stride=patch_size),
        nn.SiLU(),
        nn.BatchNorm2d(c2),
        *[nn.Sequential(
            Residual(nn.Sequential(
                nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=kernel_size, stride=1, padding=1, groups=c2),
                nn.SiLU(),
                nn.BatchNorm2d(c2)
            )),
            nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=1, stride=1, padding=0, groups=1),
            nn.SiLU(),
            nn.BatchNorm2d(c2)
        ) for i in range(depth)]
    )
    return dcovn
 
class MultiSEAM(nn.Module):
    def __init__(self, c1, c2, depth, kernel_size=3, patch_size=[3, 5, 7], reduction=16):
        super(MultiSEAM, self).__init__()
        if c1 != c2:
            c2 = c1
        self.DCovN0 = DcovN(c1, c2, depth, kernel_size

http://www.niftyadmin.cn/n/4997309.html

相关文章

CSS中可继承与不可继承属性

可继承 1. 字体属性: font、font-style、font-variant、font-weight、font-size、line-height等属性是字体样式的属性,都可以被子元素继承。 2. 文本属性: color、text-indent、text-align、text-decoration、text-transform、letter-spa…

CSS 一个好玩的卡片“开卡效果”

文章目录 一、用到的一些CSS技术二、实现效果三、代码 一、用到的一些CSS技术 渐变 conic-gradientbox-shadowclip-path变换、过渡 transform、transition动画 animation keyframes伪类、伪元素 :hover、::before、::after …绝对布局。。。 clip-path 生成网站 https://techb…

【Spring Security】UserDetails 接口介绍

文章目录 UserDetails 的作用UserDetails 接口中各个方法详解 UserDetails 的作用 UserDetails 在 Spring Security 框架中主要担任获取用户信息的接口,通过该接口就能拿到用户的信息和验证用户的信息,这些信息在下面的方法中会有讲述。 UserDetails 接…

代码数据会促进LLM的推理能力吗?

深度学习自然语言处理 原创作者:Winnie 代码数据对提升LLM的推理能力有效吗?为了解答这个问题,最近的一篇工作提出了CIRS(复杂度影响推理分数)这一新的指标,用来衡量代码数据的复杂性,进而验证不…

移位的应用整理

移位的应用整理 移位运算的概念 移位针对的对象是二进制数据 为什么计算机要使用补码存储数据&#xff1a;补码存储的原因 //例如int类型5在计算机中存储为 0000 0000 0000 0000 0000 0000 0000 0101移位运算符 << 左移&#xff1a;左边丢弃&#xff0c;右边补0 >&…

Ubuntu安装Protobuf,指定版本

参考&#xff1a;https://github.com/protocolbuffers/protobuf#readme https://github.com/protocolbuffers/protobuf/blob/v3.20.3/src/README.md 其实官网的readme给的步骤很详细。 1.安装相关依赖 sudo apt-get install autoconf automake libtool curl make g unzip …

向量数据库Annoy和Milvus

Annoy 和 Milvus 都是用于向量索引和相似度搜索的开源库&#xff0c;它们可以高效地处理大规模的向量数据。 Annoy&#xff08;Approximate Nearest Neighbors Oh Yeah&#xff09;&#xff1a; Annoy 是一种近似最近邻搜索算法&#xff0c;它通过构建一个树状结构来加速最近…

嵌入式开发-lin总线介绍 一.概述

1.1lin总线定义和历史 LIN总线&#xff08;Local Interconnect Network&#xff09;是一种基于UART/SCI&#xff08;Universal Asynchronous Receiver-Transmitter/Serial Communication Interface&#xff09;的低成本串行通信协议。它主要用于汽车、家电、办公设备等多种领域…