yolov5训练coco数据集的部分类别

news/2024/7/10 23:32:34 标签: YOLO

yolov5训练coco数据集的部分类别

    • 创建容器
    • 准备yolov5环境
    • 定义需要训练的类别(coco-6.yaml)
    • 根据coco-6.yaml中保留的类别,生成新的数据集
    • 生成新数据集
    • 训练
    • 测试

在测试yolov5系列不同类别的模型在各种加速卡上的精度和性能时,我们希望得到一个准确的评估结果。因此,本文从一个COCO数据集中创建一个子集,该子集仅包含特定的类别。具体来说,它首先从源数据集中读取JSON文件,然后过滤出所需的类别,并将它们保存到新的JSON文件中。接下来,它将所需的图像和标签复制到新的目标目录中。最后,它创建一个包含所有图像文件路径的文本文件,并更新数据集的YAML配置文件。以此为数据集,训练并测试模型,从而得到准确的评估结果。

创建容器

mkdir yolov5
cd yolov5
docker run -it --gpus all --name yolov5_dev -v $PWD:/home/ cuda_dev_image:v1.0 bash

准备yolov5环境

apt update
apt install git -y
git clone https://github.com/ultralytics/yolov5
cd yolov5
bash data/scripts/get_coco.sh --train --val
cd ../datasets/coco/
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
rm -rf annotations
unzip annotations_trainval2017.zip

定义需要训练的类别(coco-6.yaml)

path: /home/dataset/coco 
train: train2017.txt 
val: val2017.txt
names:
  0: person
  1: bicycle
  2: car
  3: motorcycle
  4: airplane
  5: bus

根据coco-6.yaml中保留的类别,生成新的数据集

#create_sub_coco_dataset.py

import json
import yaml
import sys
import os
import shutil
import tqdm

def MakeDirs(dir):
	if not os.path.exists(dir):
		os.makedirs(dir,True)

def create_sub_coco_dataset(data_yaml="coco-6.yaml",
							src_root_dir="../datasets/coco",
							dst_root_dir="class6/coco",
							folder="val2017"
							):

	MakeDirs(dst_root_dir+"/annotations/")
	MakeDirs(dst_root_dir+"/images/"+folder)
	MakeDirs(dst_root_dir+"/labels/"+folder)

	keep_names=[x+1 for x in yaml.safe_load(open(data_yaml).read())['names'].keys()]
	all_annotations=json.loads(open(src_root_dir+"/annotations/instances_{}.json".format(folder)).read())
	keep_categories=[x for x in all_annotations["categories"] if x["id"] in keep_names]
	keep_annotations=[x for x in all_annotations['annotations'] if x['category_id'] in keep_names]

	all_annotations['annotations']=keep_annotations
	all_annotations["categories"]=keep_categories
	
	if not os.path.exists(dst_root_dir+"/annotations/instances_{}.json".format(folder)):
		with open(dst_root_dir+"/annotations/instances_{}.json".format(folder), "w") as f:
			json.dump(all_annotations, f)
		
	filelist=set()
	for i in tqdm.tqdm(keep_annotations):
		img_src_path="/images/{}/{:012d}.jpg".format(folder,i["image_id"])
		label_src_path="/labels/{}/{:012d}.txt".format(folder,i["image_id"])
		if not os.path.exists(dst_root_dir+img_src_path):
			shutil.copy(src_root_dir+img_src_path, dst_root_dir+img_src_path)
		if not os.path.exists(dst_root_dir+label_src_path):
			keep_records=[x for x in open(src_root_dir+label_src_path,"r").readlines() if (int(x.strip().split(" ")[0])+1) in keep_names]
			with open(dst_root_dir+label_src_path,"w") as f:
				for r in keep_records:
					f.write(r)
		filelist.add("./images/{}/{:012d}.jpg\n".format(folder,i["image_id"]))
		
	with open(dst_root_dir+"/{}.txt".format(folder),"w") as f:
		for r in filelist:
			f.write(r)
			
	new_data_yaml=yaml.safe_load(open(data_yaml).read())
	new_data_yaml["path"]=dst_root_dir

	with open(dst_root_dir+"/coco.yaml", 'w') as f:
		f.write(yaml.dump(new_data_yaml, allow_unicode=True))

