【RT-DETR有效改进】 主干篇 | SwinTransformer替换Backbone(附代码 + 详细修改步骤 +原理介绍)

前言

大家好,这里是RT-DETR有效涨点专栏

本专栏的内容为根据ultralytics版本的RT-DETR进行改进,内容持续更新,每周更新文章数量3-10篇。

专栏以ResNet18、ResNet50为基础修改版本,同时修改内容也支持ResNet32、ResNet101和PPHGNet版本,其中ResNet为RT-DETR官方版本1:1移植过来的,参数量基本保持一致(误差很小很小),不同于ultralytics仓库版本的ResNet官方版本,同时ultralytics仓库的一些参数是和RT-DETR相冲的所以我也是会教大家调好一些参数和代码,真正意义上的跑ultralytics的和RT-DETR官方版本的无区别。

👑欢迎大家订阅本专栏,一起学习RT-DETR👑   

 一、本文介绍

本文给大家带来的改进机制是利用Swin Transformer替换RT-DETR中的骨干网络其是一个开创性的视觉变换器模型,它通过使用位移窗口来构建分层的特征图,有效地适应了计算机视觉任务。与传统的变换器模型不同,Swin Transformer的自注意力计算仅限于局部窗口内,使得计算复杂度与图像大小成线性关系,而非二次方。这种设计不仅提高了模型的效率,还保持了强大的特征提取能力。Swin Transformer的创新在于其能够在不同层次上捕捉图像的细节和全局信息,使其成为各种视觉任务的强大通用骨干网络。

推荐指数:⭐⭐⭐⭐

涨点效果:⭐⭐⭐⭐

 专栏链接:RT-DETR剑指论文专栏,持续复现各种顶会内容——论文收割机RT-DETR

目录

 一、本文介绍

二、Swin Transformer原理

2.1 Swin Transformer的基本原理

2.2 层次化特征映射

2.3 局部自注意力计算

2.4 移动窗口自注意力

2.5 移动窗口分区

三、 Swin Transformer的完整代码

四、手把手教你添加Swin Transformer网络结构

4.1 修改一

4.2 修改二 

4.3 修改三 

4.4 修改四

4.5 修改五

4.6 修改六

4.7 修改七 

4.8 修改八

4.9 RT-DETR不能打印计算量问题的解决

4.10 可选修改

五、Swin Transformer的yaml文件

5.1 yaml文件

5.2 运行文件

5.3 成功训练截图

六、全文总结


二、Swin Transformer原理

论文地址:官方论文地址

代码地址:官方代码地址


2.1 Swin Transformer的基本原理

Swin Transformer是一个新的视觉变换器,能够作为通用的计算机视觉骨干网络。这个模型解决了将Transformer从语言处理领域适应到视觉任务中的挑战,主要是因为这两个领域之间存在差异,例如视觉实体的尺度变化大,以及图像中像素的高分辨率与文本中的单词相比。下图对比展示了Swin Transformer与Vision Transformer (ViT)的不同之处,清楚地展示了Swin Transformer在构建特征映射和处理计算复杂度方面的创新优势。

(a) Swin Transformer:提出的Swin Transformer通过在更深层次合并图像小块(灰色部分所示)来构建层次化的特征映射。在每个局部窗口(红色部分所示)内只计算自注意力,因此它对输入图像大小有线性的计算复杂度。它可以作为通用的骨干网络,用于图像分类和密集识别任务,如分割和检测。

(b) Vision Transformer (ViT):以前的视觉Transformer模型(如ViT)产生单一低分辨率的特征映射,并且由于全局自注意力的计算,其计算复杂度与输入图像大小呈二次方关系。

我们可以将Swin Transformer的基本原理分为以下几点:

1. 层次化特征映射:Swin Transformer通过合并图像的相邻小块(patches),在更深的Transformer层次中逐步构建层次化的特征映射。这样的层次化特征映射可以方便地利用密集预测的高级技术,如特征金字塔网络(Feature Pyramid Networks, FPN)或U-Net。

2. 局部自注意力计算:为了实现线性计算复杂性,Swin Transformer在非重叠的局部窗口内计算自注意力,这些窗口是通过划分图像来创建的。每个窗口内的小块数量是固定的,因此计算复杂性与图像大小成线性关系。

