PyTorch损失函数全解析:从回归到分类,选对Loss让模型训练事半功倍

2025-03-12 人工智能 190 次阅读 0 次点赞
本文介绍了PyTorch中常用的损失函数,按任务类型分类讲解。回归任务包括MSELoss(均方误差)、L1Loss(平均绝对误差)、SmoothL1Loss和HuberLoss;分类任务涵盖CrossEntropyLoss(交叉熵)、BCEWithLogitsLoss(推荐用于二分类)、NLLLoss及KLDivLoss等;特殊任务如CTCLoss用于序列对齐,TripletMarginLoss用于度量学习。文章还给出了任务与损失函数的匹配建议,并提醒注意输入张量形状、是否内置softmax以及正负样本平衡问题。

在深度学习任务中,损失函数(Loss Function)扮演着至关重要的角色——它衡量模型预测值与真实标签之间的差距,是模型优化的核心驱动力。PyTorch的torch.nn模块提供了丰富且高效的损失函数实现,本文将为你介绍其中最常用的几种。

回归任务常用损失函数

1.nn.MSELoss(均方误差损失)

功能:计算输入x与目标y之间每个元素差的平方的平均值,即L2范数的平方。适用于回归任务,对大误差惩罚更重。

示例代码

import torch
import torch.nn as nn

x = torch.tensor([1, 2, 3], dtype=torch.float32)
y = torch.tensor([4, 5, 6], dtype=torch.float32)

criterion = nn.MSELoss()
loss = criterion(x, y)

print(loss.item())  # 输出: 9.0

计算过程:((1-4)² + (2-5)² + (3-6)²)/3 = (9+9+9)/3 = 9

2. nn.L1Loss(平均绝对误差损失)

功能:计算输入x与目标y之间每个元素差的绝对值的平均值。对异常值不敏感,但梯度恒定,可能在极值点附近震荡。

3. nn.SmoothL1Loss(平滑L1损失)

功能:结合了L1和L2的优点,当绝对误差小于阈值beta时使用平方项,否则使用线性项。常用于目标检测中的边框回归。

4. nn.HuberLoss

功能:类似SmoothL1Loss,通过delta参数控制平方项和L1项的切换点。

分类任务常用损失函数

1. nn.CrossEntropyLoss(交叉熵损失)

功能:最常用的多分类损失函数,内部结合了LogSoftmax和NLLLoss。输入为未经过归一化的logits,目标为类别索引。

示例代码

import torch
import torch.nn as nn

# 3个样本,每个样本有3个类别的logits输出
x = torch.tensor([[0.8, 2.1, 3.2]], dtype=torch.float32)
# 真实类别索引
y = torch.tensor([2], dtype=torch.long)

criterion = nn.CrossEntropyLoss()
loss = criterion(x, y)

print(loss.item())

2. nn.BCELoss(二分类交叉熵损失)

功能:用于二分类任务,输入必须是经过sigmoid后的概率值(范围[0,1]),目标为0或1。

3. nn.BCEWithLogitsLoss

功能:将Sigmoid层和BCELoss合并为一个类,数值稳定性更好。推荐直接使用此函数代替BCELoss

4. nn.NLLLoss(负对数似然损失)

功能:通常在CrossEntropyLoss内部使用。使用时需要先对模型输出做LogSoftmax。

5. nn.KLDivLoss(KL散度损失)

功能:衡量两个概率分布之间的差异。常用于知识蒸馏、变分自编码器等场景。

特殊任务的损失函数

1. nn.CTCLoss(连接时序分类损失)

功能:用于序列对齐问题,如语音识别、手写文字识别,无需对齐输入输出序列。

2. nn.MultiMarginLoss(多分类合页损失)

功能:基于边界的多分类损失,类似SVM的损失函数。

3. nn.TripletMarginLoss(三元组损失)

功能:用于度量学习,拉近正样本对、推远负样本对。常用于人脸识别、图像检索。

4. nn.CosineEmbeddingLoss(余弦嵌入损失)

功能:基于余弦相似度的损失函数,衡量两个输入的相似性。

5. nn.PoissonNLLLoss(泊松负对数似然损失)

功能:假设目标服从泊松分布,适用于计数预测任务。

损失函数选择方法

任务类型 推荐损失函数
回归(普通) MSELoss
回归(有异常值) L1Loss 或 SmoothL1Loss
二分类 BCEWithLogitsLoss
多分类 CrossEntropyLoss
多标签分类 MultiLabelSoftMarginLoss
度量学习 TripletMarginLoss
序列对齐 CTCLoss
概率分布比较 KLDivLoss

总结

PyTorch提供了涵盖大多数深度学习任务的损失函数实现。使用时需要注意:

  • 输入张量的形状和数据类型
  • 是否需要预先做softmax/sigmoid(CrossEntropyLoss内部已包含)
  • 正负样本的平衡问题(可借助weight参数)

希望本文能帮助你在PyTorch中正确选择和使用损失函数。

最后更新于1小时前
本文由人工编写,AI优化,转载请注明原文地址: PyTorch损失函数全解析:从回归到分类,选对Loss让模型训练事半功倍

评论 (0)

登录 后发表评论

暂无评论,快来发表第一条评论吧!