[Darknet 源码]2.模型训练的流程

正午 2020-07-05 PM 55℃ 0条

先介绍下怎么把框架用起来,比如要自定义一个自己的文本分类任务该在哪里修改,需要做哪些事情。Darknet训练模型的 pipline 如何。
截屏2020-07-05 下午9.54.25.png

官方文档的例子 跑一下 cifa10 的分类模型

./darknet classifier train cfg/cifar.data cfg/cifar_small.cfg

cfg/cifar.data 是数据相关的配置, 分类模型的配置在 cfg/cifar_small.cfg

任务定义

分析下图像分类模型的流程,Darknet 也主要用来做图像任务,我尝试来定义一个文本分类的任务,现在是空壳任务,只是把代码加进去

入口文件examples/darknet.c 只负责解析参数,搞清楚是想运行什么任务,classifier 就是要做分类任务
然后根据不同任务调用不同的函数, 不同任务的具体实现在单独的文件里,
classifier.c 实现的是分类任务,相应的还有很多比如 detector.c 或者分割 segmenter.c 这给了我们一个方向如果我有一个任务需要做些特殊处理,就可以在 examples 下面实现,然后修改 Makefile 具体修改方法参考上一篇 再添加到darknet.c里面。下面举个例子好了,

新建一个文本分类的任务
新建文件 example\text_classifier.c 并写入如下的代码

#include <stdio.h>

int run_text_classify(int argc, char **argv){
  printf("text classifier\n");
  return 0;
}

修改 example\darknet.c 中两行
增加函数声明

extern void run_text_classify(int argc, char **argv);

main 的函数中增加一个条件

 else if (0 == strcmp(argv[1], "text_classifier")) {
        run_text_classify(argc, argv);
}

修改Makefile

EXECOBJA=text_classifier.o captcha.o lsd.o super.o art.o tag.o cifar.o go.o rnn.o segmenter.o regressor.o classifier.o coco.o yolo.o detector.o nightmare.o instance-segmenter.o darknet.o

然后

make
./darknet text_classifier

就能在终端看到 run_text_classify 函数的输出,

代码执行流程

先不具体实现,宏观的感受下大致流程。
要实现这个文本分类,剩下的就是处理数据读入,和模型构建的问题。前者看似简单但也并不是很容易的事情, 当多卡训练的时候数据的读取速度也是很重要的. 当数据加载到内存,模型输入是如何定义的,如何实现 batch 的选择,等等一系列问题. 对比来看 Pytorch 为我们提供了 DataSetDataLoader 还有 Sampler 在 Datknet 中其实有对应的逻辑,只是不是面向对象,可复用也会弱一些。

classifier.c 里会根据第二个参数来判断具体是要训练模型还是要做预测等等,所以我也可以类似的设计

 if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename, top);
    else if(0==strcmp(argv[2], "fout")) file_output_classifier(data, cfg, weights, filename);
    else if(0==strcmp(argv[2], "try")) try_classifier(data, cfg, weights, filename, atoi(layer_s));
    else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights, gpus, ngpus, clear);
    else if(0==strcmp(argv[2], "demo")) demo_classifier(data, cfg, weights, cam_index, filename);
    else if(0==strcmp(argv[2], "gun")) gun_classifier(data, cfg, weights, cam_index, filename);
    else if(0==strcmp(argv[2], "threat")) threat_classifier(data, cfg, weights, cam_index, filename);
    else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer);
    else if(0==strcmp(argv[2], "csv")) csv_classifier(data, cfg, weights);
    else if(0==strcmp(argv[2], "label")) label_classifier(data, cfg, weights);
    else if(0==strcmp(argv[2], "valid")) validate_classifier_single(data, cfg, weights);
    else if(0==strcmp(argv[2], "validmulti")) validate_classifier_multi(data, cfg, weights);
    else if(0==strcmp(argv[2], "valid10")) validate_classifier_10(data, cfg, weights);
    else if(0==strcmp(argv[2], "validcrop")) validate_classifier_crop(data, cfg, weights);
    else if(0==strcmp(argv[2], "validfull")) validate_classifier_full(data, cfg, weights);

直接去看训练部分的函数 train_classifier
我们理解这函数里的逻辑应该是和用 TensorflowPytorch 一样的, 构建模型net 然后把数据喂进去,前向传播,反向传播,然后更新参数
构造模型 ,这里是构造了多个 net 因为要在多卡上训练

network **nets = calloc(ngpus, sizeof(network*));

数据加载是多线程完成的

load_thread = load_data(args);

数据都有了然后就是一个 while 循环,迭代优化模型。模型的存储和 log 的输出都在这里可以找到,具体细节稍后等介绍了用到的 struct 再来实现。

标签: none

非特殊说明,本博所有文章均为博主原创。

评论