今日目标
o 理解GAN的基本概念和架构
o 掌握生成器和判别器的设计
o 学会训练GAN的平衡策略
o 了解不同类型的GAN变体
o 掌握GAN在图像生成中的应用
GAN概述
生成对抗网络(GAN)是一种深度学习架构,通过两个神经网络相互对抗来生成数据:
o 生成器(Generator):生成假数据,试图欺骗判别器
o 判别器(Discriminator):区分真实数据和生成数据
o 对抗训练:两个网络相互竞争,共同提升
GAN应用领域
# 主要应用领域:
# - 图像生成:人脸、风景、艺术作品
# - 图像编辑:风格转换、超分辨率
# - 数据增强:生成训练数据
# - 文本生成:对话系统、文本创作
# - 音频生成:音乐、语音合成
GAN基础
1. 安装和导入
pip install torch torchvision numpy matplotlib seaborn
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings('ignore')
# Configure matplotlib to render the Chinese axis labels / titles used below.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Select the compute device: GPU when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
2. 简单GAN实现
def simple_gan():
    """Train a minimal fully-connected GAN on a 2-D Gaussian toy dataset.

    The "real" distribution is a standard 2-D normal; the generator maps
    100-D latent noise to 2-D points and is trained adversarially against
    a binary discriminator with BCE loss. Training curves and a scatter
    comparison of real vs. generated samples are plotted at the end.

    Returns:
        tuple: (generator, discriminator, g_losses, d_losses) — the trained
        networks and the per-epoch loss histories.
    """

    class Generator(nn.Module):
        """Maps a latent noise vector to a 2-D sample."""

        def __init__(self, latent_dim, output_dim):
            super(Generator, self).__init__()
            # NOTE: no Tanh on the output. The target data below is an
            # unbounded standard normal; squashing the generator output
            # into [-1, 1] would make matching the target impossible.
            self.model = nn.Sequential(
                nn.Linear(latent_dim, 128),
                nn.ReLU(),
                nn.Linear(128, 256),
                nn.ReLU(),
                nn.Linear(256, 512),
                nn.ReLU(),
                nn.Linear(512, output_dim),
            )

        def forward(self, z):
            return self.model(z)

    class Discriminator(nn.Module):
        """Scores a 2-D sample with the probability that it is real."""

        def __init__(self, input_dim):
            super(Discriminator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(input_dim, 512),
                nn.LeakyReLU(0.2),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2),
                nn.Linear(256, 128),
                nn.LeakyReLU(0.2),
                nn.Linear(128, 1),
                nn.Sigmoid(),
            )

        def forward(self, x):
            return self.model(x)

    # Hyper-parameters
    latent_dim = 100
    data_dim = 2  # generate 2-D points
    batch_size = 64
    num_epochs = 1000

    generator = Generator(latent_dim, data_dim).to(device)
    discriminator = Discriminator(data_dim).to(device)

    # Adam with beta1=0.5 is the conventional GAN setting.
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    def generate_real_data(n_samples):
        # "Real" data: samples from a standard 2-D Gaussian.
        return torch.randn(n_samples, data_dim).to(device)

    def generate_noise(n_samples):
        return torch.randn(n_samples, latent_dim).to(device)

    # Label tensors are constant across epochs — build them once.
    real_labels = torch.ones(batch_size, 1).to(device)
    fake_labels = torch.zeros(batch_size, 1).to(device)

    g_losses = []
    d_losses = []

    print("开始训练简单GAN...")
    for epoch in range(num_epochs):
        # ---- Discriminator step ----
        d_optimizer.zero_grad()

        real_data = generate_real_data(batch_size)
        d_real_output = discriminator(real_data)
        d_real_loss = criterion(d_real_output, real_labels)

        noise = generate_noise(batch_size)
        fake_data = generator(noise)
        # detach: the discriminator loss must not backprop into the generator
        d_fake_output = discriminator(fake_data.detach())
        d_fake_loss = criterion(d_fake_output, fake_labels)

        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step: try to make D label the fakes as real ----
        g_optimizer.zero_grad()
        g_fake_output = discriminator(fake_data)
        g_loss = criterion(g_fake_output, real_labels)
        g_loss.backward()
        g_optimizer.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'D Loss: {d_loss.item():.4f}, '
                  f'G Loss: {g_loss.item():.4f}')

    # Plot the training curves
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(g_losses, label='生成器损失')
    plt.plot(d_losses, label='判别器损失')
    plt.title('GAN训练损失')
    plt.xlabel('Epoch')
    plt.ylabel('损失')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Scatter comparison of real vs. generated samples
    plt.subplot(1, 2, 2)
    with torch.no_grad():
        noise = generate_noise(1000)
        fake_data = generator(noise).cpu().numpy()
        real_data = generate_real_data(1000).cpu().numpy()
    plt.scatter(real_data[:, 0], real_data[:, 1], alpha=0.6, label='真实数据')
    plt.scatter(fake_data[:, 0], fake_data[:, 1], alpha=0.6, label='生成数据')
    plt.title('生成结果对比')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return generator, discriminator, g_losses, d_losses
