YOLOv5 分类模型 预处理 OpenCV实现

news/2024/7/11 1:16:29 标签: YOLO, 分类, opencv

YOLOv5 分类模型 预处理 OpenCV实现

flyfish

YOLOv5 分类模型 预处理 PIL 实现
YOLOv5 分类模型 OpenCV和PIL两者实现预处理的差异

YOLOv5 分类模型 数据集加载 1 样本处理
YOLOv5 分类模型 数据集加载 2 切片处理
YOLOv5 分类模型 数据集加载 3 自定义类别

YOLOv5 分类模型的预处理(1) Resize 和 CenterCrop
YOLOv5 分类模型的预处理(2)ToTensor 和 Normalize

YOLOv5 分类模型 Top 1和Top 5 指标说明
YOLOv5 分类模型 Top 1和Top 5 指标实现

判断图像是否是np.ndarray类型和维度

OpenCV读取一张图像时,类型类型就是<class 'numpy.ndarray'>,这里判断图像是否是np.ndarray类型
dim是dimension维度的缩写,shape属性的长度也是它的ndim
灰度图的shape为HW,二个维度
RGB图的shape为HWC,三个维度
在这里插入图片描述

def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})

实现ToTensor和Normalize

def totensor_normalize(img):
    print("preprocess:",img.shape)
    images = (img/255-mean)/std
    images = images.transpose((2, 0, 1))# HWC to CHW
    images = np.ascontiguousarray(images)
    return images

实现Resize

插值可以是以下参数

# 'nearest': cv2.INTER_NEAREST,
# 'bilinear': cv2.INTER_LINEAR,
# 'area': cv2.INTER_AREA,
# 'bicubic': cv2.INTER_CUBIC,
# 'lanczos': cv2.INTER_LANCZOS4
def resize(img, size, interpolation=cv2.INTER_LINEAR):
    r"""Resize the input numpy ndarray to the given size.
    Args:
        img (numpy ndarray): Image to be resized.
        size: like pytroch about size interpretation flyfish.
        interpolation (int, optional): Desired interpolation. Default is``cv2.INTER_LINEAR``  
    Returns:
        numpy Image: Resized image.like opencv
    """
    if not _is_numpy_image(img):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))
    h, w = img.shape[0], img.shape[1]

    if isinstance(size, int):
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)
    else:
        ow, oh = size[1], size[0]
    output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation)
    if img.shape[2] == 1:
        return output[:, :, np.newaxis]
    else:
        return output

实现CenterCrop

def crop(img, i, j, h, w):
    """Crop the given Image flyfish.
    Args:
        img (numpy ndarray): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
    Returns:
        numpy ndarray: Cropped image.
    """
    if not _is_numpy_image(img):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))

    return img[i:i + h, j:j + w, :]


def center_crop(img, output_size):
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    h, w = img.shape[0:2]
    th, tw = output_size
    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    return crop(img, i, j, th, tw)

完整

import time
from models.common import DetectMultiBackend
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import collections
import torch
import numbers


classes_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']

mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def totensor_normalize(img):
    print("preprocess:",img.shape)
    images = (img/255-mean)/std
    images = images.transpose((2, 0, 1))# HWC to CHW
    images = np.ascontiguousarray(images)
    return images

def resize(img, size, interpolation=cv2.INTER_LINEAR):
    r"""Resize the input numpy ndarray to the given size.
    Args:
        img (numpy ndarray): Image to be resized.
        size: like pytroch about size interpretation flyfish.
        interpolation (int, optional): Desired interpolation. Default is``cv2.INTER_LINEAR``  
    Returns:
        numpy Image: Resized image.like opencv
    """
    if not _is_numpy_image(img):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))
    h, w = img.shape[0], img.shape[1]

    if isinstance(size, int):
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)
    else:
        ow, oh = size[1], size[0]
    output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation)
    if img.shape[2] == 1:
        return output[:, :, np.newaxis]
    else:
        return output

def crop(img, i, j, h, w):
    """Crop the given Image flyfish.
    Args:
        img (numpy ndarray): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
    Returns:
        numpy ndarray: Cropped image.
    """
    if not _is_numpy_image(img):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))

    return img[i:i + h, j:j + w, :]


def center_crop(img, output_size):
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    h, w = img.shape[0:2]
    th, tw = output_size
    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    return crop(img, i, j, th, tw)

class DatasetFolder:

    def __init__(
        self,
        root: str,

    ) -> None:
        self.root = root

        if classes_name is None or not classes_name:
            classes, class_to_idx = self.find_classes(self.root)
            print("not classes_name")

        else:
            classes = classes_name
            class_to_idx ={cls_name: i for i, cls_name in enumerate(classes)}
            print("is classes_name")

        print("classes:",classes)
        
        print("class_to_idx:",class_to_idx)
        samples = self.make_dataset(self.root, class_to_idx)

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Optional[Dict[str, int]] = None,

    ) -> List[Tuple[str, int]]:

        directory = os.path.expanduser(directory)

        if class_to_idx is None:
            _, class_to_idx = self.find_classes(directory)
        elif not class_to_idx:
            raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

        instances = []
        available_classes = set()
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if 1:  # 验证:
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

        empty_classes = set(class_to_idx.keys()) - available_classes
        if empty_classes:
            msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "

        return instances

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:

        classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        path, target = self.samples[index]
        sample = self.loader(path)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)

    def loader(self, path):
        print("path:", path)
        img = cv2.imread(path)  # BGR HWC
        img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)#RGB
        print("type:",type(img))
        return img


def time_sync():
    return time.time()



dataset = DatasetFolder(root="/media/flyfish/datasets/imagewoof/val")
weights = "/home/classes.pt"
device = "cpu"
model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
model.eval()



