随机梯度下降算法的Python实现

随机梯度下降算法的Python实现当用于训练的数据量非常大时 批量梯度下降算法变得不再适用 此时其速度会非常慢 为解决这个问题 人们又想出了随机梯度下降算法 随机梯度下降算法的核心思想并没有变 它仍是基于梯度 通过对目标函数中的参数不断迭代更新 使得目标函数逐渐靠近最小值 具体代码实现如下 先导入要用到的各种包 matplotlibno

当用于训练的数据量非常大时,批量梯度下降算法变得不再适用(此时其速度会非常慢),为解决这个问题,人们又想出了随机梯度下降算法。随机梯度下降算法的核心思想并没有变,它仍是基于梯度,通过对目标函数中的参数不断迭代更新,使得目标函数逐渐靠近最小值。

具体代码实现如下:

先导入要用到的各种包:

%matplotlib notebook import pandas as pd import matplotlib.pyplot as plt import numpy as np

读入数据并查看数据的相关信息:

data = pd.read_csv('ex1data1.txt',header = None,names=['Population','Profits']) data.head() # 查看data中的前五条数据

data中前五条数据如下图所示:

随机梯度下降算法的Python实现

data.describe() # 查看data的各描述统计量信息

绘制原始数据的散点图:

fig,axes = plt.subplots() data.plot(kind='scatter',x='Population',y='Profits',ax=axes,marker='o',color='r') axes.set(xlabel='Population',ylabel='Profits') fig.savefig('p1.png')

绘制的散点图为:

随机梯度下降算法的Python实现

向data中添加一列便于矩阵计算的辅助列:

data.insert(0,'Ones',1) data.head()

加入辅助列的data如下所示:

随机梯度下降算法的Python实现

随机梯度下降的实现:

# 定义数据特征和标签的提取函数: def get_fea_lab(data): cols = data.shape[1] X = data.iloc[:,0:cols-1] # X是data中的前两列(不包括索引列) y = data.iloc[:,cols-1:cols] # y是data中的最后一列 # 将X和y都转化成矩阵的形式: X = np.matrix(X.values) y = np.matrix(y.values) return X,y # 定义代价函数: def computeCost(data,theta,i): X,y = get_fea_lab(data) inner = np.power(((X*theta.T)-y),2) return (float(inner[i]/2)) # 定义随机梯度下降函数: def stochastic_gradient_descent(data,theta,alpha,epoch): X0,y0 = get_fea_lab(data) # 提取X和y矩阵 temp = np.matrix(np.zeros(theta.shape)) parameters = int(theta.shape[1]) cost = np.zeros(len(X0)) avg_cost = np.zeros(epoch) for k in range(epoch): new_data = data.sample(frac=1) # 打乱数据 X,y = get_fea_lab(new_data) # 提取新的X和y矩阵 for i in range(len(X)): error = X[i:i+1]*theta.T-y[i] cost[i] = computeCost(new_data,theta,i) for j in range(parameters): temp[0,j] = theta[0,j] - alpha*error*X[i:i+1,j] theta = temp avg_cost[k] = np.average(cost) return theta,avg_cost # 初始化学习率、迭代轮次和参数theta: alpha = 0.001 epoch = 200 theta = np.matrix(np.array([0,0])) # 调用随机梯度下降函数来计算线性回归中的theat参数: g,avg_cost = stochastic_gradient_descent(data,theta,alpha,epoch) # g的值为matrix([[-3., 1.]])

绘制每轮迭代中代价函数的平均值与迭代轮次的关系图像:

本例中因为数据集中一共只有97个样本,所以对于每轮迭代,我选择的是计算所有样本对应的的代价函数的平均值。在数据集非常大的情况下,我们可以选择计算每轮迭代中最后一部分样本对应的代价函数的平均值。

fig, axes = plt.subplots() axes.plot(np.arange(epoch), avg_cost, 'r') axes.set_xlabel('Epoch') axes.set_ylabel('avg_cost') axes.set_title('avg_cost vs. Epoch') fig.savefig('p2.png')

具体如下图所示:

随机梯度下降算法的Python实现

从上图中我们可以看到,大约从第90轮迭代开始,代价函数的平均值在某个值上下进行小范围波动(某个值其实就是值全局最小值)。前面,我们把最大迭代轮次设为了200,并据此计算除了线性回归参数theta的值为matrix([[-3., 1.]])。而用正规方程计算出的theta参数的精确值为matrix([[-3.],[ 1.]]),二者的差别在可接受范围内。关于用正规方程求解线性回归参数可以参考:https://blog.csdn.net/_/article/details/、https://blog.csdn.net/_/article/details/。

根据前文计算出的theta参数值,绘制原始数据的线性拟合图:

