YOLO数据集划分(训练集、验证集、测试集)

news/2024/7/11 0:12:33 标签: YOLO, 深度学习, 机器学习, 人工智能, python

1.将训练集、验证集、测试集按照7:2:1随机划分

1.项目准备

1.在项目下新建一个py文件,名字就叫做splitDataset1.py

2.将自己需要划分的原数据集就放在项目文件夹下面

以我的为例,我的原数据集名字叫做hatDataXml

里面的JPEGImages装的是图片

Annotations里面装的是xml标签

2.代码实现

python">
# 将标签为xml格式的数据集按照7:2:1的比例划分为训练集,验证集和测试集

import os, shutil, random
from tqdm import tqdm


def split_img(img_path, label_path, split_list):
    try:
        Data = 'DataSet'
        # Data是你要将要创建的文件夹路径(路径一定是相对于你当前的这个脚本而言的)
        os.mkdir(Data)

        train_img_dir = Data + '/images/train'
        val_img_dir = Data + '/images/val'
        test_img_dir = Data + '/images/test'

        train_label_dir = Data + '/labels/train'
        val_label_dir = Data + '/labels/val'
        test_label_dir = Data + '/labels/test'

        # 创建文件夹
        os.makedirs(train_img_dir)
        os.makedirs(train_label_dir)
        os.makedirs(val_img_dir)
        os.makedirs(val_label_dir)
        os.makedirs(test_img_dir)
        os.makedirs(test_label_dir)

    except:
        print('文件目录已存在')

    train, val, test = split_list
    all_img = os.listdir(img_path)
    all_img_path = [os.path.join(img_path, img) for img in all_img]
    # all_label = os.listdir(label_path)
    # all_label_path = [os.path.join(label_path, label) for label in all_label]
    train_img = random.sample(all_img_path, int(train * len(all_img_path)))
    train_img_copy = [os.path.join(train_img_dir, img.split('\\')[-1]) for img in train_img]
    train_label = [toLabelPath(img, label_path) for img in train_img]
    train_label_copy = [os.path.join(train_label_dir, label.split('\\')[-1]) for label in train_label]
    for i in tqdm(range(len(train_img)), desc='train ', ncols=80, unit='img'):
        _copy(train_img[i], train_img_dir)
        _copy(train_label[i], train_label_dir)
        all_img_path.remove(train_img[i])
    val_img = random.sample(all_img_path, int(val / (val + test) * len(all_img_path)))
    val_label = [toLabelPath(img, label_path) for img in val_img]
    for i in tqdm(range(len(val_img)), desc='val ', ncols=80, unit='img'):
        _copy(val_img[i], val_img_dir)
        _copy(val_label[i], val_label_dir)
        all_img_path.remove(val_img[i])
    test_img = all_img_path
    test_label = [toLabelPath(img, label_path) for img in test_img]
    for i in tqdm(range(len(test_img)), desc='test ', ncols=80, unit='img'):
        _copy(test_img[i], test_img_dir)
        _copy(test_label[i], test_label_dir)


def _copy(from_path, to_path):
    shutil.copy(from_path, to_path)


def toLabelPath(img_path, label_path):
    img = img_path.split('\\')[-1]
    label = img.split('.jpg')[0] + '.xml'  # 因为这个数据集的标签是xml格式,所以将这里改成xml,如果标签格式是txt格式,就将这里改成txt
    return os.path.join(label_path, label)


def main():
    # 需要修改的地方:装图片的文件夹以及装标签的文件夹
    img_path = 'hatDataXml/JPEGImages'
    label_path = 'hatDataXml/Annotations'
    split_list = [0.7, 0.2, 0.1]  # 数据集划分比例[train:val:test]
    split_img(img_path, label_path, split_list)


if __name__ == '__main__':
    main()

3.需要修改的地方

1.代码65行,如果你的标签格式是txt,就将这里的xml改成txt即可

2.代码71,72行,将原数据集的图片路径和标签路径填写在这里

4.直接运行splitDataset1.py,转换成功

2.将训练集、验证集按照8:2随机划分

在项目下新建一个py文件,名字叫做splitDataset2.py

1.代码实现

python">
# 将标签格式为xml的数据集按照8:2的比例划分为训练集和验证集

import os
import shutil
import random
from tqdm import tqdm