3. 移动窗口自注意力(Shifted Window based Self-Attention):标准的Transformer架构在全局范围内计算自注意力,即计算一个标记与所有其他标记之间的关系。这种全局计算导致与标记数量成二次方的计算复杂性,不适用于许多需要处理大规模高维数据的视觉问题。Swin Transformer通过一个基于移动窗口的多头自注意力(MSA)模块取代了传统的MSA模块。每个Swin Transformer块由一个基于移动窗口的MSA模块组成,然后是两层带有GELU非线性的MLP,之前是LayerNorm(LN)层,之后是残差连接。

4. 移动窗口分区:为了在连续的Swin Transformer块中引入跨窗口连接的同时保持非重叠窗口的有效计算,提出了一种移动窗口分区方法。这种方法在连续的块之间交替使用两种分区配置。第一个模块使用常规的窗口分区策略,然后下一个模块采用的窗口配置与前一层相比,通过移动窗口偏移了一定距离,从而实现窗口的交替。

下图详细展示了Swin Transformer的架构和两个连续Swin Transformer块的设计。图中的W-MSA和SW-MSA分别代表带有常规和移动窗口配置的多头自注意力模块。这两种类型的注意力模块交替使用,允许模型在保持局部计算的同时,也能够捕捉更广泛的上下文信息。

(a) 架构(Architecture):图展示了Swin Transformer的四个阶段。每个阶段都包含若干Swin Transformer块。输入图像首先通过“Patch Partition”被划分成小块,并通过“Linear Embedding”转换成向量序列。各个阶段通过“Patch Merging”操作降低特征图的分辨率,同时增加特征维数(例如,第一阶段输出的特征维数为C,第二阶段为2C,依此类推)。

(b) 两个连续的Swin Transformer块(Two Successive Swin Transformer Blocks):每个Swin Transformer块由多头自注意力模块(W-MSA和SW-MSA)和多层感知机(MLP)组成,其中W-MSA使用常规窗口配置,而SW-MSA使用移动窗口配置。每个块内部,先是LayerNorm(LN)层,然后是自注意力模块,再是另一个LayerNorm层,最后是MLP。块之间通过残差连接进行连接,这样的设计可以避免深层网络中的梯度消失问题,并允许信息在网络中更有效地流动。


2.2 层次化特征映射

层次化特征映射可以使Swin Transformer有效地处理不同分辨率的特征,并适用于各种视觉任务,如图像分类、对象检测和语义分割。这种层次化设计使Swin Transformer与以往基于Transformer的架构(这些架构产生单一分辨率的特征图并具有二次方复杂度)形成对比,后者不适合需要在像素级进行密集预测的视觉任务。Swin Transformer的层次化特征映射主要通过以下步骤实现:

1. 分块和线性嵌入:首先,输入图像被分割成小块(通常是4x4像素大小),每个小块被视为一个“标记”,其特征是原始像素RGB值的串联。然后,一个线性嵌入层被应用于这些原始值特征,将其投影到任意维度(表示为C)。这些步骤构成了所谓的“第1阶段”。

2. 分块合并:随着网络深入,通过合并层减少标记的数量,从而降低特征图的分辨率。例如,第一个合并层将每组2x2相邻小块的特征合并,并应用一个线性层到这些4C维度的串联特征上,这样做将标记的数量减少了4倍(分辨率降低了2倍),并将输出维度设为2C。这个过程在后续的“第2阶段”、“第3阶段”和“第4阶段”中重复,分别产生更低分辨率的输出。

3. 层次化特征图:通过在更深的Transformer层合并相邻小块,Swin Transformer构建了层次化的特征映射。这些层次化特征映射允许模型方便地使用密集预测的高级技术,例如特征金字塔网络(FPN)或U-Net。

4. 计算效率:Swin Transformer在非重叠的局部窗口内局部计算自注意力,从而实现了线性的计算复杂度。每个窗口中的小块数量是固定的,因此复杂度与图像大小成线性关系。


2.3 局部自注意力计算

