OpenCV实战——使用YOLO进行目标检测

news/2024/7/10 3:04:50 标签: opencv, YOLO, 目标检测

OpenCV实战——使用YOLO进行目标检测

0. 前言

在本节中,我们将使用 YOLO 算法执行目标检测目标检测是计算机视觉中的一项常见任务,借助深度学习技术,我们可以实现高准确度的检测。YOLOCOCO 数据集(数据集中包含 80 个类别和超过 300000 张图像)中可以达到 60.6mAP (20 fps) 或 33mAP (220 fps)。

YOLO__3">1. YOLO 模型简介

YOLO 是深度学习网络目标检测的一类重要分枝,其将输入图像划分为 SxS 网格。对于每个网格,YOLO 检查 B 个边界框,然后深度学习模型提取每个网格的边界框、包含可能对象的置信度以及每个边界框中(训练数据集中)每个类别的置信度:

<a class=YOLO 网格" />
YOLO 使用 19x19 个网格,每个网格包含 5 个边界框,训练数据集中包含 80 个类别。网络的输出结果为 19x19x425,其中 425 来自边界框 (x,y,width,height)、边界框中是否包含对象的置信度、对象属于每个类别(共 80 个类别)的置信度:

5_bounding box*(x,y,w,h,object_confidence,classify_confidence[80])=5*(4 + 1 + 80)

YOLO 架构基于 DarkNet (包含 53 层网络),YOLODarkNet 的基础上增加了 53 层网络,共 106 层网络。如果我们需要预测速度更快的架构,可以使用包含较少网络层 TinyYOLO 架构。

YOLO__12">2. 基于 YOLO 实现目标检测

在本节中,我们使用与深度学习简介一节相同的函数和类来加载模型、预处理图像和预测结果,同时介绍非极大值抑制 (non-maximum suppression, NMS),以及绘制带有标签的预测结果:

(1) 创建 object_detection_yolo.cpp 文件,导入所需的头文件,初始化所需的全局变量:

#include <fstream>
#include <sstream>
#include <iostream>

#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

using namespace cv;
using namespace dnn;
using namespace std;

// Initialize the parameters
float confThreshold = 0.5; // Confidence threshold
float nmsThreshold = 0.4;  // Non-maximum suppression threshold
int inpWidth = 416;  // Width of network's input image
int inpHeight = 416; // Height of network's input image
vector<string> classes;

(2) 我们从 main 函数开始,首先读取存储模型可以预测的所有类别的文件:

