YOLOv5 分类模型 预处理

news/2024/7/11 1:07:18 标签: YOLO, 分类

YOLOv5 分类模型 预处理
flyfish

主要是 替换 classify_transforms 分类模型的 4块预处理

PyTorch实现

def classify_transforms(size=224):
    # Transforms to apply if albumentations not installed
    return T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

YOLOv5 分类模型的预处理(1) Resize 和 CenterCrop
YOLOv5 分类模型的预处理(2)ToTensor 和 Normalize
YOLOv5 分类模型 Top 1和Top 5 指标说明
YOLOv5 分类模型 Top 1和Top 5 指标实现

替换 PyTorch实现

T.Resize的实现

#实现 PyTorch Resize
target_size =224

img_w = images.width
img_h = images.height

if(img_h >= img_w):# hw
 
    resize_img = images.resize((target_size, int(target_size * img_h / img_w)), Image.BILINEAR)
else:
    resize_img = images.resize((int(target_size * img_w  / img_h),target_size), Image.BILINEAR)

T.CenterCrop的实现

#实现 PyTorch CenterCrop
width = resize_img.width
height = resize_img.height

center_x,center_y = width//2,height//2
left = center_x - (target_size//2)
top = center_y- (target_size//2)
right =center_x +target_size//2
bottom = center_y+target_size//2
cropped_img = resize_img.crop((left, top, right, bottom))

T.ToTensor和T.Normalize的实现

#实现 PyTorch ToTensor Normalize
images = np.asarray(cropped_img)
print("preprocess:",images.shape)
images = images.astype('float32')
images = (images/255-mean)/std
images = images.transpose((2, 0, 1))# HWC to CHW
print("preprocess:",images.shape)

images = np.ascontiguousarray(images)
images=torch.from_numpy(images)

完整描述

import time
from models.common import DetectMultiBackend
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np

import torch
#from utils.augmentations import classify_transforms
from PIL import Image
import torchvision.transforms as transforms

class DatasetFolder:

    def __init__(
        self,
        root: str,

    ) -> None:
        self.root = root
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx)

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Optional[Dict[str, int]] = None,

    ) -> List[Tuple[str, int]]:

        directory = os.path.expanduser(directory)

        if class_to_idx is None:
            _, class_to_idx = self.find_classes(directory)
        elif not class_to_idx:
            raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

        instances = []
        available_classes = set()
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if 1:  # 验证:
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

        empty_classes = set(class_to_idx.keys()) - available_classes
        if empty_classes:
            msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "

        return instances

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:

        classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        path, target = self.samples[index]
        sample = self.loader(path)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)

    def loader(self, path):
        print("path:", path)
        #img = cv2.imread(path)  # BGR HWC
        img=Image.open(path) # RGB HWC
        return img


def time_sync():
    return time.time()


dataset = DatasetFolder(root="/media/a/flyfish/test/val")


# image, label=dataset[7]
# print(image.shape)
#
weights = "/media/a/flyfish/yolov5-6.2/classes10.pt"
device = "cpu"
model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
model.eval()

mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
def preprocess(images):
  

    #实现 PyTorch Resize
    target_size =224

    img_w = images.width
    img_h = images.height
    
    if(img_h >= img_w):# hw
 
        resize_img = images.resize((target_size, int(target_size * img_h / img_w)), Image.BILINEAR)
    else:
        resize_img = images.resize((int(target_size * img_w  / img_h),target_size), Image.BILINEAR)

    #实现 PyTorch CenterCrop
    width = resize_img.width
    height = resize_img.height

    center_x,center_y = width//2,height//2
    left = center_x - (target_size//2)
    top = center_y- (target_size//2)
    right =center_x +target_size//2
    bottom = center_y+target_size//2
    cropped_img = resize_img.crop((left, top, right, bottom))

    #实现 PyTorch ToTensor Normalize
    images = np.asarray(cropped_img)
    print("preprocess:",images.shape)
    images = images.astype('float32')
    images = (images/255-mean)/std
    images = images.transpose((2, 0, 1))# HWC to CHW
    print("preprocess:",images.shape)

    images = np.ascontiguousarray(images)
    images=torch.from_numpy(images)
    #images = images.unsqueeze(dim=0).float()
    return images

pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
# current batch size =1
for i, (images, labels) in enumerate(dataset):
    print("i:", i)
    im = preprocess(images)
    images = im.unsqueeze(0).to("cpu").float()
 
    print(images.shape)


        
    t1 = time_sync()
    images = images.to(device, non_blocking=True)
    t2 = time_sync()
    # dt[0] += t2 - t1

    y = model(images)
    y=y.numpy()
   
    print("y:", y)
    t3 = time_sync()
    # dt[1] += t3 - t2

    tmp1=y.argsort()[:,::-1][:, :5]
   
    print("tmp1:", tmp1)
    pred.append(tmp1)

    print("labels:", labels)

    
    targets.append(labels)

    print("for pred:", pred)  # list
    print("for targets:", targets)  # list
    # dt[2] += time_sync() - t3


pred, targets = np.concatenate(pred), np.array(targets)
print("pred:", pred)
print("pred:", pred.shape)
print("targets:", targets)
print("targets:", targets.shape)
correct = ((targets[:, None] == pred)).astype(np.float32)
print("correct:", correct.shape)
print("correct:", correct)
acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
print("acc:", acc.shape)
print("acc:", acc)
top = acc.mean(0)
print("top1:", top[0])
print("top5:", top[1])


http://www.niftyadmin.cn/n/5189399.html

相关文章

这款IDEA插件真的爱了

IDEA是一款功能强大的集成开发环境(IDE),它可以帮助开发人员更加高效地编写、调试和部署软件应用程序。我们在编写完接口代码后需要进行接口调试等操作,一般需要打开额外的调试工具。 今天给大家介绍一款IDEA插件:Api…

vite+vue3+electron开发环境搭建

环境 node 18.14.2 yarn 1.22 项目创建 yarn create vite test01安装vue环境 cd test01 yarn yarn dev说明vue环境搭建成功 安装electron # 因为有的版本会报错所以指定了版本 yarn add electron26.1.0 -D安装vite-plugin-electron yarn add -D vite-plugin-electron根目…

面试鸭 - 专注于面试刷题的网站

网上面试题有很多,但此套面试题真实、原创、高频,全网最强。 题目涵盖大中小公司,真实靠谱,有频率和难度的标记,助你成为Offer收割机。 面试鸭地址:https://mianshiya.skyofit.com/ 本套题是我原创&…

国家大基金三期线上金融正式倒计时!11月17日,共启芯片产业新篇章

国家大基金三期线上金融正式倒计时!11月17日,共启芯片产业新篇章 新时代浪潮下,全球化进程不断推动各科技大国的核心发展,芯片作为强有力的竞争标志,是国与国之间的重要技术战争焦点。同时,国内基金发展势…

米尔AM62x核心板,高配价低,AM335x升级首选

AM335x是TI经典的工业MPU,它引领了一个时代,即工业市场从MCU向MPU演进,帮助产业界从Arm9迅速迁移至高性能Cortex-A8处理器。随着工业4.0的发展,HMI人机交互、工业工控、医疗等领域的应用面临迫切的升级需求,AM62x处理器…

接口自动化和UI自动化的区别

接口自动化和UI自动化的区别 目录 1 自动化概念1.1 接口自动化1.1.1 接口概念1.1.2 接口自动化概念1.2 UI自动化概念1.2.1 UI测试概念1.2.2 UI自动化测试概念2 自动化结构2.1 接口自动化结构2.2 UI自动化结构3 自动化差别3.1 相同3.1.1 设计思想3.1.2 框架搭建3.1.3 持续集成3…

[ 云计算 | AWS ] AI 编程助手新势力 Amazon CodeWhisperer:优势功能及实用技巧

文章目录 一、Amazon CodeWhisperer 简介1.1 CodeWhisperer 是什么1.2 Amazon CodeWhisperer 是如何工作的 二、Amazon CodeWhisperer 的优势和功能2.1 Amazon CodeWhisperer 的优势2.2 Amazon CodeWhisperer 的代码功能 三、Amazon CodeWhisperer 安装3.1 安装到 IntelliJ IDE…

stylelint报错at-rule-no-unknown

stylelint报错at-rule-no-unknown stylelint还将各种 sass -rules 标记mixin为include显示未知错误 at-rule-no-unknown ✖ stylelint --fix:Deprecation warnings: 78:1 ✖ Unexpected unknown at-rule "mixin" at-rule-no-unknown 112:3 ✖ Unexpected un…