YOLOv5模型转ONNX,ONNX转TensorRT Engine

news/2024/7/11 1:30:01 标签: YOLO

系列文章目录

第一章 YOLOv5模型训练集标注、训练流程
第二章 YOLOv5模型转ONNX,ONNX转TensorRT Engine
第三章 TensorRT量化

文章目录

  • 系列文章目录
  • 前言
  • 一、yolov5模型导出ONNX
    • 1.1 工作机制
    • 1.2 修改yolov5代码,输出ONNX
  • 二、TensorRT部署
    • 2.1 模型部署
    • 2.2 模型推理
  • 总结


前言

学习笔记–恩培老师


一、yolov5模型导出ONNX

1.1 工作机制

使用tensort deconde plugin 来替代yolov5代码中的deconde操作,需要修改yolov5代码导出onnx模型的部分。

在这里插入图片描述

1.2 修改yolov5代码,输出ONNX

批量修改

#将patch复制到yolov5文件夹
cp export.patch yolov5/
#进入yolov5文件夹
cd yolov5/
#应用patch
git am export.patch

安装需要依赖

pip install seaborn
pip install onnx-graphsurgeon
pip install opencv-python==4.5.5.64
pip install onnx-simplifier==0.3.10

apt update
apt install -y libgl1-mesa-glx

安装完成后,准备训练好的模型文件,默认为yolov5s.pt,然后执行下列代码,生成Onnx文件。

安装不上onnx-graphsurgeon,使用下面的命令再次安装

pip install nvidia-pyindex
pip install onnx-graphsurgeon
python export.py --weights weights/yolov5s_person.pt --include onnx --simplify

这里的yolov5s_person.pt文件就是我们刚刚训练好的best.pt复制过来的。
可视化模型工具

pip install netron
netron ./weights/yolov5s_person.onnx

二、TensorRT部署

使用TensorRT docker容器:

docker run --gpus all -it --name env_trt -v ${pwd}: /app nvcr.io/nvidia/tensorrt:22.08-py3

2.1 模型部署

推荐博客TensorRT部署流程
yolov5转到onnx后进行模型的构建并保存序列化后的模型为文件。

  • 模型导出成 ONNX 格式。
  • 把 ONNX 格式模型输入给 TensorRT,并指定优化参数。
  • 使用 TensorRT 优化得到 TensorRT Engine。
  • 使用 TensorRT Engine 进行 inference。
  1. 创建builder
    这里使用了std::unqique_ptr,只能指针包装我们的builder,实现自动管理指针生命周期。
//**************1.创建builder***************//

auto builder = std::unique_ptr<nvinferl::IBuilder>
(nvinfer1::IBuilder::createInferBuilder(sampelr::gLogger.getTRTLogger())));
if (!builder)
{
    std::cerr<<"Failed to create builder"<<std::endl;
    return -1;
}

  1. 创建网络。这里指定了explicitBatch

  2. 创建onnxparser,用于解析onnx文件

4.配置网络参数。
我们需要告诉tensorrt我们最终运行时,输入图像的范围,batch size范围。

#include <iostream>
#include "NvInfer.h"

int main() {
    // Create a logger
    nvinfer1::ILogger* logger = new nvinfer1::ILogger();

    // Create a builder
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(*logger);
    if (!builder) {
        std::cerr << "Failed to create builder" << std::endl;
        return -1;
    }

    // Set up builder configurations (optional)
    builder->setMaxBatchSize(1);
    builder->setMaxWorkspaceSize(1 << 30); // 1GB

    // Create a network definition
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U);

    // ... Add layers and define the network ...

    // Build the engine
    nvinfer1::ICudaEngine* engine = builder->buildCudaEngine(*network);

    if (!engine) {
        std::cerr << "Failed to build engine" << std::endl;
        return -1;
    }

    // Clean up
    network->destroy();
    engine->destroy();
    builder->destroy();
    logger->log(nvinfer1::ILogger::Severity::kINFO, "Finished successfully!");

    delete logger;

    return 0;
}


2.2 模型推理