Swin Transformer的局部自注意力计算通过在小窗口内计算自注意力以及通过移动窗口在连续层之间引入跨窗口的信息流通,使得计算更加高效,同时保留了模型捕捉长距离依赖的能力。Swin Transformer中的局部自注意力计算我们可以通过以下方式实现:

1. 替代标准多头自注意力模块:Swin Transformer使用基于移动窗口的多头自注意力(MSA)模块替代了传统Transformer块中的标准多头自注意力模块,其他层保持不变。每个Swin Transformer块由一个基于移动窗口的MSA模块组成,后跟一个两层的MLP,中间包含GELU非线性激活函数。在每个MSA模块和MLP之前都会应用一个LayerNorm(LN)层,每个模块之后都会应用残差连接。

2. 在各个窗口内计算自注意力:在每一层中,采用常规的窗口分区方案,每个窗口内部独立计算自注意力。在下一层中,窗口分区会发生移动,形成新的窗口。新窗口中的自注意力计算会跨越之前层中窗口的边界,建立它们之间的连接。

3. 非重叠窗口中的自注意力:为了有效的建模,Swin Transformer在非重叠的局部窗口内计算自注意力。这些窗口被安排以均匀非重叠的方式分割图像。假设每个窗口包含M×M个小块,全局MSA模块和基于窗口的MSA模块的计算复杂度分别为二次方和线性,当M固定时(默认设为7)。

4. 循环位移和掩码机制:提出了一种通过循环位移来提高批量计算的效率的方法。通过这种位移,一个批次的窗口可能由几个在特征图中不相邻的子窗口组成,因此采用掩码机制限制在每个子窗口内计算自注意力。这种循环位移保持了批次窗口的数量与常规窗口分区相同。

5. 窗口间的位移:为了在连续层之间实现更高效的硬件实现,Swin Transformer提出在连续层之间位移窗口,这样的位移允许跨窗口的连接,同时维持计算的高效性。

6. 相对位置偏置:在计算自注意力时,Swin Transformer包括了相对位置偏置B,以增强模型对不同位置之间关系的学习能力。


2.4 移动窗口自注意力

移动窗口自注意力是Swin Transformer设计的核心元素,它通过在局部窗口内计算自注意力并在连续层之间引入窗口位移,以实现高效的计算和强大的建模能力。在Swin Transformer论文中,移动窗口自注意力(shifted window self-attention)的主要特点包括:

1. 替代多头自注意力模块:在Swin Transformer块中,标准的多头自注意力(MSA)模块被基于移动窗口的MSA模块替换。这种基于移动窗口的MSA模块后跟一个两层的MLP,中间有GELU非线性激活函数。每个MSA模块和MLP之前都会应用一个LayerNorm(LN)层,每个模块之后都会应用残差连接。

2. 移动窗口分区:在连续的Swin Transformer块中,窗口分区策略在每一层之间交替。在某一层中,采用常规窗口分区,而在下一层中,窗口分区会发生移动,从而形成新的窗口。这种移动窗口分区方法能够跨越前一层中窗口的边界,提供窗口间的连接。

3. 交替分区配置:移动窗口分区方法在连续的Swin Transformer块中交替使用两种分区配置。例如,第一个模块从左上角像素开始使用常规窗口分区策略,接着下一个模块采用的窗口配置将与前一层相比移动一定距离。

4. 移动窗口自注意力的计算:移动窗口自注意力计算的有效性不仅在图像分类、目标检测和语义分割任务中得到了验证,而且它的实现也被证明在所有MLP架构中有益。

5. 效率:相比于滑动窗口方法,移动窗口方法具有更低的延迟,但在建模能力上却相似。此外,移动窗口方法也有助于提高批量计算的效率。

6. 连续块的计算:在移动窗口分区方法中,连续的Swin Transformer块的计算方式如下:\hat{z_l} = W\text{-}MSA(LN(z_{l-1})) + z_{l-1},然后是MLP层,之后是\\hat{z_{l+1}} = SW\text{-}MSA(LN(z_l)) + z_l。这里,\hat{z_l}z_l分别代表块l的(S)W-MSA模块和MLP模块的输出特征。

