手把手教你用PyTorch实现GAN:从零生成动漫人物头像的完整教程

2025-03-10 人工智能 219 次阅读 0 次点赞
本文介绍了生成对抗网络(GAN)的基本原理,并使用PyTorch实现了一个用于生成动漫人物头像的完整模型。文章详细讲解了GAN的生成器和判别器结构,通过两者相互博弈的训练方式,生成器能生成逼真的图像。作者提供了数据集准备、模型定义、训练代码及图像生成代码,并使用了标签平滑、训练技巧等优化方法。训练100个epoch后,模型可生成发色、表情多样的动漫头像。文章还给出了提升生成质量的优化建议,适合初学者实践GAN项目。

生成对抗网络(Generative Adversarial Network,GAN)是近年来深度学习领域最令人兴奋的突破之一。它由Ian Goodfellow在2014年提出,通过让两个神经网络相互博弈的方式,实现了高质量的图像生成。

本文将带领大家使用PyTorch实现一个完整的GAN模型,用于生成动漫人物头像。我们将使用来自Kaggle的Anime Faces数据集,该数据集包含大量预处理好的动漫人物头像图片。

GAN的工作原理

GAN包含两个核心组件:

  • 生成器(Generator):接收随机噪声作为输入,试图生成逼真的图像来欺骗判别器
  • 判别器(Discriminator):接收真实图像和生成图像,试图区分它们的真伪

这两个网络相互博弈:生成器不断学习如何生成更逼真的图像,而判别器不断提高自己的辨别能力。最终,生成器能够生成以假乱真的图像。

环境准备与数据集

首先,我们需要安装必要的依赖:

pip install torch torchvision matplotlib pillow

从Kaggle下载数据集后,解压到 ./data/Anime Faces 目录。

完整训练代码

下面是完整的训练代码,包含数据预处理、模型定义和训练过程:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import os

num_epochs = 100
batch_size = 64
learning_rate = 0.0002
latent_dim = 100
img_channels = 3
feature_maps = 64

anime_faces_folder = "./data/Anime Faces"
generator_model_path = "./model/anime_faces/generator.pth"
discriminator_model_path = "./model/anime_faces/discriminator.pth"

if not os.path.exists("./model/anime_faces"):
    os.makedirs("./model/anime_faces")


class AnimeFacesDataset(Dataset):
    def __init__(self, root, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform

        self.files = []
        self.labels = []

        for file in os.listdir(root):
            self.files.append(file)
            self.labels.append(1)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        path = os.path.join(self.root, self.files[index])
        image = Image.open(path).convert("RGB")
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
        return image, label


transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

dataset = AnimeFacesDataset(root=anime_faces_folder, transform=transform)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Generator(nn.Module):
    def __init__(self, latent_dim, img_channels, feature_maps):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(
                latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0
            ),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(),
            nn.ConvTranspose2d(
                feature_maps * 8, feature_maps * 4, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(
                feature_maps * 4, feature_maps * 2, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(
                feature_maps * 2, img_channels, kernel_size=4, stride=2, padding=1
            ),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)


class Discriminator(nn.Module):
    def __init__(self, img_channels, feature_maps):
        super().__init__()
        self.net = nn.Sequential(
            # 输入: (3, 32, 32)
            nn.Conv2d(img_channels, feature_maps, 4, 2, 1),  # 输出: (64, 16, 16)
            nn.LeakyReLU(0.2),
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1),  # 输出: (128, 8, 8)
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1),  # 输出: (256, 4, 4)
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(feature_maps * 4, 1, 4, 1, 0),  # 输出: (1, 1, 1)
            nn.Sigmoid(),
            nn.Flatten(),
        )

    def forward(self, x):
        return self.net(x)


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


generator = (
    Generator(latent_dim, img_channels, feature_maps).apply(weights_init).to(device)
)
discriminator = Discriminator(img_channels, feature_maps).apply(weights_init).to(device)

if os.path.exists(generator_model_path):
    generator.load_state_dict(torch.load(generator_model_path))

if os.path.exists(discriminator_model_path):
    discriminator.load_state_dict(torch.load(discriminator_model_path))

criterion = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(
    discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999)
)

