PyTorch损失函数全解析:从回归到分类,选对Loss让模型训练事半功倍
在深度学习任务中,损失函数(Loss Function)扮演着至关重要的角色——它衡量模型预测值与真实标签之间的差距,是模型优化的核心驱动力。PyTorch的torch.nn模块提供了丰富且高效的损失函数实现,本文将为你介绍其中最常用的几种。
回归任务常用损失函数
1.nn.MSELoss(均方误差损失)
功能:计算输入x与目标y之间每个元素差的平方的平均值,即L2范数的平方。适用于回归任务,对大误差惩罚更重。
示例代码:
import torch
import torch.nn as nn
x = torch.tensor([1, 2, 3], dtype=torch.float32)
y = torch.tensor([4, 5, 6], dtype=torch.float32)
criterion = nn.MSELoss()
loss = criterion(x, y)
print(loss.item()) # 输出: 9.0
计算过程:((1-4)² + (2-5)² + (3-6)²)/3 = (9+9+9)/3 = 9
2. nn.L1Loss(平均绝对误差损失)
功能:计算输入x与目标y之间每个元素差的绝对值的平均值。对异常值不敏感,但梯度恒定,可能在极值点附近震荡。
3. nn.SmoothL1Loss(平滑L1损失)
功能:结合了L1和L2的优点,当绝对误差小于阈值beta时使用平方项,否则使用线性项。常用于目标检测中的边框回归。
4. nn.HuberLoss
功能:类似SmoothL1Loss,通过delta参数控制平方项和L1项的切换点。
分类任务常用损失函数
1. nn.CrossEntropyLoss(交叉熵损失)
功能:最常用的多分类损失函数,内部结合了LogSoftmax和NLLLoss。输入为未经过归一化的logits,目标为类别索引。
示例代码:
import torch
import torch.nn as nn
# 3个样本,每个样本有3个类别的logits输出
x = torch.tensor([[0.8, 2.1, 3.2]], dtype=torch.float32)
# 真实类别索引
y = torch.tensor([2], dtype=torch.long)
criterion = nn.CrossEntropyLoss()
loss = criterion(x, y)
print(loss.item())
2. nn.BCELoss(二分类交叉熵损失)
功能:用于二分类任务,输入必须是经过sigmoid后的概率值(范围[0,1]),目标为0或1。
3. nn.BCEWithLogitsLoss
功能:将Sigmoid层和BCELoss合并为一个类,数值稳定性更好。推荐直接使用此函数代替BCELoss。
4. nn.NLLLoss(负对数似然损失)
功能:通常在CrossEntropyLoss内部使用。使用时需要先对模型输出做LogSoftmax。
5. nn.KLDivLoss(KL散度损失)
功能:衡量两个概率分布之间的差异。常用于知识蒸馏、变分自编码器等场景。
特殊任务的损失函数
1. nn.CTCLoss(连接时序分类损失)
功能:用于序列对齐问题,如语音识别、手写文字识别,无需对齐输入输出序列。
2. nn.MultiMarginLoss(多分类合页损失)
功能:基于边界的多分类损失,类似SVM的损失函数。
3. nn.TripletMarginLoss(三元组损失)
功能:用于度量学习,拉近正样本对、推远负样本对。常用于人脸识别、图像检索。
4. nn.CosineEmbeddingLoss(余弦嵌入损失)
功能:基于余弦相似度的损失函数,衡量两个输入的相似性。
5. nn.PoissonNLLLoss(泊松负对数似然损失)
功能:假设目标服从泊松分布,适用于计数预测任务。
损失函数选择方法
| 任务类型 | 推荐损失函数 |
|---|---|
| 回归(普通) | MSELoss |
| 回归(有异常值) | L1Loss 或 SmoothL1Loss |
| 二分类 | BCEWithLogitsLoss |
| 多分类 | CrossEntropyLoss |
| 多标签分类 | MultiLabelSoftMarginLoss |
| 度量学习 | TripletMarginLoss |
| 序列对齐 | CTCLoss |
| 概率分布比较 | KLDivLoss |
总结
PyTorch提供了涵盖大多数深度学习任务的损失函数实现。使用时需要注意:
- 输入张量的形状和数据类型
- 是否需要预先做softmax/sigmoid(CrossEntropyLoss内部已包含)
- 正负样本的平衡问题(可借助
weight参数)
希望本文能帮助你在PyTorch中正确选择和使用损失函数。