下面我给大家展示了所提出的Swin Transformer架构中用于计算自注意力的移动窗口方法

在第l层(左侧),采用了常规窗口划分方案,并且在每个窗口内计算自注意力。在接下来的第l+1层(右侧),窗口划分被移动,结果在新的窗口中进行了自注意力计算。这些新窗口中的自注意力计算跨越了l层中之前窗口的边界,提供了它们之间的连接。这种移动窗口方法提高了效率,因为它限制了自注意力计算在非重叠的局部窗口内,同时允许窗口间的交叉连接。 


2.5 移动窗口分区

移动窗口分区是Swin Transformer中一项关键的创新,它通过在连续层之间交替窗口的分区方式,有效地促进了信息在窗口之间的流动,同时保持了处理高分辨率图像时的计算效率。下面我将通过图片解释如何使用循环位移来计算在移动窗口中的自注意力,以及如何高效地实施这一计算

(1)窗口分区(Window partition):首先,图像被分成多个窗口。
(2)循环位移(Cyclic shift):接着,为了计算自注意力,窗口内的像素或特征会进行循环位移。这样可以将本来不相邻的像素或特征暂时性地排列到同一个窗口内,使得可以在局部窗口中计算原本跨窗口的自注意力。
(3)掩码多头自注意力(Masked MSA):在经过循环位移后,可以在这些临时形成的窗口上执行掩码多头自注意力操作,以此计算注意力得分和更新特征。
(4)逆循环位移(Reverse cyclic shift):完成自注意力计算后,特征会进行逆循环位移,恢复到它们原来在图像中的位置。


三、 Swin Transformer的完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix.type(x.dtype)
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchMerging(nn.Module):
    """ Patch Merging Layer

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x

class SwinTransformer(nn.Module):
    """ Swin Transformer backbone.
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute postion embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]

            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)
        self.width_list = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]

    def forward(self, x):
        """Forward function."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return outs

四、手把手教你添加Swin Transformer网络结构

 下面教大家如何修改该网络结构,主干网络结构的修改步骤比较复杂,我也会将task.py文件上传到CSDN的文件中,大家如果自己修改不正确,可以尝试用我的task.py文件替换你的,然后只需要修改其中的第1、2、3、5步即可。

⭐修改过程中大家一定要仔细⭐


4.1 修改一

首先我门中到如下“ultralytics/nn”的目录,我们在这个目录下在创建一个新的目录,名字为'Addmodules'(此文件之后就用于存放我们的所有改进机制),之后我们在创建的目录内创建一个新的py文件复制粘贴进去 ,可以根据文章改进机制来起,这里大家根据自己的习惯命名即可。


4.2 修改二 

第二步我们在我们创建的目录内创建一个新的py文件名字为'__init__.py'(只需要创建一个即可),然后在其内部导入我们本文的改进机制即可,其余代码均为未发大家没有不用理会!


4.3 修改三 

第三步我门中到如下文件'ultralytics/nn/tasks.py'然后在开头导入我们的所有改进机制(如果你用了我多个改进机制,这一步只需要修改一次即可)


4.4 修改四

添加如下两行代码!!!


4.5 修改五

找到七百多行大概把具体看图片,按照图片来修改就行,添加红框内的部分,注意没有()只是函数名(此处我的文件里已经添加很多了后期都会发出来,大家没有的不用理会即可)。

        elif m in {自行添加对应的模型即可,下面都是一样的}:
            m = m(*args)
            c2 = m.width_list  # 返回通道列表
            backbone = True


4.6 修改六

用下面的代码替换红框内的内容。 

if isinstance(c2, list):
    m_ = m
    m_.backbone = True
else:
    m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
    t = str(m)[8:-2].replace('__main__.', '')  # module type
m.np = sum(x.numel() for x in m_.parameters())  # number params
m_.i, m_.f, m_.type = i + 4 if backbone else i, f, t  # attach index, 'from' index, type
if verbose:
    LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}')  # print
save.extend(
    x % (i + 4 if backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
layers.append(m_)
if i == 0:
    ch = []
if isinstance(c2, list):
    ch.extend(c2)
    if len(c2) != 5:
        ch.insert(0, 0)
else:
    ch.append(c2)


4.7 修改七 

修改七这里非常要注意,不是文件开头YOLOv8的那predict,是400+行的RTDETR的predict!!!初始模型如下,用我给的代码替换即可!!!

代码如下->

 def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
        """
        Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
            visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
            batch (dict, optional): Ground truth data for evaluation. Defaults to None.
            augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): Model's output tensor.
        """
        y, dt, embeddings = [], [], []  # outputs
        for m in self.model[:-1]:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if hasattr(m, 'backbone'):
                x = m(x)
                if len(x) != 5:  # 0 - 5
                    x.insert(0, None)
                for index, i in enumerate(x):
                    if index in self.save:
                        y.append(i)
                    else:
                        y.append(None)
                x = x[-1]  # 最后一个输出传给下一层
            else:
                x = m(x)  # run
                y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        head = self.model[-1]
        x = head([y[j] for j in head.f], batch)  # head inference
        return x

4.8 修改八

我们将下面的s用640替换即可,这一步也是部分的主干可以不修改,但有的不修改就会报错,所以我们还是修改一下。


4.9 RT-DETR不能打印计算量问题的解决

计算的GFLOPs计算异常不打印,所以需要额外修改一处, 我们找到如下文件'ultralytics/utils/torch_utils.py'文件内有如下的代码按照如下的图片进行修改,大家看好函数就行,其中红框的640可能和你的不一样, 然后用我给的代码替换掉整个代码即可。

def get_flops(model, imgsz=640):
    """Return a YOLO model's FLOPs."""
    try:
        model = de_parallel(model)
        p = next(model.parameters())
        # stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
        stride = 640
        im = torch.empty((1, 3, stride, stride), device=p.device)  # input image in BCHW format
        flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 if thop else 0  # stride GFLOPs
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        return flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
    except Exception:
        return 0