推理过程

  • 读取模型文件
  • 对输入进行预处理
  • 读取模型输出
  • 后处理(NMS)

1.创建运行时
2.反序列化模型得到推理Engine
3.创建执行上下文
4.创建输入输出缓冲区管理器
5.读取视频文件,并逐帧读取图像送入模型,进行推理

#include <iostream>
#include <fstream>
#include <string>
#include <sstream>
#include <chrono>
#include <opencv2/opencv.hpp>
#include "NvInfer.h"

int main() {
    // Create a logger
    nvinfer1::ILogger* logger = new nvinfer1::ILogger();
    
    // Create a runtime
    nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(*logger);
    if (!runtime) {
        std::cerr << "Failed to create runtime" << std::endl;
        return -1;
    }

    // Deserialize the engine
    const std::string engineFilePath = "path/to/your/engine.plan";
    std::ifstream engineFile(engineFilePath, std::ios::binary);
    if (!engineFile) {
        std::cerr << "Failed to open engine file" << std::endl;
        return -1;
    }
    engineFile.seekg(0, engineFile.end);
    const int engineSize = engineFile.tellg();
    engineFile.seekg(0, engineFile.beg);
    char* engineData = new char[engineSize];
    engineFile.read(engineData, engineSize);
    engineFile.close();

    nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(engineData, engineSize, nullptr);
    if (!engine) {
        std::cerr << "Failed to deserialize engine" << std::endl;
        return -1;
    }

    delete[] engineData;

    // Create an execution context
    nvinfer1::IExecutionContext* context = engine->createExecutionContext();
    if (!context) {
        std::cerr << "Failed to create execution context" << std::endl;
        return -1;
    }

    // Create input and output buffer managers
    const int maxBatchSize = engine->getMaxBatchSize();
    nvinfer1::Dims inputDims = engine->getBindingDimensions(0);
    const int inputSize = inputDims.d[1] * inputDims.d[2] * inputDims.d[3];
    nvinfer1::Dims outputDims = engine->getBindingDimensions(1);
    const int outputSize = outputDims.d[1];

    nvinfer1::IHostMemory* inputMemory = engine->createHostMemory(engine->getBindingDataType(0), maxBatchSize * inputSize);
    void* inputBuffer = inputMemory->data();

    nvinfer1::IHostMemory* outputMemory = engine->createHostMemory(engine->getBindingDataType(1), maxBatchSize * outputSize);
    void* outputBuffer = outputMemory->data();

    // Open the video file
    const std::string videoFilePath = "path/to/your/video.mp4";
    cv::VideoCapture cap(videoFilePath);
    if (!cap.isOpened()) {
        std::cerr << "Failed to open video file" << std::endl;
        return -1;
    }

    // Loop through all frames
    cv::Mat frame;
    int frameCount = 0;
    auto startTime = std::chrono::high_resolution_clock::now();

    while (true) {
        // Read the next frame
        cap >> frame;
        if (frame.empty()) {
            break;
        }

        // Prepare the input data
        cv::Mat resizedFrame;
        cv::resize(frame, resizedFrame, cv::Size(inputDims.d[3], inputDims.d[2]));
        float* inputData = static_cast<float*>(inputBuffer) + frameCount * inputSize;

        const int channelSize = inputDims.d[2] * inputDims.d[3];
        for (int c = 0; c < inputDims.d[1]; ++c) {
            for (int h = 0; h < inputDims.d[2]; ++h) {
                for (int w = 0; w < inputDims.d[3]; ++w) {
                    const float pixel = resizedFrame.at<cv::Vec3b>(h, w)[c] / 255.0f;
                    inputData[c * channelSize + h * inputDims.d[3] + w] = pixel;
                }
            }
        }

        // Run inference
        context->executeV2(&inputBuffer, &outputBuffer);

        // Process the output data
        float* outputData = static_cast<float*>(outputBuffer) + frameCount * outputSize;

        // ... Process the output data ...

        ++frameCount;
    }

    // Measure and print the inference time
    auto endTime = std::chrono::high_resolution_clock::now();
    auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
    std::cout << "Inference time: " << elapsedTime.count() << "ms" << std::endl;

    // Clean up
    inputMemory->destroy();
    outputMemory->destroy();
    context->destroy();
    engine->destroy();
    runtime->destroy();
    logger->log(nvinfer1::ILogger::Severity::kINFO, "Finished successfully!");

    delete logger;

    return 0;
}

