YOLOv8 训练自己的数据集

news/2024/7/10 22:55:40 标签: YOLO, 数学建模, 深度学习, 人工智能

本范例我们使用 ultralytics中的YOLOv8目标检测模型训练自己的数据集,从而能够检测气球。

#安装
!pip install -U ultralytics -i https://pypi.tuna.tsinghua.edu.cn/simple
import ultralytics 
ultralytics.checks()

一,准备数据

公众号算法美食屋后台回复关键词:yolov8,获取本文notebook源代码和数据集。

训练yolo模型需要将数据集整理成yolo数据集格式。然后写一个yaml的数据集配置文件。

yolo_dataset
├── images
│   ├── train
│   │   ├── train0.jpg
│   │   └── train1.jpg
│   ├── val
│   │   ├── val0.jpg
│   │   └── val1.jpg
│   └── test
│       ├── test0.jpg
│       └── test1.jpg
└── labels
    ├── train
    │   ├── train0.txt
    │   └── train1.txt
    ├── val
    │   ├── val0.txt
    │   └── val1.txt
    └── test
        ├── test0.txt
        └── test1.txt

其中标签文件(如train0.txt)格式如下:

class_id center_x center_y bbox_width bbox_height
0 0.300926 0.617063 0.601852 0.765873
1 0.575 0.319531 0.4 0.551562

注意class_id从0开始,中心点坐标和高宽都是相对坐标。

使用 Labelme或者 makesense标注样本可以直接导出该种类型样本。

%%writefile balloon.yaml
# Ultralytics YOLO 🚀, GPL-3.0 license

path: /tf/liangyun2/torchkeras/notebooks/datasets/balloon   # dataset root dir
train: images/train  # train images (relative to 'path') 128 images
val: images/val  # val images (relative to 'path') 128 images
test:  # test images (optional)

# Classes
names:
  0: ballon
Overwriting balloon.yaml
import torch
from torch.utils.data import DataLoader
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.utils import DEFAULT_CFG,yaml_load 
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.data import build_yolo_dataset,build_dataloader

overrides = {'task':'detect',
             'data':'balloon.yaml',
             'imgsz':640,
             'workers':4
            }
cfg = get_cfg(cfg = DEFAULT_CFG,overrides=overrides)
data_info = check_det_dataset(cfg.data)
ds_train = build_yolo_dataset(cfg,img_path=data_info['train'],batch=cfg.batch,
                              data_info = data_info,mode='train',rect=False,stride=32)

ds_val = build_yolo_dataset(cfg,img_path=data_info['val'],batch=cfg.batch,data_info = data_info,
    mode='val',rect=False,stride=32)
#dl_train = build_dataloader(ds_train,batch=cfg.batch,workers=0)
#dl_val = build_dataloader(ds_val,batch=cfg.batch,workers =0,shuffle=False)
dl_train = DataLoader(ds_train,batch_size = cfg.batch, num_workers = cfg.workers,
                      collate_fn = ds_train.collate_fn)

dl_val = DataLoader(ds_val,batch_size = cfg.batch, num_workers = cfg.workers,
                      collate_fn = ds_val.collate_fn)
for batch in dl_val:
    break
batch.keys()
dict_keys(['im_file', 'ori_shape', 'resized_shape', 'ratio_pad', 'img', 'cls', 'bboxes', 'batch_idx'])

二,定义模型

from ultralytics.nn.tasks import DetectionModel

model = DetectionModel(cfg = 'yolov8n.yaml', ch=3, nc=1)
#weights = torch.hub.load_state_dict_from_url('https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt')
weights = torch.load('yolov8n.pt')
model.load(weights['model'])
model.args = cfg
model.nc = data_info['nc']  # attach number of classes to model
model.names = data_info['names']

三,训练模型

1,使用ultralytics原生接口

使用ultralytics的原生接口,只需要以下几行代码即可。

from ultralytics import YOLO 
yolo_model = YOLO('yolov8n.pt')

yolo_model.train(data='balloon.yaml',epochs=10)

