基于YOLOv5n/s/m不同参数量级模型开发构建茶叶嫩芽检测识别模型,使用pruning剪枝技术来对模型进行轻量化处理,探索不同剪枝水平下模型性能影响

news/2024/7/11 0:39:09 标签: YOLO, 剪枝, 算法

今天有点时间就想着之前遗留的一个问题正好拿过来做一下看看,主要的目的就是想要对训练好的目标检测模型进行剪枝处理,这里就以茶叶嫩芽检测数据场景为例了,在我前面的博文中已经有过相关的实践介绍了,感兴趣的话可以自行移步阅读即可:

《融合CBAM注意力机制基于YOLOv5开发构建毛尖茶叶嫩芽检测识别系统》

这里就不再赘述了。

本文选取了n/s/m三款不同量级的模型来依次构建训练模型,所有的参数保持同样的设置,之后探索在不同剪枝处理操作下的性能影响。

简单看下数据集情况:

 三款模型的训练指令如下所示:

#yolov5n
python3 train.py --cfg models/yolov5n.yaml --weights weights/yolov5n.pt --name yolov5n --epochs 100 --batch-size 4 --img-size 416

#yolov5s
python3 train.py --cfg models/yolov5s.yaml --weights weights/yolov5s.pt --name yolov5s --epochs 100 --batch-size 4 --img-size 416

#yolov5m
python3 train.py --cfg models/yolov5m.yaml --weights weights/yolov5m.pt --name yolov5m --epochs 100 --batch-size 4 --img-size 416

主要是两点,一是batchsize这里设置的比较小因为同时在跑三款模型,这里设置的都是4;另一方面是imgsize,这里为了加快实验节奏,设置的是416,比较低的分辨率而不是640。

默认都是100次epoch的迭代计算,接下来依次看下实际训练情况:
【yolov5n】

 【yolov5s】

 【yolov5m】

 从最终模型的评估结果上面来看:s系列的模型结果还不如n系列的模型,或者说是二者差异不大,m系列模型的结果要优于其他两款模型。

为了能够整体直观地对三款不同参数量级的模型进行直观地对比分析,这里对其主要指标进行了可视化处理,如下所示:

【Precision曲线】
精确率曲线(Precision-Recall Curve)是一种用于评估二分类模型在不同阈值下的精确率性能的可视化工具。它通过绘制不同阈值下的精确率和召回率之间的关系图来帮助我们了解模型在不同阈值下的表现。
精确率(Precision)是指被正确预测为正例的样本数占所有预测为正例的样本数的比例。召回率(Recall)是指被正确预测为正例的样本数占所有实际为正例的样本数的比例。
绘制精确率曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的精确率和召回率。
将每个阈值下的精确率和召回率绘制在同一个图表上,形成精确率曲线。
根据精确率曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
通过观察精确率曲线,我们可以根据需求确定最佳的阈值,以平衡精确率和召回率。较高的精确率意味着较少的误报,而较高的召回率则表示较少的漏报。根据具体的业务需求和成本权衡,可以在曲线上选择合适的操作点或阈值。
精确率曲线通常与召回率曲线(Recall Curve)一起使用,以提供更全面的分类器性能分析,并帮助评估和比较不同模型的性能。


【Recall曲线】
召回率曲线(Recall Curve)是一种用于评估二分类模型在不同阈值下的召回率性能的可视化工具。它通过绘制不同阈值下的召回率和对应的精确率之间的关系图来帮助我们了解模型在不同阈值下的表现。
召回率(Recall)是指被正确预测为正例的样本数占所有实际为正例的样本数的比例。召回率也被称为灵敏度(Sensitivity)或真正例率(True Positive Rate)。
绘制召回率曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的召回率和对应的精确率。
将每个阈值下的召回率和精确率绘制在同一个图表上,形成召回率曲线。
根据召回率曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
通过观察召回率曲线,我们可以根据需求确定最佳的阈值,以平衡召回率和精确率。较高的召回率表示较少的漏报,而较高的精确率意味着较少的误报。根据具体的业务需求和成本权衡,可以在曲线上选择合适的操作点或阈值。
召回率曲线通常与精确率曲线(Precision Curve)一起使用,以提供更全面的分类器性能分析,并帮助评估和比较不同模型的性能。


