手把手带你用 Python 写一个 GAN
GAN
前言
我们接下来从这几个方面来探索.
GAN能用来做什么GAN的原理GAN的代码实现
用途
GAN自 2014 年诞生以来, 就一直备受关注, 著名的应用也随即产出, 比如比较著名的 GAN 的应用有 Pix2Pix、CycleGAN 等, 大家也将它用于各个地方.
- 缺失 /模糊像素的补充
- 图片修复
- etc...
我觉得还有一个比较重要的用途, 很多人都会缺少数据集, 那么就通过GAN去生成数据集了, 通过调节部分参数来进行数据集的产生的相似度.
原理
GAN 的基本原理其实非常简单,这里以生成图片为例进行说明。假设我们有两个网络,G(Generator) 和 D(Discriminator)。正如它的名字所暗示的那样, 它们的功能分别是:
G是一个生成图片的网络, 它接收一个随机的噪声(随机生成的图片)z, 通过这个噪声生成图片,记做G(z). D是一个判别网络, 判别一张图片是不是“真实的”。它的输入参数是x, x代表一张图片,输出D(x)代表x为真实图片的概率,如果为>0.5,就代表是真实(相似)的图片,反之,就代表不是真实的图片。

我们通过一个假产品宣传的例子来理解:
首先, 我们来定义一下角色.
- 进行宣传的
'专家'(生成网络) - 正在听讲的
'我们'(判别网络)
'专家'的手里面拿着一堆高仿的产品, 正在进行宣讲, 我们是熟知真品的相关信息的, 通过去对比两个产品之间的差距, 来判断是赝品的可能性.
这时, 我们就可以引出来一个概念, 如果'专家'团队比较厉害, 完美的仿造了我们的判断依据, 比如说产出方, 发明日期, 说明文等等, 那么我们就会觉得他是真的, 那么他就是一个好的生成网络, 反之, 我们会判断他是赝品.
从我们(判别网络)出发, 我们的判断条件越苛刻, 赝品和真品之间的差距会越来越小, 这样的最后的产出就是真假难分, 完全被模仿了.
相关资源
Generative Adversarial Networks, 深层的原理推荐大家可以去阅读以下这一篇论文.
* 损失函数等相关细节我们在实现里在讲
实现
接下来我们就以 mnist来实现 GAN吧.
- 首先, 我们先下载数据集.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./Dataset/datasets/MNIST_data', one_hot=False)
我们通过tensorflow去下载mnist的数据集, 然后加载到内存, one-hot参数决定我们的label是否要经过编码(mnist 数据集是有 10 个类别), 但是我们判别网络是对比真实的和生成的之间的区别以及相似的可能性, 所以不需要执行one-hot编码了.
这里读取出来的图片已经归一化到[0, 1]之间了.
- 俗话说, 知己知彼, 百战百胜, 那我们拿到数据集, 就先来看看它长什么样.
def show_images(images):
images = np.reshape(images, [images.shape[0], -1]) # images reshape to (batch_size, D)
sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
fig = plt.figure(figsize=(sqrtn, sqrtn))
gs = gridspec.GridSpec(sqrtn, sqrtn)
gs.update(wspace=0.05, hspace=0.05)
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(img.reshape([sqrtimg,sqrtimg]))

