mmdetection data pipeline

正午 2020-01-30 PM 48℃ 0条

mmdetection 是基于 Pytorch 的目标检测框架。本文对整个数据的处理流程做一个梳理

Pytorch data utils

Dataset

Pytorch 定义了一个相当方便和简洁的数据流程,在 torchvision 中也有比较好的实现,

Dataset 的接口定义如下, 主要重写 object 的两个方法, __getitem____len__ 这个类的作用是存储原始数据相关的信息,我们要实现自己的Dataset 可以继承这个类,实现两个抽象方法就好

class Dataset(object):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

比如我们可以自定义个简单的 dataset 如下,从文件夹下读取所有的图片作为训练数据

import os
from torch.utils import data

class FloaderDataset(data.Dataset):
  def __init__(self, root_dir):
    self.root_dir = root_dir
    self._load_image()
    
  def _load_image(self):
    self.images = [os.path.join(self.root_dir, fname) for fname in os.listdir(self.root_dir)]
    
  def __getitem__(self, index):
    return self.images[index]
  
  def __len__(self):
    return  len(self.images)

Pytorch 还定义了 IterableDatasetConcatDataset 等其他接口
上面的接口还是不够方便,因为数据要输入到模型,通常还有很多预处理流程,比如图像的 resize, crop, padding 或其他数据增强等操作,所以 torchvision 抽象出了 transform 的接口,在 __getitem__方法中调用预定义的 transformer 就能实现预处理的流程。下面就是 flip 的实现, 在 __init__ 可以定义当前 transform 的参数,最后调用 __call__ 方法实现转换过程。


class RandomHorizontalFlip(object):
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)
            bbox = target["boxes"]
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox
            if "masks" in target:
                target["masks"] = target["masks"].flip(-1)
            if "keypoints" in target:
                keypoints = target["keypoints"]
                keypoints = _flip_coco_person_keypoints(keypoints, width)
                target["keypoints"] = keypoints
        return image, target

Sampler

训练的时候数据并不是顺序遍历的,而且放入模型的时候也是一个 batch ,有时候还有其他的设置,比如根据图像的尺度或者句子长度组成 batch Pytorch 抽象了 sampler 接口 定义如下 :


class Sampler(object):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

sampler 的 iter 处理的是 dataset 中的 index 所以所有的 dataset 都需要实现 __len__方法

  1. SequentialSampler 是对数据顺序迭代,实现也很简单, data_source 就是 dataset 对象

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

2.RandomSampler 每次随机选择一定数量的样本。如果 replacement = true 每次随机从样本选择一定样本返回,如果 replacement=False 则是无放回采样,把样本 index 打乱返回一个迭代器就好。

  1. SubsetRandomSampler 只对传入的部分 index 做采样。
  2. WeightedRandomSampler 根据样本权重采用
  3. BatchSampler 上面的 sampler 都是每次返回一个样本的迭代器,BatchSampler 需要传入上面的 sampler 作为参数,每次返回 batch_size 个 样本。
  4. DistributedSampler 并行的时候不同设备负责训练部分数据,这个 sampler 会根据环境信息,划分数据集。

DataLoader

有了 Dataset 就有了原始数据,通过 transform 完成数据预处理,Sampler 实现了每次迭代采用的方法,但是要组合这些接口,而且在数据加载过程通常是比较耗时的,多线程或者多进程实现必不可少,DataLoader 负责整合 Dataset 和 Sampler 并实现了多进程的支持

mmdetection data pipeline 实现

在 Pytorch 提供的接口上 mmdetection 利用注册机制,让 transformer 等接口可配置,整个数据的处理流程更加顺畅和灵活。犹豫整体接口和 Pytorch 定义一致,这里主要关注 mmdetection 实现的 transformer

  1. Resize
  2. 的所有 transformer 传入所有字段的 dict ,对处理结果和必要的参数都存放在 dict 里面,Reize 会对 传入的 box 和 mask 同时 resize , 还支持多尺度的输出。
  3. RandomFlip 随机翻转图像和 bbox 和 mask
  4. Pad 对 image 和 mask 做 padding
  5. Normalize 对图像做标准化
  6. RandomCrop 随机 crop
  7. PhotoMetricDistortion 数据增强,可以调整包括亮度,对比度,饱和度等常见的
    还有一些其他的 transformer 比如数据加载的,把 image 转换为 tensor 的 formating 等
标签: none

非特殊说明,本博所有文章均为博主原创。

上一篇 没有了
下一篇 利用序列模型实现 HTML 信息抽取

评论