YOLOv7 pytorch,支持剪枝【附代码】

news/2024/7/10 23:27:04 标签: YOLO, pytorch, 深度学习

yolov7主干部分结构图:yolov7主干

yolov7数据集处理代码:yolov7数据集处理代码

yolov7训练参数解释:yolov7训练参数【与本文代码有区别】

yolov7训练代码详解:yolov7训练代码详解

目录

训练自己的训练集

生成推理阶段的模型

生成剪枝后的推理模型

torch转onnx

剪枝

剪枝后的微调训练

预测图像或视频


训练自己的训练集

此处的数据集是采用VOC的格式。

数据集存放格式:

─dataset
│  ├─Annotations  # 存放xml标签文件
│  ├─images # 存放图片
│  ├─ImageSets # 存放图片名称的txt文件
│  └─labels # 存放标签txt文件

先运行项目代码makeTXT:

python makeTXT.py

此时会在ImageSets下生成4个txt文件(这四个txt中仅包含每个图像的名称)

ImageSets/
|-- test.txt
|-- train.txt
|-- trainval.txt
`-- val.txt

打开voc_label.py.修改classes为自己的类。

然后运行该代码。

python voc_label.py

 将会在dataset文件下生成test.txt、train.txt、val.txt【这些txt仅包含图像路径】。然后在dataset/labels下会生成每个图像的txt【这些txt格式内容表示为类别索引+(center_x,center_y,w,h)】

接下来是配置文件的修改

打开cfg/training/yolov7.yaml。将nc修改为自己的类别数量。

接下来在data/文件下新建一个yaml文件【我这里写的是mydata.yaml】,内容如下,需要修改两个地方:

train: ./dataset/train.txt
val: ./dataset/val.txt
test: ./dataset/test.txt

# number of classes
nc: 1 # 修改处1  修改为自己的类

# class names
names: [ 'target' ]  # 修改处2 类的名称

 有关训练中的超参数设置【比如初始学习率,动量,权重衰减等,可自行在data/hyp.scratch.p5.yaml中修改】。

训练:

python train.py --weights yolov7.pt --batch-size 2 --device 0

 正常的训练将会看到以下信息。

2023-03-11 11:50:48.658 | INFO     | __main__:train:316 - 
     Epoch   gpu_mem       box       obj       cls     total    labels  img_size
     0/299     2.58G   0.04649    0.4474         0    0.4939         5       640: 100%|██████████████████████████████████████████████████████████████| 359/359 [02:39<00:00,  2.25it/s] 
               Class      Images      Labels           P           R      mAP@.5  mAP@.5:.95: 100%|████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.78it/s]
                 all          80         147         0.2       0.204       0.102      0.0191

生成推理阶段的模型

由于yolov7中训练与推理并不是一个模型,是将训练后的模型进行重参数生成新模型。

因此需要运行tools/Reparameterization.py文件。【运行前注意修改文件中的权重路径以及类的数量】

python tools/Reparameterization.py --ckpt yolov7.pt --num_classes 80 

生成剪枝后的推理模型

python tools/Reparameterization.py --ckpt runs/train/exp2/weights/best.pt --num_classes 1 --pruned 

不过发现重参数化后的剪枝模型,鲁棒性不如未参数化

将会在cfg/deploy下生成yolov7.pt

torch转onnx

修改tools/pytorch2onnx.py中的权重路径

运行该代码即可得到onnx模型

剪枝

进入tools文件,修改prunmodel.py文件中需要剪枝的权重路径。重点修改58~62行。这里是以修改model的前10层为例。

    included_layers = []
    for layer in model.model[:10]:  # 获取backbone
        if type(layer) is Conv:
            included_layers.append(layer.conv)
            included_layers.append(layer.bn)

下面代码是剪枝conv和BN层。【重点是tp.prune_conv】,自己修改amout

        if isinstance(m, nn.Conv2d) and m in included_layers:
            # amount是剪枝率
            # 卷积剪枝
            pruning_plan = DG.get_pruning_plan(m, tp.prune_conv, idxs=strategy(m.weight, amount=0.8))
            logger.info(pruning_plan)
            # 执行剪枝
            pruning_plan.exec()
        if isinstance(m, nn.BatchNorm2d) and m in included_layers:
            # BN层剪枝
            pruning_plan = DG.get_pruning_plan(m, tp.prune_batchnorm, idxs=strategy(m.weight, amount=0.8))
            logger.info(pruning_plan)
            pruning_plan.exec()
出现以下内容说明剪枝成功
【感觉不如yolov5剪的参数多,v7的剪枝感觉效果一般,请自行尝试】
2023-03-15 14:57:40.825 | INFO     | __main__:layer_pruning:84 -   Params: 37196556 => 36839795
​
2023-03-15 14:57:41.176 | INFO     | __main__:layer_pruning:95 - 剪枝完成

剪枝的模型会保存在model_data下

剪枝后的微调训练

与之前的训练一样。只不过需要传入weights,和pruned

python train.py --weights model_data/layer_pruning.pt --pruned

预测图像或视频

支持剪枝后的预测

python detect.py --weights cfg/deploy/yolov7.pt --source dataset/images/

代码:GitHub - YINYIPENG-EN/yolov7_torch: yolov7 pytorch

后续将更新tensorrt,请持续关注

如果剪枝遇到什么问题可以留言,有关精确度的问题还请自己尝试,因为每个人剪枝的地方不同,数据集不同,会有很多区别


http://www.niftyadmin.cn/n/160597.html

相关文章

BoT-SORT: Robust Associations Multi-Pedestrian Tracking 论文详细解读

BoT-SORT: Robust Associations Multi-Pedestrian Tracking 论文详细解读 文章目录BoT-SORT: Robust Associations Multi-Pedestrian Tracking 论文详细解读BoT-SORT:BoT-SORT简述修改卡尔曼滤波状态向量和其他矩阵参数相机的运动补偿IOU与Re-ID的融合实验效果MOT17&#xff1a…

论文阅读:chain of thought Prompting elicits reasoning in large language models

论文阅读&#xff1a;chain of thought Prompting elicits reasoning in large language models 跟着沐神读论文 视频链接&#xff1a;https://www.bilibili.com/video/BV1t8411e7Ug/?spm_id_from333.788&vd_source350cece3ec9a0c2aee50da8ccc315bf4 title:chain of tho…

pwn入门HTB_You know 0xDiablos例题讲解

我希望能将我的疑惑记录&#xff0c;但是堆栈函数调用这些这些&#xff0c;几句话我很难讲清楚&#xff0c;多看教程&#xff0c;好教程很多 至于解题基础&#xff0c;知道栈这种数据结构是一个线性表之后就够了&#xff0c;这题看不懂你来砍我 文章目录前言这题干什么举个例子…

vue面试题(day05)

vue面试题vue3中Composition API 的优势&#xff1f;1.了解 Options ApiCompositioncomposition APi&#xff1a;2.shallowReactive和shallowRef的区别&#xff1f;3.provide与inject如何使用&#xff1f;总结4.toRaw 与 markRaw是什么作用&#xff1f;5.readonly 与 shallowRe…

[ROC-RK3568-PC] [Firefly-Android] 10min带你了解LCD的使用

&#x1f347; 博主主页&#xff1a; 【Systemcall小酒屋】&#x1f347; 博主追寻&#xff1a;热衷于用简单的案例讲述复杂的技术&#xff0c;“假传万卷书&#xff0c;真传一案例”&#xff0c;这是林群院士说过的一句话&#xff0c;另外“成就是最好的老师”&#xff0c;技术…

API的常识与对接,商品详情数据案例分析

原则上API接口设计一般出现在开发的详细设计中&#xff0c;但是随着诸多公司建立开放平台&#xff0c;产品经理也逐渐需要能理解API接口&#xff0c;尤其是做平台性的产品&#xff0c;还要学会定义接口。本文就关于产品经理在设计接口中需要定义什么、需要注意什么来展开陈述。…

oracle19c迁移手册

windows10- 查看当前用户所有的表&#xff1a;select table_name from user_tables;- 创建用户给与权限&#xff1a;&#xff08;用户名是c##开头是因为oracle版本问题)- create user C##test identified by 1 default tablespace T1 temporary tablespace T2; grant connect,d…

Dynamics365业务理解

Dynamics 365 是微软的一套基于云端的企业应用软件&#xff0c;提供一系列业务功能模块&#xff0c;包括销售、客户服务、人力资源、财务和操作等领域。它能够帮助企业整合各个部门的数据和流程&#xff0c;提高工作效率、盈利能力和客户满意度。简介Dynamics 365是由微软公司开…