# 运行简单GAN示例
simple_generator, simple_discriminator, g_losses, d_losses = simple_gan()
DCGAN实现
1. 深度卷积GAN
def dcgan_implementation():
    """DCGAN implementation.

    Trains a deep-convolutional generator/discriminator pair with BCE loss.
    NOTE: generate_real_data() returns random noise images — this demo
    exercises the training-loop mechanics, not a real image dataset, so
    the generated samples are not expected to be meaningful.

    Returns:
        tuple: (generator, discriminator, g_losses, d_losses)
    """
    # Generator network: upsamples a latent vector to a channels x 32 x 32
    # image via strided transposed convolutions (DCGAN architecture).
    class DCGANGenerator(nn.Module):
        def __init__(self, latent_dim, channels=1):
            super(DCGANGenerator, self).__init__()
            self.latent_dim = latent_dim
            self.main = nn.Sequential(
                # input: latent_dim x 1 x 1
                nn.ConvTranspose2d(latent_dim, 256, 4, 1, 0, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(True),
                # state: 256 x 4 x 4
                nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(True),
                # state: 128 x 8 x 8
                nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
                # state: 64 x 16 x 16
                nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False),
                nn.Tanh()
                # output: channels x 32 x 32
            )

        def forward(self, z):
            return self.main(z)

    # Discriminator network: mirrors the generator with strided convolutions,
    # reducing the image to a single realness probability.
    class DCGANDiscriminator(nn.Module):
        def __init__(self, channels=1):
            super(DCGANDiscriminator, self).__init__()
            self.main = nn.Sequential(
                # input: channels x 32 x 32
                nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # state: 64 x 16 x 16
                nn.Conv2d(64, 128, 4, 2, 1, bias=False),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(0.2, inplace=True),
                # state: 128 x 8 x 8
                nn.Conv2d(128, 256, 4, 2, 1, bias=False),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(0.2, inplace=True),
                # state: 256 x 4 x 4
                nn.Conv2d(256, 1, 4, 1, 0, bias=False),
                nn.Sigmoid()
                # output: 1 x 1 x 1
            )

        def forward(self, x):
            return self.main(x)

    # Hyper-parameters
    latent_dim = 100
    image_size = 32
    batch_size = 64
    num_epochs = 50

    # Build the networks
    generator = DCGANGenerator(latent_dim).to(device)
    discriminator = DCGANDiscriminator().to(device)

    # Optimizers (Adam with beta1=0.5 per the DCGAN paper)
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Loss function
    criterion = nn.BCELoss()

    # Latent noise shaped (batch, latent_dim, 1, 1) as required by ConvTranspose2d
    def generate_noise(batch_size):
        return torch.randn(batch_size, latent_dim, 1, 1).to(device)

    # "Real" data: random noise images as a stand-in dataset (demo only)
    def generate_real_data(batch_size):
        return torch.randn(batch_size, 1, image_size, image_size).to(device)

    # Loss history
    g_losses = []
    d_losses = []

    print("开始训练DCGAN...")
    for epoch in range(num_epochs):
        # ---- Train the discriminator ----
        d_optimizer.zero_grad()
        # Real batch; labels shaped (batch, 1, 1, 1) to match the conv output
        real_data = generate_real_data(batch_size)
        real_labels = torch.ones(batch_size, 1, 1, 1).to(device)
        d_real_output = discriminator(real_data)
        d_real_loss = criterion(d_real_output, real_labels)
        # Fake batch; detach so the D loss does not backprop into the generator
        noise = generate_noise(batch_size)
        fake_data = generator(noise)
        fake_labels = torch.zeros(batch_size, 1, 1, 1).to(device)
        d_fake_output = discriminator(fake_data.detach())
        d_fake_loss = criterion(d_fake_output, fake_labels)
        # Total discriminator loss
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # ---- Train the generator: make D classify the fakes as real ----
        g_optimizer.zero_grad()
        g_fake_output = discriminator(fake_data)
        g_loss = criterion(g_fake_output, real_labels)
        g_loss.backward()
        g_optimizer.step()

        # Record losses
        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'D Loss: {d_loss.item():.4f}, '
                  f'G Loss: {g_loss.item():.4f}')

    # Plot the training curves
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(g_losses, label='生成器损失')
    plt.plot(d_losses, label='判别器损失')
    plt.title('DCGAN训练损失')
    plt.xlabel('Epoch')
    plt.ylabel('损失')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Show a grid of generated samples
    plt.subplot(1, 2, 2)
    with torch.no_grad():
        noise = generate_noise(16)
        fake_images = generator(noise).cpu()
    # Arrange the 16 images into a 4x4 grid for display
    fake_grid = torchvision.utils.make_grid(fake_images, nrow=4, normalize=True)
    plt.imshow(fake_grid.permute(1, 2, 0), cmap='gray')
    plt.title('DCGAN生成结果')
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    return generator, discriminator, g_losses, d_losses
# 运行DCGAN示例
dcgan_generator, dcgan_discriminator, dcgan_g_losses, dcgan_d_losses = dcgan_implementation()
GAN变体
1. WGAN实现
def wgan_implementation():
    """Train a Wasserstein GAN (weight-clipping variant) on 2-D Gaussian data.

    Differences from a vanilla GAN:
      * the critic outputs an unbounded score (no Sigmoid),
      * losses are mean scores instead of BCE,
      * the critic is updated several times per generator update,
      * critic weights are clipped to enforce a Lipschitz constraint.

    Returns:
        tuple: (generator, critic, g_losses, c_losses)
    """

    class WGANGenerator(nn.Module):
        """Maps latent noise to a 2-D sample."""

        def __init__(self, latent_dim, output_dim):
            super(WGANGenerator, self).__init__()
            # NOTE: no Tanh on the output — the target data is an unbounded
            # standard normal, which a [-1, 1]-bounded output cannot match.
            self.model = nn.Sequential(
                nn.Linear(latent_dim, 128),
                nn.ReLU(),
                nn.Linear(128, 256),
                nn.ReLU(),
                nn.Linear(256, 512),
                nn.ReLU(),
                nn.Linear(512, output_dim),
            )

        def forward(self, z):
            return self.model(z)

    class WGANCritic(nn.Module):
        """Critic (WGAN's discriminator): outputs an unbounded realness score."""

        def __init__(self, input_dim):
            super(WGANCritic, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(input_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Linear(128, 1),
                # deliberately no Sigmoid: WGAN critics output raw scores
            )

        def forward(self, x):
            return self.model(x)

    # Hyper-parameters (RMSprop with a small lr, per the WGAN recipe)
    latent_dim = 100
    data_dim = 2
    batch_size = 64
    num_epochs = 1000
    critic_iterations = 5  # critic updates per generator update
    clip_value = 0.01

    generator = WGANGenerator(latent_dim, data_dim).to(device)
    critic = WGANCritic(data_dim).to(device)

    g_optimizer = optim.RMSprop(generator.parameters(), lr=0.00005)
    c_optimizer = optim.RMSprop(critic.parameters(), lr=0.00005)

    def generate_real_data(n_samples):
        # "Real" data: samples from a standard 2-D Gaussian.
        return torch.randn(n_samples, data_dim).to(device)

    def generate_noise(n_samples):
        return torch.randn(n_samples, latent_dim).to(device)

    def clip_weights(model, clip_value):
        # In-place clamp of every parameter: crude Lipschitz enforcement.
        for param in model.parameters():
            param.data.clamp_(-clip_value, clip_value)

    g_losses = []
    c_losses = []

    print("开始训练WGAN...")
    for epoch in range(num_epochs):
        # ---- Critic: several updates per generator step ----
        for _ in range(critic_iterations):
            c_optimizer.zero_grad()

            real_data = generate_real_data(batch_size)
            real_scores = critic(real_data)

            noise = generate_noise(batch_size)
            # detach: the critic update must not build or backprop a graph
            # through the generator (its grads would be discarded anyway)
            fake_data = generator(noise).detach()
            fake_scores = critic(fake_data)

            # Wasserstein critic loss: maximize E[real] - E[fake]
            c_loss = -torch.mean(real_scores) + torch.mean(fake_scores)
            c_loss.backward()
            c_optimizer.step()

            clip_weights(critic, clip_value)

        # ---- Generator: maximize the critic score on fresh fakes ----
        g_optimizer.zero_grad()
        noise = generate_noise(batch_size)
        fake_data = generator(noise)
        fake_scores = critic(fake_data)
        g_loss = -torch.mean(fake_scores)
        g_loss.backward()
        g_optimizer.step()

        g_losses.append(g_loss.item())
        c_losses.append(c_loss.item())

        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'C Loss: {c_loss.item():.4f}, '
                  f'G Loss: {g_loss.item():.4f}')

    # Plot the training curves
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(g_losses, label='生成器损失')
    plt.plot(c_losses, label='判别器损失')
    plt.title('WGAN训练损失')
    plt.xlabel('Epoch')
    plt.ylabel('损失')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Scatter comparison of real vs. generated samples
    plt.subplot(1, 2, 2)
    with torch.no_grad():
        noise = generate_noise(1000)
        fake_data = generator(noise).cpu().numpy()
        real_data = generate_real_data(1000).cpu().numpy()
    plt.scatter(real_data[:, 0], real_data[:, 1], alpha=0.6, label='真实数据')
    plt.scatter(fake_data[:, 0], fake_data[:, 1], alpha=0.6, label='生成数据')
    plt.title('WGAN生成结果对比')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return generator, critic, g_losses, c_losses