4.10 可选修改

有些读者的数据集部分图片比较特殊,在验证的时候会导致形状不匹配的报错,如果大家在验证的时候报错形状不匹配的错误可以固定验证集的图片尺寸,方法如下 ->

找到下面这个文件ultralytics/models/yolo/detect/train.py然后其中有一个类是DetectionTrainer class中的build_dataset函数中的一个参数rect=mode == 'val'改为rect=False


五、Swin Transformer的yaml文件

5.1 yaml文件

大家复制下面的yaml文件,然后通过我给大家的运行代码运行即可,RT-DETR的调参部分需要后面的文章给大家讲,现在目前免费给大家看这一部分不开放。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, SwinTransformer, []]  # 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 5 input_proj.2
  - [-1, 1, AIFI, [1024, 8]] # 6
  - [-1, 1, Conv, [256, 1, 1]]  # 7, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 8
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 9 input_proj.1
  - [[-2, -1], 1, Concat, [1]] # 10
  - [-1, 3, RepC3, [256, 0.5]]  # 11, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]   # 12, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 13
  - [2, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 14 input_proj.0
  - [[-2, -1], 1, Concat, [1]]  # 15 cat backbone P4
  - [-1, 3, RepC3, [256, 0.5]]    # X3 (16), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]]   # 17, downsample_convs.0
  - [[-1, 12], 1, Concat, [1]]  # 18 cat Y4
  - [-1, 3, RepC3, [256, 0.5]]    # F4 (19), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]   # 20, downsample_convs.1
  - [[-1, 7], 1, Concat, [1]]  # 21 cat Y5
  - [-1, 3, RepC3, [256, 0.5]]    # F5 (22), pan_blocks.1

  - [[16, 19, 22], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]]  # Detect(P3, P4, P5)


5.2 运行文件

大家可以创建一个train.py文件将下面的代码粘贴进去然后替换你的文件运行即可开始训练。

import warnings
from ultralytics import RTDETR
warnings.filterwarnings('ignore')

if __name__ == '__main__':
    model = RTDETR('替换你想要运行的yaml文件')
    # model.load('') # 可以加载你的版本预训练权重
    model.train(data=r'替换你的数据集地址即可',
                cache=False,
                imgsz=640,
                epochs=72,
                batch=4,
                workers=0,
                device='0',
                project='runs/RT-DETR-train',
                name='exp',
                # amp=True
                )


