醋醋百科网

Good Luck To You!

PyTorch 深度学习实战(6):生成对抗网络(GAN)与图像生成

在前几篇文章中,我们学习了如何使用卷积神经网络(CNN)和迁移学习解决图像分类问题。本文将介绍一种全新的深度学习模型——生成对抗网络(Generative Adversarial Network, GAN),并展示如何使用 GAN 生成逼真的图像。


一、生成对抗网络简介

生成对抗网络是由 Ian Goodfellow 等人于 2014 年提出的一种生成模型。它的核心思想是通过两个神经网络的对抗训练来生成数据:

  1. 生成器(Generator):生成虚假数据(如图像)。
  2. 判别器(Discriminator):区分真实数据和生成器生成的虚假数据。

1. GAN 的训练过程

  • 生成器试图生成越来越逼真的数据,以欺骗判别器。
  • 判别器试图区分真实数据和生成器生成的虚假数据。
  • 两者通过对抗训练共同提升性能。

2. GAN 的应用

  • 图像生成(如人脸、风景)。
  • 图像修复(如去噪、补全)。
  • 风格迁移(如将照片转换为油画风格)。

二、使用 GAN 生成手写数字图像

我们将使用 PyTorch 构建一个简单的 GAN 模型,并在 MNIST 数据集上训练生成器生成手写数字图像。

1. 实现步骤

  1. 加载和预处理数据。
  2. 定义生成器和判别器。
  3. 定义损失函数和优化器。
  4. 训练 GAN 模型。
  5. 可视化生成结果。

2. 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置字体为 SimHei(黑体)
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 1. 加载和预处理数据
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化到 [-1, 1]
])

# 下载并加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)

# 2. 定义生成器
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()  # 输出范围 [-1, 1]
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(-1, 1, 28, 28)
        return img

# 3. 定义判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出范围 [0, 1]
        )

    def forward(self, img):
        img_flat = img.view(-1, 28 * 28)
        validity = self.model(img_flat)
        return validity

# 4. 初始化模型、损失函数和优化器
latent_dim = 100
generator = Generator(latent_dim)
discriminator = Discriminator()

criterion = nn.BCELoss()  # 二分类交叉熵损失
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 5. 训练 GAN 模型
num_epochs = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(train_loader):
        # 将数据移动到设备
        real_imgs = imgs.to(device)
        
        # ---------------------
        #  训练判别器
        # ---------------------
        optimizer_D.zero_grad()
        
        # 生成随机噪声
        z = torch.randn(imgs.size(0), latent_dim).to(device)
        fake_imgs = generator(z)
        
        # 计算判别器损失
        real_loss = criterion(discriminator(real_imgs), torch.ones(imgs.size(0), 1).to(device))
        fake_loss = criterion(discriminator(fake_imgs.detach()), torch.zeros(imgs.size(0), 1).to(device))
        d_loss = real_loss + fake_loss
        
        # 反向传播并更新参数
        d_loss.backward()
        optimizer_D.step()
        
        # ---------------------
        #  训练生成器
        # ---------------------
        optimizer_G.zero_grad()
        
        # 生成随机噪声
        z = torch.randn(imgs.size(0), latent_dim).to(device)
        fake_imgs = generator(z)
        
        # 计算生成器损失
        g_loss = criterion(discriminator(fake_imgs), torch.ones(imgs.size(0), 1).to(device))
        
        # 反向传播并更新参数
        g_loss.backward()
        optimizer_G.step()
        
        # 打印训练信息
        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], "
                  f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
    
    # 每个 epoch 结束后生成一些图像
    with torch.no_grad():
        z = torch.randn(16, latent_dim).to(device)
        fake_imgs = generator(z).cpu()
        fake_imgs = 0.5 * fake_imgs + 0.5  # 反标准化到 [0, 1]
        fake_imgs = fake_imgs.numpy()
        
        plt.figure(figsize=(4, 4))
        for j in range(16):
            plt.subplot(4, 4, j + 1)
            plt.imshow(fake_imgs[j, 0], cmap='gray')
            plt.axis('off')
        plt.suptitle(f"Epoch {epoch + 1}")
        plt.show()

三、代码解析

1.数据加载与预处理:

  • 使用 torchvision.datasets.MNIST 加载 MNIST 数据集。
  • 使用 transforms.Normalize 将图像标准化到 [-1, 1]。

2.生成器和判别器:

  • 生成器将随机噪声映射为 28x28 的图像。
  • 判别器将图像映射为一个标量,表示图像的真实性。

3.训练过程:

  • 交替训练判别器和生成器。
  • 使用二分类交叉熵损失函数和 Adam 优化器。

4.可视化生成结果:

  • 每个 epoch 结束后生成 16 张图像并可视化。

四、运行结果

运行上述代码后,你将看到以下输出:

  1. 训练过程中每 100 步打印一次判别器和生成器的损失值。
  2. 每个 epoch 结束后生成的手写数字图像。

随着训练的进行,生成器生成的图像会越来越逼真。


五、总结

本文介绍了生成对抗网络的基本概念,并使用 PyTorch 实现了一个简单的 GAN 模型来生成手写数字图像。通过对抗训练,生成器能够生成越来越逼真的图像。

在下一篇文章中,我们将学习如何使用循环神经网络(RNN)处理序列数据。敬请期待!


代码实例说明:

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
  • 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:generator = generator.to(device)discriminator = discriminator.to(device)

希望这篇文章能帮助你更好地理解生成对抗网络的原理和应用!如果有任何问题,欢迎在评论区留言讨论。

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言