for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)

        # 训练判别器
        optimizer_D.zero_grad()

        # 真实图像标签为1
        real_labels = torch.full((batch_size, 1), 0.9, device=device)
        real_output = discriminator(real_imgs)
        d_loss_real = criterion(real_output, real_labels)

        # 生成假图像
        z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
        fake_imgs = generator(z)

        # 假图像标签为0
        fake_labels = torch.full((batch_size, 1), 0.1, device=device)
        fake_output = discriminator(fake_imgs.detach())
        d_loss_fake = criterion(fake_output, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # 训练生成器两次
        for _ in range(2):
            optimizer_G.zero_grad()
            # 重新生成假图像以增加多样性
            z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
            fake_imgs = generator(z)
            output = discriminator(fake_imgs)
            # 生成器目标为真实标签
            g_loss = criterion(output, torch.ones_like(output))
            g_loss.backward()
            optimizer_G.step()

        if i % 100 == 0:
            print(
                f"Epoch [{epoch + 1}/{num_epochs}] Batch {i}/{len(dataloader)} "
                f"Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}"
            )


torch.save(generator.state_dict(), generator_model_path)
torch.save(discriminator.state_dict(), discriminator_model_path)

# 生成随机噪声
z = torch.randn(1, latent_dim, 1, 1).to(device)
# 生成图像
generated_img = generator(z).detach().cpu()
# 反归一化处理
generated_img = generated_img.squeeze().permute(1, 2, 0) * 0.5 + 0.5
# 可视化
plt.imshow(generated_img)
plt.axis("off")
plt.show()

代码解释

数据集类AnimeFacesDataset 负责加载图像并进行预处理

数据增强:使用了随机水平翻转、颜色抖动等技术,增强模型泛化能力

生成器架构:使用转置卷积层逐步将100维噪声向量上采样到32×32×3的图像

判别器架构:使用卷积层逐步下采样图像,最终输出一个标量表示真假概率

训练技巧

  • 使用标签平滑(0.9和0.1而不是1和0)
  • 生成器每次迭代训练两次,增加生成频率
  • 使用Adam优化器和合适的beta值

图像生成代码

训练完成后,我们可以使用保存的模型来生成新的动漫头像:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os

latent_dim = 100
img_channels = 3
feature_maps = 64

generator_model_path = "./model/anime_faces/generator.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Generator(nn.Module):
    def __init__(self, latent_dim, img_channels, feature_maps):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(
                latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0
            ),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(),
            nn.ConvTranspose2d(
                feature_maps * 8, feature_maps * 4, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(
                feature_maps * 4, feature_maps * 2, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(
                feature_maps * 2, img_channels, kernel_size=4, stride=2, padding=1
            ),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)


generator = Generator(latent_dim, img_channels, feature_maps).to(device)

if os.path.exists(generator_model_path):
    generator.load_state_dict(torch.load(generator_model_path))

rows, cols = 4, 4

figure = plt.figure()

for row in range(rows):
    for col in range(cols):
        z = torch.randn(1, latent_dim, 1, 1).to(device)
        generated_img = generator(z).detach().cpu()
        generated_img = generated_img.squeeze().permute(1, 2, 0) * 0.5 + 0.5

        index = row * cols + col
        plot = figure.add_subplot(rows, cols, index + 1)
        plot.imshow(generated_img)
        plot.axis("off")

plt.show()

生成效果展示

经过100个epoch的训练,我们的模型能够生成多样化的动漫人物头像:

生成对抗网络生成动漫人物头像

从上图可以看出,生成的动漫头像具有以下特点:

  • 多样的发色和发型
  • 不同的面部表情
  • 相对清晰的五官轮廓
  • 符合动漫风格的色彩搭配

优化建议

如果你想进一步提升生成质量,可以尝试以下方法:

  1. 增加模型复杂度:增加更多的卷积层或使用更大的特征图
  2. 调整训练策略:使用WGAN-GP损失函数,改善训练稳定性
  3. 提高图像分辨率:逐步从低分辨率训练到高分辨率(渐进式GAN)
  4. 使用更好的数据增强:如随机裁剪、旋转等
  5. 调整超参数:学习率、批次大小、噪声维度等

总结

本文实现了一个完整的GAN模型,用于生成动漫人物头像。我们详细介绍了:

  • GAN的基本原理和架构
  • PyTorch中的数据加载和预处理方法
  • 生成器和判别器的具体实现
  • 训练过程中的技巧和注意事项
  • 如何使用训练好的模型生成新图像

通过这个项目,你不仅掌握了GAN的核心概念,还学会了如何使用PyTorch实现一个完整的深度学习项目。你可以将同样的技术应用到其他图像生成任务中,如风景图生成、人脸生成等。

GAN是一个非常活跃的研究领域,还有许多变体和改进版本值得探索,比如DCGAN、StyleGAN、CycleGAN等。希望这篇文章能为你的深度学习之旅提供一个良好的起点!

最后更新于1小时前
本文由人工编写,AI优化,转载请注明原文地址: 手把手教你用PyTorch实现GAN:从零生成动漫人物头像的完整教程

评论 (0)

登录 后发表评论

暂无评论,快来发表第一条评论吧!