def split_img(img_path, label_path, split_list):
    try:  # 创建数据集文件夹
        Data = 'DataSet2parts'
        os.mkdir(Data)

        train_img_dir = Data + '/images/train'
        val_img_dir = Data + '/images/val'
        # test_img_dir = Data + '/images/test'

        train_label_dir = Data + '/labels/train'
        val_label_dir = Data + '/labels/val'
        # test_label_dir = Data + '/labels/test'

        # 创建文件夹
        os.makedirs(train_img_dir)
        os.makedirs(train_label_dir)
        os.makedirs(val_img_dir)
        os.makedirs(val_label_dir)
        # os.makedirs(test_img_dir)
        # os.makedirs(test_label_dir)

    except:
        print('文件目录已存在')

    train, val = split_list
    all_img = os.listdir(img_path)
    all_img_path = [os.path.join(img_path, img) for img in all_img]
    # all_label = os.listdir(label_path)
    # all_label_path = [os.path.join(label_path, label) for label in all_label]
    train_img = random.sample(all_img_path, int(train * len(all_img_path)))
    train_img_copy = [os.path.join(train_img_dir, img.split('\\')[-1]) for img in train_img]
    train_label = [toLabelPath(img, label_path) for img in train_img]
    train_label_copy = [os.path.join(train_label_dir, label.split('\\')[-1]) for label in train_label]
    for i in tqdm(range(len(train_img)), desc='train ', ncols=80, unit='img'):
        _copy(train_img[i], train_img_dir)
        _copy(train_label[i], train_label_dir)
        all_img_path.remove(train_img[i])
    val_img = all_img_path
    val_label = [toLabelPath(img, label_path) for img in val_img]
    for i in tqdm(range(len(val_img)), desc='val ', ncols=80, unit='img'):
        _copy(val_img[i], val_img_dir)
        _copy(val_label[i], val_label_dir)


def _copy(from_path, to_path):
    shutil.copy(from_path, to_path)


def toLabelPath(img_path, label_path):
    img = img_path.split('\\')[-1]
    label = img.split('.jpg')[0] + '.xml'
    return os.path.join(label_path, label)


def main():
    img_path = 'hatDataXml/JPEGImages'
    label_path = 'hatDataXml/Annotations'
    split_list = [0.8, 0.2]  # 数据集划分比例[train:val]
    split_img(img_path, label_path, split_list)


if __name__ == '__main__':
    main()

2.需要修改的地方

跟上面的一样,如果标签类型不一样就修改标签类型,然后修改原数据集的图片路径以及标签路径。

3.结果展示


http://www.niftyadmin.cn/n/4991762.html

相关文章

Acwing 1233. 全球变暖 (每日一题)

如果你觉得这篇题解对你有用,可以点个赞或关注再走呗,谢谢你的关注~ 题目描述 你有一张某海域 NN 像素的照片,”.”表示海洋、”#”表示陆地,如下所示: … .##… .##… …##. …####. …###. … 其中”上下左右”…

MVC、MVP、MVVM的成本角度结合业务,如何考虑选型?一文了解方方面面

大家都知道,使用架构的目的是使程序模块化,做到模块内部的高聚合和模块之间的低耦合,使得程序在开发的过程中,开发人员只需要专注于一点,提高程序开发的效率。那么MVC、MVP、MVVM,该怎么选?在什…

Unity中立体声平移的应用

实现的效果 若从左声道开始,播放效果逐渐从左声道过渡到右声道,再从右声道过渡到左声道,具体效果请戴上耳机播放下列视频。 StereoPanning 代码实现 public class AudioInfo {[HideInInspector] public float[] StereoTranslationValues;//立…

java 浅谈ThreadLocal底层源码(通俗易懂)

目录 一、ThreadLocal类基本介绍 1.概述 : 2.作用及特定 : 二、ThreadLocal类源码解读 1.代码准备 : 1.1 图示 1.2 数据对象 1.3 测试类 1.4 运行测试 2.源码分析 : 2.1 set方法解读 2.2 get方法解读 一、ThreadLocal类基本介绍 1.概述 : (1) ThreadLocal,本…

Is f(z)=1/z truly an analytic function

https://math.stackexchange.com/questions/755566/is-fz-1-z-truly-an-analytic-function

【C++】C++ 引用详解 ⑩ ( 常量引用案例 )

文章目录 一、常量引用语法1、语法简介2、常引用语法示例 二、常量引用语法1、int 类型常量引用示例2、结构体类型常量引用示例 在 C 语言中 , 常量引用 是 引用类型 的一种 ; 借助 常量引用 , 可以将一个变量引用 作为实参 传递给一个函数形参 , 同时保证该值不会在函数内部被…

【VTK】 vtkMapper

很高兴在雪易的CSDN遇见你 ,给你糖糖 欢迎大家加入雪易社区-CSDN社区云 前言 本文主要分享VTK中关于vtkMapper的相关知识和使用方法,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~YO </

Spring版本与JDK版本演变

Java各版本变更核心API Java8 lambada表达式函数式接口方法引用默认方法Stream API 对元素流进行函数式操作Optional 解决NullPointerExceptionDate Time API重复注解 RepeatableBase64使用元空间Metaspace代替持久代&#xff08;PermGen space&#xff09; Java7 switch 支…