- 深度学习入门:基于PyTorch和TensorFlow的理论与实现
- 红色石头
- 1458字
- 2025-02-23 06:37:20
2.6 线性回归
本章的前5节介绍了PyTorch的基本内容和语法,本节将通过一个简单的线性回归实例,介绍如何使用PyTorch编写一个完整的模型并验证它的好坏。
2.6.1 线性回归的基本原理
线性回归是一个最基本、最简单的机器学习算法,相信读者对它的基本原理已经非常熟悉了,本小节仅做简要介绍。
线性回归一般用于数值预测,如房屋价格预测、信用卡额度预测等。线性回归算法就是要找出这样一条拟合线或拟合面,能够最大限度地拟合真实的数据分布,如图2-5所示。

图2-5 线性回归算法示意
这条直线可以表示为

式中,是预测值,w0和w1是直线参数,正是需要去求的两个值。
如何确定这条直线以及相关的参数w0和w1呢?我们希望预测值与真实值y越接近越好,因此引入代价函数。代价函数是定义在整个训练集上的,是所有样本误差的平均,也就是所有样本损失函数的平均。其实,代价函数与损失函数的唯一区别在于前者针对整个训练集,后者针对单个样本。代价函数越小,表明直线拟合得越好。
此时,代价函数通过均方差的计算而得到,计算公式为:

式中,m表示总的样本个数,yi表示第i个样本的真实值,表示第i个样本的预测值。分母是2m而不是m仅仅是为了平方求导的方便。
接下来要求最小化代价函数J时对应的参数w0和w1。如何最小化代价函数J呢?最简单的方法就是使用梯度下降算法,其核心思想是在函数曲线上的某一点,函数沿梯度方向具有最大的变化率,那么沿着负梯度方向移动会不断逼近最小值,这样一个迭代的过程可以最终实现代价函数的最小化目标。
梯度下降算法中,w0和w1迭代更新的表达式为:


式中,α表示学习率。
这样,经过多次的迭代更新,J会不断接近全局最小值。此时,就可以得到参数w0和w1,直线也就确定了。
2.6.2 线性回归的PyTorch实现
1.数据集
首先,我们要构造一些数据集,代码如下:
# y=3x+10,后面加上torch.randn()函数制造噪音 x = torch.unsqueeze(torch.linspace(-1, 1, 50), dim=1) y = 3*x + 10 + 0.5 * torch.randn(x.size())
显然,原始的数据集中,y是由直线103x+加上一些随机噪声而得到的。原始数据的分布如图2-6所示。

图2-6 原始数据的分布
2.模型定义
定义线性回归模型,代码如下:

接下来,我们定义损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降算法进行优化,代码如下:
# 定义损失函数和优化函数 criterion = nn.MSELoss() optimizer = optim.SGD(model.parameters(), lr=5e-3)
3.模型训练
开始进行模型的训练,代码如下:

上面的模型训练代码中,迭代的次数为1000次,是遍历整个数据集的次数。先进行前向传播计算代价函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于观察结果,每20次迭代之后输出当前的均方差损失。
4.模型测试
最后,我们通过model.eval()函数将模型由训练模式变为测试模式,将数据放入模型中进行预测。最后,通过绘图工具Matplotlib判断拟合的直线与原始数据的贴近程度,代码如下:
model.eval() y_hat = model(x) plt.scatter(x.numpy(), y.numpy(), label=′原始数据′) plt.plot(x.numpy(), y_hat.detach().numpy(), c=′r′, label=′拟合直线′) # 显示图例 plt.legend() plt.show()
y_hat就是训练好的线性回归模型的预测值。注意,y_hat.detach().numpy()中,.detach()用于停止对张量的梯度跟踪。模型训练阶段需要跟踪梯度,但是模型测试的时候就不需要梯度跟踪了。最后显示的拟合直线如图2-7所示。

图2-7 原始数据与拟合直线显示效果
可以看到,线性回归模型与原始数据拟合得非常好。我们可以使用下面的代码来查看这条直线的参数w0和w1:
>>> list(model.named_parameters()) [(′fc.weight′, Parameter containing: tensor([[2.7447]], requires_grad=True)), (′fc.bias′, Parameter containing: tensor([10.1136], requires_grad=True))]
通过参数查询可得w0=10.1136,w1=2.7447,与构造数据时使用的直线y=10+3x非常接近。
至此,我们已经介绍了PyTorch的基本用法,并使用PyTorch实现了一个简单的线性回归模型。本书后面的神经网络章节中,我们将会学习更多、更重要的与PyTorch有关的知识,并使用这些知识来处理更复杂的机器视觉(Computer Vision,CV)和自然语言处理(Natural Language Processing,NLP)问题。