YOLOv8-seg Segmentation Code Explained (Part 3): Val

2024/7/10 23:05:25  Tags: YOLO, python, machine learning, computer vision, deep learning

Preface

YOLOv8-seg Segmentation Code Explained (Part 1): Predict
YOLOv8-seg Segmentation Code Explained (Part 2): Train
YOLOv8-seg Segmentation Code Explained (Part 3): Val

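  This article consists mainly of source code with annotations, and shows how YOLOv8 actually implements the computation of its evaluation metrics.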

Raw model output

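```python
preds = model(batch['img'], augment=augment)
# preds: (list:2)
#     0: (Tensor:(b, 4+cls_n+32, anchors))         4 box values, cls_n class scores, 32 mask coefficients
#     1: (tuple:3)
#         0: (list:3)                             raw outputs of the three detection scales
#             0: (Tensor:(b, 64+cls_n, 80, 80))   64 = 4 box sides x 16 DFL bins
#             1: (Tensor:(b, 64+cls_n, 40, 40))
#             2: (Tensor:(b, 64+cls_n, 20, 20))
#         1: (Tensor:(b, 32, anchors))            mask coefficients
#         2: (Tensor:(b, 32, 160, 160))           mask prototypes
```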
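To make the channel layout of `preds[0]` concrete, here is a minimal NumPy sketch (not Ultralytics code; the helper names `num_anchors` and `split_raw_output` are hypothetical). It assumes the three feature maps above come from strides 8, 16 and 32, so a 640×640 input gives 80×80 + 40×40 + 20×20 = 8400 anchors.

```python
import numpy as np


def num_anchors(imgsz=640, strides=(8, 16, 32)):
    """Number of anchor points: one per cell of every feature map (assumed strides)."""
    return sum((imgsz // s) ** 2 for s in strides)


def split_raw_output(out, nc, nm=32):
    """Split a raw head output of shape (b, 4 + nc + nm, anchors).

    Returns box values (b, 4, anchors), class scores (b, nc, anchors)
    and mask coefficients (b, nm, anchors).
    """
    assert out.shape[1] == 4 + nc + nm, "unexpected channel count"
    boxes, scores, coeffs = np.split(out, [4, 4 + nc], axis=1)
    return boxes, scores, coeffs


# Usage with random data: batch of 2, 80 classes, 640x640 input
a = num_anchors(640)  # 80*80 + 40*40 + 20*20
out = np.random.rand(2, 4 + 80 + 32, a)
boxes, scores, coeffs = split_raw_output(out, nc=80)
print(a, boxes.shape, scores.shape, coeffs.shape)
```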

Output post-processing

  After NMS, each detection is a row of 38 values: $38 = 4\ (\text{box}) + 1\ (\text{class\_score}) + 1\ (\text{class}) + 32\ (\text{mask coefficients})$.

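```python
preds = self.postprocess(preds)
# preds: (tuple:2)
#     0: (list, b)                     one entry per image
#         i: (Tensor:(obj_i, 38))      x1, y1, x2, y2, conf, class, 32 mask coefficients
#     1: (Tensor:(b, 32, 160, 160))    mask prototypes
```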
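Columns 6 onward of each NMS row (`pred[:, 6:]` in `update_metrics`) are the 32 mask coefficients, and the second element of `preds` holds the 32 prototype masks. As a rough sketch of how `self.process` turns them into instance masks, assuming it behaves like Ultralytics' `ops.process_mask` without upsampling (linear combination of prototypes, sigmoid, crop to the box scaled down to prototype resolution, threshold at 0.5), here is a NumPy version; `assemble_masks` is a hypothetical name.

```python
import numpy as np


def assemble_masks(protos, coeffs, boxes, img_shape):
    """Build binary instance masks at prototype resolution (sketch, assumed behaviour).

    protos:    (32, mh, mw) prototype masks of one image
    coeffs:    (n, 32) mask coefficients (columns 6: of each NMS row)
    boxes:     (n, 4) boxes as x1, y1, x2, y2 in network-input pixels
    img_shape: (ih, iw) size of the network input, e.g. (640, 640)
    """
    c, mh, mw = protos.shape
    ih, iw = img_shape
    logits = coeffs @ protos.reshape(c, -1)       # (n, mh*mw): linear combination
    masks = 1.0 / (1.0 + np.exp(-logits))         # sigmoid
    masks = masks.reshape(-1, mh, mw)

    # Scale boxes down to prototype resolution and zero everything outside them
    scaled = boxes * np.array([mw / iw, mh / ih, mw / iw, mh / ih])
    cols = np.arange(mw)[None, None, :]           # (1, 1, mw)
    rows = np.arange(mh)[None, :, None]           # (1, mh, 1)
    x1, y1, x2, y2 = (scaled[:, k, None, None] for k in range(4))
    inside = (cols >= x1) & (cols < x2) & (rows >= y1) & (rows < y2)
    return (masks * inside) > 0.5                 # binarize at 0.5


protos = np.random.randn(32, 160, 160)
coeffs = np.random.randn(3, 32)
boxes = np.array([[0, 0, 320, 320], [100, 100, 200, 300], [400, 400, 640, 640]], dtype=float)
print(assemble_masks(protos, coeffs, boxes, (640, 640)).shape)
```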

Updating the metrics

```python
self.update_metrics(preds, batch)


def update_metrics(self, preds, batch):
    """Metrics."""
    # Iterate over the outputs of each image in the batch
    for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
        idx = batch['batch_idx'] == si
        cls = batch['cls'][idx]
        bbox = batch['bboxes'][idx]
        nl, npr = cls.shape[0], pred.shape[0]  # number of labels, predictions
        shape = batch['ori_shape'][si]
        correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
        correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
        self.seen += 1

        if npr == 0:
            if nl:
                self.stats.append((correct_bboxes, correct_masks, *torch.zeros(
                    (2, 0), device=self.device), cls.squeeze(-1)))
                if self.args.plots:
                    self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
            continue

        # Masks
        midx = [si] if self.args.overlap_mask else idx
        gt_masks = batch['masks'][midx]
        pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])

        # Predictions
        if self.args.single_cls:
            pred[:, 5] = 0
        predn = pred.clone()
        ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
                        ratio_pad=batch['ratio_pad'][si])  # native-space pred

        # pn = npr (number of predictions), ln = nl (number of labels)
        # predn:      (Tensor:(pn, 38))
        # pred_masks: (Tensor:(pn, 160, 160))
        # gt_masks:   (Tensor:(1, 160, 160)) when overlap_mask=True
        # Evaluate
        if nl:
            height, width = batch['img'].shape[2:]
            tbox = ops.xywh2xyxy(bbox) * torch.tensor(
                (width, height, width, height), device=self.device)  # target boxes
            ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
                            ratio_pad=batch['ratio_pad'][si])  # native-space labels
            # labelsn:        (Tensor:(ln, 5))   class, x1, y1, x2, y2
            # correct_bboxes: (Tensor:(pn, 10))
            # correct_masks:  (Tensor:(pn, 10))
            labelsn = torch.cat((cls, tbox), 1)  # native-space labels
            correct_bboxes = self._process_batch(predn, labelsn)
            correct_masks = self._process_batch(predn,
                                                labelsn,
                                                pred_masks,
                                                gt_masks,
                                                overlap=self.args.overlap_mask,
                                                masks=True)
            if self.args.plots:
                self.confusion_matrix.process_batch(predn, labelsn)

        # Append correct_masks, correct_boxes, pconf, pcls, tcls
        self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
        # The remaining plotting and saving code is omitted
```

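The ground-truth boxes in `batch['bboxes']` are normalized center-x, center-y, width, height. The `tbox` line converts them to pixel x1, y1, x2, y2 on the letterboxed network input, and `ops.scale_boxes` then maps them back to the original image. A NumPy sketch of the first step, equivalent in spirit to `ops.xywh2xyxy` followed by the multiplication (the function name is hypothetical):

```python
import numpy as np


def normalized_xywh_to_pixel_xyxy(bbox, width, height):
    """(n, 4) normalized cx, cy, w, h -> (n, 4) pixel x1, y1, x2, y2."""
    bbox = np.asarray(bbox, dtype=float)
    xy, wh = bbox[:, :2], bbox[:, 2:]
    xyxy = np.concatenate((xy - wh / 2, xy + wh / 2), axis=1)  # corners, still normalized
    return xyxy * np.array([width, height, width, height])    # scale to pixels


# A box centred in a 640x640 input, covering half of each side
print(normalized_xywh_to_pixel_xyxy([[0.5, 0.5, 0.5, 0.5]], 640, 640))
# [[160. 160. 480. 480.]]
```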
self._process_batch

```python
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
    """
    Return correct prediction matrix
    Arguments:
        detections (array[N, 6]), x1, y1, x2, y2, conf, class
        labels (array[M, 5]), class, x1, y1, x2, y2
    Returns:
        correct (array[N, 10]), for 10 IoU levels
    """
    if masks:
        # One-hot: with overlap_mask, all instances share one mask whose pixels hold the
        # 1-based instance index; expand it into one binary mask per label
        if overlap:
            nl = len(labels)
            index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
            gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
            gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
        if gt_masks.shape[1:] != pred_masks.shape[1:]:
            gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
            gt_masks = gt_masks.gt_(0.5)
        iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
    else:  # boxes
        # Plain box IoU
        iou = box_iou(labels[:, 1:], detections[:, :4])

    return self.match_predictions(detections[:, 5], labels[:, 0], iou)
```
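With `overlap_mask=True`, each pixel of the shared mask stores the 1-based index of the instance covering it (0 for background), which is why `_process_batch` compares the repeated mask against `arange(nl) + 1`. A NumPy sketch of that expansion (hypothetical function name); it uses broadcasting instead of `repeat`, which gives the same result:

```python
import numpy as np


def overlap_to_binary_masks(overlap_mask, n):
    """Expand one (H, W) mask whose pixels hold instance ids 1..n (0 = background)
    into n binary masks of shape (n, H, W)."""
    index = np.arange(1, n + 1).reshape(n, 1, 1)            # (n, 1, 1): ids 1..n
    return (overlap_mask[None] == index).astype(np.float32)  # broadcast to (n, H, W)


m = np.array([[0, 1, 1],
              [2, 2, 0],
              [0, 0, 2]])
print(overlap_to_binary_masks(m, 2))
```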

IoU details

  Both functions compute the plain IoU: intersection divided by union.

```python
def box_iou(box1, box2, eps=1e-7):
    """
    Calculate intersection-over-union (IoU) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py

    Args:
        box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
        box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
    """

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)

    # IoU = inter / (area1 + area2 - inter)
    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)


def mask_iou(mask1, mask2, eps=1e-7):
    """
    Calculate masks IoU.

    Args:
        mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
                        product of image width and height.
        mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
                        product of image width and height.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): A tensor of shape (N, M) representing masks IoU.
    """
    intersection = torch.matmul(mask1, mask2.T).clamp_(0)
    union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection  # (area1 + area2) - intersection
    return intersection / (union + eps)
```
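For masks flattened to 0/1 vectors, the dot product of two rows counts the pixels set in both, which is why `mask_iou` gets every pairwise intersection from one matrix product. A NumPy re-sketch of both functions (hypothetical names `box_iou_np`, `mask_iou_np`) with a worked box example:

```python
import numpy as np


def box_iou_np(box1, box2, eps=1e-7):
    """Pairwise IoU of (N, 4) and (M, 4) boxes in x1, y1, x2, y2 format -> (N, M)."""
    box1, box2 = np.asarray(box1, float), np.asarray(box2, float)
    lt = np.maximum(box1[:, None, :2], box2[None, :, :2])  # top-left of intersection
    rb = np.minimum(box1[:, None, 2:], box2[None, :, 2:])  # bottom-right of intersection
    inter = np.clip(rb - lt, 0, None).prod(2)              # 0 when boxes do not overlap
    area1 = (box1[:, 2:] - box1[:, :2]).prod(1)
    area2 = (box2[:, 2:] - box2[:, :2]).prod(1)
    return inter / (area1[:, None] + area2[None] - inter + eps)


def mask_iou_np(mask1, mask2, eps=1e-7):
    """Pairwise IoU of (N, n) and (M, n) flattened binary masks -> (N, M)."""
    mask1, mask2 = np.asarray(mask1, float), np.asarray(mask2, float)
    inter = mask1 @ mask2.T                                 # pixels set in both masks
    union = mask1.sum(1)[:, None] + mask2.sum(1)[None] - inter
    return inter / (union + eps)


# Two 2x2 boxes overlapping in a 1x1 square: IoU = 1 / (4 + 4 - 1) = 1/7
print(box_iou_np([[0, 0, 2, 2]], [[1, 1, 3, 3]]))
```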

Whether each predicted box correctly detects a target at each IoU threshold

```python
def match_predictions(self, pred_classes, true_classes, iou):
    """
    Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.

    Args:
        pred_classes (torch.Tensor): Predicted class indices of shape(N,).
        true_classes (torch.Tensor): Target class indices of shape(M,).
        iou (torch.Tensor): An MxN tensor of pairwise IoU values between ground truths and predictions.

    Returns:
        (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
    """
    # self.iouv:     IoU thresholds 0.50 to 0.95 in steps of 0.05
    # correct:       (Tensor:(pn, 10))
    # correct_class: (Tensor:(ln, pn))
    correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
    correct_class = true_classes[:, None] == pred_classes
    for i, iouv in enumerate(self.iouv):
        # x: (Tensor:(n, 2)), n = number of (label, prediction) pairs with IoU >= threshold
        #    and matching class; each row is [ln_idx, pn_idx]
        x = torch.nonzero(iou.ge(iouv) & correct_class)  # IoU >= threshold and classes match
        if x.shape[0]:
            # Concatenate [label, detect, iou]
            # matches: (Tensor:(n, 3)), i.e. each row of x with its IoU value appended
            matches = torch.cat((x, iou[x[:, 0], x[:, 1]].unsqueeze(1)), 1).cpu().numpy()
            if x.shape[0] > 1:
                # Sort by IoU (descending), then deduplicate by pn_idx and then by ln_idx,
                # so each prediction and each label is matched at most once
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
            correct[matches[:, 1].astype(int), i] = True
    return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
```
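A NumPy re-sketch of `match_predictions` on a toy case (the function name `match_np` is hypothetical): one ground-truth object of class 0, and three predictions with IoU 0.92 (class 0), 0.6 (class 0) and 0.99 (class 1). Only the first prediction ends up correct, at every threshold up to 0.90: the second is a duplicate detection of the same object and the third has the wrong class.

```python
import numpy as np


def match_np(pred_classes, true_classes, iou, iouv=np.linspace(0.5, 0.95, 10)):
    """NumPy re-sketch of match_predictions.

    iou: (M, N) IoU between M ground truths and N predictions.
    Returns an (N, len(iouv)) boolean matrix of correct predictions.
    """
    correct = np.zeros((len(pred_classes), len(iouv)), dtype=bool)
    correct_class = np.asarray(true_classes)[:, None] == np.asarray(pred_classes)
    for i, t in enumerate(iouv):
        x = np.argwhere((iou >= t) & correct_class)            # rows: [label_idx, pred_idx]
        if len(x):
            matches = np.concatenate((x, iou[x[:, 0], x[:, 1]][:, None]), 1)
            if len(x) > 1:
                matches = matches[matches[:, 2].argsort()[::-1]]                   # highest IoU first
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]  # one label per prediction
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]  # one prediction per label
            correct[matches[:, 1].astype(int), i] = True
    return correct


c = match_np(np.array([0, 0, 1]), np.array([0]), np.array([[0.92, 0.6, 0.99]]))
print(c.astype(int))
```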

Computing the statistics

  mAP is the mean of the per-class AP values, and AP is the area under the precision-recall (PR) curve.
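Assuming, as in Ultralytics' `Metric` class, that the AP results form a matrix with one row per class and one column per IoU threshold (0.50 to 0.95 in steps of 0.05), mAP@0.5 and mAP@0.5:0.95 are plain means over that matrix. The AP values below are made up:

```python
import numpy as np

# Hypothetical per-class AP matrix: rows = classes, columns = IoU thresholds 0.50, 0.55, ..., 0.95
ap = np.array([
    [0.90, 0.88, 0.85, 0.80, 0.75, 0.70, 0.60, 0.50, 0.35, 0.10],
    [0.70, 0.68, 0.65, 0.60, 0.55, 0.50, 0.40, 0.30, 0.20, 0.05],
])

ap50 = ap[:, 0]       # AP@0.5 of each class
map50 = ap50.mean()   # mAP@0.5: mean over classes
map50_95 = ap.mean()  # mAP@0.5:0.95: mean over classes and all 10 thresholds
print(round(map50, 4), round(map50_95, 4))
```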

```python
def get_stats(self):
    """Returns metrics statistics and results dictionary."""
    stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)]  # to numpy
    if len(stats) and stats[0].any():
        self.metrics.process(*stats)
    # np.bincount counts occurrences of each non-negative integer; here, the number of targets per class
    self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc)  # number of targets per class
    return self.metrics.results_dict
```
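`minlength` pads the result so that classes absent from the validation labels still get a count of 0. A small example with made-up labels:

```python
import numpy as np

target_cls = np.array([0, 2, 2, 1, 2], dtype=float)  # class of every ground-truth object
nt_per_class = np.bincount(target_cls.astype(int), minlength=5)
print(nt_per_class)  # [1 1 3 0 0]
```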

Average precision

```python
def process(self, tp_b, tp_m, conf, pred_cls, target_cls):
    """
    Processes the detection and segmentation metrics over the given set of predictions.

    Args:
        tp_b (list): List of True Positive boxes.
        tp_m (list): List of True Positive masks.
        conf (list): List of confidence scores.
        pred_cls (list): List of predicted classes.
        target_cls (list): List of target classes.
    """
    # tp_b:       (ndarray:(pn, 10))
    # tp_m:       (ndarray:(pn, 10))
    # conf:       (ndarray:(pn,))
    # pred_cls:   (ndarray:(pn,))
    # target_cls: (ndarray:(ln,))

    results_mask = ap_per_class(tp_m,
                                conf,
                                pred_cls,
                                target_cls,
                                plot=self.plot,
                                on_plot=self.on_plot,
                                save_dir=self.save_dir,
                                names=self.names,
                                prefix='Mask')[2:]
    self.seg.nc = len(self.names)
    self.seg.update(results_mask)
    results_box = ap_per_class(tp_b,
                               conf,
                               pred_cls,
                               target_cls,
                               plot=self.plot,
                               on_plot=self.on_plot,
                               save_dir=self.save_dir,
                               names=self.names,
                               prefix='Box')[2:]
    self.box.nc = len(self.names)
    self.box.update(results_box)
```
```python
def ap_per_class(tp,
                 conf,
                 pred_cls,
                 target_cls,
                 plot=False,
                 on_plot=None,
                 save_dir=Path(),
                 names=(),
                 eps=1e-16,
                 prefix=''):
    """
    Computes the average precision per class for object detection evaluation.

    Args:
        tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
        conf (np.ndarray): Array of confidence scores of the detections.
        pred_cls (np.ndarray): Array of predicted classes of the detections.
        target_cls (np.ndarray): Array of true classes of the detections.
        plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
        on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
        save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
        names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
        prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.

    Returns:
        (tuple): A tuple of six arrays and one array of unique classes, where:
            tp (np.ndarray): True positive counts for each class.
            fp (np.ndarray): False positive counts for each class.
            p (np.ndarray): Precision values at each confidence threshold.
            r (np.ndarray): Recall values at each confidence threshold.
            f1 (np.ndarray): F1-score values at each confidence threshold.
            ap (np.ndarray): Average precision for each class at different IoU thresholds.
            unique_classes (np.ndarray): An array of unique classes that have data.

    """

    # Sort by objectness (i.e. by predicted confidence, from highest to lowest)
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes: the classes present in the labels and the number of targets of each
    unique_classes, nt = np.unique(target_cls, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes, number of detections

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = nt[ci]  # number of labels
        n_p = i.sum()  # number of predictions
        if n_p == 0 or n_l == 0:
            continue

        # Accumulate FPs and TPs
        # fpc: (ndarray:(n, 10)), tpc: (ndarray:(n, 10))
        # n is the number of predictions of the current class; the cumulative sums give the
        # TP and FP counts at every confidence cut-off, for each of the 10 IoU thresholds
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        # Recall
        recall = tpc / (n_l + eps)  # recall curve
        r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases

        # Precision
        precision = tpc / (tpc + fpc)  # precision curve
        p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

        # AP from recall-precision curve
        for j in range(tp.shape[1]):
            ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
            if plot and j == 0:
                py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)
    names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
    names = dict(enumerate(names))  # to dict
    if plot:
        plot_pr_curve(px, py, ap, save_dir / f'{prefix}PR_curve.png', names, on_plot=on_plot)
        plot_mc_curve(px, f1, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1', on_plot=on_plot)
        plot_mc_curve(px, p, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision', on_plot=on_plot)
        plot_mc_curve(px, r, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall', on_plot=on_plot)

    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    return tp, fp, p, r, f1, ap, unique_classes.astype(int)
```
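For a single class and a single IoU threshold (one column of `tp`), the sorting and cumulative sums in `ap_per_class` yield one (recall, precision) point per confidence cut-off. A worked example with made-up numbers:

```python
import numpy as np

# One class, 3 ground-truth objects, 5 predictions; tp is for a single IoU threshold
conf = np.array([0.3, 0.9, 0.6, 0.8, 0.4])
tp = np.array([0, 1, 1, 0, 1], dtype=float)
n_l = 3

order = np.argsort(-conf)        # highest confidence first
tp = tp[order]                   # [1, 0, 1, 1, 0]
tpc = tp.cumsum()                # true positives above each cut-off:  [1, 1, 2, 3, 3]
fpc = (1 - tp).cumsum()          # false positives above each cut-off: [0, 1, 1, 1, 2]
recall = tpc / n_l               # [1/3, 1/3, 2/3, 1, 1]
precision = tpc / (tpc + fpc)    # [1, 1/2, 2/3, 3/4, 3/5]
print(recall)
print(precision)
```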

Computing AP from the PR curve

```python
def compute_ap(recall, precision):
    """
    Compute the average precision (AP) given the recall and precision curves.

    Arguments:
        recall (list): The recall curve.
        precision (list): The precision curve.

    Returns:
        (float): Average precision.
        (np.ndarray): Precision envelope curve.
        (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([1.0], precision, [0.0]))

    # Compute the precision envelope
    # np.maximum.accumulate(arr) returns, at each position, the maximum of all elements up to
    # and including it, e.g. [3, 1, 4, 1, 5, 9, 2, 6, 5] -> [3 3 4 4 5 9 9 9 9].
    # Applied to the flipped array and flipped back, it replaces each precision with the maximum
    # precision at equal or higher recall, so the curve becomes monotonically non-increasing.
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        # np.trapz: integration with the trapezoidal rule
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x-axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap, mpre, mrec
```
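The two integration methods can give noticeably different values for the same curve. Below is a re-sketch of `compute_ap` (hypothetical name `compute_ap_sketch`) with the trapezoidal rule written out by hand, so it does not depend on `np.trapz`, which NumPy 2.0 deprecated in favor of `np.trapezoid`:

```python
import numpy as np


def compute_ap_sketch(recall, precision, method="interp"):
    """Re-sketch of compute_ap with a hand-written trapezoidal rule."""
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([1.0], precision, [0.0]))
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))        # precision envelope
    if method == "interp":
        x = np.linspace(0, 1, 101)                                # 101-point (COCO-style)
        y = np.interp(x, mrec, mpre)
        return float(np.sum((x[1:] - x[:-1]) * (y[1:] + y[:-1]) / 2))  # trapezoidal rule
    i = np.where(mrec[1:] != mrec[:-1])[0]                        # where recall changes
    return float(np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]))  # step ("continuous") rule


recall = np.array([0.5, 1.0])
precision = np.array([1.0, 0.5])
print(compute_ap_sketch(recall, precision, "continuous"))  # 0.75
print(compute_ap_sketch(recall, precision, "interp"))
```

For this toy curve the step ("continuous") rule uses the precision at the right end of each recall step and gives 0.75, while the 101-point interpolation connects the points linearly and comes out higher, at roughly 0.87.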

Best model and early stopping

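  How well the current epoch has trained is judged by a fitness score computed from the mAP on the validation set, as shown below. Early stopping can be disabled during training with patience=0; the default is 50.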

```python
# map50 = mAP@0.5, map = mAP@0.5:0.95, for boxes and for masks
fitness = 0.1*(box_map50+seg_map50) + 0.9*(box_map+seg_map)
```
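A sketch of the fitness formula together with a hypothetical patience-based stopper (`EarlyStopper` is an illustration, not the Ultralytics class): the epoch with the highest fitness is remembered as the best one, training stops once fitness has not reached a new best for `patience` epochs, and `patience=0` is treated as "never stop".

```python
def fitness(box_map50, box_map, seg_map50, seg_map):
    """Weighted fitness: 10% mAP@0.5 + 90% mAP@0.5:0.95, summed over box and mask metrics."""
    return 0.1 * (box_map50 + seg_map50) + 0.9 * (box_map + seg_map)


class EarlyStopper:
    """Hypothetical patience-based early stopping; patience=0 disables it."""

    def __init__(self, patience=50):
        self.patience = patience or float("inf")  # 0 -> never stop
        self.best_fitness = float("-inf")
        self.best_epoch = 0

    def step(self, epoch, fit):
        """Record this epoch's fitness; return True when training should stop."""
        if fit >= self.best_fitness:  # new best model
            self.best_fitness, self.best_epoch = fit, epoch
        return epoch - self.best_epoch >= self.patience


stopper = EarlyStopper(patience=3)
history = [0.30, 0.35, 0.34, 0.33, 0.32, 0.31]
for epoch, f in enumerate(history):
    if stopper.step(epoch, f):
        print(f"stop at epoch {epoch}, best epoch {stopper.best_epoch}")
        break
```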