int main(int argc, char** argv) {
    // 加载类别名
    string classesFile = "data/coco.names";
    ifstream ifs(classesFile.c_str());
    string line;
    while (getline(ifs, line)) classes.push_back(line);

(3) 使用模型定义和权重文件加载模型:

    // 提供模型的配置和权重文件
    String modelConfiguration = "data/yolov3.cfg";
    String modelWeights = "data/yolov3.weights";
    // 加载网络
    Net net = readNetFromDarknet(modelConfiguration, modelWeights);

(4) 加载图像并将其转换为 blob

    Mat input, blob;
    input= imread(argv[1]);
    if (input.empty()) {
        cout << "No input image" << endl;
        return 0;
    }
    // 创建输入
    blobFromImage(input, blob, 1/255.0, Size(inpWidth, inpHeight), Scalar(0,0,0), true, false);

(5) 使用 setInputforward 函数检测所有对象及其类别:

    // 设定网络输入
    net.setInput(blob);
    // 执行前向传播
    vector<Mat> outs;
    net.forward(outs, getOutputsNames(net));

(6) 对输出结果进行后处理,绘制检测到的目标及预测置信度:

    // 移除低置信度边界框
    postprocess(input, outs);

(7)postprocess 函数中,存储所有预测置信度高于 confThreshold 的边界框框:

    vector<int> classIds;
    vector<float> confidences;
    vector<Rect> boxes;
    for (size_t i = 0; i < outs.size(); ++i) {
        // 扫描网络输出的所有边界框,仅保留具有高置信度分数的边界框
        // 将边界框的类标签指定为边界框得分最高的类别
        float* data = (float*)outs[i].data;
        for (int j = 0; j < outs[i].rows; ++j, data += outs[i].cols) {
            Mat scores = outs[i].row(j).colRange(5, outs[i].cols);
            Point classIdPoint;
            double confidence;
            // 获取最大分数的值和位置
            minMaxLoc(scores, 0, &confidence, 0, &classIdPoint);
            if (confidence > confThreshold) {
                int centerX = (int)(data[0] * frame.cols);
                int centerY = (int)(data[1] * frame.rows);
                int width = (int)(data[2] * frame.cols);
                int height = (int)(data[3] * frame.rows);
                int left = centerX - width / 2;
                int top = centerY - height / 2;
                
                classIds.push_back(classIdPoint.x);
                confidences.push_back((float)confidence);
                boxes.push_back(Rect(left, top, width, height));
            }
        }
    }

(8) 使用 NMSBoxes 函数应用非极大值抑制,只得到具有高置信度的非重叠边界框并进行绘制:

    // 执行非极大值抑制
    // 消除具有较低置信度的冗余重叠边界框
    vector<int> indices;
    NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);
    for (size_t i = 0; i < indices.size(); ++i) {
        int idx = indices[i];
        Rect box = boxes[idx];
        drawPred(classIds[idx], confidences[idx], box.x, box.y,
                 box.x + box.width, box.y + box.height, frame);
    }

使用 YOLO 执行目标检测的结果如下所示:

检测结果

3. 完整代码

完整代码 object_detection_yolo.cpp 如下所示:

#include <fstream>
#include <sstream>
#include <iostream>

#include <opencv2/core/core.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>

using namespace cv;
using namespace dnn;
using namespace std;

// 初始化参数
float confThreshold = 0.5;  // 置信度阈值
float nmsThreshold = 0.4;   // 非极大值抑制阈值
int inpWidth = 416;         // 网络输入图像宽度
int inpHeight = 416;        // 网络输入图像高度
vector<string> classes;

// 绘制预测边界框
void drawPred(int classId, float conf, int left, int top, int right, int bottom, Mat& frame) {
    // 绘制显示边界框矩形
    rectangle(frame, Point(left, top), Point(right, bottom), Scalar(255, 255, 255), 1);
    // 获取类别名的标签及其置信度
    string conf_label = format("%.2f", conf);
    string label="";
    if (!classes.empty()) {
        label = classes[classId] + ":" + conf_label;
    }
    // 在边界框顶部显示标签
    int baseLine;
    Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
    top = max(top, labelSize.height);
    rectangle(frame, Point(left, top - labelSize.height), Point(left + labelSize.width, top + baseLine), Scalar(255, 255, 255), FILLED);
    putText(frame, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0,0,0),1,LINE_AA);
}

// 使用非最大值抑制移除置信度低的边界框
void postprocess(Mat& frame, const vector<Mat>& outs) {
    vector<int> classIds;
    vector<float> confidences;
    vector<Rect> boxes;
    for (size_t i = 0; i < outs.size(); ++i) {
        // 扫描网络输出的所有边界框,仅保留具有高置信度分数的边界框
        // 将边界框的类标签指定为边界框得分最高的类别
        float* data = (float*)outs[i].data;
        for (int j = 0; j < outs[i].rows; ++j, data += outs[i].cols) {
            Mat scores = outs[i].row(j).colRange(5, outs[i].cols);
            Point classIdPoint;
            double confidence;
            // 获取最大分数的值和位置
            minMaxLoc(scores, 0, &confidence, 0, &classIdPoint);
            if (confidence > confThreshold) {
                int centerX = (int)(data[0] * frame.cols);
                int centerY = (int)(data[1] * frame.rows);
                int width = (int)(data[2] * frame.cols);
                int height = (int)(data[3] * frame.rows);
                int left = centerX - width / 2;
                int top = centerY - height / 2;
                
                classIds.push_back(classIdPoint.x);
                confidences.push_back((float)confidence);
                boxes.push_back(Rect(left, top, width, height));
            }
        }
    }
    
    // 执行非极大值抑制
    // 消除具有较低置信度的冗余重叠边界框
    vector<int> indices;
    NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);
    for (size_t i = 0; i < indices.size(); ++i) {
        int idx = indices[i];
        Rect box = boxes[idx];
        drawPred(classIds[idx], confidences[idx], box.x, box.y,
                 box.x + box.width, box.y + box.height, frame);
    }
}

