14 | 令人拍案叫绝的WGAN

简介

在DCGAN的基础上,介绍WGAN的原理和实现,并在LFW和CelebA两个数据集上进一步实践

问题

GAN一直面临以下问题和挑战

  • 训练困难,需要精心设计模型结构,并小心协调G和D的训练程度
  • G和D的损失函数无法指示训练过程,缺乏一个有意义的指标和生成图片的质量相关联
  • 模式崩坏(mode collapse),生成的图片虽然看起来像是真的,但是缺乏多样性

原理

相对于传统的GAN,WGAN只做了以下三点简单的改动

  • D最后一层去掉sigmoid
  • G和D的loss不取log(sigmoid_cross_entropy_with_logits
  • 每次更新D的参数之后,将其绝对值截断到不超过一个固定常数c,即gradient clipping(前作);或使用梯度惩罚,即gradient penalty(后作)

G的损失函数原本为

其导致的结果是,如果D训练得太好,G将学习不到有效的梯度

但是,如果D训练得不够好,G也学习不到有效的梯度

就像警察如果太厉害,便直接把小偷干掉了;但警察如果不厉害,就无法迫使小偷变得更厉害

因此以上损失函数导致GAN训练特别不稳定,需要小心协调G和D的训练程度

GAN的作者提出了G损失函数的另一个版本,即所谓的-logD trick

G需要最小化以上损失函数,等价于最小化以下损失函数

其中前者为KL散度(Kullback–Leibler Divergence)

后者为JS散度(Jensen-Shannon Divergence)

两者都可以用于衡量两个分布之间的距离,越小说明两个分布越相似

因此以上损失函数,一方面要减小KL散度,另一方面却要增大JS散度,一边拉近一边推远,从而导致训练不稳定

除此之外,KL散度的不对称性,导致对以下两种情况的不同惩罚

  • G生成了不真实的图片,即缺乏准确性,惩罚较高
  • G生成了和真实图片类似的图片,即缺乏多样性,惩罚较低

从而导致,G倾向于生成一些有把握但相似的图片,而不敢轻易地尝试去生成没把握的新图片,即所谓的mode collapse问题

WGAN所做的三点改动,解决了GAN训练困难和不稳定、mode collapse等问题,而且G的损失函数越小,对应生成的图片质量就越高

WGAN训练过程如下,gradient penalty使得D满足1-Lipschitz连续条件,详细原理和细节可以阅读相关论文进一步了解

论文中部分实验结果如下,WGAN虽然需要更长的训练时间,但收敛更加稳定

更重要的是,WGAN提供了一种更稳定的GAN框架。DCGAN中的G去掉Batch Normalization就会崩掉,但WGAN则没有这种限制

如果用Deep Convolutional结构实现WGAN,那么其结果和DCGAN差不多。但是在WGAN的框架下,可以用更深更复杂的网络实现G和D,例如ResNet(arxiv.org/abs/1512.03…),从而达到更好的生成效果

数据

还是之前使用过的两个人脸数据集

  • LFW:vis-www.cs.umass.edu/lfw/,Labeled Faces in the Wild,包括1680人共计超过1.3W张图片
  • CelebA:mmlab.ie.cuhk.edu.hk/projects/Ce…,CelebFaces Attributes Dataset,包括10177人共计超过20W张图片,并且每张图片还包括人脸的5个关键点位置和40个属性的01标注,例如是否有眼镜、帽子、胡子等

实现

加载库

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
%matplotlib inline
from imageio import imread, imsave, mimsave
import cv2
import glob
from tqdm import tqdm

选择数据集

dataset = 'lfw_new_imgs' # LFW
# dataset = 'celeba' # CelebA
images = glob.glob(os.path.join(dataset, '*.*')) 
print(len(images))

定义一些常量、网络输入、辅助函数

batch_size = 100
z_dim = 100
WIDTH = 64
HEIGHT = 64
LAMBDA = 10
DIS_ITERS = 3 # 5

OUTPUT_DIR = 'samples_' + dataset
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

X = tf.placeholder(dtype=tf.float32, shape=[batch_size, HEIGHT, WIDTH, 3], name='X')
noise = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='noise')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')

def lrelu(x, leak=0.2):
    return tf.maximum(x, leak * x)

判别器部分,注意需要去掉Batch Normalization,否则会导致batch之间的相关性,从而影响gradient penalty的计算

def discriminator(image, reuse=None, is_training=is_training):
    momentum = 0.9
    with tf.variable_scope('discriminator', reuse=reuse):
        h0 = lrelu(tf.layers.conv2d(image, kernel_size=5, filters=64, strides=2, padding='same'))

        h1 = lrelu(tf.layers.conv2d(h0, kernel_size=5, filters=128, strides=2, padding='same'))

        h2 = lrelu(tf.layers.conv2d(h1, kernel_size=5, filters=256, strides=2, padding='same'))

        h3 = lrelu(tf.layers.conv2d(h2, kernel_size=5, filters=512, strides=2, padding='same'))

        h4 = tf.contrib.layers.flatten(h3)
        h4 = tf.layers.dense(h4, units=1)
        return h4

生成器部分

