使用PyTorch拟合线性函数
数学方法
import torch
import matplotlib.pyplot as plt
x = torch.linspace(0, 100, 100).type(torch.float32)
y = x + torch.randn(100) * 10
learning_rate = 0.0001
weight = torch.rand(1, requires_grad=True)
bias = torch.rand(1, requires_grad=True)
for i in range(1000):
predictions = weight.expand_as(x) * x + bias.expand_as(x)
loss = torch.mean((predictions - y) ** 2)
print('loss:', loss.item())
loss.backward()
weight.data.add_(-learning_rate * weight.grad.data)
bias.data.add_(-learning_rate * bias.grad.data)
weight.grad.data.zero_()
bias.grad.data.zero_()
plt.figure(figsize=(10, 8))
xplot, = plt.plot(x, y, 'o')
yplot, = plt.plot(x, weight.data * x + bias.data)
plt.xlabel('X')
plt.ylabel('Y')
str1 = str(weight.data[0]) + 'x + ' + str(bias.data[0])
plt.legend([xplot, yplot], ['Data', str1])
plt.show()
使用模型
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
x = torch.linspace(0, 100, 100).type(torch.float32)
y = x + torch.randn(100) * 10
x_train = x.reshape(-1, 1)
y_train = y.reshape(-1, 1)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
def forward(self, x):
return self.fc(x)
model = Model()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.0001)
for i in range(1000):
predictions = model(x_train)
loss = criterion(predictions, y_train)
print('loss:', loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
weight = model.fc.weight[0, 0].item()
bias = model.fc.bias[0].item()
plt.figure(figsize=(10, 8))
xplot, = plt.plot(x, y, 'o')
yplot, = plt.plot(x, weight * x + bias)
plt.xlabel('X')
plt.ylabel('Y')
str1 = str(weight) + 'x + ' + str(bias)
plt.legend([xplot, yplot], ['Data', str1])
plt.show()
效果图

最后更新于1年前
本文由人工编写,AI优化,转载请注明原文地址: 使用PyTorch拟合线性函数
推荐阅读
Ollama工具调用原理详解及Python代码实现教程
3672025-11-27
谷歌Antigravity IDE:AI智能体驱动的软件开发平台详解
7242025-11-24
使用vtk.js加载vtu格式(UnstructuredGrid)的文件
3472025-12-02
OpenAI Codex命令行工具安装与使用教程:AI编程助手实战指南
15022025-10-08
Kaggle Notebook性能实测:免费GPU主机配置与运行时间分析
7202025-11-23
Windows系统PyTorch安装教程:CUDA 12.1环境配置与TorchText版本兼容性指南
22622025-10-08
评论 (0)
发表评论
昵称:加载中...
暂无评论,快来发表第一条评论吧!