YOLOV5代码yolo.py文件解读

news/2024/7/10 23:34:29 标签: YOLO

YOLOV5源码的下载:

git clone https://github.com/ultralytics/yolov5.git

YOLOV5代码yolo.py文件解读:

import argparse
import logging
import sys
from copy import deepcopy
from pathlib import Path

import math

sys.path.append('./')  # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)

import torch
import torch.nn as nn

from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
from models.experimental import MixConv2d, CrossConv, C3
from utils.general import check_anchor_order, make_divisible, check_file, set_logging
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
    select_device, copy_attr
    
class Detect(nn.Module):
    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super(Detect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor. VOC: 20+5=25
        self.nl = len(anchors)  # number of detection layers = 3
        self.na = len(anchors[0]) // 2  # number of anchors  =3
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        # 模型中需要保存下来的参数包括两种: 一种是反向传播需要被optimizer更新的,称之为 parameter;
        # 一种是反向传播不需要被optimizer更新,称之为 buffer。
        # 第二种参数我们需要创建tensor, 然后将tensor通过register_buffer()进行注册,
        # 可以通过model.buffers() 返回,注册完后参数也会自动保存到OrderDict中去。
        # 注意:buffer的更新在forward中,optim.step只能更新nn.parameter类型的参数
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv 1*1卷积

    def forward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no)) # 预测框坐标信息

        return x if self.training else (torch.cat(z, 1), x) # 预测框坐标, obj, cls

    @staticmethod
    def _make_grid(nx=20, ny=20):
        # 划分为单元网格
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

# 网络模型类
class Model(nn.Module):
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None):  # model, input channels, number of classes
        super(Model, self).__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.load(f, Loader=yaml.FullLoader)  # model dict

        # Define model
        if nc and nc != self.yaml['nc']:
            print('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
            self.yaml['nc'] = nc  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist, ch_out
        # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            s = 128  # 2x min stride
            # m.stride = [8,16,32]
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward 
            # anchor大小计算, 例如 [10, 13] --> [1.25, 1.625]
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m) # 检查anchor顺序和stride顺序是否一致
            self.stride = m.stride
            self._initialize_biases()  # 初始化偏置 only run once
            # print('Strides: %s' % m.stride.tolist())

        # Init weights, biases
        initialize_weights(self) # 初始化权重
        self.info()
        print('')

    def forward(self, x, augment=False, profile=False):
        if augment: # TTA (Test Time Augmentation)
            img_size = x.shape[-2:]  # height, width
            s = [1, 0.83, 0.67]  # scales
            f = [None, 3, None]  # flips (2-ud, 3-lr)
            y = []  # outputs
            for si, fi in zip(s, f):
                xi = scale_img(x.flip(fi) if fi else x, si) # 改变图像尺寸
                yi = self.forward_once(xi)[0]  # forward
                # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
                yi[..., :4] /= si  # de-scale
                if fi == 2:
                    yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
                elif fi == 3:
                    yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
                y.append(yi)
            return torch.cat(y, 1), None  # augmented inference, train
        else:
            return self.forward_once(x, profile)  # single-scale inference, train

    def forward_once(self, x, profile=False):
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

            if profile:
                try:
                    import thop # THOP: PyTorch-OpCounter 估算PyTorch模型的FLOPs模块
                    o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
                except:
                    o = 0
                t = time_synchronized()
                for _ in range(10):
                    _ = m(x)
                dt.append((time_synchronized() - t) * 100)
                print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))

            x = m(x)  # 执行网络组件操作
            y.append(x if m.i in self.save else None)  # save output

        if profile:
            print('%.1fms total' % sum(dt))
        return x

    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def _print_biases(self):
        m = self.model[-1]  # Detect() module
        for mi in m.m:  # from
            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
            print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))

    # def _print_weights(self):
    #     for m in self.model.modules():
    #         if type(m) is Bottleneck:
    #             print('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        print('Fusing layers... ')
        for m in self.model.modules():
            if type(m) is Conv and hasattr(m, 'bn'):
                m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatability
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.fuseforward  # update forward
        self.info()
        return self
        
    def nms(self, mode=True):  # add or remove NMS module
        present = type(self.model[-1]) is NMS  # last layer is NMS
        if mode and not present:
            print('Adding NMS... ')
            m = NMS()  # module
            m.f = -1  # from
            m.i = self.model[-1].i + 1  # index
            self.model.add_module(name='%s' % m.i, module=m)  # add
            self.eval()
        elif not mode and present:
            print('Removing NMS... ')
            self.model = self.model[:-1]  # remove
        return self

    def autoshape(self):  # add autoShape module
        print('Adding autoShape... ')
        m = autoShape(self)  # wrap model
        copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
        return m
        
    def info(self, verbose=False):  # print model information
        model_info(self, verbose)