5.3 成功训练截图

下面是成功运行的截图(确保我的改进机制是可用的),已经完成了有1个epochs的训练,图片太大截不全第2个epochs了。 


六、全文总结

从今天开始正式开始更新RT-DETR剑指论文专栏,本专栏的内容会迅速铺开,在短期呢大量更新,价格也会乘阶梯性上涨,所以想要和我一起学习RT-DETR改进,可以在前期直接关注,本文专栏旨在打造全网最好的RT-DETR专栏为想要发论文的家进行服务。

 专栏链接:RT-DETR剑指论文专栏,持续复现各种顶会内容——论文收割机RT-DETR


http://www.niftyadmin.cn/n/5340448.html

相关文章

Ultraleap 3Di配置以及在 Unity 中使用 Ultraleap 3Di手部跟踪

0 开发需求 1、硬件&#xff1a;Ultraleap 手部追踪相机&#xff08;Ultraleap 3Di&#xff09; 2、软件&#xff1a;在计算机上安装Ultraleap Gemini (V5.2) 手部跟踪软件。 3、版本&#xff1a;Unity 2021 LTS 或更高版本 4、Unity XR插件管理&#xff1a;可从软件包管理器窗…

设计模式-委托模式

设计模式专栏 模式介绍模式特点应用场景委托模式与代理模式的区别代码示例Java实现委托模式Python实现委托模式 委托模式在spring中的应用 模式介绍 委托模式是一种行为模式&#xff0c;用于在面向对象设计中解决多个对象接收并处理同一请求的问题。它通过将请求委托给另一个对…

Sqoop数据导入到Hive表的最佳实践

将数据从关系型数据库导入到Hive表是大数据领域中的常见任务之一&#xff0c;Sqoop是一个强大的工具&#xff0c;可以帮助实现这一目标。本文将提供Sqoop数据导入到Hive表的最佳实践&#xff0c;包括详细的步骤、示例代码和最佳建议&#xff0c;以确保数据导入过程的高效性和可…

Flink对接Kafka的topic数据消费offset设置参数

scan.startup.mode 是 Flink 中用于设置消费 Kafka topic 数据的起始 offset 的配置参数之一。 scan.startup.mode 可以设置为以下几种模式&#xff1a; earliest-offset&#xff1a;从最早的 offset 开始消费数据。latest-offset&#xff1a;从最新的 offset 开始消费数据。…

flink1.18 广播流 The Broadcast State Pattern 官方案例scala版本

对应官网 https://nightlies.apache.org/flink/flink-docs-master/docs/dev/datastream/fault-tolerance/broadcast_state/ 测试数据 * 广播流 官方案例 scala版本* 广播状态* https://nightlies.apache.org/flink/flink-docs-master/docs/dev/datastream/fault-tolerance…

RK3568 移植Ubuntu

使用ubuntu-base构建根文件系统 1、到ubuntu官网获取 ubuntu-base-18.04.5-base-arm64.tar.gz Ubuntu Base 18.04.5 LTS (Bionic Beaver) 2、将获取的文件拷贝到ubuntu虚拟机,新建目录,并解压 mkdir ubuntu_rootfs sudo tar -xpf u

Golang 通过开源库 go-redis 操作 NoSQL 缓存服务器

前置条件&#xff1a; 1、导入库&#xff1a; import ( "github.com/go-redis/redis/v8" ) 2、搭建哨兵模式集群 具体可以百度、谷歌搜索&#xff0c;网上现成配置教程太多了&#xff0c;不行还可以搜教程视频&#xff0c;跟着视频博主一步一个慢动作&#xff0…

【GitHub项目推荐--基于 Flutter 的游戏引擎】【转载】

Flame 引擎的目的是为使用 Flutter 开发的游戏会遇到的常见问题提供一套完整的解决方案。 目前 Flame 提供了以下功能&#xff1a; 游戏循环 (game loop) 组件/对象系统 (FCS) 特效与粒子效果 碰撞检测 手势和输入支持 图片、动画、精灵图 (sprite) 以及精灵图组 一些简化…