0796ae19d15665d4d116e5ece0842c5f.png


2,使用torchkeras梦中情炉

尽管使用ultralytics原生接口非常简单,再使用torchkeras实现自定义训练逻辑似乎有些多此一举。

但ultralytics的源码结构相对复杂,不便于用户做个性化的控制和修改。

并且,torchkeras在可视化上会比ultralytics的原生训练代码优雅许多。

此外,掌握自定义训练逻辑对大家熟悉ultralytics这个库的代码结构也会有所帮助。

for batch in dl_train:
    break
from ultralytics.yolo.v8.detect.train import Loss 

model.cuda()
loss_fn = Loss(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) 


x = batch['img'].float()/255 

preds = model.forward(x.cuda())
loss = loss_fn(preds,batch)[0]
print(loss)
tensor(74.5465, device='cuda:0', grad_fn=<MulBackward0>)
from torchkeras import KerasModel 

#我们需要修改StepRunner以适应Yolov8的数据集格式

class StepRunner:
    def __init__(self, net, loss_fn, accelerator, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator
        if self.stage=='train':
            self.net.train() 
        else:
            self.net.eval()
    
    def __call__(self, batch):
        
        features = batch['img'].float() / 255
        
        #loss
        preds = self.net(features)
        loss = self.loss_fn(preds,batch)[0]

        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            
        all_preds = self.accelerator.gather(preds)
        all_loss = self.accelerator.gather(loss).sum()
        
        #losses
        step_losses = {self.stage+"_loss":all_loss.item()}
        
        #metrics
        step_metrics = {}
        
        if self.stage=="train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses,step_metrics
    
KerasModel.StepRunner = StepRunner
keras_model = KerasModel(net = model, 
                         loss_fn = loss_fn, 
                         optimizer = optimizer)
keras_model.fit(train_data=dl_train,
                val_data=dl_val,
                epochs = 200,
                ckpt_path='checkpoint.pt',
                patience=20,
                monitor='val_loss',
                mode='min',
                mixed_precision='no',
                plot= True,
                wandb = False,
                quiet = True
               )

9e02b43baca40414b19510ff7a3cb212.png

d242920e834f1e8615503e8581c88c0f.png

四,评估模型

为了便于评估 map等指标,我们将权重再次保存后,用ultralytics的原生YOLO接口进行加载后评估。

keras_model.evaluate(dl_val)
100%|██████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.32it/s, val_loss=28.7]



{'val_loss': 28.715129852294922}
from ultralytics import YOLO 
keras_model.load_ckpt('checkpoint.pt')
save_dic = dict(model = keras_model.net, train_args =dict(cfg))
torch.save(save_dic, 'best_yolo.pt')
from ultralytics import YOLO 
best_model = YOLO(model = 'best_yolo.pt')
metrics = best_model.val(data = cfg.data )
metrics.results_dict
{'metrics/precision(B)': 0.9188790992746612,
 'metrics/recall(B)': 0.74,
 'metrics/mAP50(B)': 0.8516599658911874,
 'metrics/mAP50-95(B)': 0.7321355695315829,
 'fitness': 0.7440880091675434}
import pandas as pd 
df = pd.DataFrame()
df['metric'] = metrics.keys
for i,c in best_model.names.items():
    df[c] = metrics.class_result(i)

df

f40c837b5009663440bae9772217fca8.png

五,使用模型

from pathlib import Path 
root_path = './datasets/balloon/'
data_root = Path(root_path)

best_model = YOLO(model = 'best_yolo.pt')
val_imgs = [str(x) for x in (data_root/'images'/'train').rglob("*.jpg") if 'checkpoint' not in str(x)]
img_path = val_imgs[5]
import os 
from PIL import Image 
result = best_model.predict(source = img_path,save=True)
best_model.predictor.save_dir/os.path.basename(img_path)
Image.open(best_model.predictor.save_dir/os.path.basename(img_path))

639b298b28c05879b3bef9a5c94303da.png

六,导出模型