def generator(z, is_training=is_training):
    momentum = 0.9
    with tf.variable_scope('generator', reuse=None):
        d = 4
        h0 = tf.layers.dense(z, units=d * d * 512)
        h0 = tf.reshape(h0, shape=[-1, d, d, 512])
        h0 = tf.nn.relu(tf.contrib.layers.batch_norm(h0, is_training=is_training, decay=momentum))

        h1 = tf.layers.conv2d_transpose(h0, kernel_size=5, filters=256, strides=2, padding='same')
        h1 = tf.nn.relu(tf.contrib.layers.batch_norm(h1, is_training=is_training, decay=momentum))

        h2 = tf.layers.conv2d_transpose(h1, kernel_size=5, filters=128, strides=2, padding='same')
        h2 = tf.nn.relu(tf.contrib.layers.batch_norm(h2, is_training=is_training, decay=momentum))

        h3 = tf.layers.conv2d_transpose(h2, kernel_size=5, filters=64, strides=2, padding='same')
        h3 = tf.nn.relu(tf.contrib.layers.batch_norm(h3, is_training=is_training, decay=momentum))

        h4 = tf.layers.conv2d_transpose(h3, kernel_size=5, filters=3, strides=2, padding='same', activation=tf.nn.tanh, name='g')
        return h4

损失函数

g = generator(noise)
d_real = discriminator(X)
d_fake = discriminator(g, reuse=True)

loss_d_real = -tf.reduce_mean(d_real)
loss_d_fake = tf.reduce_mean(d_fake)
loss_g = -tf.reduce_mean(d_fake)
loss_d = loss_d_real + loss_d_fake

alpha = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.)
interpolates = alpha * X + (1 - alpha) * g
grad = tf.gradients(discriminator(interpolates, reuse=True), [interpolates])[0]
slop = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1]))
gp = tf.reduce_mean((slop - 1.) ** 2)
loss_d += LAMBDA * gp

vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]

优化函数

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_d, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_g, var_list=vars_g)

读取图片的函数

def read_image(path, height, width):
    image = imread(path)
    h = image.shape[0]
    w = image.shape[1]

    if h > w:
        image = image[h // 2 - w // 2: h // 2 + w // 2, :, :]
    else:
        image = image[:, w // 2 - h // 2: w // 2 + h // 2, :]    

    image = cv2.resize(image, (width, height))
    return image / 255.

合成图片的函数

def montage(images):    
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    if len(images.shape) == 4 and images.shape[3] == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 3)) * 0.5
    elif len(images.shape) == 4 and images.shape[3] == 1:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 1)) * 0.5
    elif len(images.shape) == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1)) * 0.5
    else:
        raise ValueError('Could not parse image shape of {}'.format(images.shape))
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m

随机产生批数据的函数

def get_random_batch(nums):
    img_index = np.arange(len(images))
    np.random.shuffle(img_index)
    img_index = img_index[:nums]
    batch = np.array([read_image(images[i], HEIGHT, WIDTH) for i in img_index])
    batch = (batch - 0.5) * 2

    return batch

模型的训练

sess = tf.Session()
sess.run(tf.global_variables_initializer())
z_samples = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
samples = []
loss = {'d': [], 'g': []}

for i in tqdm(range(60000)):
    for j in range(DIS_ITERS):
        n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
        batch = get_random_batch(batch_size)
        _, d_ls = sess.run([optimizer_d, loss_d], feed_dict={X: batch, noise: n, is_training: True})

    _, g_ls = sess.run([optimizer_g, loss_g], feed_dict={X: batch, noise: n, is_training: True})

    loss['d'].append(d_ls)
    loss['g'].append(g_ls)

    if i % 500 == 0:
        print(i, d_ls, g_ls)
        gen_imgs = sess.run(g, feed_dict={noise: z_samples, is_training: False})
        gen_imgs = (gen_imgs + 1) / 2
        imgs = [img[:, :, :] for img in gen_imgs]
        gen_imgs = montage(imgs)
        plt.axis('off')
        plt.imshow(gen_imgs)
        imsave(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i), gen_imgs)
        plt.show()
        samples.append(gen_imgs)

plt.plot(loss['d'], label='Discriminator')
plt.plot(loss['g'], label='Generator')
plt.legend(loc='upper right')
plt.savefig(os.path.join(OUTPUT_DIR, 'Loss.png'))
plt.show()
mimsave(os.path.join(OUTPUT_DIR, 'samples.gif'), samples, fps=10)

LFW人脸生成结果如下,和DCGAN相比更加稳定

CelebA人脸生成结果如下

保存模型,便于后续使用

saver = tf.train.Saver()
saver.save(sess, os.path.join(OUTPUT_DIR, 'wgan_' + dataset), global_step=60000)

在单机上使用模型生成人脸图片

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

batch_size = 100
z_dim = 100
# dataset = 'lfw_new_imgs'
dataset = 'celeba'

def montage(images):    
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    if len(images.shape) == 4 and images.shape[3] == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 3)) * 0.5
    elif len(images.shape) == 4 and images.shape[3] == 1:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 1)) * 0.5
    elif len(images.shape) == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1)) * 0.5
    else:
        raise ValueError('Could not parse image shape of {}'.format(images.shape))
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m

sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver = tf.train.import_meta_graph(os.path.join('samples_' + dataset, 'wgan_' + dataset + '-60000.meta'))
saver.restore(sess, tf.train.latest_checkpoint('samples_' + dataset))
graph = tf.get_default_graph()
g = graph.get_tensor_by_name('generator/g/Tanh:0')
noise = graph.get_tensor_by_name('noise:0')
is_training = graph.get_tensor_by_name('is_training:0')

n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
gen_imgs = sess.run(g, feed_dict={noise: n, is_training: False})
gen_imgs = (gen_imgs + 1) / 2
imgs = [img[:, :, :] for img in gen_imgs]
gen_imgs = montage(imgs)
gen_imgs = np.clip(gen_imgs, 0, 1)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(gen_imgs)
plt.show()

参考

Wasserstein GANs:arxiv.org/abs/1701.07…

Improved Training of Wasserstein GANs:arxiv.org/abs/1704.00…

令人拍案叫绝的Wasserstein GAN:zhuanlan.zhihu.com/p/25071913

results matching ""

    No results matching ""