从零教你写一个完整的GAN

    xiaoxiao2021-04-16  198

    导言

    啦啦啦,现今 GAN 算法可以算作 ML 领域下比较热门的一个方向。事实上,GAN 已经作为一种思想来渗透在 ML 的其余领域,从而做出了很多很 Amazing 的东西。比如结合卷积神经网络,可以用于生成图片。或者结合 NLP,可以生成特定风格的短句子。(比如川普风格的 twitter......)

    可惜的是,网络上很多老司机开 GAN 的车最后都翻了,大多只是翻译了一篇论文,一旦涉及算法实现部分就直接放开源的实现地址,而那些开源的东东,缺少了必要的引导,实在对于新手来说很是懵逼。所以兔子哥哥带着开好车,开稳车的心态,特定来带一下各位想入门 GAN 的其他小兔兔们来飞一会。

    GAN 的介绍与训练

    先来阐述一下 GAN 的基本做法,这里不摆公式,因为你听完后,该怎么搭建和怎么训练你心里应该有数了。

    首先,GAN 全称为 Generative Adversarial Nets(生成对抗网络), 其构成分为两部份:

    Generator(生成器),下文简称 G

    Discriminator(辨别器), 下文简称 D。

    在本文,为了方便小兔兔理解,使用一个较为简单,也是 GAN 论文提及到的例子,训练 G 生成符合指定均值和标准差的数据,在这里,我们指定 MEAN=4,STD=1.5 的高斯分布(正态分布)。

    这货的样子大概长这样

    下面是数据生成的代码:

    def sample_data(size, length=100): """ 随机mean=4 std=1.5的数据 :param size: :param length: :return: """ data = [] for _ in range(size): data.append(sorted(np.random.normal(4, 1.5, length))) return np.array(data)

    在生成高斯分布的数据后,我们还对数据进行了排序,这时因为排序后的训练会相对平滑。具体原因看这个 [Generative Adversarial Nets in TensorFlow (Part I)]

    这一段是生成噪音的代码,既然是噪音,那么我们只需要随机产生 0~1 的数据就好。

    def random_data(size, length=100): """ 随机生成数据 :param size: :param length: :return: """ data = [] for _ in range(size): x = np.random.random(length) data.append(x) return np.array(data)

    随机分布的数据长这样

    接下来就是开撸 GAN 了。

    首先的一点就是,我们需要确定 G, 和 D 的具体结构,这里因为本文的例子太过于入门级,并不需要使用到复杂的神经网络结构,比如卷积层和递归层,这里 G 和 D 只需全连接的神经网络就好。全连接层的神经网络本质就是矩阵的花式相乘。为神马说是花式相乘呢,因为大多数时候,我们在矩阵相乘的结果后面会添加不同的激活函数。

    G 和 D 分别为三层的全链接的神经网络,其中 G 的激活函数分别为,relu,sigmoid,liner,这里前两层只是因为考虑到数据的非线性转换,并没有什么特别选择这两个激活函数的原因。其次 D 的三层分别为 relu,sigmoid,sigmoid。

    接下来就引出 GAN 的训练问题。GAN 的思想源于博弈论,直白一点就是套路与反套路。

    先从 D 开始分析,D 作为辨别器,它的职责就是区分于真实的高斯分布和 G 生成的” 假” 高斯分布。所以很显然,针对 D 来说,其需要解决的就是传统的二分类问题。

    在二分类问题中,我们习惯用交叉熵来衡量分类效果。

    从公式中不难看出,在全部分类正确时,交叉熵会接近于 0,因此,我们的目标就是通过拟合 D 的参数来最小化交叉熵的值。

    D 既然是传统的二分类问题,那么 D 的训练过程也很容易得出了

    即先把真实数据标识为‘1’(真实分布),由生成器生成的数据标识为’0‘(生成分布),反复迭代训练 D   ------------ (1)

    说 G 的训练之前先来打个比方,假如一男一女在一起了,现在两人性格出现矛盾了,女生并不愿意改变,但两个人都想继续在一起,这时,唯一的方法就是男生改变了。先忽略现实生活的问题,但从举例的角度说,显然久而久之男生就会变得更加 fit 这个女生。

    G 的训练也是如此:

    先将 G 拼接在 D 的上方,即 G 的输出作为 D 的输入(男生女生在一起),而同时固定 D 的参数(女生不愿意改变),并将进入 G 的噪音样本标签全部改成'1'(目标是两个人继续在一起,没有其他选择),为了最小化损失函数,此时就只能改变 G 的每一层权重,反复迭代后 G 的生成能力因此得以改进。(男生更适合女生)  ------------ (2)

    反复迭代(1)(2),最终 G 就会得到较好的生成能力。

    补充一点,在训练 D 的时候,我曾把数据直接放进去,这样的后果是最后生成的数据,能学习到高斯分布的轮廓,但标准差和均值则和真实样本相差很大。因此,这里我建议直接使用平均值和标准差作为 D 的输入。

    这使得 D 在训练前需要对数据进行预处理。

    def preprocess_data(x): """ 计算每一组数据平均值和方差 :param x: :return: """ return [[np.mean(data), np.std(data)] for data in x]

    G 和 D 的连接之间也需要做出处理。

    # 先求出G_output3的各行平均值和方差 MEAN = tf.reduce_mean(G_output3, 1) # 平均值,但是是1D向量 MEAN_T = tf.transpose(tf.expand_dims(MEAN, 0)) # 转置 STD = tf.sqrt(tf.reduce_mean(tf.square(G_output3 - MEAN_T), 1)) DATA = tf.concat(1, [MEAN_T, tf.transpose(tf.expand_dims(STD, 0))] # 拼接起来

    以下是损失函数变化图:

    蓝色是 D 单独作二分类问题处理时的变化

    绿色是拼接 G 在 D 的上方后损失函数的变化

    不难看出,两者在经历反复震荡 (互相博弈而导致),最后稳定于 0.5 附近,这时我们可以认为,G 的生成能力已经达到以假乱真,D 再也不能分出真假。

    接下来的这个就是 D-G 博弈 200 次后的结果:

    绿色是真实分布

    蓝色是噪音原本的分布

    红色是生成分布

    后话

    兔子哥哥的车这次就开到这里了。作为一个大三且数学能力较为一般的学生, 从比较感性的角度来描述了一次 GAN 的基本过程,有说得不对地方请各位见谅和指点。

    如果各位读者需要更严格的数学公式和证明,可以阅读 GAN 的开山之作([1406.2661] Generative Adversarial Networks) , 另外本文提及的代码都可在这里找到(MashiMaroLjc/learn-GAN),有需要的童鞋也可以私信交流。

    这就是全部内容了,下次心情好的话怼 DCGAN,看看能不能生成点好玩的图片,嗯~ 睡觉去~

    本文作者:AI研习社 本文转自雷锋网禁止二次转载, 原文链接 相关资源:Matlab的简单GAN基础:基于matlab的简单的生成对抗网络-源码

    最新回复(0)