使用生成对抗网络(GAN)生成人脸

创建日期:2025-05-02
更新日期:2025-05-02

深度卷积生成对抗网络(DCGAN)

示例代码

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Hyperparameters
num_epochs = 20
batch_size = 256
learning_rate = 0.0002
betas = (0.5, 0.999)  # Adam betas commonly used for DCGAN training

# CelebA faces, resized to 64x64 and normalized to [-1, 1] per channel
# to match the generator's tanh output range.
dataset = datasets.CelebA(
    root='./data/celeba',
    split='train',
    transform=transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
    download=True,
)

# Fix: the DataLoader hard-coded batch_size=128 while the module defines
# batch_size = 256 above; use the configured constant consistently.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

# Train on the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sample random noise for the generator.
def generate_random(batch_size, size):
    """Return a [batch_size, size] tensor of standard-normal noise on `device`."""
    noise = torch.randn(batch_size, size)
    return noise.to(device)

# Generator - upsamples a noise vector to a 64x64 RGB image with transposed convolutions.
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Length of the input noise vector.
        self.input_size = 100
        # Base number of feature maps.
        self.ngf = 64

        ngf = self.ngf
        # (in_channels, out_channels, kernel, stride, padding) per upsampling stage.
        stages = [
            (self.input_size, ngf * 8, 4, 1, 0),  # [batch, 100, 1, 1] -> [batch, ngf*8, 4, 4]
            (ngf * 8, ngf * 4, 4, 2, 1),          # -> [batch, ngf*4, 8, 8]
            (ngf * 4, ngf * 2, 4, 2, 1),          # -> [batch, ngf*2, 16, 16]
            (ngf * 2, ngf, 4, 2, 1),              # -> [batch, ngf, 32, 32]
        ]
        layers = []
        for cin, cout, k, s, p in stages:
            layers += [
                nn.ConvTranspose2d(cin, cout, k, s, p, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(True),
            ]
        # Final stage maps to 3 RGB channels in [-1, 1]: -> [batch, 3, 64, 64]
        layers += [nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise of shape [batch, 100] to images of shape [batch, 3, 64, 64]."""
        # Reshape the flat noise vector to [batch, 100, 1, 1] for the conv stack.
        x = x.view(x.size(0), self.input_size, 1, 1)
        return self.main(x)

# Discriminator - convolutional classifier: real (1) vs. fake (0) 64x64 images.
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Base number of feature maps.
        self.ndf = 64

        ndf = self.ndf
        layers = [
            # [batch, 3, 64, 64] -> [batch, ndf, 32, 32]; no BatchNorm on the
            # input block, following the DCGAN convention.
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Each stage halves the spatial size and doubles the channel count.
        for cin, cout in ((ndf, ndf * 2), (ndf * 2, ndf * 4), (ndf * 4, ndf * 8)):
            layers += [
                nn.Conv2d(cin, cout, 4, 2, 1, bias=False),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # [batch, ndf*8, 4, 4] -> [batch, 1, 1, 1] probability via sigmoid.
        layers += [nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-image real/fake probability with shape [batch, 1]."""
        return self.main(x).view(-1, 1)

# Instantiate both networks and move them to the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

def weights_init(m):
    """DCGAN weight initialization (Radford et al., 2016).

    Conv/ConvTranspose weights are drawn from N(0, 0.02); BatchNorm scale is
    drawn from N(1, 0.02) so normalization starts near the identity; all
    biases are zeroed.
    """
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        if hasattr(m, "weight") and m.weight is not None:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif classname.find("BatchNorm") != -1:
        # Fix: the original drew BatchNorm gamma from N(0, 0.02), which starts
        # the scale near zero and suppresses activations; DCGAN uses N(1, 0.02).
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Apply DCGAN weight initialization recursively to every submodule.
generator.apply(weights_init)
discriminator.apply(weights_init)

# Optimizers: Adam with the betas configured above.
optimizer_g = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=betas)
optimizer_d = torch.optim.Adam(
    discriminator.parameters(), lr=learning_rate, betas=betas
)

# Binary cross-entropy over the discriminator's sigmoid probabilities.
criterion = nn.BCELoss()

# Put both networks in training mode (affects BatchNorm statistics).
generator.train()
discriminator.train()

# Sampled loss histories, appended by train() and plotted by plot_loss().
loss_G = []
loss_D = []

# Training loop
def train():
    """Run the DCGAN adversarial training loop over the CelebA dataloader.

    Alternates one discriminator step (real batch + detached fake batch) with
    one generator step per mini-batch; appends sampled losses to the
    module-level loss_G / loss_D lists every 64 batches.
    """
    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(dataloader):
            # The last batch of an epoch may be smaller than the configured size.
            current_batch_size = images.size(0)
            images = images.to(device)

            # Label smoothing (0.9 / 0.1) keeps D from becoming over-confident.
            real_labels = torch.full((current_batch_size, 1), 0.9, device=device)
            fake_labels = torch.full((current_batch_size, 1), 0.1, device=device)

            # --- Discriminator step ---
            optimizer_d.zero_grad()

            # Real batch.
            real_output = discriminator(images)
            d_loss_real = criterion(real_output, real_labels)

            # Fake batch; detach so no gradients flow into G during D's update.
            z = generate_random(current_batch_size, 100)
            fake_data = generator(z).detach()
            fake_output = discriminator(fake_data)
            d_loss_fake = criterion(fake_output, fake_labels)

            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            optimizer_d.step()

            # --- Generator step: push D to label fresh fakes as real ---
            optimizer_g.zero_grad()
            z = generate_random(current_batch_size, 100)
            fake_output = discriminator(generator(z))
            g_loss = criterion(fake_output, real_labels)
            g_loss.backward()
            optimizer_g.step()

            # Record and report every 64 batches (the original had two
            # identical `if i % 64 == 0` checks; merged into one).
            if i % 64 == 0:
                loss_G.append(g_loss.item())
                loss_D.append(d_loss.item())
                # Fix: report the batch index; the original printed
                # `i + current_batch_size`, which is not a meaningful step count.
                print(
                    f"Epoch {epoch + 1}, Step {i}, loss_G: {loss_G[-1]}, loss_D: {loss_D[-1]}"
                )

# Persist trained weights once training has finished.
def save_model():
    """Write the generator and discriminator state dicts to the working directory."""
    for net, path in ((generator, "generator.pth"), (discriminator, "discriminator.pth")):
        torch.save(net.state_dict(), path)

# Plot the recorded loss curves.
def plot_loss():
    """Plot generator and discriminator loss histories on a single figure."""
    plt.figure(figsize=(10, 5))
    for history, title in ((loss_G, "Generator Loss"), (loss_D, "Discriminator Loss")):
        plt.plot(history, label=title)
    plt.legend()
    plt.show()

# Sample and display a grid of generated faces.
def generate_image():
    """Generate a 4x4 grid of faces from random noise and show it with matplotlib."""
    rows, cols = 4, 4

    figure = plt.figure(figsize=(8, 8))

    # Switch to eval mode so BatchNorm uses its running statistics.
    generator.eval()

    # Fix: inference needs no autograd graph; no_grad saves memory and time
    # (the original built a graph and then detached).
    with torch.no_grad():
        for index in range(rows * cols):
            z = generate_random(1, 100)
            output = generator(z)
            # Convert [1, 3, 64, 64] to [64, 64, 3] for matplotlib.
            img = output.cpu().squeeze().permute(1, 2, 0).numpy()
            # Map the tanh output from [-1, 1] to [0, 1]; clip to guard against
            # floating-point values drifting slightly outside the range.
            img = ((img + 1) / 2).clip(0, 1)
            plot = figure.add_subplot(rows, cols, index + 1)
            plot.imshow(img)
            plot.axis("off")

    plt.show()

# Entry point: train, persist the models, then visualize the results.
if __name__ == "__main__":
    train()
    save_model()
    plot_loss()
    generate_image()

损失曲线

Figure_1.png

生成图片

Figure_2.png

WGAN-GP

示例代码

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Hyperparameters
num_epochs = 20
batch_size = 256
learning_rate = 0.0001  # WGAN-GP typically uses a smaller learning rate
n_critic = 5  # critic updates per generator update
lambda_gp = 10  # gradient penalty coefficient

# CelebA faces, resized to 64x64 and normalized to [-1, 1] per channel
# to match the generator's tanh output range.
dataset = datasets.CelebA(
    root='./data/celeba',
    split='train',
    transform=transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
    download=True,
)

# Fix: the DataLoader hard-coded batch_size=128 while the module defines
# batch_size = 256 above; use the configured constant consistently.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

# Train on the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sample random noise for the generator.
def generate_random(batch_size, size):
    """Return a [batch_size, size] tensor of standard-normal noise on `device`."""
    noise = torch.randn(batch_size, size)
    return noise.to(device)

# Generator - upsamples a noise vector to a 64x64 RGB image with transposed convolutions.
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Length of the input noise vector.
        self.input_size = 100
        # Base number of feature maps.
        self.ngf = 64

        ngf = self.ngf
        # (in_channels, out_channels, kernel, stride, padding) per upsampling stage.
        stages = [
            (self.input_size, ngf * 8, 4, 1, 0),  # [batch, 100, 1, 1] -> [batch, ngf*8, 4, 4]
            (ngf * 8, ngf * 4, 4, 2, 1),          # -> [batch, ngf*4, 8, 8]
            (ngf * 4, ngf * 2, 4, 2, 1),          # -> [batch, ngf*2, 16, 16]
            (ngf * 2, ngf, 4, 2, 1),              # -> [batch, ngf, 32, 32]
        ]
        layers = []
        for cin, cout, k, s, p in stages:
            layers += [
                nn.ConvTranspose2d(cin, cout, k, s, p, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(True),
            ]
        # Final stage maps to 3 RGB channels in [-1, 1]: -> [batch, 3, 64, 64]
        layers += [nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise of shape [batch, 100] to images of shape [batch, 3, 64, 64]."""
        # Reshape the flat noise vector to [batch, 100, 1, 1] for the conv stack.
        x = x.view(x.size(0), self.input_size, 1, 1)
        return self.main(x)

# Critic - convolutional network producing an unbounded realness score (no sigmoid).
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Base number of feature maps.
        self.ndf = 64

        ndf = self.ndf
        layers = [
            # [batch, 3, 64, 64] -> [batch, ndf, 32, 32]; no BatchNorm on the
            # input block, following the DCGAN convention.
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Each stage halves the spatial size and doubles the channel count.
        # NOTE(review): the WGAN-GP paper recommends avoiding BatchNorm in the
        # critic (it couples samples within a batch) — confirm it is intended here.
        for cin, cout in ((ndf, ndf * 2), (ndf * 2, ndf * 4), (ndf * 4, ndf * 8)):
            layers += [
                nn.Conv2d(cin, cout, 4, 2, 1, bias=False),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # [batch, ndf*8, 4, 4] -> [batch, 1, 1, 1]; the sigmoid is removed for
        # the Wasserstein objective.
        layers += [nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-image critic score with shape [batch, 1]."""
        return self.main(x).view(-1, 1)

# Instantiate both networks and move them to the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

def weights_init(m):
    """DCGAN-style weight initialization (Radford et al., 2016).

    Conv/ConvTranspose weights are drawn from N(0, 0.02); BatchNorm scale is
    drawn from N(1, 0.02) so normalization starts near the identity; all
    biases are zeroed.
    """
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        if hasattr(m, "weight") and m.weight is not None:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif classname.find("BatchNorm") != -1:
        # Fix: the original drew BatchNorm gamma from N(0, 0.02), which starts
        # the scale near zero and suppresses activations; DCGAN uses N(1, 0.02).
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Apply DCGAN-style weight initialization recursively to every submodule.
generator.apply(weights_init)
discriminator.apply(weights_init)

# Optimizers.
# NOTE(review): the WGAN-GP paper uses Adam betas of (0.0, 0.9); this code
# keeps the DCGAN betas (0.5, 0.999) — confirm this is intentional.
optimizer_g = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(
    discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999)
)

# NOTE(review): criterion is never used below — the WGAN-GP loop computes the
# Wasserstein losses directly. Apparently a leftover from the DCGAN version.
criterion = nn.BCELoss()

# Put both networks in training mode (affects BatchNorm statistics).
generator.train()
discriminator.train()

# Sampled loss histories, appended by train() and plotted by plot_loss().
loss_G = []
loss_D = []

# Gradient penalty term for WGAN-GP.
def compute_gradient_penalty(D, real_samples, fake_samples):
    """WGAN-GP gradient penalty (Gulrajani et al., 2017).

    Penalizes the critic so that the gradient norm of its output with respect
    to random interpolations between real and fake samples stays close to 1.

    Args:
        D: the critic; maps [batch, C, H, W] tensors to [batch, 1] scores.
        real_samples: real batch, same shape as fake_samples.
        fake_samples: generated batch.

    Returns:
        Scalar tensor: mean squared distance of the gradient norm from 1.
    """
    batch_size = real_samples.size(0)
    # Improvement: derive the device from the inputs instead of relying on a
    # module-level global, so the function works on tensors living anywhere.
    device = real_samples.device
    # One random mixing coefficient per sample, broadcast over C/H/W.
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    grad_outputs = torch.ones(batch_size, 1, device=device)
    # create_graph=True so the penalty is itself differentiable for D's update.
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# Training loop
def train():
    """Run the WGAN-GP training loop over the CelebA dataloader.

    For each mini-batch the critic is updated n_critic times (Wasserstein loss
    plus gradient penalty), followed by a single generator update. Sampled
    losses are appended to the module-level loss_G / loss_D lists.
    """
    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(dataloader):
            current_batch_size = images.size(0)  # last batch may be smaller
            images = images.to(device)

            # --- Critic steps ---
            # NOTE(review): the same real batch is reused for all n_critic
            # updates; the reference algorithm samples a fresh real batch per
            # critic step — confirm this simplification is intended.
            for _ in range(n_critic):
                optimizer_d.zero_grad()

                # Critic maximizes E[D(real)] - E[D(fake)], i.e. minimizes
                # the negated difference below.
                real_output = discriminator(images)
                d_loss_real = -real_output.mean()

                z = generate_random(current_batch_size, 100)
                fake_data = generator(z).detach()  # no gradients into G here
                fake_output = discriminator(fake_data)
                d_loss_fake = fake_output.mean()

                gradient_penalty = compute_gradient_penalty(
                    discriminator, images.data, fake_data.data
                )

                d_loss = d_loss_real + d_loss_fake + lambda_gp * gradient_penalty
                d_loss.backward()
                optimizer_d.step()

            # --- Generator step: maximize the critic's score on fresh fakes ---
            optimizer_g.zero_grad()
            z = generate_random(current_batch_size, 100)
            fake_output = discriminator(generator(z))
            g_loss = -fake_output.mean()
            g_loss.backward()
            optimizer_g.step()

            # Record and report every 64 batches.
            if i % 64 == 0:
                loss_G.append(g_loss.item())
                loss_D.append(d_loss.item())
                # Fix: report the batch index; the original printed
                # `i + current_batch_size`, which is not a meaningful step count.
                print(
                    f"Epoch {epoch + 1}, Step {i}, loss_G: {loss_G[-1]}, loss_D: {loss_D[-1]}"
                )

# Persist trained weights once training has finished.
def save_model():
    """Write the generator and discriminator state dicts to the working directory."""
    for net, path in ((generator, "generator.pth"), (discriminator, "discriminator.pth")):
        torch.save(net.state_dict(), path)

# Plot the recorded loss curves.
def plot_loss():
    """Plot generator and discriminator loss histories on a single figure."""
    plt.figure(figsize=(10, 5))
    for history, title in ((loss_G, "Generator Loss"), (loss_D, "Discriminator Loss")):
        plt.plot(history, label=title)
    plt.legend()
    plt.show()

# Sample and display a grid of generated faces.
def generate_image():
    """Generate a 4x4 grid of faces from random noise and show it with matplotlib."""
    rows, cols = 4, 4

    figure = plt.figure(figsize=(8, 8))

    # Switch to eval mode so BatchNorm uses its running statistics.
    generator.eval()

    # Fix: inference needs no autograd graph; no_grad saves memory and time
    # (the original built a graph and then detached).
    with torch.no_grad():
        for index in range(rows * cols):
            z = generate_random(1, 100)
            output = generator(z)
            # Convert [1, 3, 64, 64] to [64, 64, 3] for matplotlib.
            img = output.cpu().squeeze().permute(1, 2, 0).numpy()
            # Map the tanh output from [-1, 1] to [0, 1]; clip to guard against
            # floating-point values drifting slightly outside the range.
            img = ((img + 1) / 2).clip(0, 1)
            plot = figure.add_subplot(rows, cols, index + 1)
            plot.imshow(img)
            plot.axis("off")

    plt.show()

# Entry point: train, persist the models, then visualize the results.
if __name__ == "__main__":
    train()
    save_model()
    plot_loss()
    generate_image()

损失曲线

Figure_3.png

生成图片

Figure_4.png

转载请注明转自www.hylab.cn,原文地址:使用生成对抗网络(GAN)生成人脸

网站简介

一个来自三线小城市的程序员开发经验总结。