PytorchLearning-testGAN
Tags: DeepLearning Pytorch
Overview
Today I studied a simple generative adversarial network. This GAN learns a normal distribution; the details are below.
Introduction to GANs
Generative Adversarial Networks (GANs) have been one of the hottest topics in deep learning in recent years; LeCun once said, "GAN is the most interesting idea in the last 10 years in machine learning." A GAN solves the following problem: given a set of samples, train a system that can generate similar new samples.
A simple schematic of a GAN is shown below:
A GAN consists of two sub-networks: a generator and a discriminator. The generator takes random noise as input and outputs a batch of data generated from that noise; the discriminator judges whether its input is real or fake. When training the discriminator, we feed it both fake data from the generator and real data from the sample set; when training the generator, we only generate fake data from noise, and the discriminator scores the quality of that data. The generator's goal is to produce data realistic enough to fool the discriminator; the discriminator's goal is to tell the real samples apart from the generator's fakes. These two networks with opposing goals compete against each other during training and improve together, so that in the end the generator can produce data very close to the real samples.
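Formally, this adversarial game is the minimax objective from the original GAN paper (Goodfellow et al., 2014):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

D is trained to maximize V by outputting 1 on real samples and 0 on generated ones, while G is trained to minimize it by making D(G(z)) as large as possible.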
DCGAN is a widely used fully convolutional GAN intended mainly for generating image data. Its generator takes a 100-dimensional input and progressively enlarges it with transposed convolutions, finally outputting a 3×64×64 image. Its structure is sketched below:
GAN Implementation
Implementation Approach
The goal of the GAN implemented here is to make the network learn a normal distribution with known mean and standard deviation: given input noise, the generator should produce a batch of data of the same dimensionality that follows that normal distribution.
The main components to consider are:
R: the real data; in this model, a 1×100 batch drawn from the normal distribution
I: the random noise fed into the generator; a 100×1 batch in this model (100 samples of 1-dimensional noise)
G: the generator network
D: the discriminator network
train: the loop that pits G and D against each other
In this model, the real input is a 1×100 batch drawn from the normal distribution N(4, 1.25), and the noise input is a random 100×1 batch (g_input_size = 1 per sample, with a minibatch of 100). The generator is a feedforward network with two hidden layers built from three fully connected (Linear) layers, and it outputs a batch of 100 generated values. The discriminator has a similar structure, only with different input/output dimensions per layer: it takes a row of 100 values and outputs a single real-vs-fake decision.
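Concretely, the two input sources can be written as small sampler factories. The sketch below follows the reference gan_pytorch code, so the exact names (get_distribution_sampler, d_sampler, gi_sampler) should be treated as assumptions:

import numpy as np
import torch

def get_distribution_sampler(mu, sigma):
    # R: a 1 x n batch of real samples drawn from N(mu, sigma)
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))

def get_generator_input_sampler():
    # I: uniform m x n noise for the generator (deliberately not Gaussian,
    # so the generator has to learn a non-trivial transformation)
    return lambda m, n: torch.rand(m, n)

d_sampler = get_distribution_sampler(4, 1.25)  # real data ~ N(4, 1.25)
gi_sampler = get_generator_input_sampler()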
Code Implementation
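The snippets below are written against the old, pre-0.4 PyTorch API (Variable wrappers, F.sigmoid). Presumably they assume imports along these lines (a sketch, not copied from the original post):

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable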
The parameters for this GAN are set as follows:
# Data params
data_mean = 4
data_stddev = 1.25
# Model params
g_input_size = 1 # Random noise dimension coming into generator, per output vector
g_hidden_size = 50 # Generator complexity
g_output_size = 1 # size of generated output vector
d_input_size = 100 # Minibatch size - cardinality of distributions
d_hidden_size = 50 # Discriminator complexity
d_output_size = 1 # Single dimension for 'real' vs. 'fake'
minibatch_size = d_input_size
d_learning_rate = 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)
num_epochs = 30000
print_interval = 200
d_steps = 1
g_steps = 1
The generator is implemented as follows:
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x)
Printing the model gives:
Generator (
(map1): Linear (1 -> 50)
(map2): Linear (50 -> 50)
(map3): Linear (50 -> 1)
)
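As a quick sanity check (a hypothetical snippet, reusing the samplers sketched earlier), feeding the generator a minibatch of noise yields one generated scalar per noise value:

G = Generator(input_size=g_input_size,
              hidden_size=g_hidden_size,
              output_size=g_output_size)

noise = Variable(gi_sampler(minibatch_size, g_input_size))  # 100 x 1 uniform noise
fake = G(noise)                                             # 100 x 1 generated samples
print(fake.size())                                          # torch.Size([100, 1])

This 100×1 shape is why the training loop below transposes the generator output with .t() before handing it to the discriminator, which expects a single row of 100 values.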
The discriminator is implemented as follows:
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x))
Printing the model gives:
Discriminator (
(map1): Linear (200 -> 50)
(map2): Linear (50 -> 50)
(map3): Linear (50 -> 1)
)
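Note that the printed input dimension is 200 rather than d_input_size = 100. In the reference code the discriminator is constructed with input_size = d_input_func(d_input_size), and when its "data and variances" preprocessing is enabled, preprocess appends each element's squared deviation from the minibatch mean, doubling the feature dimension. A sketch of that preprocessing, following the reference code (implementation details assumed):

def decorate_with_diffs(data, exponent):
    # Concatenate (x - mean)^exponent to the raw values, so D sees the
    # spread of the minibatch as well as its location
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)

preprocess = lambda data: decorate_with_diffs(data, 2.0)  # 100 -> 200 features
d_input_func = lambda x: x * 2                            # D's actual input size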
A single training step looks like this:
for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()

        # 1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true
        d_real_error.backward()  # compute/store gradients, but don't change params

        # 1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake
        d_fake_error.backward()
        d_optimizer.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuine
        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters
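For this loop to run, the networks, loss, and optimizers still have to be constructed. A minimal sketch using the hyperparameters defined earlier (Adam with optim_betas and BCELoss follow the reference code; treat the exact wiring as an assumption):

G = Generator(input_size=g_input_size,
              hidden_size=g_hidden_size,
              output_size=g_output_size)
D = Discriminator(input_size=d_input_func(d_input_size),  # 200 with the diff features
                  hidden_size=d_hidden_size,
                  output_size=d_output_size)

criterion = nn.BCELoss()  # binary cross-entropy on D's sigmoid output
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)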
The complete code is available as gan_pytorch on GitHub.
Test Results
Every 200 epochs the code prints one set of statistics: the epoch, D's loss on real and fake data, G's loss, and the [mean, std] of the real and generated minibatches. The results for epochs 200 through 1000 are as follows:
200: D: 0.00214281608351/0.345613867044 G: 1.219622612 (Real: [3.6329677352309226, 1.2743346703142211], Fake: [0.27478053793311119, 0.061389553398676418])
400: D: 0.0024734903127/0.296345323324 G: 1.15904438496 (Real: [4.0885050618648533, 1.1591927512120153], Fake: [0.076806889995932576, 0.17435286763256208])
600: D: 0.000299081235426/0.135397613049 G: 2.17891645432 (Real: [3.8815199007093906, 1.358352249087974], Fake: [0.30949043005704879, 0.44819238084500146])
800: D: 0.00544052384794/0.122047036886 G: 3.39399766922 (Real: [4.0806429767608643, 1.2595771498812962], Fake: [1.2480358286574482, 0.6732695700859147])
1000: D: 0.199153736234/1.49040198326 G: 0.50436270237 (Real: [4.0617488116025928, 1.1983677948055262], Fake: [3.254411913752556, 1.0292862598070549])
We can see that the mean of the generated data has been trained from around 0 at the start up to around 3.
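The [mean, std] pairs in these logs are per-minibatch statistics; in the reference code they are computed with small helpers roughly like these (names assumed):

extract = lambda v: v.data.storage().tolist()  # Variable -> flat Python list
stats = lambda d: [np.mean(d), np.std(d)]      # [mean, std] of a minibatch

# printed as, e.g., Real: stats(extract(d_real_data)), Fake: stats(extract(d_fake_data))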
By epochs 22400 through 23000, the generator's output is already quite satisfactory:
22400: D: 0.0718741938472/0.0800286009908 G: 1.94343280792 (Real: [4.236325685083866, 1.20664243159802], Fake: [4.294258806705475, 1.3486204260215457])
22600: D: 0.180983379483/0.119378611445 G: 2.08759236336 (Real: [3.9757382905483247, 1.3206865382611883], Fake: [4.3551864910125735, 1.2890139727165058])
22800: D: 0.0731378570199/0.15659108758 G: 1.91181635857 (Real: [3.8459619009494781, 1.2760463424123054], Fake: [4.5980649662017825, 1.3286327698283822])
23000: D: 0.617322683334/0.0718911290169 G: 1.75618612766 (Real: [4.1258033090829853, 1.2917487711772497], Fake: [4.7274960541725157, 1.1655121586747199])
By epochs 34800 through 35800, the learned distribution has moved a step closer still:
34800: D: 1.21260821819/0.579729855061 G: 1.19163322449 (Real: [4.1337242549657818, 1.0945158014515735], Fake: [3.9489597564935686, 1.2484162018589191])
35000: D: 0.595994710922/0.492183923721 G: 0.961775839329 (Real: [4.0887473231554035, 1.2073327110558152], Fake: [4.3758284652233126, 1.227369298261034])
35200: D: 0.441894412041/0.215708702803 G: 1.49583280087 (Real: [3.7236255550384523, 1.2211940328679571], Fake: [4.0118666481971736, 1.2601506958034965])
35400: D: 0.0004737903364/0.229784876108 G: 2.48558044434 (Real: [3.9932019078731535, 1.2578601497639903], Fake: [4.0291512513160708, 1.3606264369187477])
35600: D: 0.126971721649/0.777751982212 G: 0.719469606876 (Real: [4.0348688819259406, 1.2721717771911061], Fake: [4.3053950923681263, 1.157315203477822])
35800: D: 0.0589950755239/0.277370721102 G: 1.22292232513 (Real: [4.2585669487714766, 1.30765629936759], Fake: [3.9463100868463514, 1.1169550739101126])
Summary
Although this GAN is simple, it is not crude: it implements the essential machinery of a GAN, and building it gave me a better understanding of how GANs are implemented.