Softmax 回归
在之前的章节中,我们探讨了线性回归及其实现,包括从零开始的实现和使用高级API的实现。回归模型通常用于定量输出,例如预测价格、胜场数或患者住院天数。然而,并非所有问题都适合使用回归模型,这取决于其输出的性质。这导 致了对数回归或生存建模等特殊情况。
分类问题
分类将重点从“多少?”转向“哪个类别?”的问题。示例包括确定电子邮件是否为垃圾邮件、预测客户行为或识别图像中的对象。分类可以分为:
- 硬分类:直接将项分配到一个类别。
- 软分类:项属于某个类别的概率。
- 多标签分类:项可以同时属于多个类别。
分类问题的示例:
- 这封电子邮件属于垃圾邮件文件夹还是收件箱?
- 这位客户是否有可能订阅某项服务?
- 这张图片中描绘的是哪种动物?
- 某人接下来可能会看哪部电影?
图像分类问题设置
考虑一个简单示例,其中每个输入都是一个2x2灰度图像,由四个特征表示。每张图像被分为三类之一:猫、鸡或狗。
标签表示
- 自然编码:使用整数,例如 ,其中每个数字代表一个不同的类别。
- 独热编码:每个类别由一个二进制向量表示:
分类的线性模型
为了处理多个类别,我们使用多个仿射函数——每个类别一个。例如,对于四个特征和三个类别,我们总共需要12个权重和三个偏置。模型输出由以下给出:
这种设置等同于一个单层全连接神经网络。
Softmax 操作
softmax函数通过应用指数函数并将这些值归一化使其和为一,从而将线性输出转换为概率。它定义为:
此函数确保输出值为非负且和为1,这是概率的必要属性。
矢量化以提高效率
为了计算效率,特别是在处理数据小批量时,我们使用矢量化操作。这涉及到矩阵-矩阵乘法,它们在计算上更快,更适合现代计算架构。
Softmax 回归的损失函数
我们利用交叉熵损失,这是分类任务的常见选择。此损失衡量预测概率与独热编码标签所代表的实际分布之间的差异:
这种设置对应于在给定预测的情况下最大化观测标签的似然性,为学习模型参数提供了概率基础。
使用 PyTorch 从零开始实现 Softmax 回归
import torch
from d2l import torch as d2l
class SoftmaxRegressionScratch(d2l.Classifier):
def __init__(self, num_inputs, num_outputs, lr, sigma=0.01):
super().__init__()
self.save_hyperparameters()
self.W = torch.normal(0, sigma, size=(num_inputs, num_outputs), requires_grad=True)
self.b = torch.zeros(num_outputs, requires_grad=True)
def parameters(self):
return [self.W, self.b]
def forward(self, X):
X = X.reshape((-1, self.W.shape[0]))
return softmax(torch.matmul(X, self.W) + self.b)
def loss(self, y_hat, y):
return cross_entropy(y_hat, y)
def softmax(X):
X_exp = torch.exp(X)
partition = X_exp.sum(1, keepdims=True)
return X_exp / partition # The broadcasting mechanism is applied here
def cross_entropy(y_hat, y):
return -torch.log(y_hat[list(range(len(y_hat))), y]).mean()
data = d2l.FashionMNIST(batch_size=256)
model = SoftmaxRegressionScratch(num_inputs=784, num_outputs=10, lr=0.1)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)
# Prediction
X, y = next(iter(data.val_dataloader()))
preds = model(X).argmax(axis=1)
preds.shape
# Identify and visualize incorrect predictions
wrong = preds.type(y.dtype) != y
X, y, preds = X[wrong], y[wrong], preds[wrong]
labels = [a+'\n'+b for a, b in zip(
data.text_labels(y), data.text_labels(preds))]
data.visualize([X, y], labels=labels)