一、背景
由于部署到硬件上时,硬件不支持模型的后处理部分,因此需要把后处理挪到 CPU 上执行。
二、转int8.tflite版本时挪出后处理部分
- 需要修改的文件
models/tf.py
将 `TFDetect(keras.layers.Layer)` 的 `call()` 函数修改为下面的内容:
def call(self, inputs):
    """Forward pass of ``TFDetect`` used for the int8 TFLite export.

    By default (deploy mode) only the per-level detection convolutions
    ``self.m[i]`` are run. The box decoding (sigmoid, grid offsets, anchor
    scaling and xywh normalisation) is skipped so it can be done on the CPU
    after inference, because the target hardware cannot run it.

    Set ``self.deploy = False`` on the layer to get the original YOLOv5
    decoding path back without editing this method.

    Args:
        inputs: one feature map per detection level (``self.nl`` entries),
            each fed to the matching ``self.m[i]``.

    Returns:
        Deploy mode: a 1-tuple holding the raw outputs of every level, each
        reshaped to rows of ``self.no`` values and concatenated along axis 0
        (level after level). Presumably one row per (grid cell, anchor),
        assuming the conv channels are laid out as ``na * no`` as the decode
        path's reshape implies -- TODO confirm.
        Otherwise: the raw per-level outputs when ``self.training``, else a
        1-tuple with the decoded predictions concatenated along axis 1.
    """
    if getattr(self, 'deploy', True):
        # Deploy mode: raw conv outputs only. ``self.no`` replaces the former
        # hard-coded 6, which only kept rows aligned when ``self.no == 6``.
        z = [tf.reshape(self.m[i](inputs[i]), [-1, self.no]) for i in range(self.nl)]
        return (tf.concat(z, 0),)

    # Original YOLOv5 path: decode boxes inside the model.
    z = []  # inference output
    x = []
    for i in range(self.nl):
        x.append(self.m[i](inputs[i]))
        # x(bs,20,20,255) to x(bs,3,20,20,85)
        ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
        x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])
        if not self.training:  # inference
            y = x[i]
            grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
            anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
            xy = (tf.sigmoid(y[..., 0:2]) * 2 + grid) * self.stride[i]  # xy
            wh = tf.sigmoid(y[..., 2:4]) ** 2 * anchor_grid
            # Normalize xywh to 0-1 to reduce calibration error
            xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
            wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
            y = tf.concat([xy, wh, tf.sigmoid(y[..., 4:5 + self.nc]), y[..., 5 + self.nc:]], -1)
            z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))
    return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1),)
- 模型转换代码
python export.py --weights ckpt/best_620.pt --imgsz 320 --opset 10 --include tflite --int8
- 变化如下图:坐标变换部分被删掉了(这里使用的输入尺寸为 320×320,并且检测头中删掉了 40×40 分辨率的部分)。
三、转 onnx 版本时挪出后处理部分
- 需要修改的文件
models/yolo.py
将 `Detect(nn.Module)` 的 `forward()` 函数修改为下面的内容:
def forward(self, x):
    """Forward pass of ``Detect`` used for the ONNX export.

    By default (deploy mode) only the per-level conv ``self.m[i]`` is applied.
    The grid/anchor box decoding is skipped so it can run on the CPU after
    inference. Set ``self.deploy = False`` on the module to get the original
    YOLOv5 decoding back without editing this method.

    Args:
        x: list with one feature map per detection level (``self.nl``
            entries). Entries are overwritten in place with the conv outputs
            (and, outside deploy mode, with their permuted view).

    Returns:
        ``x`` when ``self.training``. Otherwise ``(torch.cat(z, 1),)`` when
        ``self.export`` and ``(torch.cat(z, 1), x)`` when not. In deploy mode
        each ``z`` entry is the raw conv output reshaped to
        ``(self.na * self.no, -1)``.
    """
    z = []  # inference output
    deploy = getattr(self, 'deploy', True)
    for i in range(self.nl):
        x[i] = self.m[i](x[i])  # conv
        if deploy:
            # Plain reshape with no permute, presumably so the exported graph
            # gets no extra Transpose node. x[i] is (bs, C, ny, nx), so this gives
            # one row per output channel only when bs == 1; with bs > 1 the
            # batches are interleaved across rows.
            # self.na * self.no replaces the former hard-coded 18.
            z.append(x[i].reshape(self.na * self.no, -1))
            continue
        bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
        x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
        if not self.training:  # inference
            if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
            if isinstance(self, Segment):  # (boxes + masks)
                xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
            else:  # Detect (boxes only)
                xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
                xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                y = torch.cat((xy, wh, conf), 4)
            z.append(y.view(bs, self.na * nx * ny, self.no))
    return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
即跳过坐标变换(框解码)部分,不在模型内执行。
- 模型转换代码
python export.py --weights ckpt/best_620.pt --imgsz 320 --opset 10 --include onnx