// 获取输出层的名称
vector<String> getOutputsNames(const Net& net) {
    static vector<String> names;
    if (names.empty()) {
        // 获取输出层的索引
        vector<int> outLayers = net.getUnconnectedOutLayers();
        // 获取网络中所有层的名称
        vector<String> layersNames = net.getLayerNames();
        // 获取names变量中输出层的名称
        names.resize(outLayers.size());
        for (size_t i = 0; i < outLayers.size(); ++i) {
            names[i] = layersNames[outLayers[i] - 1];
        }
    }
    return names;
}

int main(int argc, char** argv) {
    // 加载类别名
    string classesFile = "data/coco.names";
    ifstream ifs(classesFile.c_str());
    string line;
    while (getline(ifs, line)) classes.push_back(line);
    // 提供模型的配置和权重文件
    String modelConfiguration = "data/yolov3.cfg";
    String modelWeights = "data/yolov3.weights";
    // 加载网络
    Net net = readNetFromDarknet(modelConfiguration, modelWeights);
    net.setPreferableBackend(DNN_BACKEND_OPENCV);
    net.setPreferableTarget(DNN_TARGET_CPU);
    
    Mat input, blob;
    input= imread(argv[1]);
    if (input.empty()) {
        cout << "No input image" << endl;
        return 0;
    }
    // 创建输入
    blobFromImage(input, blob, 1/255.0, Size(inpWidth, inpHeight), Scalar(0,0,0), true, false);
    // 设定网络输入
    net.setInput(blob);
    // 执行前向传播
    vector<Mat> outs;
    net.forward(outs, getOutputsNames(net));
    // 移除低置信度边界框
    postprocess(input, outs);
    vector<double> layersTimes;
    double freq = getTickFrequency() / 1000;
    double t = net.getPerfProfile(layersTimes) / freq;
    string label = format("Inference time for compute the image : %.2f ms", t);
    cout << label << endl;
    
    imshow("YOLOv3", input);
    waitKey(0);
    return 0;
}

相关链接

OpenCV实战(1)——OpenCV与图像处理基础
OpenCV实战(2)——OpenCV核心数据结构
OpenCV实战(3)——图像感兴趣区域
OpenCV实战(4)——像素操作
OpenCV实战(5)——图像运算详解
OpenCV实战(6)——OpenCV策略设计模式
OpenCV实战(7)——OpenCV色彩空间转换
OpenCV实战(8)——直方图详解
OpenCV实战(9)——基于反向投影直方图检测图像内容
OpenCV实战(10)——积分图像详解
OpenCV实战(11)——形态学变换详解
OpenCV实战(12)——图像滤波详解
OpenCV实战(13)——高通滤波器及其应用
OpenCV实战(14)——图像线条提取
OpenCV实战(15)——轮廓检测详解
OpenCV实战(16)——角点检测详解
OpenCV实战(17)——FAST特征点检测
OpenCV实战(18)——特征匹配
OpenCV实战(19)——特征描述符
OpenCV实战(20)——图像投影关系
OpenCV实战(21)——基于随机样本一致匹配图像
OpenCV实战(22)——单应性及其应用
OpenCV实战(23)——相机标定
OpenCV实战(24)——相机姿态估计
OpenCV实战(25)——3D场景重建
OpenCV实战(26)——视频序列处理
OpenCV实战(27)——追踪视频中的特征点
OpenCV实战(28)——光流估计
OpenCV实战(29)——视频对象追踪
OpenCV实战(30)——OpenCV与机器学习的碰撞
OpenCV实战(31)——基于级联Haar特征的目标检测
OpenCV实战(32)——使用SVM和定向梯度直方图执行目标检测
OpenCV实战(33)——OpenCV与深度学习的碰撞