【F1值曲线】
F1值曲线是一种用于评估二分类模型在不同阈值下的性能的可视化工具。它通过绘制不同阈值下的精确率(Precision)、召回率(Recall)和F1分数的关系图来帮助我们理解模型的整体性能。
F1分数是精确率和召回率的调和平均值,它综合考虑了两者的性能指标。F1值曲线可以帮助我们确定在不同精确率和召回率之间找到一个平衡点,以选择最佳的阈值。
绘制F1值曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的精确率、召回率和F1分数。
将每个阈值下的精确率、召回率和F1分数绘制在同一个图表上,形成F1值曲线。
根据F1值曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
F1值曲线通常与接收者操作特征曲线(ROC曲线)一起使用,以帮助评估和比较不同模型的性能。它们提供了更全面的分类器性能分析,可以根据具体应用场景来选择合适的模型和阈值设置。

 【loss曲线】

 整体来看m系列模型无意识这三款不同参数量级模型中表现最好的,n和s系列模型的表现相近。

接下来要对三款模型进行剪枝处理,这里使用到一个很好用的第三方模块torch_pruning,官方项目地址在这里,如下所示:

 安装方式很简单如下所示:

pip install torch-pruning 
或者
git clone https://github.com/VainF/Torch-Pruning.git

在结构修剪中,“组”被定义为可以在深度网络中移除的最小单元。这些组由多个相互依赖的层组成,需要一起修剪,以保持生成结构的完整性。然而,深度网络的层之间往往存在复杂的依赖关系,这使得结构修剪成为一项具有挑战性的任务。这项工作通过引入一种名为“DepGraph”的自动化机制来解决这一挑战。DepGraph允许无缝的参数分组,并有助于在各种类型的深度网络中进行修剪。

官方提供很多可用的实例,如下所示:

【Naive pruning】

为了演示依赖性的含义,让我们在ResNet-18上尝试结构化修剪。以下代码片段尝试从第一个model.conv1中删除由0和1索引的通道:

from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # remove channel 0 and channel 1
output = model(torch.randn(1,3,224,224)) # test


输出 
ResNet(
  (conv1): Conv2d(3, 62, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
...

【An improved version】

事实上,上述情况下的依赖关系比我们已经观察到的要复杂得多。让我们改进我们的代码,看看如果处理BN和Conv会发生什么。

from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) 
tp.prune_batchnorm_out_channels(model.bn1, idxs=[0,1])
tp.prune_batchnorm_in_channels(model.layer1[0].conv1, idxs=[0,1])
output = model(torch.randn(1,3,224,224)) 

【A Minimal Example】

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )

# 3. prune all grouped layers that are coupled with model.conv1 (included).
if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
    group.prune()
    
# 4. Save & Load
model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the model object

上面的示例演示了使用DepGraph的基本修剪管道。目标层resnet.conv1与多个层耦合,这需要在结构修剪中同时移除。让我们打印该组,并观察修剪操作是如何“触发”其他修剪操作的。在以下输出中,A=>B表示修剪操作A触发修剪操作B。group[0]表示DG.get_pruning_group中的修剪根。

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------

【High-level Pruners】

利用DependencyGraph,我们在这个存储库中开发了几个高级修剪器,以方便轻松修剪。通过指定所需的通道稀疏性,可以修剪整个模型,并使用自己的训练代码对其进行微调。有关这个过程的详细信息,请参阅本教程,它展示了如何从头开始实现瘦身修剪器。此外,您可以在benchmarks/main.py中找到更实用的示例。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...

还有很多其他的功能实例,这里就不再一一赘述了,可以参考使用即可,这里我借鉴官方的实例来完成对yolov5n/s/m三款不同参数量级模型的剪枝处理。剪枝完成后结果如下所示:

 接下来我想直接使用剪枝后的模型来进行评估测试,不出意外的话结果应该会很差的,先来简单看下吧。

【yolov5n_layer_pruning】

 【yolov5s_layer_pruning】

 【yolov5m_layer_pruning】

 果然是惨不忍睹,直接使用剪枝后的模型文件是不行的,这样破坏了原始完整的模型结构,导致原有学习后的知识已经无效了。

接下来就需要基于剪枝后的结构来进行微调训练。这里我同样保持了与最初模型训练一样的参数设置,如下所示:

#yolov5n
python3 train.py --weights yolov5n_layer_pruning.pt --pt --name yolov5n_pruning  --epochs 100 --batch-size 4 --img-size 416

#yolov5s
python3 train.py --weights yolov5s_layer_pruning.pt --pt --name yolov5s_pruning  --epochs 100 --batch-size 4 --img-size 416

