CoordAtt注意力网络结构

news/2024/7/10 22:57:51 标签: 深度学习, pytorch, 机器学习, YOLO

源码:

import torch
import torch.nn as nn
import math
import torch.nn.functional as F

class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6

class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)

class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        

    def forward(self, x):
        identity = x
        
        n,c,h,w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y) 
        
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out



def CA_onnx_gen():
    conv=CoordAtt(64,64)
    dummy_input = torch.randn(8,64, 128, 128)
    out=conv(dummy_input)
    print(out.shape)
 
    print(conv)
    # conv.load_state_dict(checkpoint)
    conv.eval()
    input_names = ["input"]
    output_names = ["output"]
    torch.onnx.export(conv, dummy_input, "CA.onnx", verbose=True, opset_version=13,input_names=input_names,
                      output_names=output_names)


if __name__=="__main__":
    CA_onnx_gen()

onnx结构:

 


http://www.niftyadmin.cn/n/4944037.html

相关文章

Java 单例模式简单介绍

何为单例模式 所谓类的单例设计模式,就是采取一定的方法保证在整个的软件系统中,对某个类只能存在一个对象实例,并且该类只提供一个取得其对象实例的方法。 实现思路 如果我们要让类在一个虚拟机中只能产生一个对象,我们首先必…

Azure使用CLI创建VM

使用CLI创建VM之前,确保资源中的IP资源已经释放掉了,避免创建的过程中没有可以利用的公共IP地址打开 cloudshell ,并输入创建CLI的命令如下,-n指定名称,-g指定资源组,image指定镜像,admin-usernam指定用户名…

c#设计模式-结构型模式 之 桥接模式

前言 桥接模式是一种设计模式,它将抽象与实现分离,使它们可以独立变化。这种模式涉及到一个接口作为桥梁,使实体类的功能独立于接口实现类。这两种类型的类可以结构化改变而互不影响。 桥接模式的主要目的是通过将实现和抽象分离,…

网络通信原理网络层TCP/IP协议(第四十三课)

1.什么是TCP/IP 目前应用广泛的网络通信协议集 国际互联网上电脑相互通信的规则、约定。 2.主机通信的三要素 IP地址:用来标识一个节点的网络地址(区分网络中电脑身份的地址,如人有名字) 子网掩码:配合IP地址确定网络号 IP路由:网关的地址,网络的出口 3.IP地址 …

nvm命令

1. 常见命令 1. nvm -v //查看nvm版本 nvm --version :显示 nvm 版本 2. nvm list //显示版本列表 nvm list :显示已安装的版本(同 nvm list installednvm list installed:显示已安装的版本nvm list available:显示所有…

kube-prometheus 系列1 项目介绍

Prometheus 已经成为云原生监控的事实标准。整个生态包含诸多组件,为了简化安装部署和配置高可用等,社区开发了kube-prometheus项目。接下来用一系列文章介绍一下相关配置。 项目简介: kube-prometheus 是一个基于 Kubernetes 部署的 Prometh…

axios同一个接口,同时接收 文件 或者 数据

1、前端代码 const service axios.create({baseURL: "http://192.168.2.200:8080/api",timeout: 180000 })// 响应拦截 service.interceptors.response.use(async response > {if(response){// 请求时设置返回blob, 但是实际上可能返回的是json的情况if (respon…

开源代码分享(13)—整合本地电力市场与级联批发市场的投标策略(附matlab代码)

1.引言 1.1摘要 本地电力市场是在分配层面促进可再生能源的效率和使用的一种有前景的理念。然而,作为一个新概念,如何设计和将这些本地市场整合到现有市场结构中,并从中获得最大利润仍然不清楚。在本文中,我们提出了一个本地市场…