def classify_transforms(img):
    img=resize(img,224)
    img=center_crop(img,224)
    img=totensor_normalize(img)
    return img;

pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
# current batch size =1
for i, (images, labels) in enumerate(dataset):
    print("i:", i)
    print(images.shape, labels)
    im = classify_transforms(images)


    images=torch.from_numpy(im).to(torch.float32) # numpy to tensor
    images = images.unsqueeze(0).to("cpu")
 
    print(images.shape)


        
    t1 = time_sync()
    images = images.to(device, non_blocking=True)
    t2 = time_sync()
    # dt[0] += t2 - t1

    y = model(images)
    y=y.numpy()
   
    print("y:", y)
    t3 = time_sync()
    # dt[1] += t3 - t2

    tmp1=y.argsort()[:,::-1][:, :5]
   
    print("tmp1:", tmp1)
    pred.append(tmp1)

    print("labels:", labels)

    
    targets.append(labels)

    print("for pred:", pred)  # list
    print("for targets:", targets)  # list

    # dt[2] += time_sync() - t3


pred, targets = np.concatenate(pred), np.array(targets)
print("pred:", pred)
print("pred:", pred.shape)
print("targets:", targets)
print("targets:", targets.shape)
correct = ((targets[:, None] == pred)).astype(np.float32)
print("correct:", correct.shape)
print("correct:", correct)
acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
print("acc:", acc.shape)
print("acc:", acc)
top = acc.mean(0)
print("top1:", top[0])
print("top5:", top[1])

结果

pred: [[0 3 6 2 1]
 [0 7 2 9 3]
 [0 5 6 2 9]
 ...
 [9 8 7 6 1]
 [9 3 6 7 0]
 [9 5 0 2 7]]
pred: (3929, 5)
targets: [0 0 0 ... 9 9 9]
targets: (3929,)
correct: (3929, 5)
correct: [[          1           0           0           0           0]
 [          1           0           0           0           0]
 [          1           0           0           0           0]
 ...
 [          1           0           0           0           0]
 [          1           0           0           0           0]
 [          1           0           0           0           0]]
acc: (3929, 2)
acc: [[          1           1]
 [          1           1]
 [          1           1]
 ...
 [          1           1]
 [          1           1]
 [          1           1]]
top1: 0.86230594
top5: 0.98167473

http://www.niftyadmin.cn/n/5211313.html

相关文章

python+pytest接口自动化-requests发送post请求

简介 在HTTP协议中&#xff0c;与get请求把请求参数直接放在url中不同&#xff0c;post请求的请求数据需通过消息主体(request body)中传递。 且协议中并没有规定post请求的请求数据必须使用什么样的编码方式&#xff0c;所以其请求数据可以有不同的编码方式&#xff0c;服务…

MySQL 优化器 MRR

什么是 MRR MRR 的全称是 Multi-Range Read Optimization&#xff0c;是优化器将随机 IO 转化为顺序 IO 以降低查询过程中 IO 开销的一种手段&#xff0c;咱们对比一下 mrron & mrroff 时的执行计划&#xff1a; 其中表结构如下&#xff1a; mysql> show create tabl…

生态对对碰|华为OceanStor闪存存储与OceanBase完成兼容性互认证!

近日&#xff0c;北京奥星贝斯科技有限公司 OceanBase 数据库与华为技术有限公司 OceanStor Dorado 全闪存存储系统、OceanStor 混合闪存存储系统完成兼容性互认证。 OceanBase 数据库挂载 OceanStor 闪存存储做为数据盘和日志盘&#xff0c;在 OceanStor 闪存存储系统卓越性能…

C现代方法(第23章)笔记——库对数值和字符数据的支持

文章目录 第23章 库对数值和字符数据的支持23.1 <float.h>: 浮点类型的特性23.2 <limits.h>: 整数类型的大小23.3 <math.h>: 数学计算(C89)23.3.1 错误23.3.2 三角函数23.3.3 双曲函数23.3.4 指数函数和对数函数23.3.5 幂函数23.3.6 就近舍入、绝对值函数和取…

jQuery 第十一章(表单验证插件推荐)

文章目录 前言jValidateZebra FormjQuery.validValValidityValidForm BuilderForm ValidatorProgressionformvalidationjQuery Validation PluginjQuery Validation EnginejQuery ValidateValidarium后言 前言 hello world欢迎来到前端的新世界 &#x1f61c;当前文章系列专栏&…

深度学习卷积神经网络参数计算难点重点

目录 一、卷积层图像输出尺寸 二、池化层图像输出尺寸 三、全连接层输出尺寸 四、卷积层参数数量 五、全连接层参数数量 六、代码实现与验证 以LeNet5经典模型为例子并且通道数为1 LeNet5网络有7层&#xff1a; ​ 1.第1层&#xff1a;卷积层 ​ 输入&#xff1a;原始的图片像素…

五种多目标优化算法(NSDBO、NSGA3、MOGWO、NSWOA、MOPSO)求解微电网多目标优化调度(MATLAB代码)

一、多目标优化算法简介 &#xff08;1&#xff09;非支配排序的蜣螂优化算法NSDBO 多目标应用&#xff1a;基于非支配排序的蜣螂优化算法NSDBO求解微电网多目标优化调度&#xff08;MATLAB&#xff09;-CSDN博客 &#xff08;2&#xff09;NSGA3 NSGA-III求解微电网多目标…

模拟退火算法应用——求解函数的最小值

仅作自己学习使用 一、问题 需求&#xff1a; 计算函数 的极小值&#xff0c;其中个体x的维数n10&#xff0c;即x(x1,x2,…,x10)&#xff0c;其中每一个分量xi均需在[-20,20]内。因此可以知道&#xff0c;这个函数只有一个极小值点x (0,0,…,0)&#xff0c;且其极小值是0&…