机器学习之linear_model(普通最小二乘法手写+sklearn实现+评价指标)

机器学习之linear_model(普通最小二乘法手写+sklearn实现+评价指标)y wi xi b 基于最小二乘法的线性回归 寻找参数 w 和 b 使得 w 和 b 对 x test data 的预测值 y pred data 与真实的回归目标 y test data 之间的均方误差最小

y=wi*xi+b,基于最小二乘法的线性回归:寻找参数w和b,使得w和b对x_test_data的预测值y_pred_data与真实的回归目标y_test_data之间的均方误差最小。

from sklearn import linear_model import numpy as np import matplotlib.pyplot as plt from sklearn.metrics import mean_squared_error,r2_score,mean_absolute_error 

sklearn中有专门的线性模型包linear_model,numpy用于生成数据,matplotlib用于画图,另外导入MSE,R_Square和MAE三个评价指标。
2、构造数据集。可以自动生成数据,也可以寻找现有数据,以下数据是作业中的数据,样本数据只有一个特征。
3、训练模型。
4、输出系数w和截距b并对测试集进行预测。
5、作图。








完整代码:

import pandas as pd import matplotlib.pyplot as plt from sklearn import linear_model import numpy as np from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error def load_data(): data = pd.read_csv('Salary_Data.csv', encoding='gbk') data = data.values.tolist() train_x = [] train_y = [] test_x = [] test_y = [] # 前一半作为训练集,后一半作为测试集 for i in range(len(data)): if i < len(data) / 2: train_x.append(data[i][0]) train_y.append(data[i][1]) else: test_x.append(data[i][0]) test_y.append(data[i][1]) return train_x, train_y, test_x, test_y def model(): print('手写:') train_x, train_y, test_x, test_y = load_data() # 最小二乘法得到参数 sum = 0.0 sum_square = 0.0 sum_2 = 0.0 sum_b = 0.0 for i in range(len(train_x)): sum = sum + train_x[i] sum_square = sum_square + train_x[i]  2 ave_x = sum / len(train_x) for i in range(len(train_x)): sum_2 = sum_2 + (train_y[i] * (train_x[i] - ave_x)) w = sum_2 / (sum_square - sum  2 / len(train_x)) for i in range(len(train_x)): sum_b = sum_b + (train_y[i] - w * train_x[i]) b = sum_b / len(train_x) print('w=', w, 'b=', b) # 测试 pred_y = [] for i in range(len(test_x)): pred_y.append(w * test_x[i] + b) # 计算MSE,MAE,r2_score sum_mse = 0.0 sum_mae = 0.0 sum1 = 0.0 sum2 = 0.0 for i in range(len(pred_y)): sum_mae = sum_mae + np.abs(pred_y[i] - test_y[i]) sum_mse = sum_mse + (pred_y[i] - test_y[i])  2 sum_y = 0.0 for i in range(len(test_y)): sum_y = sum_y + test_y[i] ave_y = sum_y / len(test_y) for i in range(len(pred_y)): sum1 = sum1 + (pred_y[i] - test_y[i])  2 sum2 = sum2 + (ave_y - test_y[i])  2 print('MSE:', sum_mse / len(pred_y)) print('MAE:', sum_mae / len(pred_y)) print('R2_Squared:', 1 - sum1 / sum2) # 显示 plt.scatter(test_x, test_y, color='black') plt.plot(test_x, pred_y, color='blue', linewidth=3) plt.show() print('\n') # 调包 def sklearn_linearmodel(): print('调包:') train_x, train_y, test_x, test_y = load_data() train_x = np.array(train_x).reshape(-1, 1) train_y = np.array(train_y).reshape(-1, 1) test_x = np.array(test_x).reshape(-1, 1) test_y = np.array(test_y).reshape(-1, 1) # 训练+测试 lr = linear_model.LinearRegression() lr.fit(train_x, train_y) y_pred = lr.predict(test_x) # 输出系数和截距 print('w:', lr.coef_, 'b:', lr.intercept_) # 输出评价指标 print('MSE:', mean_squared_error(test_y, y_pred)) print('MAE:', mean_absolute_error(test_y, y_pred)) print('R2_Squared:', r2_score(test_y, y_pred)) # 显示 plt.scatter(test_x, test_y, color='black') plt.plot(test_x, y_pred, color='blue', linewidth=3) plt.show() if __name__ == '__main__': model() sklearn_linearmodel() 

在这里插入图片描述

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/177690.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月26日 下午6:56
下一篇 2026年3月26日 下午6:56


相关推荐

  • 掌握 8 組電商套圖提示詞模板,用 Nano Banana 批量生成專業產品圖

    掌握 8 組電商套圖提示詞模板,用 Nano Banana 批量生成專業產品圖

    2026年3月15日
    1
  • DM368_了解电脑硬件基本知识

    DM368_了解电脑硬件基本知识最近到了找工作准备期,之前已将C语言、数据结构与算法、APUE总结完毕,现在需要抓紧将以往项目加以总结。关于DM368首先我们先从硬件部分开始讲起,然后再讲环境搭建、系统移植、文件烧写、最后程序开发。一、认识开发板参看下面网址可下载DM368参考原理图和Gerber文件。参看:EVMDM368SupportHome参看:EVMDM365SupportHomeDM365与DM

    2022年8月13日
    10
  • MegaCli 使用

    MegaCli 使用安装 wgetftp rpmfind net linux Mandriva devel cooker x86 64 media non free release megacli 8 02 21 1 mdv2012 0 x86 64 rpmrpm ivhmegacli 8 02 21 1 mdv2012 0 x86 64 rpm 如何使用 megacli 功能

    2026年3月6日
    2
  • auto.js微信自动回复脚本_微信群助手机器人

    auto.js微信自动回复脚本_微信群助手机器人一、前言整体思路1)找到头像右上角有消息标志的聊天(注意直接跑下面代码的时候请确保聊天界面由此前提)2)点击进入聊天窗口,找到所有消息3)取最后一个消息(最新消息)4)和之前的新消息对比是否发生变化5)新消息推送至API6)收到API消息发送微信v8版本发送消息时,不再显示“发送”按钮了,也就没办法用找到“发送”控件的方法实现发送消息了。尝试用KeyCode(code)方式,发送回车键,发现也无效,原因查了一下好像是需要ROOT还是安卓9以上此方法失效。于是用坐标点击的方式点击键盘上的

    2022年9月30日
    6
  • android studio快捷键集合[通俗易懂]

    \itemCtrl+P  查看变量参数信息,也就是看变量是哪种类型  \item Ctrl+B  查找该变量的定义位置。  \item Ctrl+Q  查找快速文档,即在另外一个窗口中打开其声明  \item Alt+Shift+C  查看工程最近更改的地方  \item Ctrl+space  自动完成代码  \item Ctrl+shift+Enter  自动填充表达式

    2022年3月10日
    46
  • kettle 教程(四):自定义 Java 代码

    kettle 教程(四):自定义 Java 代码kettle拥有很多自带的组件,能帮我们实现很多的功能。但是我们总有一些很复(qi)杂(pa)的需求,用自带的组件实现不了,或者说实现起来很复杂。那么这时我们就要用到万能的组件了(Java代码),通过自己写代码来实现任何想要的功能。自定义Java代码假设有这样一个需…

    2022年5月23日
    246

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号