转yolov5模型时,不转后处理部分

news/2024/7/11 1:10:26 标签: YOLO, 转换模型

一、背景

由于部署在硬件上的时候,后处理部分硬件处理不支持,需要挪到cpu上处理。

二、转int8.tflite版本时挪出后处理部分

  1. 需要修改的文件models/tf.py
    TFDetect(keras.layers.Layer)call()函数修改为下面部分
    def call(self, inputs):
        print('************* deploy ******************')
        z = []  # inference output
        x = []
        for i in range(self.nl):
        	###  原始 ####
            # x.append(self.m[i](inputs[i])) 
            ######## 新增 ########
            if True: 
                temp = self.m[i](inputs[i])
                z.append(tf.reshape(temp,[-1,6]))
                # print('shape', self.m[i](inputs[i]).reshape(-1,6).shape)
            continue
            ######## 新增 ########
            # x(bs,20,20,255) to x(bs,3,20,20,85)
            ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
            x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])

            if not self.training:  # inference
                y = x[i]
                grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
                anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
                xy = (tf.sigmoid(y[..., 0:2]) * 2 + grid) * self.stride[i]  # xy
                wh = tf.sigmoid(y[..., 2:4]) ** 2 * anchor_grid
                # Normalize xywh to 0-1 to reduce calibration error
                xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                y = tf.concat([xy, wh, tf.sigmoid(y[..., 4:5 + self.nc]), y[..., 5 + self.nc:]], -1)
                z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))
        ######## 新增 ########
        return (tf.concat(z, 0), )
        ######## 新增 ########
        ######## 原始 ########
        # return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1), ) ## org
  1. 模型转换代码
python export.py --weights ckpt/best_620.pt --imgsz 320   --opset 10 --include tflite --int8
  1. 变化如下图,坐标变换部分删掉了(这里使用的是320320的,检测头删掉了4040分辨率部分)
    在这里插入图片描述

二、转onnx版本时挪出后处理部分

  1. 需要修改的文件models/yolo.py
    Detect(nn.Module)forward()函数修改为下面部分
    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
           
            ######## 新增 ########
            if True:
                z.append(x[i].reshape(18, -1))
            continue
            ######## 新增 ########
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                if isinstance(self, Segment):  # (boxes + masks)
                    xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                    xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
                else:  # Detect (boxes only)
                    xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, self.na * nx * ny, self.no))
        print('self.export',self.export)

        return  x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

将坐标变换部分跳过,不执行
2. 模型转换代码

python export.py --weights ckpt/best_620.pt --imgsz 320   --opset 10 --include onnx

http://www.niftyadmin.cn/n/5346802.html

相关文章

iOS 升级Xcode15报错问题

一、Xcode 15 libarclite 缺失问题 升级到Xcode 15运行项目报错,报错信息如下:SDK does not contain libarclite at the path /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/arc/libarclite_iphonesimulato…

《开始使用PyQT》 第01章 PyQT入门 02 安装Python3和PyQT6

02 安装Python3和PyQT6 《开始使用PyQT》 第01章 PyQT入门 02 安装Python3和PyQT6 So that all readers are on the same page, let’s begin by installing or updating your version of Python. 为了让所有读者都能理解,让我们从安装或更新 Python 版本开始。 …

牛刀小试 - C++ 推箱子小游戏

参考文档 C笔记:推箱子小游戏 copy函数 memcpy()函数用法(可复制数组) 使用memcpy踩出来的坑,值得注意 完整代码 /********************************************************************* 程序名:推箱子小游戏 说明&#x…

Hbuilder从gitlab上面拉取项目

要先下载TortoiseGit-2.15.0.0-64bit这个软件 在HBuilder中从GitLab上拉取项目,请按照以下步骤操作: 1. 打开HBuilder,点击左上角的“文件”菜单,然后选择“新建”->“项目”。 2. 在弹出的对话框中,选择“从Git导…

哪些 3D 建模软件值得推荐?

云端地球是一款免费的在线实景三维建模软件,不需要复杂的技巧,只要需要手机,多拍几张照片,就可以得到完整的三维模型! 无论是大场景倾斜摄影测量还是小场景、小物体建模,都可以通过云端地球将二维数据向三…

Git 教程 | 将本地修改后的文件推送到 Github 指定远程分支上

Git 是一种分布式版本控制系统,用于敏捷高效地处理任何大小的项目。它是由 Linus Torvalds 为了帮助管理 Linux 内核开发而开发的开源版本控制软件。Git 的本地克隆就是一个完整的版本控制存储库,无论脱机还是远程都能轻松工作。开发人员会在本地提交其工…

Redis 持久化详解:RDB 与 AOF 的配置、触发机制和实际测试

什么是持久化? 就是 Redis 将内存数据持久化到硬盘,避免从数据库恢复数据。之所以避免从数据库恢复数据是因为后端数据通常有性能瓶颈,大量数据从数据库恢复可能会给数据库造成巨大压力。 Redis 持久化通常有 RDB 和 AOF 两种方式&#xff…

opencv010 卷积02(方盒滤波和均值滤波)

今天继续学习滤波器的相关知识!这篇比较简单,也短一些,明天写高斯滤波 方盒滤波 boxFilter(scr, ddepth, ksize[, dst[, anchor[, normalize[, borderType]]]]) 方盒滤波的卷积核如下: normalize(标准化&#xff0…