【目标检测】YOLOv5算法实现(九):模型预测

news/2024/7/11 0:12:09 标签: 目标检测, YOLO, yolov5, 模型预测

  本系列文章记录本人硕士阶段YOLO系列目标检测算法自学及其代码实现的过程。其中算法具体实现借鉴于ultralytics YOLO源码Github,删减了源码中部分内容,满足个人科研需求。
  本系列文章主要以YOLOv5为例完成算法的实现,后续修改、增加相关模块即可实现其他版本的YOLO算法。

文章地址:
YOLOv5算法实现(一):算法框架概述
YOLOv5算法实现(二):模型加载
YOLOv5算法实现(三):数据集加载
YOLOv5算法实现(四):正样本匹配与损失计算
YOLOv5算法实现(五):预测结果后处理
YOLOv5算法实现(六):评价指标及实现
YOLOv5算法实现(七):模型训练
YOLOv5算法实现(八):模型验证
YOLOv5算法实现(九):模型预测

本文目录

引言

  本篇文章综合之前文章中的功能,实现模型的预测。模型预测的逻辑如图1所示。

在这里插入图片描述

图1 模型预测流程

模型预测(predict.py)

def predice():
    img_size = 640  # 必须是32的整数倍 [416, 512, 608]
    file = "yolov5s"
    cfg = f"cfg/models/{file}.yaml"  # 改成生成的.cfg文件
    weights_path = f"weights/{file}/{file}.pt"  # 改成自己训练好的权重文件
    json_path = "data/dataset.json"  # json标签文件
    img_path = "test.jpg"
    save_path = f"results/{file}/test_result8.jpg"
    assert os.path.exists(cfg), "cfg file {} dose not exist.".format(cfg)
    assert os.path.exists(weights_path), "weights file {} dose not exist.".format(weights_path)
    assert os.path.exists(json_path), "json file {} dose not exist.".format(json_path)
    assert os.path.exists(img_path), "image file {} dose not exist.".format(img_path)

    with open(json_path, 'r') as f:
        class_dict = json.load(f)

    category_index = {str(v): str(k) for k, v in class_dict.items()}

    input_size = (img_size, img_size)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")
    model = Model(cfg, ch=3, nc=3)
    weights_dict = torch.load(weights_path, map_location='cpu')
    weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
    model.load_state_dict(weights_dict, strict=False)
    model.to(device)

    model.eval()
    with torch.no_grad():
        # init
        img = torch.zeros((1, 3, img_size, img_size), device=device)
        model(img)

        img_o = cv2.imread(img_path)  # BGR
        assert img_o is not None, "Image Not Found " + img_path

        img = letterbox(img_o, new_shape=input_size, auto=True, color=(0, 0, 0))[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        img = torch.from_numpy(img).to(device).float()
        img /= 255.0  # scale (0, 255) to (0, 1)
        img = img.unsqueeze(0)  # add batch dimension

        t1 = torch_utils.time_synchronized()
        pred = model(img)[0]  # only get inference result
        t2 = torch_utils.time_synchronized()
        print("inference time: {}s".format(t2 - t1))
        print('model: {}'.format(file))



        pred = utils.non_max_suppression(pred, conf_thres=0.1, iou_thres=0.6, multi_label=True)[0]
        t3 = time.time()
        print("post-processing time: {}s".format(t3 - t2))



        # process detections
        pred[:, :4] = utils.scale_coords(img.shape[2:], pred[:, :4], img_o.shape).round()


        bboxes = pred[:, :4].detach().cpu().numpy()
        scores = pred[:, 4].detach().cpu().numpy()
        classes = pred[:, 5].detach().cpu().numpy().astype(np.int) + 1

        pil_img = Image.fromarray(img_o[:, :, ::-1])
        plot_img = draw_objs(pil_img,
                             bboxes,
                             classes,
                             scores,
                             category_index=category_index,
                             box_thresh=0.2,
                             line_thickness=3,
                             font='arial.ttf',
                             font_size=30)
        plt.imshow(plot_img)
        plt.show()
        # 保存预测的图片结果

        plot_img.save(save_path)


if __name__ == "__main__":
    predict()

http://www.niftyadmin.cn/n/5337702.html

相关文章

快速了解spring boot中的@idempotent注解

目的:一定时间内,同样的请求(业务参数相同)访问同一个接口,则只能成功一次,其余被拒绝 幂等实现原理就是利用AOP面向切面编程,在执行业务逻辑之前插入一个方法,生成一个token,存入redis并插入到…

高性能前端UI库 SolidJS | 超棒 NPM 库

SolidJS是一个声明式的、高效的、编译时优化的JavaScript库,用于构建用户界面。它的核心特点是让你能够编写的代码既接近原生JavaScript,又能够享受到现代响应式框架提供的便利。 SolidJS的设计哲学强调了性能与简洁性。它不使用虚拟DOM(Vir…

Broadcom交换芯片56620架构

文章目录 架构1.系统逻辑视图2.逻辑芯片视图3.芯片框图4.MIIM(Medium Independent Interface Management)5.交换结构6.CAP 架构 1.系统逻辑视图 Ingress Chip作用: 解析报文128字节的头部(MMU(Memory Management Uni…

【机器学习300问】9、梯度下降是用来干嘛的?

当你和我一样对自己问出这个问题后,分析一下!其实我首先得知道梯度下降是什么,也就它的定义。其次我得了解它具体用在什么地方,也就是使用场景。最后才是这个问题,梯度下降有什么用?怎么用? 所以…

sqli-labs关卡25(基于get提交的过滤and和or的联合注入)

文章目录 前言一、回顾上一关知识点二、靶场第二十五关通关思路1、判断注入点2、爆字段个数3、爆显位位置4、爆数据库名5、爆数据库表名6、爆数据库列名7、爆数据库数据 总结 前言 此文章只用于学习和反思巩固sql注入知识,禁止用于做非法攻击。注意靶场是可以练习的…

【系统调用】常用系统调用函数(三)

系统调用概念 操作系统的职责 操作系统用来管理所有的资源,并将不同的设备和不同的程序关联起来。什么是Linux系统编程 在有操作系统的环境下编程,并使用操作系统提供的系统调用及各种库,对系统资源进行访问。系统编程主要就是为了让用户能够…

小程序使用echarts图表-雷达图

本文介绍下小程序中如何使用echarts 如果是通过npm安装,这样是全部安装的,体积有点大 我这边是使用echarts中的一个组件来实现的,下边是具体流程,实际效果是没有外边的红色边框的,加红色边框的效果是这篇说明 1.echa…

maven 基本知识/1.17

maven ●maven是一个基于项目对象模型(pom)的项目管理工具,帮助管理人员自动化构建、测试和部署项目 ●pom是一个xml文件,包含项目的元数据,如项目的坐标(GroupId,artifactId,version )、项目的依赖关系、构建过程 ●生命周期&…