x = np.linspace(data.Population.min(),data.Population.max(),100) f = g[0,0] + g[0,1]*x fig,axes = plt.subplots() axes.plot(x,f,'r',label='Fitted') axes.scatter(x=data.Population,y=data.Profits,label='Trainning data') axes.legend(loc='best') axes.set(xlabel='Population',ylabel='Profits',title='Population vs. Profits') fig.savefig('p3.png')

绘制的线性拟合图为:

随机梯度下降算法的Python实现

批量梯度下降算法与随机梯度下降算法的比较:

    1)批量梯度下降算法在每次迭代更新目标函数的参数时,是将训练数据集中的所有样本都考虑进去,以此计算代价函数;而随机梯度下降算法在每次迭代更新目标函数的参数时,只考虑数据集中的一个样本并据此计算代价函数。因此,当训练数据集非常大时,随机梯度下降的迭代速度要比批量梯度下降的迭代速度快很多。

两者的代价函数具体如下图所示:

随机梯度下降算法的Python实现

上图中,左边是批量梯度下降的代价函数,右边是随机梯度下降的代价函数。

    2)不同于批量梯度下降,当迭代到一定轮次时,随机梯度下降计算出的代价函数是在某个靠近全局最小值的区域内徘徊,而不是直接逼近全局最小值并停留在那点。关于随机梯度下降的收敛性,可以参考:https://www.zhihu.com/question/。

其他参考资料:

《Python Machine Learning Second Edition》——Vahid Mirjalili&Sebastian Raschka

Andew Ng机器学习公开课

https://blog.csdn.net/_/article/details/

PS:本文为博主原创文章,转载请注明出处。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/223292.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月17日 下午2:25
下一篇 2026年3月17日 下午2:25


相关推荐

  • Java基础入门笔记06——String类和StringBuffer类,Java中的三大集合,Set集合,List集合,Map集合,Collection类

    Java基础入门笔记06——String类和StringBuffer类,Java中的三大集合,Set集合,List集合,Map集合,Collection类常用类String类和StringBuffer类String类不能改变串对象中的内容,每次操作后都产生一个新串StringBuffer类可以实现字符串动态改变,对原对象增删改查equals()和”==”的区别equals()仅判断值是否相等“==”判断值还要判断引用是否相等length():获取字符串的字符个数length:获取数组长度toCharArray():将字符串对象转换为字符数组Java的三种集合都是接口,需要具体实现集合类存在于java.util包中,是一个用来存放

    2022年8月8日
    10
  • Pytest(10)assert断言[通俗易懂]

    Pytest(10)assert断言[通俗易懂]前言断言是写自动化测试基本最重要的一步,一个用例没有断言,就失去了自动化测试的意义了。什么是断言呢?简单来讲就是实际结果和期望结果去对比,符合预期那就测试pass,不符合预期那就测试failed

    2022年8月6日
    6
  • httprunner(8)用例调用-RunTestCase[通俗易懂]

    httprunner(8)用例调用-RunTestCase[通俗易懂]前言一般我们写接口自动化的时候,遇到复杂的逻辑,都会调用API方法来满足前置条件,Pytest的特性是无法用例之间相互调动的,我们一般只调用自己封装的API方法。而httprunner支持用例之间

    2022年7月28日
    16
  • U盘市场调查分析

    U盘市场调查分析U盘市场分析2005年前后,是U盘行业最辉煌的年代,然而随着互联网的普及应用,各种网盘以及智能手机传输越来越方便,给U盘施加了一定的压力。除此之外,近几年移动硬盘迅速兴起,更是抢占了大量市场份额,让U盘厂商的日子越来越艰难。但凭借体积小、携带方便、即插即用、稳定等特点,U盘在移动存储领域依然占据着一定优势。面对着各类对手的“围攻”,U盘厂商们一方面升级主流产品,留住客户群,另一方面则纷纷开发细分市场,通过挖掘用户不同领域的需求,以实现新的业务增长。需求下降,U盘行业萎缩 2017年有业内人士发出.

    2025年10月10日
    6
  • 离散 单射 满射 双射

    离散 单射 满射 双射单射双射满射阐述一下什么是单射,双射,满射1.单射:对于每一个不同的x都有不同的y,即x1!=x2–>y1!+y2条件:|X|<=|Y|2.满射:对于每一个y都有x与之对应条件:|Y|<=|X|3.双射:既是单射又是满射条件:|X|=|Y|代码实现通过map函数建立映射1.单射:map<int,int>BuildInjection(vector<int>src,vector<int>dst){map&l

    2022年6月10日
    36
  • MVEL 简单介绍

    MVEL 简单介绍MVEL 是一种基于 java 语法的表达式语言 为 java 提供更便捷灵活的动态性 这里简单介绍一些 MVEL 的操作 new 创建一个 java 对象 newString foo 当然这里是举个栗子 String 的创建一般是不同这个构造函数的 对于 java lang 中的类 无需手动导入 如果是创建自定义的对象 就需要写明类全路径名 或者手动导入 当需要针对同一个对象进行

    2026年3月20日
    1

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号