欢迎访问:沃派博客 每天不定时发布IT文章相关资讯
当前位置:沃派博客-沃派网 > IT文章 > 正文

潮科技行业入门指南 | 深度学习理论与实战:提高篇(14)——​Mask R-CNN代码简介

03-25 IT文章

编者按:本文节选自《深度学习理论与实战:提高篇 》一书,原文链接http://fancyerii.github.io/2019/03/14/dl-book/ 。作者李理,环信人工智能研发中心vp,有十多年自然语言处理和人工智能研发经验,主持研发过多款智能硬件的问答和对话系统,负责环信中文语义分析开放平台和环信智能机器人的设计与研发。

以下为正文。

潮科技行业入门指南 | 深度学习理论与实战:提高篇(14)——​Mask R-CNN代码简介

目录

  • 安装

    demo.ipynb

  • 运行

    关键代码

    train_shapes.ipynb

  • 配置

    Dataset

    创建模型

    训练

    检测

    测试

    inspect_data.ipynb

  • 选择数据集

    加载Dataset

    显示样本

    Bounding Box

    Mini Masks

    Anchor

    训练数据生成器

    Facebook(Mask R-CNN的作者He Kaiming等人目前在Facebook)的实现在这里。但是这是用Caffe2实现的,本书没有介绍这个框架,因此我们介绍Tensorflow和Keras的版本实现的版本。但是建议有兴趣的读者也可以尝试一下Facebook提供的代码。

    安装(((0)))

    demo.ipynb

    1、运行

    jupyter notebook
    打开文件samples/demo.ipynb,运行所有的Cell

    2、关键代码

    这里是使用预训练的模型,会自动上网下载,所以第一次运行会比较慢。这是下载模型参数的代码:

    COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
    # Download COCO trained weights from Releases if needed
    if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

    创建模型和加载参数:

    # 创建MaskRCNN对象,模式是inferencemodel = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

    # 加载模型参数 model.load_weights(COCO_MODEL_PATH, by_name=True)

    读取图片并且进行分割:

    # 随机加载一张图片
    file_names = next(os.walk(IMAGE_DIR))[2]
    image = skimage.io.imread(os.path.join(IMAGE_DIR, random.choice(file_names)))

    # 进行目标检测和分割
    results = model.detect([image], verbose=1)

    # 显示结果
    r = results[0]
    visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
    class_names, r['scores'])

    检测结果r包括rois(RoI)、masks(对应RoI的每个像素是否属于目标物体)、scores(得分)和class_ids(类别)。

    下图是运行的效果,我们可以看到它检测出来4个目标物体,并且精确到像素级的分割处理物体和背景。

    潮科技行业入门指南 | 深度学习理论与实战:提高篇(14)——​Mask R-CNN代码简介

    图:Mask RCNN检测效果

    train_shapes.ipynb

    除了可以使用训练好的模型,我们也可以用自己的数据进行训练,为了演示,这里使用了一个很小的shape数据集。这个数据集是on-the-fly的用代码生成的一些三角形、正方形、圆形,因此不需要下载数据。

    1、配置

    代码提供了基础的类Config,我们只需要继承并稍作修改:

    class ShapesConfig(Config):
    """用于训练shape数据集的配置
    继承子基本的Config类,然后override了一些配置项。
    """
    # 起个好记的名字
    NAME = "shapes"

    # 使用一个GPU训练,每个GPU上8个图片。因此batch大小是8 (GPUs * images/GPU).
    GPU_COUNT = 1
    IMAGES_PER_GPU = 8

    # 分类数(需要包括背景类)
    NUM_CLASSES = 1 + 3 # background + 3 shapes

    # 图片为固定的128x128
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128

    # 因为图片比较小,所以RPN anchor也是比较小的
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128) # anchor side in pixels

    # 每张图片建议的RoI数量,对于这个小图片的例子可以取比较小的值。
    TRAIN_ROIS_PER_IMAGE = 32

    # 每个epoch的数据量
    STEPS_PER_EPOCH = 100

    # 每5步验证一下。
    VALIDATION_STEPS = 5

    config = ShapesConfig()
    config.display()

    2、Dataset

    对于我们自己的数据集,我们需要继承utils.Dataset类,并且重写如下方法:

  • load_image

    load_mask

    image_reference

    在重写这3个方法之前我们首先来看load_shapes,这个函数on-the-fly的生成数据。

    class ShapesDataset(utils.Dataset):
    """随机生成shape数据。包括三角形,正方形和圆形,以及它的位置。
    这是on-th-fly的生成数据,因此不需要访问文件。
    """

    def load_shapes(self, count, height, width):
    """生成图片
    count: 返回的图片数量
    height, width: 生成图片的height和width
    """
    # 类别
    self.add_class("shapes", 1, "square")
    self.add_class("shapes", 2, "circle")
    self.add_class("shapes", 3, "triangle")

    # 注意:这里只是生成图片的specifications(说明书),
    # 具体包括性质、颜色、大小和位置等信息。
    # 真正的图片是在load_image()函数里根据这些specifications
    # 来on-th-fly的生成。
    for i in range(count):
    bg_color, shapes = self.random_image(height, width)
    self.add_image("shapes", image_id=i, path=None,
    width=width, height=height,
    bg_color=bg_color, shapes=shapes)

    其中add_image是在基类中定义:

  • 版权保护: 本文由 沃派博客-沃派网 编辑,转载请保留链接: http://www.bdice.cn/html/46744.html