create_sub_coco_dataset(sys.argv[1],sys.argv[2],sys.argv[3],sys.argv[4])

生成新数据集

cd /home/yolov5
rm -rf class6
python create_sub_coco_dataset.py coco-6.yaml ../datasets/coco class6/coco train2017
python create_sub_coco_dataset.py coco-6.yaml ../datasets/coco class6/coco val2017

训练

python train.py --data class6/coco/coco.yaml \
				--weights '' --cfg models/yolov5m.yaml \
				--img 640 --workers 0 --device 0

测试

python val.py --weights best.pt --data class6/coco/coco.yaml \
					--img 640 --conf-thres 0.001 --iou-thres 0.6 \
					--workers 0 --device 0 --half --batch-size 1

http://www.niftyadmin.cn/n/5418512.html

相关文章

近红外透光率测量积分球

透光率积分球是一个空心的漫反射球腔,具有漫反射的内表面。它的基本原理是光通过采样口被积分球收集,在积分球内部经过多次反射后非常均匀地散射在积分球内部。这种设计使得积分球可以作为光收集器,被收集的光可以用作漫反射光源或被测光源。…

【杂记】IDEA和Eclipse如何查看GC日志

1.Eclipse查看GC日志 1.1 右击代码编辑区 -> Run As -> Run Configurations 1.2 点击Arguments栏 -> VM arguments:区域填写XX参数 -> Run 1.3 控制台输出GC详细日志 2.IDEA查看GC日志 2.1 鼠标右击代码编辑器空白区域,选择Edit 项目名.main()... 2.…

pip3命令行下载

使用pip3下载包时,报错如下,发现没有安装pip3: 安装步骤如下: 1.更新包列表: 其中需要输入服务器密码 sudo apt update 2.安装pip3: 输入Y sudo apt install python3-pip 3.查看pip3版本,出现则表示安装成功 pip…

蓝桥杯2023年-飞机降落(暴力枚举,next_permutation)

题目描述 N 架飞机准备降落到某个只有一条跑道的机场。其中第 i 架飞机在 Ti 时刻到达机场上空,到达时它的剩余油料还可以继续盘旋 Di 个单位时间,即它最早 可以于 Ti 时刻开始降落,最晚可以于 Ti Di 时刻开始降落。降落过程需要 Li个单位…

分割模型TransNetR的pytorch代码学习笔记

这个模型在U-net的基础上融合了Transformer模块和残差网络的原理。 论文地址:https://arxiv.org/pdf/2303.07428.pdf 具体的网络结构如下: 网络的原理还是比较简单的, 编码分支用的是预训练的resnet模块,解码分支则重新设计了。…

Python 语句介绍

Python 解释 Python是一种高级编程语言,以其简洁、易读和易用而闻名。它是一种通用的、解释型的编程语言,适用于广泛的应用领域,包括软件开发、数据分析、人工智能等。python是一种解释型,面向对象、动态数据类型的高级程序设计…

HTTPS安全机制解析:如何保护我们的数据传输

目录 引言 HTTPS的核心安全机制 结论 引言 在数字时代,网络安全成为了互联网用户和服务提供者不可忽视的关键议题。特别是,HTTPS(全称为HyperText Transfer Protocol Secure)相比于其前身HTTP(HyperText Transfer P…

llama2c(4)之forward、sample、decode

1、forward float* logits forward(transformer, token, pos); 输入transformer的参数,当前token,pos位置,预测出下一个token的预测值(用矩阵乘,加减乘除等运算构成Transformer) (gdb) p *logits $9 2.19…