这里有一个小问题, 如果是在Notebook中执行, 记得加上这句话, 否则需要执行两次才会绘制.
%matplotlib inline
- 数据看过了, 我们该对它进行一定的处理了, 这里我们只是将数据缩放到[-1, 1]之间.
def preprocess_img(x):
return 2 * x - 1.0
def deprocess_img(x):
return (x + 1.0) / 2.0
- 数据处理完了, 接下来我们要开始搭建模型了, 这一部分我们有两个模型, 一个生成网络, 一个判别网络.
生成网络
def generator(z):
with tf.variable_scope("generator"):
fc1 = tf.layers.dense(inputs=z, units=1024, activation=tf.nn.relu)
bn1 = tf.layers.batch_normalization(inputs=fc1, training=True)
fc2 = tf.layers.dense(inputs=bn1, units=7*7*128, activation=tf.nn.relu)
bn2 = tf.layers.batch_normalization(inputs=fc2, training=True)
reshaped = tf.reshape(bn2, shape=[-1, 7, 7, 128])
conv_transpose1 = tf.layers.conv2d_transpose(inputs=reshaped, filters=64, kernel_size=4, strides=2, activation=tf.nn.relu,
padding='same')
bn3 = tf.layers.batch_normalization(inputs=conv_transpose1, training=True)
conv_transpose2 = tf.layers.conv2d_transpose(inputs=bn3, filters=1, kernel_size=4, strides=2, activation=tf.nn.tanh,
padding='same')
img = tf.reshape(conv_transpose2, shape=[-1, 784])
return img
判别网络
def discriminator(x):
with tf.variable_scope("discriminator"):
unflatten = tf.reshape(x, shape=[-1, 28, 28, 1])
conv1 = tf.layers.conv2d(inputs=unflatten, kernel_size=5, strides=1, filters=32 ,activation=leaky_relu)
maxpool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=2, strides=2)
conv2 = tf.layers.conv2d(inputs=maxpool1, kernel_size=5, strides=1, filters=64,activation=leaky_relu)
maxpool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=2, strides=2)
flatten = tf.reshape(maxpool2, shape=[-1, 1024])
fc1 = tf.layers.dense(inputs=flatten, units=1024, activation=leaky_relu)
logits = tf.layers.dense(inputs=fc1, units=1)
return logits
激活函数我们使用了leaky_relu, 他的代码实现是
def leaky_relu(x, alpha=0.01):
activation = tf.maximum(x,alpha*x)
return activation
它和 relu的区别就是, 小于 0 的值也会给与一点小的权重进行保留.
- 建立损失函数
def gan_loss(logits_real, logits_fake):
# Target label vector for generator loss and used in discriminator loss.
true_labels = tf.ones_like(logits_fake)
# DISCRIMINATOR loss has 2 parts: how well it classifies real images and how well it
# classifies fake images.
real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=true_labels)
fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=1-true_labels)
# Combine and average losses over the batch
D_loss = real_image_loss + fake_image_loss
D_loss = tf.reduce_mean(D_loss)
# GENERATOR is trying to make the discriminator output 1 for all its images.
# So we use our target label vector of ones for computing generator loss.
G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=true_labels)
# Average generator loss over the batch.
G_loss = tf.reduce_mean(G_loss)
return D_loss, G_loss
损失我们分为两部分, 一部分是生成网络的, 一部分是判别网络的.
生成网络的损失定义为, 生成图像的类别与真实标签(全是 1)的交叉熵损失.
$$ loss_{generate} = -\sum_i^n(Y_ilogG_i + (1 - Y_i)logG_i) $$
判别网络的损失定义为, 我们将真实图片的标签设置为 1, 生成图片的标签设置为 0, 然后由真实图片的输出以及生成图片的输出的交叉熵损失和.
T: True, G: Generate
生成损失
$$ loss_{G} = -\sum_i^n(Y_ilogG_i + (1 - Y_i)logG_i) $$
真实图片损失
$$ loss_{T} = -\sum_i^n(Y_ilogT_i + (1 - Y_i)logT_i) $$
总损失
$$ Loss_{D} = loss_{G} + loss_{T} $$
- 训练
def run_a_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,\
show_every=250, print_every=50, batch_size=128, num_epoch=10):
# compute the number of iterations we need
max_iter = int(mnist.train.num_examples*num_epoch/batch_size)
for it in range(max_iter):
# every show often, show a sample result
if it % show_every == 0:
samples = sess.run(G_sample)
fig = show_images(samples[:16])
plt.show()
print()
# run a batch of data through the network
minibatch,minbatch_y = mnist.train.next_batch(batch_size)
_1, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x: minibatch})
_2, G_loss_curr = sess.run([G_train_step, G_loss])
if it % show_every == 0:
print(_1,_2)
# print loss every so often.
# We want to make sure D_loss doesn't go to 0
if it % print_every == 0:
print('Iter: {}, D: {:.4}, G:{:.4}'.format(it,D_loss_curr,G_loss_curr))
print('Final images')
samples = sess.run(G_sample)
fig = show_images(samples[:16])
plt.show()
这里就是开始训练了, 并展示训练的结果.
- 查看结果
刚开始的时候, 还没学会怎么模仿.