#yolov5m
python3 train.py --weights yolov5m_layer_pruning.pt --pt --name yolov5m_pruning  --epochs 100 --batch-size 4 --img-size 416

这里其实也可以不用训练100次epoch,只不过我想默认保持一样的参数设置,等待一段时间后来看下结果记录。

【yolov5n_pruning】

 【yolov5s_pruning】

 【yolov5m_pruning】

 这里从评估结果上来看:n<s<m。接下来我们同样对其进行对比可视化分析展示。

【F1值】

 【精确率】

 【召回率】

 【loss】

 上述的三组剪枝实验结果是建立在剪枝30%的基础上,产生的结果,可以看到:甚至剪枝后的效果还要优于原始的模型,这也说明了原始的模型中存在相当量的参数冗余。

接下来我们想要进一步探索不同程度剪枝水平对于模型性能的影响程度,前车之鉴,这里写CSDN博文都不敢一篇文章写太多内容,不然突然页面崩溃就会好心酸。。。。。。

我把这部分的内容放在下一篇博文中,如下所示:
《基于YOLOv5n/s/m不同参数量级模型开发构建茶叶嫩芽检测识别模型,预计pruning剪枝技术来对模型进行轻量化处理,探索不同剪枝水平下模型性能影响【续】》


http://www.niftyadmin.cn/n/4947306.html

相关文章

栈和队列实现

目录 ​编辑 &#x1f339;1.栈 &#x1f490;1.1 栈的概念和结构 &#x1f490;1.2 栈的实现 &#x1f338;1.2.1 初始化 &#x1f338;1.2.2 销毁 &#x1f338;1.2.3 入栈 &#x1f338;1.2.4 出栈 &#x1f338;1.2.5 获取栈顶元素 &#x1f338;1.2.6 判空 &am…

C#语音播报问题之 无法嵌入互操作类型SpVoiceClass,请改用适用的窗口

C#语音播报问题之 无法嵌入互操作类型SpVoiceClass&#xff0c;请改用适用的窗口 解决办法如下&#xff1a; 只需要将引入的Interop.SpeechLib的属性嵌入互操作类型改为false 改为false 即可解决&#xff01;

微服务-Nacos(配置管理)

配置更改热更新 在Nacos中添加配置信息&#xff1a; 在弹出表单中填写配置信息&#xff1a; 配置获取的步骤如下&#xff1a; 1.引入Nacos的配置管理客户端依赖&#xff08;A、B服务&#xff09;&#xff1a; <!--nacos的配置管理依赖--><dependency><groupId&…

Maven之tomcat7-maven-plugin 版本低的问题

tomcat7-maven-plugin 版本『低』的问题 相较于当前最新版的 tomcat 10 而言&#xff0c;tomcat7-maven-plugin 确实看起来很显老旧。但是&#xff0c;这个问题并不是问题&#xff0c;至少不是大问题。 原因 1&#xff1a;tomcat7-maven-plugin 仅用于我们&#xff08;程序员&…

rust入门系列之Rust介绍及开发环境搭建

Rust教程 Rust基本介绍 网站: https://www.rust-lang.org/ rust是什么 开发rust语言的初衷是&#xff1a; 在软件发展速度跟不上硬件发展速度&#xff0c;无法在语言层面充分的利用硬件多核cpu不断提升的性能和 在系统界别软件开发上&#xff0c;C出生比较早&#xff0c;内…

ARM DIY 硬件调试

前言 之前打样的几块 ARM 板&#xff0c;一直放着没去焊接。今天再次看到&#xff0c;决定把它焊起来。 加热台焊接 为了提高焊接效率&#xff0c;先使用加热台焊接。不过板子为双面贴片&#xff0c;使用加热台只能焊接一面&#xff0c;那就优先焊主芯片那面&#xff0c;并…

探索Perfetto:开源性能追踪工具的未来之光

探索Perfetto&#xff1a;开源性能追踪工具的未来之光 1. 引言 A. 介绍Perfetto的背景和作用 随着移动应用、桌面软件和嵌入式系统的不断发展&#xff0c;软件性能优化变得愈发重要。在这个背景下&#xff0c;Perfetto作为一款开源性能追踪工具&#xff0c;日益引起了开发者…

MATLAB算法实战应用案例精讲-【图像处理】图像分类模型ResNetResNeXtRes2Net

目录 ResNet 1. 更深层次的网络? 2. 为什么深度网络不仅仅是层数的堆叠? 2.1 梯度消失 or 爆炸