Example Code
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import random

# Real data: four values following the 1010 pattern (high, low, high, low)
def generate_real():
    real_data = torch.FloatTensor(
        [
            random.uniform(0.8, 1.0),
            random.uniform(0.0, 0.2),
            random.uniform(0.8, 1.0),
            random.uniform(0.0, 0.2),
        ]
    )
    return real_data
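
# (Optional sanity check, not part of the original listing: each call returns
# a fresh noisy sample of the target pattern, e.g. roughly
# tensor([0.93, 0.08, 0.86, 0.15]); the exact values vary from run to run.)
# print(generate_real())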

# Generator: maps a single scalar input to four outputs in (0, 1)
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(1, 3),
            nn.Sigmoid(),
            nn.Linear(3, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)

# Discriminator: scores a four-value sample with a single output in (0, 1),
# interpreted as the probability that the sample is real
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(4, 3),
            nn.Sigmoid(),
            nn.Linear(3, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)

generator = Generator()
discriminator = Discriminator()

# Optimizers
optimizer_g = torch.optim.SGD(generator.parameters(), lr=0.01)
optimizer_d = torch.optim.SGD(discriminator.parameters(), lr=0.01)

# Loss function
criterion = nn.MSELoss()
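
# Note: the classic GAN objective uses binary cross-entropy (nn.BCELoss());
# plain MSE, as used here, also works for this toy example and is essentially
# the least-squares GAN (LSGAN) form of the objective.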

# Model training
counter = 0
loss_G = []
loss_D = []
for i in range(10000):
    # Train the discriminator on a real sample (target 1.0)
    real_data = generate_real()
    real_output = discriminator(real_data)
    d_loss_real = criterion(real_output, torch.FloatTensor([1.0]))
    optimizer_d.zero_grad()
    d_loss_real.backward()
    optimizer_d.step()

    # Train the discriminator on a generated sample (target 0.0);
    # detach() stops gradients from flowing back into the generator here
    fake_data = generator(torch.FloatTensor([0.5]))
    fake_output = discriminator(fake_data.detach())
    d_loss_fake = criterion(fake_output, torch.FloatTensor([0.0]))
    optimizer_d.zero_grad()
    d_loss_fake.backward()
    optimizer_d.step()

    # Train the generator: push the discriminator's score on fake data toward 1.0
    fake_data = generator(torch.FloatTensor([0.5]))
    fake_output = discriminator(fake_data)
    g_loss = criterion(fake_output, torch.FloatTensor([1.0]))
    optimizer_g.zero_grad()
    g_loss.backward()
    optimizer_g.step()

    # Record losses
    if i % 10 == 0:
        counter += 1
        loss_G.append(g_loss.item())
        loss_D.append(d_loss_real.item())
    if i % 100 == 0:
        print(f"Epoch {i}, loss_G: {loss_G[-1]}, loss_D: {loss_D[-1]}")

# Plot the loss curves
plt.figure(figsize=(10, 5))
plt.plot(loss_G, label="Generator Loss")
plt.plot(loss_D, label="Discriminator Loss")
plt.legend()
plt.show()

# Generate a sample in the 1010 pattern with the trained generator
result = generator(torch.FloatTensor([0.5])).detach()
print(result)
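The run shown below was not seeded, so the exact numbers will differ from run to run. If reproducible results matter, one optional tweak (not part of the listing above) is to seed the random number generators before training:

    random.seed(0)
    torch.manual_seed(0)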
Execution Results
Console output:
Epoch 9000, loss_G: 0.2383641004562378, loss_D: 0.2622174620628357
Epoch 9100, loss_G: 0.23865269124507904, loss_D: 0.25279828906059265
Epoch 9200, loss_G: 0.23924800753593445, loss_D: 0.23575149476528168
Epoch 9300, loss_G: 0.23898108303546906, loss_D: 0.24855899810791016
Epoch 9400, loss_G: 0.239225372672081, loss_D: 0.24059531092643738
Epoch 9500, loss_G: 0.23939572274684906, loss_D: 0.24873214960098267
Epoch 9600, loss_G: 0.23965498805046082, loss_D: 0.2534474730491638
Epoch 9700, loss_G: 0.2392021119594574, loss_D: 0.25969743728637695
Epoch 9800, loss_G: 0.23949465155601501, loss_D: 0.2547983229160309
Epoch 9900, loss_G: 0.2388973981142044, loss_D: 0.23489777743816376
tensor([0.9435, 0.0349, 0.8724, 0.0378])
Generator and discriminator loss chart: (the figure produced by the plt.plot calls above is not reproduced here)
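A note on reading these numbers: with nn.MSELoss, a discriminator output of 0.5 scored against a target of 1.0 (or 0.0) gives a loss of 0.5^2 = 0.25, so both losses settling near 0.25 means the discriminator can no longer reliably tell real samples from generated ones, which is the expected equilibrium for this toy GAN. The final tensor is already close to the 1, 0, 1, 0 pattern; to read it as discrete bits, one simple option is to round it:

    print(result.round())
    # tensor([1., 0., 1., 0.]) for the result shown above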