经过学习改进.


项目地址
声明
该文章参考了知乎的文章.
手把手带你用 Python 写一个 GAN
我来带你手搓一个简单的GAN(生成对抗网络),用MNIST数据集生成手写数字。这个代码可以直接运行,用的是PyTorch框架。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 超参数
latent_dim = 100
batch_size = 128
epochs = 50
learning_rate = 0.0002
# 数据加载
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(
root='./data', train=True, transform=transform, download=True
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset, batch_size=batch_size, shuffle=True
)
# 生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, 28*28),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), 1, 28, 28)
return img
# 判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(28*28, 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
flattened = img.view(img.size(0), -1)
validity = self.model(flattened)
return validity
# 初始化模型
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# 损失函数和优化器
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
# 训练循环
for epoch in range(epochs):
for i, (imgs, _) in enumerate(train_loader):
batch_size = imgs.shape[0]
# 真实和假标签
real = torch.ones(batch_size, 1).to(device)
fake = torch.zeros(batch_size, 1).to(device)
# 真实图像
real_imgs = imgs.to(device)
# ============ 训练判别器 ============
optimizer_D.zero_grad()
# 真实图像的损失
real_loss = adversarial_loss(discriminator(real_imgs), real)
# 生成假图像
z = torch.randn(batch_size, latent_dim).to(device)
gen_imgs = generator(z)
# 假图像的损失
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
# 总判别器损失
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# ============ 训练生成器 ============
optimizer_G.zero_grad()
# 生成器想让判别器把假图像判断为真
z = torch.randn(batch_size, latent_dim).to(device)
gen_imgs = generator(z)
g_loss = adversarial_loss(discriminator(gen_imgs), real)
g_loss.backward()
optimizer_G.step()
# 打印进度
if i % 200 == 0:
print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(train_loader)}] "
f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
# 每个epoch保存生成的图像
with torch.no_grad():
z = torch.randn(16, latent_dim).to(device)
gen_imgs = generator(z).cpu()
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for idx, ax in enumerate(axes.flat):
ax.imshow(gen_imgs[idx].squeeze(), cmap='gray')
ax.axis('off')
plt.savefig(f'gan_generated_epoch_{epoch}.png')
plt.close()
print("训练完成!")
# 生成最终图像
with torch.no_grad():
z = torch.randn(64, latent_dim).to(device)
gen_imgs = generator(z).cpu()
fig, axes = plt.subplots(8, 8, figsize=(12, 12))
for idx, ax in enumerate(axes.flat):
ax.imshow(gen_imgs[idx].squeeze(), cmap='gray')
ax.axis('off')
plt.tight_layout()
plt.show()
这个GAN的核心思想是让生成器和判别器互相博弈:
- 生成器:接收随机噪声,生成假图像,目标是骗过判别器
- 判别器:判断输入图像是真实的还是生成的,目标是准确区分
训练过程分两步:
- 先固定生成器,训练判别器识别真假图像
- 再固定判别器,训练生成器生成更逼真的图像
代码里用了简单的全连接网络,你可以换成卷积网络(DCGAN)效果会更好。训练时注意观察损失值,如果判别器损失降得太快而生成器损失很高,说明训练不平衡,可以调整学习率或网络结构。
简单说就是:生成器造假币,判别器是验钞机,两者互相提升。

