YOLOv8改进实战 | 更换主干网络Backbone(三)之轻量化模型ShuffleNetV2

news/2024/7/11 0:48:24 标签: YOLO, 网络, 深度学习, 人工智能, 目标检测

在这里插入图片描述


前言

轻量化网络设计是一种针对移动设备等资源受限环境的深度学习模型设计方法。下面是一些常见的轻量化网络设计方法:

  1. 网络剪枝:移除神经网络中冗余的连接和参数,以达到模型压缩和加速的目的。
  2. 分组卷积:将卷积操作分解为若干个较小的卷积操作,并将它们分别作用于输入的不同通道,从而减少计算量。
  3. 深度可分离卷积:将标准卷积分解成深度卷积和逐点卷积两个步骤,使得在大部分情况下可以大幅减少计算量。
  4. 跨层连接:通过跨越多个层级的连接方式来增加神经网络的深度和复杂性,同时减少了需要训练的参数数量。
  5. 模块化设计:将神经网络分解为多个可重复使用的模块,以提高模型的可调节性和适应性。

传统的YOLOv8系列中,Backbone采用的是较为复杂的C2f网络结构,这使得模型计算量大幅度的增加,检测速度较慢,应用受限,在某些真实的应用场景如移动或者嵌入式设备,如此大而复杂的模型时难以被应用的。为了解决这个问题,本章节通过采用ShuffleNetV2轻量化主干网络作为Backbone的基础结构,从而在保证检测性能的同时,将网络结构精简到最小,大大减小了模型的参数量和计算量。

目录

  • 一、ShuffleNetV2
  • 二、代码实现
    • 2.1 添加ShuffleNetV2
    • 2.2 注册模块
    • 2.3 配置yaml文件
      • yolov8-shufflenetv2.yaml
    • 2.3 模型验证
    • 2.4 模型训练
  • 三、总结

一、ShuffleNetV2

2018 论文链接:ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design
Pytorch code:ShuffleNet-Series

论文总结的四条轻量化模型设计的指导思想

  • G1:卷积层的输入特征channel和输出特征channel相等可以最小化 MAC(Memory Access Cost,即内存占用量);
  • G2过度的分组卷积会增加MAC
    • 分组卷积是现代网络体系结构的核心。通过更改所有通道之间的稀疏卷积(仅在通道组内),可以降低计算复杂度(FLOP)。一方面,在给定固定FLOPs的情况下,它允许使用更多通道,并增加了网络容量(因此具有更高的准确性)。但是,另一方面,增加的通道数会导致更多的MAC。
  • G3网络碎片化会降低并行化的程度;
    • GoogLeNet系列和自动生成的体系结构中,每个网络模块中广泛采用“多路径”结构,使用了很多不同的小卷积或者pooling。尽管这种零散的结构已显示出对准确率有利,但由于它对具有强大并行计算能力的设备(如GPU)不友好,因此可能会降低效率。它还引入了额外的开销,例如内核启动和同步。
  • G4:不可忽略元素级的操作
    • 对于元素级操作(element-wise operators),比如ReLUAdd,虽然它们的FLOPs较小,但是却需要较大的MAC。深度卷积也算是元素级操作,也具有较高的MAC/FLOPs的比例。实验发现如果将ResNet中残差单元中的ReLU和shortcut移除的话,速度有20%的提升。

总结

  1. 使用“平衡”卷积(输入通道尽可能等于输出通道);
  2. 注意使用组卷积的代价;
  3. 减少碎片程度;
  4. 减少按元素操作。

在这里插入图片描述

二、代码实现

2.1 添加ShuffleNetV2

ultralytics/nn/modules/block.py文件中加入以下代码:

# TODO:build shuffle block
# -------------------------------------------------------------------------

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x


class conv_bn_relu_maxpool(nn.Module):
    def __init__(self, c1, c2):  # ch_in, ch_out
        super(conv_bn_relu_maxpool, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

    def forward(self, x):
        return self.maxpool(self.conv(x))


class ShuffleV2Block(nn.Module):
    def __init__(self, inp, oup, stride):
        super(ShuffleV2Block, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out
# -------------------------------------------------------------------------

2.2 注册模块

修改ultralytics/nn/modules/__init__.py文件:

from .block import (ASFF2, ASFF3, C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
                    GhostBottleneck, HGBlock, HGStem, Proto, RepC3, conv_bn_relu_maxpool, ShuffleV2Block)
__all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus',
           'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer',
           'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
           'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
           'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
           'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP', 'ASFF2', 'ASFF3',
           'conv_bn_relu_maxpool', 'ShuffleV2Block')

修改ultralytics/nn/tasks.py文件中的parse_model函数:

from ultralytics.nn.modules import (AIFI, ASFF2, ASFF3, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f,
                                    C3Ghost, C3x, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv,
                                    DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3,
                                    RepConv, RTDETRDecoder, Segment, conv_bn_relu_maxpool, ShuffleV2Block)
if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
         BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3,
         conv_bn_relu_maxpool, ShuffleV2Block):
    c1, c2 = ch[f], args[0]
    if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
        c2 = make_divisible(min(c2, max_channels) * width, 8)

    args = [c1, c2, *args[1:]]

2.3 配置yaml文件

这里我们选择替换Backbone中的所有ConvC2f模块。当然也可以将所有ConvC2f模块全部替换掉,哪个效果更好,需要各位去实测一番。

yolov8-shufflenetv2.yaml

