PytorchLearning-fileStructure

PytorchLearning-fileStructure

标签: DeepLearning Pytorch


概述

  “Dogs vs. Cats”是kaggle平台上4年前的一个比赛,今天学习了陈云实现的猫狗大战的一个开源项目,借鉴一下他的文件组织结构。

比赛简介

  “Dogs vs. Cats”是一个传统的二分类问题,要求辨别图片为猫还是狗。kaggle所给数据包括了训练集和测试集,训练集包括25000张图片,测试集包括了12500张图片,参赛者提供的结果为有格式要求的CSV文件。
  本文的目的主要是学习和借鉴一下陈云完成这个项目的文件结构。

代码结构分析

整体结构

  首先使用tree工具查看整个工程的结构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
$ tree
.
├── checkpoints
├── config.py
├── config.pyc
├── data
│   ├── dataset.py
│   ├── dataset.pyc
│   ├── get_data.sh
│   ├── __init__.py
│   └── __init__.pyc
├── main.py
├── models
│   ├── AlexNet.py
│   ├── AlexNet.pyc
│   ├── BasicModule.py
│   ├── BasicModule.pyc
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── ResNet34.py
│   └── ResNet34.pyc
├── README.md
├── requirements.txt
└── utils
├── __init__.py
├── __init__.pyc
├── visualize.py
└── visualize.pyc

  其中包括了如下内容:
   checkpoints文件夹:用于保存训练好的模型
   data文件夹:用于进行原始数据的保存以及数据相关的操作
   models文件夹:用于模型的定义
   utils文件夹:用于保存所用的工具函数
   main.py:为主文件,训练和测试程序的入口
   config.py:配置文件,集中所有可配置的变量

data

  dataset.py中主要包括了如下函数:

1
def __init__(self,root,transforms=None,train=True,test=False):

  用于获取所有图片的地址,并根据训练,验证,测试划分数据

1
def __getitem__(self,index):

  用于返回一张图片的数据

model

  model用于模型的定义,文件夹中包含了三个文件,其中BasicModule.py是对于nn.Module的封装,提供了加载和保存模型的接口,在实际使用的时候直接调用其中的save函数和load函数即可实现模型的保存和加载。
  AlexNet.py实现了一个AlexNet网络,ResNet.py实现了简化的ResNet34网络。对于模型的选择可以在配置文件中进行相应的配置,配置方式如下:

1
model = 'ResNet34' # 使用的模型,名字必须与models/__init__.py中的名字一致

utils

  visualize.py中主要封装了一些visdom的操作,例如修改visdom配置,相关画图操作等,此处不在细说,可以去参考下他的源码。

main.py

  main.py为工程的主文件,在其中定义的了训练,验证,测试网络的步骤,训练的步骤如下:
   1、定义网络
   2、定义数据
   3、定义损失函数和优化器
   4、计算重要指标
   5、开始训练
  验证的目的是计算模型在交叉验证集上的准确率等信息,之后还需要将模型恢复到训练模型上,利用交叉验证机改进模型。
  测试的目的是使用测试集对于训练好的模型进行测试,并将结果输出保存到CSV文件中。
  train函数中的训练训练部分如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# train
for epoch in range(opt.max_epoch):

loss_meter.reset()
confusion_matrix.reset()

for ii,(data,label) in enumerate(train_dataloader):

# train model
input = Variable(data)
target = Variable(label)
if opt.use_gpu:
input = input.cuda()
target = target.cuda()

optimizer.zero_grad()
score = model(input)
loss = criterion(score,target)
loss.backward()
optimizer.step()


# meters update and visualize
loss_meter.add(loss.data[0])
confusion_matrix.add(score.data, target.data)

if ii%opt.print_freq==opt.print_freq-1:
vis.plot('loss', loss_meter.value()[0])

# 进入debug模式
if os.path.exists(opt.debug_file):
import ipdb;
ipdb.set_trace()


model.save()

# validate and visualize
val_cm,val_accuracy = val(model,val_dataloader)

vis.plot('val_accuracy',val_accuracy)
vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format(
epoch = epoch,loss = loss_meter.value()[0],val_cm = str(val_cm.value()),train_cm=str(confusion_matrix.value()),lr=lr))

# update learning rate
if loss_meter.value()[0] > previous_loss:
lr = lr * opt.lr_decay
# 第二种降低学习率的方法:不会有moment等信息的丢失
for param_group in optimizer.param_groups:
param_group['lr'] = lr


previous_loss = loss_meter.value()[0]

config.py

  config.py作为配置文件,其中包含全部需要配置的文件,包括了训练模型,输入数据的规模,网络训练过程中的各种参数,以及是否使用GPU加速等配置。

小结

  虽然这种文件组织方式作者本人认为存在值得商榷的地方,但是他这种分工明确的组织方式仍然值得我去学习和借鉴,在之后的工程中应该注意好各程序部分之间分工和组织。