YOLOv5 分类模型 数据集加载 1

news/2024/7/10 23:44:43 标签: YOLO, 分类

YOLOv5 分类模型 数据集加载 1

flyfish

数据集的加载 python实现,不使用torch库

目标:得到样本前面是图像文件路径,后面是标签索引

samples: [('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00000293.JPEG', 0),
          ('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00002138.JPEG', 0),
          ('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00003014.JPEG', 0),
          ('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00006697.JPEG', 0))]

简化实现

import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union


class DatasetFolder:

    def __init__(
        self,
        root: str,

    ) -> None:
        self.root=root
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx)
        
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]


    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Optional[Dict[str, int]] = None,

    ) -> List[Tuple[str, int]]:
 
        directory = os.path.expanduser(directory)

        if class_to_idx is None:
            _, class_to_idx = self.find_classes(directory)
        elif not class_to_idx:
            raise ValueError("'class_to_index' must have at least one entry to collect any samples.")



        instances = []
        available_classes = set()
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if 1:#验证:
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

        empty_classes = set(class_to_idx.keys()) - available_classes
        if empty_classes:
            msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "


        return instances

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
 
        classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx



dataset =  DatasetFolder(root="/media/a/flyfish/test");

print(dataset)
print("dataset.targets:",dataset.targets)
print("dataset.classes:",dataset.classes)
print("samples:",dataset.samples)

find_classes 将标签索引和标签内容对应

0,1,2是标签索引
'n01440764', 'n01443537', 'n01484850'是类别名字也是文件夹名字
按照升序排序

dataset.targets: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
dataset.classes: ['n01440764', 'n01443537', 'n01484850']

样本中一个是图像文件的绝对路径,后面的是标签

samples: [('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00000293.JPEG', 0),
          ('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00002138.JPEG', 0),
          ('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00003014.JPEG', 0),
          ('/media/a/flyfish/test/n01440764/ILSVRC2012_val_00006697.JPEG', 0),
          ('/media/a/flyfish/test/n01443537/ILSVRC2012_val_00000236.JPEG', 1),
          ('/media/a/flyfish/test/n01443537/ILSVRC2012_val_00000262.JPEG', 1),
          ('/media/a/flyfish/test/n01443537/ILSVRC2012_val_00000307.JPEG', 1),
          ('/media/a/flyfish/test/n01443537/ILSVRC2012_val_00000994.JPEG', 1),
          ('/media/a/flyfish/test/n01484850/ILSVRC2012_val_00002338.JPEG', 2),
          ('/media/a/flyfish/test/n01484850/ILSVRC2012_val_00002752.JPEG', 2),
          ('/media/a/flyfish/test/n01484850/ILSVRC2012_val_00004311.JPEG', 2),
          ('/media/a/flyfish/test/n01484850/ILSVRC2012_val_00004329.JPEG', 2)]

可以功能丰富一些,例如检测文件的扩展名是否是支持的图像文件

IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
    """检查文件是否为允许的扩展名
    """
    return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))

def is_image_file(filename: str) -> bool:
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)

测试

r=is_image_file("/media/a/flyfish/data/imagewoof/val/n02086240/1.jpeg");

print(r)#True

r=is_image_file("/media/a/flyfish/data/imagewoof/val/n02086240/1.txt");

print(r)#False

http://www.niftyadmin.cn/n/5182804.html

相关文章

PowerPoint技巧:如何将一张图片同时加到全部幻灯片里?

想把一张图片加到PPT每一张幻灯片的同一个位置,如果一张一张的添加就太耗时间了,一起来看看如何利用母版快速设置同时添加吧。 首先,打开需要编辑的PPT,在菜单栏依次点击【视图】→【幻灯片母版】; 打开母版后&#x…

数据库SQLite3 笔记

浅显易懂 SQLite3 笔记(04)— SQL数据更新(增加、删除、修改)_sqlite 删除新增字段-CSDN博客 SQLite 在一条语句中添加多个列|极客教程

普通测径仪升级的智能测径仪 增添11大实用功能!

普通测径仪能对各种钢材进行非接触式的外径及椭圆度在线检测,测量数据准确且无损,可测、监测、超差提示、系统分析等。在此基础上,为测径仪进行了进一步升级制成智能测径仪,为其增添更多智能化模块,让其使用更加方便。…

springboot整合vue2实现简单的新增删除,整合ECharts实现图表渲染

先看效果图&#xff1a; 1.后端接口 // 查询所有商品信息 // CrossOrigin(origins "*")RequestMapping("/list1")ResponseBodypublic List<Goodsinfo> list1(){List<Goodsinfo> list goodsService.list();return list;}// 删除 // …

python图神经网络,注意力机制、Transformer模型、目标检测算法、强化学习等

近年来&#xff0c;伴随着以卷积神经网络&#xff08;CNN&#xff09;为代表的深度学习的快速发展&#xff0c;人工智能迈入了第三次发展浪潮&#xff0c;AI技术在各个领域中的应用越来越广泛 本文重点为&#xff1a;注意力机制、Transformer模型&#xff08;BERT、GPT-1/2/3/…

2023年人工智能还好找工作吗?

人工智能的就业形势并不严峻&#xff0c;相反&#xff0c;很多岗位都是供不应求的状态&#xff0c;可以看一下下面的官方数据。 脉脉高聘人才智库发布《2023泛人工智能人才洞察》&#xff0c;对23年1-8月的人工智能行业现状进行了分析总结。 人工智能相关岗位数据&#xff1a…

【数据处理】python Matplotlib将图进行局部放大;标出所关注的局部放大子图

前言 在数据可视化中&#xff0c;很多时候需要对某一区间的数据进行局部放大&#xff0c;以获得对比度更高的可视化效果。下面利用 Python 语言的 Matplotlib 库实现一个简单的局部放大图效果。 依赖库 matplotlib&#xff1a;绘图库 numpy&#xff1a;支持大量的维度数组、…

KVM给虚拟Linux加磁盘

添加一块 qcow2的磁盘 virsh attach-disk centos /kvm/vdisks/centos-diskadd.qcow2 vdb --subdriver qcow2这个命令的含义是将一个额外的虚拟磁盘(centos-diskadd.qcow2)连接到名为centos的虚拟机上&#xff0c;并将它作为vdb设备进行挂载。 参数的含义&#xff1a; virsh:…