潮科技行业入门指南 | 深度学习理论与实战:提高篇(14)——Mask R-CNN代码简介
编者按:本文节选自《深度学习理论与实战:提高篇 》一书,原文链接http://fancyerii.github.io/2019/03/14/dl-book/ 。作者李理,环信人工智能研发中心vp,有十多年自然语言处理和人工智能研发经验,主持研发过多款智能硬件的问答和对话系统,负责环信中文语义分析开放平台和环信智能机器人的设计与研发。
以下为正文。
目录
安装
demo.ipynb
运行
关键代码
train_shapes.ipynb
配置
Dataset
创建模型
训练
检测
测试
inspect_data.ipynb
选择数据集
加载Dataset
显示样本
Bounding Box
Mini Masks
Anchor
训练数据生成器
Facebook(Mask R-CNN的作者He Kaiming等人目前在Facebook)的实现在这里。但是这是用Caffe2实现的,本书没有介绍这个框架,因此我们介绍Tensorflow和Keras的版本实现的版本。但是建议有兴趣的读者也可以尝试一下Facebook提供的代码。
安装(((0)))
demo.ipynb
1、运行
jupyter notebook
打开文件samples/demo.ipynb,运行所有的Cell2、关键代码
这里是使用预训练的模型,会自动上网下载,所以第一次运行会比较慢。这是下载模型参数的代码:
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download COCO trained weights from Releases if needed
if not os.path.exists(COCO_MODEL_PATH):
utils.download_trained_weights(COCO_MODEL_PATH)创建模型和加载参数:
# 创建MaskRCNN对象,模式是inferencemodel = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)
# 加载模型参数 model.load_weights(COCO_MODEL_PATH, by_name=True)
读取图片并且进行分割:
# 随机加载一张图片
file_names = next(os.walk(IMAGE_DIR))[2]
image = skimage.io.imread(os.path.join(IMAGE_DIR, random.choice(file_names)))
# 进行目标检测和分割
results = model.detect([image], verbose=1)
# 显示结果
r = results[0]
visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
class_names, r['scores'])检测结果r包括rois(RoI)、masks(对应RoI的每个像素是否属于目标物体)、scores(得分)和class_ids(类别)。
下图是运行的效果,我们可以看到它检测出来4个目标物体,并且精确到像素级的分割处理物体和背景。
图:Mask RCNN检测效果
train_shapes.ipynb
除了可以使用训练好的模型,我们也可以用自己的数据进行训练,为了演示,这里使用了一个很小的shape数据集。这个数据集是on-the-fly的用代码生成的一些三角形、正方形、圆形,因此不需要下载数据。
1、配置
代码提供了基础的类Config,我们只需要继承并稍作修改:
class ShapesConfig(Config):
"""用于训练shape数据集的配置
继承子基本的Config类,然后override了一些配置项。
"""
# 起个好记的名字
NAME = "shapes"
# 使用一个GPU训练,每个GPU上8个图片。因此batch大小是8 (GPUs * images/GPU).
GPU_COUNT = 1
IMAGES_PER_GPU = 8
# 分类数(需要包括背景类)
NUM_CLASSES = 1 + 3 # background + 3 shapes
# 图片为固定的128x128
IMAGE_MIN_DIM = 128
IMAGE_MAX_DIM = 128
# 因为图片比较小,所以RPN anchor也是比较小的
RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128) # anchor side in pixels
# 每张图片建议的RoI数量,对于这个小图片的例子可以取比较小的值。
TRAIN_ROIS_PER_IMAGE = 32
# 每个epoch的数据量
STEPS_PER_EPOCH = 100
# 每5步验证一下。
VALIDATION_STEPS = 5
config = ShapesConfig()
config.display()2、Dataset
对于我们自己的数据集,我们需要继承utils.Dataset类,并且重写如下方法:
load_image
load_mask
image_reference
在重写这3个方法之前我们首先来看load_shapes,这个函数on-the-fly的生成数据。
class ShapesDataset(utils.Dataset):
"""随机生成shape数据。包括三角形,正方形和圆形,以及它的位置。
这是on-th-fly的生成数据,因此不需要访问文件。
"""
def load_shapes(self, count, height, width):
"""生成图片
count: 返回的图片数量
height, width: 生成图片的height和width
"""
# 类别
self.add_class("shapes", 1, "square")
self.add_class("shapes", 2, "circle")
self.add_class("shapes", 3, "triangle")
# 注意:这里只是生成图片的specifications(说明书),
# 具体包括性质、颜色、大小和位置等信息。
# 真正的图片是在load_image()函数里根据这些specifications
# 来on-th-fly的生成。
for i in range(count):
bg_color, shapes = self.random_image(height, width)
self.add_image("shapes", image_id=i, path=None,
width=width, height=height,
bg_color=bg_color, shapes=shapes)其中add_image是在基类中定义:
版权保护: 本文由 沃派博客-沃派网 编辑,转载请保留链接: http://www.bdice.cn/html/46744.html