手把手教你用PyTorch实现GAN:从零生成动漫人物头像的完整教程
生成对抗网络(Generative Adversarial Network,GAN)是近年来深度学习领域最令人兴奋的突破之一。它由Ian Goodfellow在2014年提出,通过让两个神经网络相互博弈的方式,实现了高质量的图像生成。
本文将带领大家使用PyTorch实现一个完整的GAN模型,用于生成动漫人物头像。我们将使用来自Kaggle的Anime Faces数据集,该数据集包含大量预处理好的动漫人物头像图片。
GAN的工作原理
GAN包含两个核心组件:
- 生成器(Generator):接收随机噪声作为输入,试图生成逼真的图像来欺骗判别器
- 判别器(Discriminator):接收真实图像和生成图像,试图区分它们的真伪
这两个网络相互博弈:生成器不断学习如何生成更逼真的图像,而判别器不断提高自己的辨别能力。最终,生成器能够生成以假乱真的图像。
环境准备与数据集
首先,我们需要安装必要的依赖:
pip install torch torchvision matplotlib pillow
从Kaggle下载数据集后,解压到 ./data/Anime Faces 目录。
完整训练代码
下面是完整的训练代码,包含数据预处理、模型定义和训练过程:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import os
num_epochs = 100
batch_size = 64
learning_rate = 0.0002
latent_dim = 100
img_channels = 3
feature_maps = 64
anime_faces_folder = "./data/Anime Faces"
generator_model_path = "./model/anime_faces/generator.pth"
discriminator_model_path = "./model/anime_faces/discriminator.pth"
if not os.path.exists("./model/anime_faces"):
os.makedirs("./model/anime_faces")
class AnimeFacesDataset(Dataset):
def __init__(self, root, transform=None):
super().__init__()
self.root = root
self.transform = transform
self.files = []
self.labels = []
for file in os.listdir(root):
self.files.append(file)
self.labels.append(1)
def __len__(self):
return len(self.files)
def __getitem__(self, index):
path = os.path.join(self.root, self.files[index])
image = Image.open(path).convert("RGB")
label = self.labels[index]
if self.transform:
image = self.transform(image)
return image, label
transform = transforms.Compose(
[
transforms.Resize((32, 32)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
dataset = AnimeFacesDataset(root=anime_faces_folder, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Generator(nn.Module):
def __init__(self, latent_dim, img_channels, feature_maps):
super().__init__()
self.net = nn.Sequential(
nn.ConvTranspose2d(
latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0
),
nn.BatchNorm2d(feature_maps * 8),
nn.ReLU(),
nn.ConvTranspose2d(
feature_maps * 8, feature_maps * 4, kernel_size=4, stride=2, padding=1
),
nn.BatchNorm2d(feature_maps * 4),
nn.ReLU(),
nn.ConvTranspose2d(
feature_maps * 4, feature_maps * 2, kernel_size=4, stride=2, padding=1
),
nn.BatchNorm2d(feature_maps * 2),
nn.ReLU(),
nn.ConvTranspose2d(
feature_maps * 2, img_channels, kernel_size=4, stride=2, padding=1
),
nn.Tanh(),
)
def forward(self, x):
return self.net(x)
class Discriminator(nn.Module):
def __init__(self, img_channels, feature_maps):
super().__init__()
self.net = nn.Sequential(
# 输入: (3, 32, 32)
nn.Conv2d(img_channels, feature_maps, 4, 2, 1), # 输出: (64, 16, 16)
nn.LeakyReLU(0.2),
nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1), # 输出: (128, 8, 8)
nn.BatchNorm2d(feature_maps * 2),
nn.LeakyReLU(0.2),
nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1), # 输出: (256, 4, 4)
nn.BatchNorm2d(feature_maps * 4),
nn.LeakyReLU(0.2),
nn.Conv2d(feature_maps * 4, 1, 4, 1, 0), # 输出: (1, 1, 1)
nn.Sigmoid(),
nn.Flatten(),
)
def forward(self, x):
return self.net(x)
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
generator = (
Generator(latent_dim, img_channels, feature_maps).apply(weights_init).to(device)
)
discriminator = Discriminator(img_channels, feature_maps).apply(weights_init).to(device)
if os.path.exists(generator_model_path):
generator.load_state_dict(torch.load(generator_model_path))
if os.path.exists(discriminator_model_path):
discriminator.load_state_dict(torch.load(discriminator_model_path))
criterion = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(
discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999)
)
for epoch in range(num_epochs):
for i, (real_imgs, _) in enumerate(dataloader):
real_imgs = real_imgs.to(device)
batch_size = real_imgs.size(0)
# 训练判别器
optimizer_D.zero_grad()
# 真实图像标签为1
real_labels = torch.full((batch_size, 1), 0.9, device=device)
real_output = discriminator(real_imgs)
d_loss_real = criterion(real_output, real_labels)
# 生成假图像
z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
fake_imgs = generator(z)
# 假图像标签为0
fake_labels = torch.full((batch_size, 1), 0.1, device=device)
fake_output = discriminator(fake_imgs.detach())
d_loss_fake = criterion(fake_output, fake_labels)
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
optimizer_D.step()
# 训练生成器两次
for _ in range(2):
optimizer_G.zero_grad()
# 重新生成假图像以增加多样性
z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
fake_imgs = generator(z)
output = discriminator(fake_imgs)
# 生成器目标为真实标签
g_loss = criterion(output, torch.ones_like(output))
g_loss.backward()
optimizer_G.step()
if i % 100 == 0:
print(
f"Epoch [{epoch + 1}/{num_epochs}] Batch {i}/{len(dataloader)} "
f"Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}"
)
torch.save(generator.state_dict(), generator_model_path)
torch.save(discriminator.state_dict(), discriminator_model_path)
# 生成随机噪声
z = torch.randn(1, latent_dim, 1, 1).to(device)
# 生成图像
generated_img = generator(z).detach().cpu()
# 反归一化处理
generated_img = generated_img.squeeze().permute(1, 2, 0) * 0.5 + 0.5
# 可视化
plt.imshow(generated_img)
plt.axis("off")
plt.show()
代码解释
数据集类:AnimeFacesDataset 负责加载图像并进行预处理
数据增强:使用了随机水平翻转、颜色抖动等技术,增强模型泛化能力
生成器架构:使用转置卷积层逐步将100维噪声向量上采样到32×32×3的图像
判别器架构:使用卷积层逐步下采样图像,最终输出一个标量表示真假概率
训练技巧:
- 使用标签平滑(0.9和0.1而不是1和0)
- 生成器每次迭代训练两次,增加生成频率
- 使用Adam优化器和合适的beta值
图像生成代码
训练完成后,我们可以使用保存的模型来生成新的动漫头像:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
latent_dim = 100
img_channels = 3
feature_maps = 64
generator_model_path = "./model/anime_faces/generator.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Generator(nn.Module):
def __init__(self, latent_dim, img_channels, feature_maps):
super().__init__()
self.net = nn.Sequential(
nn.ConvTranspose2d(
latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0
),
nn.BatchNorm2d(feature_maps * 8),
nn.ReLU(),
nn.ConvTranspose2d(
feature_maps * 8, feature_maps * 4, kernel_size=4, stride=2, padding=1
),
nn.BatchNorm2d(feature_maps * 4),
nn.ReLU(),
nn.ConvTranspose2d(
feature_maps * 4, feature_maps * 2, kernel_size=4, stride=2, padding=1
),
nn.BatchNorm2d(feature_maps * 2),
nn.ReLU(),
nn.ConvTranspose2d(
feature_maps * 2, img_channels, kernel_size=4, stride=2, padding=1
),
nn.Tanh(),
)
def forward(self, x):
return self.net(x)
generator = Generator(latent_dim, img_channels, feature_maps).to(device)
if os.path.exists(generator_model_path):
generator.load_state_dict(torch.load(generator_model_path))
rows, cols = 4, 4
figure = plt.figure()
for row in range(rows):
for col in range(cols):
z = torch.randn(1, latent_dim, 1, 1).to(device)
generated_img = generator(z).detach().cpu()
generated_img = generated_img.squeeze().permute(1, 2, 0) * 0.5 + 0.5
index = row * cols + col
plot = figure.add_subplot(rows, cols, index + 1)
plot.imshow(generated_img)
plot.axis("off")
plt.show()
生成效果展示
经过100个epoch的训练,我们的模型能够生成多样化的动漫人物头像:

从上图可以看出,生成的动漫头像具有以下特点:
- 多样的发色和发型
- 不同的面部表情
- 相对清晰的五官轮廓
- 符合动漫风格的色彩搭配
优化建议
如果你想进一步提升生成质量,可以尝试以下方法:
- 增加模型复杂度:增加更多的卷积层或使用更大的特征图
- 调整训练策略:使用WGAN-GP损失函数,改善训练稳定性
- 提高图像分辨率:逐步从低分辨率训练到高分辨率(渐进式GAN)
- 使用更好的数据增强:如随机裁剪、旋转等
- 调整超参数:学习率、批次大小、噪声维度等
总结
本文实现了一个完整的GAN模型,用于生成动漫人物头像。我们详细介绍了:
- GAN的基本原理和架构
- PyTorch中的数据加载和预处理方法
- 生成器和判别器的具体实现
- 训练过程中的技巧和注意事项
- 如何使用训练好的模型生成新图像
通过这个项目,你不仅掌握了GAN的核心概念,还学会了如何使用PyTorch实现一个完整的深度学习项目。你可以将同样的技术应用到其他图像生成任务中,如风景图生成、人脸生成等。
GAN是一个非常活跃的研究领域,还有许多变体和改进版本值得探索,比如DCGAN、StyleGAN、CycleGAN等。希望这篇文章能为你的深度学习之旅提供一个良好的起点!