mmdetection 是基于 Pytorch 的目标检测框架。本文对整个数据的处理流程做一个梳理
Pytorch data utils
Dataset
Pytorch 定义了一个相当方便和简洁的数据流程,在 torchvision 中也有比较好的实现,
Dataset 的接口定义如下, 主要重写 object 的两个方法, __getitem__
和 __len__
这个类的作用是存储原始数据相关的信息,我们要实现自己的Dataset 可以继承这个类,实现两个抽象方法就好
class Dataset(object):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
比如我们可以自定义个简单的 dataset 如下,从文件夹下读取所有的图片作为训练数据
import os
from torch.utils import data
class FloaderDataset(data.Dataset):
def __init__(self, root_dir):
self.root_dir = root_dir
self._load_image()
def _load_image(self):
self.images = [os.path.join(self.root_dir, fname) for fname in os.listdir(self.root_dir)]
def __getitem__(self, index):
return self.images[index]
def __len__(self):
return len(self.images)
Pytorch 还定义了 IterableDataset
和 ConcatDataset
等其他接口
上面的接口还是不够方便,因为数据要输入到模型,通常还有很多预处理流程,比如图像的 resize, crop, padding 或其他数据增强等操作,所以 torchvision 抽象出了 transform 的接口,在 __getitem__
方法中调用预定义的 transformer 就能实现预处理的流程。下面就是 flip 的实现, 在 __init__
可以定义当前 transform 的参数,最后调用 __call__
方法实现转换过程。
class RandomHorizontalFlip(object):
def __init__(self, prob):
self.prob = prob
def __call__(self, image, target):
if random.random() < self.prob:
height, width = image.shape[-2:]
image = image.flip(-1)
bbox = target["boxes"]
bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
target["boxes"] = bbox
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
if "keypoints" in target:
keypoints = target["keypoints"]
keypoints = _flip_coco_person_keypoints(keypoints, width)
target["keypoints"] = keypoints
return image, target
Sampler
训练的时候数据并不是顺序遍历的,而且放入模型的时候也是一个 batch ,有时候还有其他的设置,比如根据图像的尺度或者句子长度组成 batch Pytorch 抽象了 sampler 接口 定义如下 :
class Sampler(object):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
way to iterate over indices of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~torch.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
sampler 的 iter 处理的是 dataset 中的 index 所以所有的 dataset 都需要实现 __len__
方法
- SequentialSampler 是对数据顺序迭代,实现也很简单, data_source 就是 dataset 对象
class SequentialSampler(Sampler):
r"""Samples elements sequentially, always in the same order.
Arguments:
data_source (Dataset): dataset to sample from
"""
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
2.RandomSampler 每次随机选择一定数量的样本。如果 replacement = true 每次随机从样本选择一定样本返回,如果 replacement=False 则是无放回采样,把样本 index 打乱返回一个迭代器就好。
- SubsetRandomSampler 只对传入的部分 index 做采样。
- WeightedRandomSampler 根据样本权重采用
- BatchSampler 上面的 sampler 都是每次返回一个样本的迭代器,BatchSampler 需要传入上面的 sampler 作为参数,每次返回 batch_size 个 样本。
- DistributedSampler 并行的时候不同设备负责训练部分数据,这个 sampler 会根据环境信息,划分数据集。
DataLoader
有了 Dataset 就有了原始数据,通过 transform 完成数据预处理,Sampler 实现了每次迭代采用的方法,但是要组合这些接口,而且在数据加载过程通常是比较耗时的,多线程或者多进程实现必不可少,DataLoader 负责整合 Dataset 和 Sampler 并实现了多进程的支持
mmdetection data pipeline 实现
在 Pytorch 提供的接口上 mmdetection 利用注册机制,让 transformer 等接口可配置,整个数据的处理流程更加顺畅和灵活。犹豫整体接口和 Pytorch 定义一致,这里主要关注 mmdetection 实现的 transformer
- Resize
- 的所有 transformer 传入所有字段的 dict ,对处理结果和必要的参数都存放在 dict 里面,Reize 会对 传入的 box 和 mask 同时 resize , 还支持多尺度的输出。
- RandomFlip 随机翻转图像和 bbox 和 mask
- Pad 对 image 和 mask 做 padding
- Normalize 对图像做标准化
- RandomCrop 随机 crop
- PhotoMetricDistortion 数据增强,可以调整包括亮度,对比度,饱和度等常见的
还有一些其他的 transformer 比如数据加载的,把 image 转换为 tensor 的 formating 等