http://www.niftyadmin.cn/n/5111541.html

相关文章

MyBatis篇---第一篇

系列文章目录 文章目录 系列文章目录一、什么是MyBatis二、说说MyBatis的优点和缺点三、#{}和${}的区别是什么?一、什么是MyBatis (1)Mybatis是一个半ORM(对象关系映射)框架,它内部封装了JDBC,开发时只需要关注SQL 语句本身,不需要花费精力去处理加载驱动、创建连接、…

Telent

Telnet协议是一种远程登录协议&#xff0c;它允许用户通过网络连接到远程主机并在远程主机上执行命令。 Telnet协议是TCP/IP协议族中的一员&#xff0c;是Internet远程登录服务的标准协议和主要方式。它为用户提供了在本地计算机上完成远程主机工作的能力。在终端使用者的电脑…

算法学习之 背包01问题 , 备战leecode

来看题目 我们分析一下题目&#xff0c;首先我们要排序&#xff0c;这有助于我们得到最大的值&#xff0c;我们要得到一个递推公式 代码如下: class Solution { public:int maxSatisfaction(vector<int>& satisfaction) {int n satisfaction.size();vector<v…

42917-2023 消光制品用聚氯乙烯树脂

1 范围 本文件规定了消光制品用聚氯乙烯树脂的分类、技术要求、取样、试验方法、检验规则及标志、随行 文件、包装、运输和贮存。 本文件适用于氯乙烯与交联剂悬浮共聚所制得的用于生产消光制品的聚氯乙烯树脂。 2 规范性引用文件 下列文件中的内容通过文中的规范性引用而…

模拟量开关量防抖算法(模拟量超限报警功能块)

模拟量信号的防抖,除了了可以采用延时方法。还可以采用死区过滤器实现,死区过滤器详细算法解读和完整源代码,请查看下面文章博客: PLC信号处理系列之死区滤波器(DeadZone)-CSDN博客(*死区滤波器*)ELSErValue:=rX;END_IF;博途PLC信号处理系列之限幅消抖滤波_RXXW_Dor的博…

Java开发规范记录

不要使用 count(column)或 count(1)来替代 count(*)&#xff0c;count(*)是 SQL92 定义的 标准统计行数的语法&#xff0c;跟数据库无关&#xff0c;跟 NULL 和非 NULL 无关。 注意&#xff1a;count(*)会统计值为 NULL 的行&#xff0c;而 count(列名)不会统计此列为 NULL 值的…

黔院长 | 不忘初心在逆境中前行!

随着我国经济不断发展进步&#xff0c;以及人口老龄化程度的加深&#xff0c;加上自然环境质量的下降&#xff0c;人们越来越关注和重视自己的健康问题。据世界卫生组织相关数据显示&#xff0c;目前我国的亚健康率已经高达95%&#xff01;健康发展刻不容缓&#xff01; 国家政…

MYSQL常用函数详解

今天查缺陷发现同事写的一个MYSQL的SQL中用到函数JSON_CONTAINS&#xff0c;我当时第一反应是这个函数是Mysql8新加的么&#xff1f;原来小丑尽是我自己。 有必要巩固一下Mysql函数知识&#xff0c;并记录一下。&#xff08;如果对您也有用&#xff0c;麻烦您动动发财的手点个赞…