best_model.export(format='onnx')
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
predictor = DetectionPredictor(
    overrides=dict(model='best_yolo.onnx'))
results = list(predictor.stream_inference(source=img_path))

公众号算法美食屋后台回复关键词:yolov8,获取本文notebook源代码和数据集。


http://www.niftyadmin.cn/n/397254.html

相关文章

rimraf : 无法加载文件 C:\Program Files\nodejs\rimraf.ps1,因为在此系统上禁止运行脚本。

问题&#xff1a; rimraf 运行 rimraf node_modules 命令报错&#xff1a;&#xff08;rimraf 前端同学可以多多了解 Vue、React使用 rimraf 快速删除node_modules_码农键盘上的梦的博客-CSDN博客&#xff09;rimraf : 无法加载文件 C:\Program Files\nodejs\rimraf.ps1&…

零入门kubernetes网络实战-34->将物理网卡eth0挂载到虚拟网桥上使得内部网络能够跨主机ping通外网的方案

《零入门kubernetes网络实战》视频专栏地址 https://www.ixigua.com/7193641905282875942 本篇文章视频地址(稍后上传) 本篇文章模拟一下啊&#xff0c;将宿主机的对外的物理网卡&#xff0c;挂载到虚拟网桥上&#xff0c;测试一下&#xff0c; 网桥管理的内部网络如何跟宿主…

分别使用Observable、LiveData、RxJava监听List<T>的内容变化

在Java中&#xff0c;如果要监听List的内容变化&#xff0c;可以使用Java自带的观察者模式Observable或者第三方库(LiveData \RxJava)实现&#xff0c;下面分别介绍&#xff1a; 三者的区别优劣势、使用场景分析 下面是 Observable、LiveData、RxJava 监听 bean 对象的区别整…

讨论和总结 树模型 的三种序列化 方式的区别(模型存储大小、序列化所用内存、序列化速度)...

一、前言 本文总结常用树模型&#xff1a; rf&#xff0c;xgboost&#xff0c;catboost和lightgbm等模型的保存和加载&#xff08;序列化和反序列化&#xff09;的多种方式&#xff0c;并对多种方式从运行内存的使用和存储大小做对比 二、模型 2.1 安装环境 pip install xgboos…

【MySQL】一文带你了解MySQL中的子查询

文章目录 1. 需求分析与问题解决1. 1实际问题1.2 子查询的基本使用1.3 子查询的分类 2. 单行子查询2.1 单行比较操作符2.2 代码示例2.3 HAVING 中的子查询2.4 注意的问题 3. 多行子查询3.1 多行比较操作符3.2 代码示例 4. 相关子查询4.1 相关子查询执行流程4.2 代码示例 子查询…

2020年一月联考逻辑真题

2020年一月联考逻辑真题 真题&#xff08;2020-26&#xff09;-翻译推理题-递推推理 26.领导干部对于各种批评意见应采取有则改之&#xff0c;无则加勉的态度&#xff0c;营造言者无罪&#xff0c;闻者足戒的氛围。只有这样&#xff0c;人们才能知无不言&#xff0c;言无不尽。…

⑭【动态时空图卷积网络 · 注意力 · 交通速度预测】时空依赖关系挖掘 | 动态时空建模 | 智能交通系统 |

所谓成功,就是用自己的方式度过人生。 ————《明朝那些事儿》作者,当年明月 🎯作者主页: 追光者♂🔥 🌸个人简介: 💖[1] 计算机专业硕士研究生💖 🌟[2] 2022年度博客之星人工智能领域TOP4🌟 🏅[3] 阿里云社区特邀专家博主🏅 🏆[4]…

【Python 异步编程】零基础也能轻松掌握的学习路线与参考资料

Python 异步编程学习路线&#xff1a; 1.理解同步和异步编程模型的区别&#xff0c;了解使用异步编程的优缺点。 同步编程是指一个任务执行完毕后再执行下一个任务&#xff0c;而异步编程则是在任务执行的同时还可以继续执行其他任务。 异步编程优点&#xff1a; (1)性能优…