随机梯度下降法概述与实例分析_梯度下降法推导

随机梯度下降法概述与实例分析_梯度下降法推导机器学习算法中回归算法有很多,例如神经网络回归算法、蚁群回归算法,支持向量机回归算法等,其中也包括本篇文章要讲述的梯度下降算法,本篇文章将主要讲解其基本原理以及基于SparkMLlib进行实例示范,不足之处请多多指教。梯度下降算法包含多种不同的算法,有批量梯度算法,随机梯度算法,折中梯度算法等等。对于随机梯度下降算法而言,它通过不停的判断和选择当前目标下最优的路径,从而能够在最短路径…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

机器学习算法中回归算法有很多,例如神经网络回归算法、蚁群回归算法,支持向量机回归算法等,其中也包括本篇文章要讲述的梯度下降算法,本篇文章将主要讲解其基本原理以及基于Spark MLlib进行实例示范,不足之处请多多指教。

梯度下降算法包含多种不同的算法,有批量梯度算法,随机梯度算法,折中梯度算法等等。对于随机梯度下降算法而言,它通过不停的判断和选择当前目标下最优的路径,从而能够在最短路径下达到最优的结果。我们可以在一个人下山坡为例,想要更快的到达山低,最简单的办法就是在当前位置沿着最陡峭的方向下山,到另一个位置后接着上面的方式依旧寻找最陡峭的方向走,这样每走一步就停下来观察最下路线的方法就是随机梯度下降算法的本质。
这里写图片描述

随机梯度下降算法理论基础

在线性回归中,我们给出回归方程,如下所示:
这里写图片描述
我们知道,对于最小二乘法要想求得最优变量就要使得计算值与实际值的偏差的平方最小。而随机梯度下降算法对于系数需要通过不断的求偏导求解出当前位置下最优化的数据,那么梯度方向公式推导如下公式,公式中的θ会向着梯度下降最快的方向减少,从而推断出θ的最优解。

这里写图片描述

因此随机梯度下降法的公式归结为通过迭代计算特征值从而求出最合适的值。θ的求解公式如下。
这里写图片描述

α是下降系数,即步长,学习率,通俗的说就是计算每次下降的幅度的大小,系数越大每次计算的差值越大,系数越小则差值越小,但是迭代计算的时间也会相对延长。θ的初值可以随机赋值,比如下面的例子中初值赋值为0。

Spark MLlib随机梯度下降算法实例

下面使用Spark MLlib来迭代计算回归方程y=2x的θ最优解,代码如下:

package cn.just.shinelon.MLlib.Algorithm

import java.util

import scala.collection.immutable.HashMap

/**
  * 随机梯度下降算法实战
  * 随机梯度下降算法:最短路径下达到最优结果
  * 数学表达公式如下:
  * f(θ)=θ0x0+θ1x1+θ2x2+...+θnxn
  * 对于系数要通过不停地求解出当前位置下最优化的数据,即不停对系数θ求偏导数
  * 则θ求解的公式如下:
  * θ=θ-α(f(θ)-yi)xi
  * 公式中α是下降系数,即每次下降的幅度大小,系数越大则差值越小,系数越小则差值越小,但是计算时间也相对延长
  */
object SGD {
  var data=HashMap[Int,Int]()         //创建数据集
  def getdata():HashMap[Int,Int]={
    for(i <- 1 to 50){                //创建50个数据集
      data += (i->(2*i))              //写入公式y=2x
    }
    data                              //返回数据集
  }

  var θ:Double=0                        //第一步 假设θ为0
  var α:Double=0.1                      //设置步进系数

  def sgd(x:Double,y:Double)={        //随机梯度下降迭代公式
    θ=θ-α*((θ*x)-y)                 //迭代公式
  }

  def main(args: Array[String]): Unit = {
    val dataSource=getdata()          //获取数据集
    dataSource.foreach(myMap=>{       //开始迭代
      sgd(myMap._1,myMap._2)          //输入数据
    })
    println("最终结果值θ为:"+θ)
  }
}

需要注意的是随着步长系数增大以及数据量的增大,θ值偏差越来越大。同时这里也遗留下一个问题,当数据量大到一定程度,为什么θ值会为NaN,笔者心中有所疑惑,如果哪位大佬有想法可以留言探讨,谢谢!!!


如果你想和我一起学习交流,共同进步,欢迎加群:
在这里插入图片描述

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/195274.html原文链接:https://javaforall.net

(0)
上一篇 2025年10月20日 下午2:43
下一篇 2025年10月20日 下午3:15


相关推荐

  • NGINX 配置404错误页面转向

    NGINX 配置404错误页面转向

    2021年9月24日
    45
  • java虚拟机可以运行的文件_虚拟机的网络模型有

    java虚拟机可以运行的文件_虚拟机的网络模型有Java虚拟机中的内存模型?Java虚拟机运行时内存所有的类的实例(不包括局部变量与方法参数)都存储在Java堆中,每条线程有自己的工作内存(Java栈),不同线程之间无法直接访问对方工作内存中的变量。方法区用于存储被虚拟机加载的类信息、常量、static变量等数据,堆用于存储对象实例,比如通过new创建的对象实例就保存在堆中,堆中的对象的由垃圾回收器负责回收。Java栈用于实现方法调用,每次方法调用就对应栈中的一个栈帧,栈帧包含局部变量表、操作数栈、方法接口等于方法相关的信息,栈中的数据当没有引用指向

    2025年11月25日
    6
  • goland激活码最新【2021.7最新】

    (goland激活码最新)JetBrains旗下有多款编译器工具(如:IntelliJ、WebStorm、PyCharm等)在各编程领域几乎都占据了垄断地位。建立在开源IntelliJ平台之上,过去15年以来,JetBrains一直在不断发展和完善这个平台。这个平台可以针对您的开发工作流进行微调并且能够提供…

    2022年3月21日
    52
  • 功率放大器和匹配网络学习

    功率放大器PA学习导通角:在一个周期内,由电力电子器件(如晶闸管)控制其导通的角度。交流电一般为正弦波,正半周占180°,负半周占180°。当交流电通过可控硅时,可以让交流电电流通过控制使其在0-180度的任一角度处开始导通,即所谓可控整流,当正半周加到可控硅的阳极,在180度的某一角度时,在可控硅的控制极加一触发脉冲,例如在30度加一脉冲,可控硅只能通过余下的150度的电流。这种使可控硅导电…

    2022年4月11日
    46
  • windows10系统下vue开发环境搭建

    windows10系统下vue开发环境搭建安装NodeJs下载地址:http://nodejs.cn/download/到官网下载自己系统对应的版本,按照推荐的方式默认安装,这里不再赘述。安装完成后,打卡powershell,执行命令node-v查询一下,检查是否正常安装。如果提示找不到node命令,添加node安装路径到系统环境变量,重启powershell,再试。如果你安装的是旧版本的npm,可以很容易得通过npm命令来升级。sudonpminstallnpm-g#linuxnpminstallnpm-g

    2022年10月20日
    4
  • Vim配置文件vimrc入门介绍

    Vim配置文件vimrc入门介绍本文转载自:vim教程网Vim入门级基础配置-Vim入门教程(1)介绍Vim配置文件.vimrc,配置Vim显示行号、支持utf8中文不乱码、突出显示Vim当前行,设置高亮显示括号匹配和tab缩进,解决Vim粘贴时多出缩进和空格问题。一、Vim配置文件.vimrcVim编辑器相关的所有功能开关都可以通过.vimrc文件进行设置。.vimrc配置文件分系统配置和用户配置两种。系…

    2022年4月30日
    111

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号