使用生成对抗网络(GAN)生成符合1010格式规律的数据

创建日期:2025-04-23
更新日期:2025-04-24

示例代码

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import random

# 真实数据
def generate_real():
    """Draw one noisy "1010"-patterned sample used as real training data.

    Positions 0 and 2 are sampled uniformly from [0.8, 1.0]; positions 1 and
    3 are sampled uniformly from [0.0, 0.2].

    Returns:
        A 1-D ``torch.FloatTensor`` holding the four sampled values.
    """
    # (low, high) bounds per output position, in order: high, low, high, low.
    bounds = ((0.8, 1.0), (0.0, 0.2), (0.8, 1.0), (0.0, 0.2))
    samples = [random.uniform(low, high) for low, high in bounds]
    return torch.FloatTensor(samples)

# 生成器
class Generator(nn.Module):
    """Generator network mapping a 1-value input to a 4-value output.

    The network is two linear layers (1 -> 3 -> 4), each followed by a
    sigmoid activation.
    """

    def __init__(self):
        super().__init__()
        # Layers are listed in forward order; nn.Sequential applies them
        # one after another.
        layers = [
            nn.Linear(in_features=1, out_features=3),
            nn.Sigmoid(),
            nn.Linear(in_features=3, out_features=4),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        generated = self.model(x)
        return generated

# 判别器
class Discriminator(nn.Module):
    """Discriminator network mapping a 4-value sample to a single score.

    The network is two linear layers (4 -> 3 -> 1), each followed by a
    sigmoid activation.
    """

    def __init__(self):
        super().__init__()
        # Layers are listed in forward order; nn.Sequential applies them
        # one after another.
        layers = [
            nn.Linear(in_features=4, out_features=3),
            nn.Sigmoid(),
            nn.Linear(in_features=3, out_features=1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the score."""
        score = self.model(x)
        return score

generator = Generator()
discriminator = Discriminator()

# Optimizers: plain SGD, one per network so each can be stepped on its own.
optimizer_g = torch.optim.SGD(generator.parameters(), lr=0.01)
optimizer_d = torch.optim.SGD(discriminator.parameters(), lr=0.01)

# Loss function: mean squared error between the discriminator's score and
# the target label.
criterion = nn.MSELoss()

# Loop-invariant tensors, built once instead of on every iteration.
real_label = torch.FloatTensor([1.0])  # target score for "real"
fake_label = torch.FloatTensor([0.0])  # target score for "fake"
generator_input = torch.FloatTensor([0.5])  # fixed seed value fed to the generator

# Training
counter = 0  # incremented once per recorded loss sample; never read afterwards
loss_G = []
loss_D = []

for i in range(10000):
    # Train the discriminator on a real sample (target 1.0).
    real_data = generate_real()
    real_output = discriminator(real_data)
    d_loss_real = criterion(real_output, real_label)
    optimizer_d.zero_grad()
    d_loss_real.backward()
    optimizer_d.step()

    # Train the discriminator on a generated sample (target 0.0).
    # detach() keeps this backward pass from reaching the generator.
    fake_data = generator(generator_input)
    fake_output = discriminator(fake_data.detach())
    d_loss_fake = criterion(fake_output, fake_label)
    optimizer_d.zero_grad()
    d_loss_fake.backward()
    optimizer_d.step()

    # Train the generator: it is rewarded when the discriminator scores its
    # output as real (target 1.0). The generator's weights have not changed
    # since fake_data was computed, so that output (and its autograd graph)
    # is reused instead of running an identical second forward pass.
    # Gradients that land on the discriminator here are discarded by the
    # next optimizer_d.zero_grad().
    fake_output = discriminator(fake_data)
    g_loss = criterion(fake_output, real_label)
    optimizer_g.zero_grad()
    g_loss.backward()
    optimizer_g.step()

    # Record losses every 10 iterations.
    # NOTE: loss_D stores only the real-sample loss (d_loss_real); the
    # fake-sample loss (d_loss_fake) is not recorded.
    if i % 10 == 0:
        counter += 1
        loss_G.append(g_loss.item())
        loss_D.append(d_loss_real.item())
    if i % 100 == 0:
        print(f"Epoch {i}, loss_G: {loss_G[-1]}, loss_D: {loss_D[-1]}")

# Plot the loss curves (one point per 10 training iterations).
plt.figure(figsize=(10, 5))
plt.plot(loss_G, label="Generator Loss")
plt.plot(loss_D, label="Discriminator Loss")
plt.legend()
plt.show()

# Generate 1010-patterned data with the trained generator.
# Call the module itself rather than .forward(), and skip autograd
# bookkeeping since no gradients are needed at inference time.
with torch.no_grad():
    result = generator(generator_input)
print(result)

执行结果

控制台输出(仅展示最后10次打印,完整输出从 Epoch 0 开始,每100次迭代打印一次):

Epoch 9000, loss_G: 0.2383641004562378, loss_D: 0.2622174620628357
Epoch 9100, loss_G: 0.23865269124507904, loss_D: 0.25279828906059265
Epoch 9200, loss_G: 0.23924800753593445, loss_D: 0.23575149476528168
Epoch 9300, loss_G: 0.23898108303546906, loss_D: 0.24855899810791016
Epoch 9400, loss_G: 0.239225372672081, loss_D: 0.24059531092643738
Epoch 9500, loss_G: 0.23939572274684906, loss_D: 0.24873214960098267
Epoch 9600, loss_G: 0.23965498805046082, loss_D: 0.2534474730491638
Epoch 9700, loss_G: 0.2392021119594574, loss_D: 0.25969743728637695
Epoch 9800, loss_G: 0.23949465155601501, loss_D: 0.2547983229160309
Epoch 9900, loss_G: 0.2388973981142044, loss_D: 0.23489777743816376
tensor([0.9435, 0.0349, 0.8724, 0.0378])

生成器和判别器损失图表:

Figure_1.png

网站简介

一个来自三线小城市的程序员开发经验总结。