使用cmake进行构建,cmake相关知识可看cmake学习笔记

cmake -S .-B build
cmake --build build
./build/build
./build/build ./weights/yolo5s_person.onnx
#执行推理
./build/runtime

视频文件

./weights/yolov5.engine ./media/c3.mp4

总结

接下来是了解TensorRT插件,Int8量化流程。

推荐视频链接:https://www.bilibili.com/video/BV1jj411Z7wG/?spm_id_from=333.337.search-card.all.click&vd_source=ce674108fa2e19e5322d710724193487

推荐链接:https://github.com/NVIDIA/trt-samples-for-hackathon-cn/tree/master/cookbook


http://www.niftyadmin.cn/n/5321136.html

相关文章

世界人口数据分析与探索

文章目录 世界人口数据集介绍数据集 1&#xff1a;世界国家统计数据&#xff1a;数据集 2&#xff1a;世界人口详细信息&#xff08;2023 年&#xff09;&#xff1a;数据集 3&#xff1a;按年份划分的世界人口&#xff08;1950-2023&#xff09;&#xff1a; 数据分析导入必要…

C++力扣题目530--二叉搜索树的最小绝对值

给你一个二叉搜索树的根节点 root &#xff0c;返回 树中任意两不同节点值之间的最小差值 。 差值是一个正数&#xff0c;其数值等于两值之差的绝对值。 示例 1&#xff1a; 输入&#xff1a;root [4,2,6,1,3] 输出&#xff1a;1示例 2&#xff1a; 输入&#xff1a;root […

D25XB80-ASEMI开关电源桥堆D25XB80

编辑&#xff1a;ll D25XB80-ASEMI开关电源桥堆D25XB80 型号&#xff1a;D25XB80 品牌&#xff1a;ASEMI 封装&#xff1a;GBJ-5&#xff08;带康铜丝&#xff09; 特性&#xff1a;插件、整流桥 平均正向整流电流&#xff08;Id&#xff09;&#xff1a;25A 最大反向击…

白学的小知识[css3轮播]

代码如下: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>轮播</title><style>.boss {position: relative;width: 200px;height: 200px;overflow: hidden;}.boss>div {width: 10000p…

【iOS】数据存储方式总结(持久化)沙盒结构

在iOS开发中&#xff0c;我们经常性地需要存储一些状态和数据&#xff0c;比如用户对于App的相关设置、需要在本地缓存的数据等等&#xff0c;本篇文章将介绍六个主要的数据存储方式 iOS中数据存储方式&#xff08;数据持久化&#xff09; 根据要存储的数据大小、存储数据以及…

搭建Docker私有镜像服务器

一、前言 1、本文主要内容 基于Decker Desktop&Docker Registry构建Docker私有镜像服务器测试在CentOS 7上基于Docker Registry搭建公共Docker镜像服务器修改Docker Engine配置以HTTP协议访问Docker Registry修改Docker Engine配置通过域名访问Docker Registry配置SSL证书…

【python】——turtle动态画

&#x1f383;个人专栏&#xff1a; &#x1f42c; 算法设计与分析&#xff1a;算法设计与分析_IT闫的博客-CSDN博客 &#x1f433;Java基础&#xff1a;Java基础_IT闫的博客-CSDN博客 &#x1f40b;c语言&#xff1a;c语言_IT闫的博客-CSDN博客 &#x1f41f;MySQL&#xff1a…

inflate流程分析

一.inflate的三参数重载方法else里面逻辑 我们先看到setContentView里面的inflate的调用链&#xff1a; public View inflate(LayoutRes int resource, Nullable ViewGroup root) {return inflate(resource, root, root ! null);}public View inflate(LayoutRes int resource…