# 运行WGAN示例
wgan_generator, wgan_critic, wgan_g_losses, wgan_c_losses = wgan_implementation()
GAN应用
1. 图像生成应用
def gan_applications():
    """GAN application demos: a conditional generator and a toy FID metric.

    Builds an (untrained) conditional generator, visualizes class-conditioned
    samples, and computes a simplified Frechet Inception Distance between two
    Gaussian sample sets (which should be close to 0).

    Returns:
        tuple: (cgan_generator, fid_score)
    """

    class ConditionalGenerator(nn.Module):
        """Generator conditioned on a one-hot class label (cGAN)."""

        def __init__(self, latent_dim, num_classes, output_dim):
            super(ConditionalGenerator, self).__init__()
            self.latent_dim = latent_dim
            self.num_classes = num_classes
            # Input is the noise vector concatenated with the one-hot label.
            self.model = nn.Sequential(
                nn.Linear(latent_dim + num_classes, 128),
                nn.ReLU(),
                nn.Linear(128, 256),
                nn.ReLU(),
                nn.Linear(256, 512),
                nn.ReLU(),
                nn.Linear(512, output_dim),
                nn.Tanh()
            )

        def forward(self, z, labels):
            # Concatenate noise and label along the feature axis.
            z = torch.cat([z, labels], dim=1)
            return self.model(z)

    def generate_conditional_data(generator, num_classes, samples_per_class=100):
        """Sample `samples_per_class` points for each class id; return (data, labels)."""
        all_data = []
        all_labels = []
        for class_id in range(num_classes):
            noise = torch.randn(samples_per_class, 100).to(device)
            # One-hot label batch for this class
            labels = torch.zeros(samples_per_class, num_classes).to(device)
            labels[:, class_id] = 1
            with torch.no_grad():
                fake_data = generator(noise, labels).cpu().numpy()
            all_data.append(fake_data)
            all_labels.extend([class_id] * samples_per_class)
        return np.vstack(all_data), np.array(all_labels)

    # Visualize conditional generation (generator is untrained — demo only)
    plt.figure(figsize=(15, 5))
    cgan_generator = ConditionalGenerator(100, 3, 2).to(device)
    fake_data, fake_labels = generate_conditional_data(cgan_generator, 3, 200)
    colors = ['red', 'green', 'blue']
    for i in range(3):
        mask = fake_labels == i
        plt.scatter(fake_data[mask, 0], fake_data[mask, 1],
                    c=colors[i], alpha=0.6, label=f'类别{i}')
    plt.title('条件GAN生成结果')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    def calculate_fid_score(real_data, fake_data):
        """Simplified Frechet distance between two sample sets.

        FID = ||mu_r - mu_f||^2 + Tr(C_r + C_f - 2 * (C_r C_f)^(1/2))

        Uses a proper matrix square root via scipy.linalg.sqrtm; the
        previous element-wise np.sqrt was mathematically wrong here.
        """
        from scipy import linalg  # local import: only needed for sqrtm

        real_mean = np.mean(real_data, axis=0)
        fake_mean = np.mean(fake_data, axis=0)
        real_cov = np.cov(real_data.T)
        fake_cov = np.cov(fake_data.T)

        mean_diff = real_mean - fake_mean
        # Matrix square root of the covariance product; sqrtm may return a
        # complex result from numerical error — keep the real part.
        covmean = linalg.sqrtm(real_cov @ fake_cov)
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        fid = np.sum(mean_diff**2) + np.trace(real_cov + fake_cov - 2 * covmean)
        return fid

    # FID between two samples of the same distribution should be near 0.
    real_data = torch.randn(1000, 2).numpy()
    fake_data = torch.randn(1000, 2).numpy()
    fid_score = calculate_fid_score(real_data, fake_data)
    print(f"FID分数: {fid_score:.4f}")

    return cgan_generator, fid_score
# 运行GAN应用示例
cgan_generator, fid_score = gan_applications()
今日总结
今天我们学习了生成对抗网络(GAN)的基础知识:
1. GAN基础:生成器、判别器、对抗训练
2. 简单GAN:全连接网络、BCE损失
3. DCGAN:卷积网络、图像生成
4. WGAN:Wasserstein距离、权重裁剪
5. GAN应用:条件生成、评估指标
GAN是生成模型的重要技术,掌握这些知识可以创建逼真的生成内容。