第一版

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, conv_bn_relu_maxpool, [32]]  # 0-P1/2

  - [ -1, 1, ShuffleV2Block, [116, 2] ] # 1-P3/8
  - [ -1, 9, ShuffleV2Block, [116, 1] ] # 2

  - [ -1, 1, ShuffleV2Block, [232, 2] ] # 3-P4/16
  - [ -1, 21, ShuffleV2Block, [232, 1] ] # 4

  - [ -1, 1, ShuffleV2Block, [464, 2] ] # 5-P5/32
  - [ -1, 9, ShuffleV2Block, [464, 1] ] # 6

  - [-1, 1, SPPF, [1024, 5]]  # 7

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 10

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 2], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 13 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 16 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 7], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 19 (P5/32-large)

  - [[13, 16, 19], 1, Detect, [nc]]  # Detect(P3, P4, P5)
YOLOv8n-shufflenetv2 summary: 336 layers, 2005904 parameters, 2005888 gradients, 5.9 GFLOPs

推荐:第二版(结合YOLOv5-Lite思想):

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, conv_bn_relu_maxpool, [32]]  # 0-P1/2

  - [ -1, 1, ShuffleV2Block, [116, 2] ] # 1-P3/8
  - [ -1, 9, ShuffleV2Block, [116, 1] ] # 2

  - [ -1, 1, ShuffleV2Block, [232, 2] ] # 3-P4/16
  - [ -1, 21, ShuffleV2Block, [232, 1] ] # 4

  - [ -1, 1, ShuffleV2Block, [464, 2] ] # 5-P5/32
  - [ -1, 3, ShuffleV2Block, [464, 1] ] # 6


# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 9

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 2], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 12 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 15 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 6], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 18 (P5/32-large)

  - [[12, 15, 18], 1, Detect, [nc]]  # Detect(P3, P4, P5)
YOLOv8n-shufflenetv2 summary: 231 layers, 1975560 parameters, 1975544 gradients, 5.8 GFLOPs

2.3 模型验证

from ultralytics import YOLO

# Load a model
model = YOLO("backbone/yolov8n-shufflenetv2.yaml")  # build a new model from scratch

2.4 模型训练

from ultralytics import YOLO

# Load a model
model = YOLO("backbone/yolov8n-shufflenetv2.yaml")  # build a new model from scratch

# Use the model
model.train(
    data="./mydata/data.yaml",
    epochs=300,
    batch=48)  # train the model

三、总结

  • 模型的训练具有很大的随机性,您可能需要点运气和更多的训练次数才能达到最高的 mAP。

在这里插入图片描述


http://www.niftyadmin.cn/n/5107790.html

相关文章

Redis实现附近商户

GEO数据结构的基本用法 GEO就是Geolocation的简写形式&#xff0c;代表地理坐标。Redis在3.2版本中加入了对GEO的支持&#xff0c;允许存储地理坐标信息&#xff0c;帮助我们根据经纬度来检索数据。常见的命令有&#xff1a; GEOADD&#xff1a;添加一个地理空间信息&#xf…

【ROS 2 基础-常用工具】-7 Rviz仿真机器人

所有内容请查看&#xff1a;博客学习目录_Howe_xixi的博客-CSDN博客

超低延迟直播技术路线,h265的无奈选择

超低延迟&#xff0c;多窗显示&#xff0c;自适应编解码和渲染&#xff0c;高分辨低码率&#xff0c;还有微信小程序的标配&#xff0c;这些在现今的监控和直播中都成刚需了&#xff0c;中国的音视频技术人面临着困境&#xff0c;核心门户浏览器不掌握在自己手上&#xff0c;老…

推荐《机动战士高达SEED DESTINY》

《机动战士高达SEED DESTINY》是《机动战士高达SEED》的续集&#xff0c;于日本时间2004年10月9日—2005年10月1日每周六下午六点在每日放送、TBS电视台系列电视台播出&#xff0c;全50话。 [1] 台湾版权由博英社取得&#xff0c;并于2005年10月8日起由中国电视公司在每周六播…

第六届“中国法研杯”司法人工智能挑战赛进行中!

第六届“中国法研杯”司法人工智能挑战赛 赛题上新&#xff01; 第六届“中国法研杯”司法人工智能挑战赛&#xff08;LAIC2023&#xff09;目前已发布司法大模型数据和服务集成调度 、证据推理、司法大数据征文比赛、案件要素识别四大任务。本届大赛中&#xff0c;“案件要素…

JavaScript的基本知识点解析

JavaScript的基本概念&#xff1a; 变量 变量是存储数据的容器。在JavaScript中&#xff0c;可以使用var、let或const关键字声明变量。例如&#xff1a; var x 10; // 使用var声明变量x并赋值为10 let y 20; // 使用let声明变量y并赋值为20 const z 30; // 使用const声明…

系列四、FileReader和FileWriter

一、概述 FileReader 和 FileWriter 是字符流&#xff0c;按照字符来操作IO。 1.1、继承体系 二、FileReader常用方法 new FileReader(File/String)# 每次读取单个字符就返回&#xff0c;如果读取到文件末尾返回-1 read()# 批量读取多个字符到数组&#xff0c;返回读取的字节…

Redis常用配置详解

目录 一、Redis查看当前配置命令二、Redis基本配置三、RDB全量持久化配置&#xff08;默认开启&#xff09;四、AOF增量持久化配置五、Redis key过期监听配置六、Redis内存淘汰策略七、总结 一、Redis查看当前配置命令 # Redis查看当前全部配置信息 127.0.0.1:6379> CONFIG…