# 解析网络模型配置文件并构建模型
def parse_model(d, ch):  # model_dict, input_channels(3)
    logger.info('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    #将模型结构的depth_multiple,width_multiple提取出,赋值给gd (yolov5s: 0.33),gw (yolov5s:0.50)
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'] 
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors =3
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5); VOC : 75

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out = 3
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except:
                pass
        # 控制深度的代码
        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
            c1, c2 = ch[f], args[0]

            # Normal
            # if i > 0 and args[0] != no:  # channel expansion factor
            #     ex = 1.75  # exponential (default 2.0)
            #     e = math.log(c2 / ch[1]) / math.log(2)
            #     c2 = int(ch[1] * ex ** e)
            # if m != Focus:

            # 控制宽度(卷积核个数)的代码
            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

            # Experimental
            # if i > 0 and args[0] != no:  # channel expansion factor
            #     ex = 1 + gw  # exponential (default 2.0)
            #     ch1 = 32  # ch[1]
            #     e = math.log(c2 / ch1) / math.log(2)  # level 1-n
            #     c2 = int(ch1 * ex ** e)
            # if m != Focus:
            #     c2 = make_divisible(c2, 8) if c2 != no else c2

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3]:
                args.insert(2, n)
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
        elif m is Detect:
            args.append([ch[x + 1] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        else:
            c2 = ch[f]

        # *args表示接收任意个数量的参数,调用时会将实际参数打包为一个元组传入实参
        m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum([x.numel() for x in m_.parameters()])  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        logger.info('%3s%18s%3s%10.0f  %-40s%-30s' % (i, f, n, np, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)


if __name__ == '__main__':
    # 建立参数解析对象parser
    parser = argparse.ArgumentParser()
    # 添加属性:给xx实例增加一个aa属性,如 xx.add_argument("aa")
    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    # 采用parser对象的parse_args函数获取解析的参数
    opt = parser.parse_args()
    opt.cfg = check_file(opt.cfg)  # check file
    set_logging()
    device = select_device(opt.device)

    # Create model
    model = Model(opt.cfg).to(device)
    model.train()

    # Profile
    # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
    # y = model(img, profile=True)

    # Tensorboard
    # from torch.utils.tensorboard import SummaryWriter
    # tb_writer = SummaryWriter()
    # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
    # tb_writer.add_graph(model.model, img)  # add model to tensorboard
    # tb_writer.add_image('test', img[0], dataformats='CWH')  # add model to tensorboard


http://www.niftyadmin.cn/n/385472.html

相关文章

C#,生信软件实践(06)——DNA数据库GenBank文件的详解介绍及解释器之完整C#源代码

1 GenBank 1.1 NCBI——美国国家生物技术信息中心(美国国立生物技术信息中心) NCBI(美国国立生物技术信息中心)是在NIH的国立医学图书馆(NLM)的一个分支。它的使命包括四项任务:1. 建立关于分…

「致童真」Tomcat这只猫你还没掌握正确的撸猫姿势,是因为你用错了方式!

Tomcat 的优化方式 Tomcat 作为 Web 服务器,它的处理性能直接关系到用户体验,下面是几种常见的 优化措施: 1、去掉对 web.xml 的监视,把 jsp 提前编辑成 Servlet。有富余物理内存的情况, 加大 tomcat 使用的 jvm 的内存。 2、…

Ubuntu开机桌面黑屏只有鼠标问题解决办法(搜狗输入法导致)

参考: Ubuntu开机桌面黑屏只有鼠标问题解决办法(搜狗输入法导致) 问题描述 笔者在安装完搜狗输入法重启电脑后,电脑开机黑屏,只有鼠标的光标可以移动。笔者一开始以为是系统问题,网上查阅资料才发现有大量…

网络工程师一定要会关键技能:如何进行IP子网划分?

对于所有从事IP网络方面工作的工程师来说,进行IP子网划分操作属于一个必备的关键技能,也属于必须掌握的专业内容;但对于初学者来说,真正理解IP子网划分的概念也是一件相当困难的事情。 怎样才能正确地进行子网划分操作 IP子网划…

【性能调优】真实体验 “系统调用是重开销”

实践背景是开发云原生背景下的指纹识别插件,主要针对的是镜像、容器等云时代的软件资产。 信息安全语境下的 指纹识别 指的是定位软件的特征,如名称、版本号、开源许可证等,就像指纹是人的独特生物凭证,这些特征是软件的独特电子凭…

LeetCode刷题 --- 哈希表

1. 两数之和 解题思路: 利用哈希表,key存数组当前值,value存数组下标两数之和等于target,可以看做是两个数是配对遍历数组,看哈希表中有没有值和这个当前值配对,如果没有,就存入哈希表如果有&am…

QListWidget和QListView的使用和item点击事件

QListWidget和QListView很常用,但是使用上功能类似,往往容易分不清区别,但是不知道如何选择。这里总结下二者之间的区别和使用,分享给有需要的人,有需要的可点击收藏。 QListView介绍 QListView是Qt中用于显示列表的一…

如果提取音乐的伴奏和人声,分享两个方法给大家!

音乐中的伴奏提取一直是许多音频爱好者关注的话题。在本文中,我们将介绍两种简单易用的方法,并且特别推荐一款记灵在线工具,它能够帮助你轻松提取音乐伴奏,并且支持批量处理! 方法一:Audacity 首先&#…