感知机实现代码

感知机实现代码1 线性可分数据 1 1 读取数据 importpandas metricsimpor scoredata pd read csv r E data perceptron 15 dat sep s header None data columns x1

  感知机模型的基本原理是使得误分类点最小,而数据是基于林轩田老师的《机器学习基石》,如果实在找不到,也可以通过链接进行下载:CSDN下载地址。

1. 线性可分数据

1.1 读取数据

import pandas as pd import numpy as np from sklearn.metrics import accuracy_score data = pd.read_csv(r'E:/data/perceptron_15.dat', sep = '\s', header = None) data.columns = ['x1', 'x2', 'x3', 'x4', 'y'] X_train = data.loc[:, ['x1', 'x2', 'x3', 'x4']] y_train = data.loc[:, 'y'] X_train = X_train.values y_train = y_train.values 

1.2 感知机实现

  其中为了加快运算速度,把b融入到了W中,所以对训练数据X也要做一定的处理(左边插入一列为1的数据)。

class PerceptronSeparable(): def __init__(self): self.w = None self.b = None self.W = None self.eta = 0.01 def sign(self, x): if x >= 0: return 1 else: return -1 #检查单个数据是否预测正确 def check_one(self, x_train, y_train): x_train_one = 1 x_train_extend = np.append(x_train_one, x_train) y_predict = np.dot(x_train_extend, self.W) y_predict = self.sign(y_predict) return y_predict == y_train #检查所有数据是否预测正确 def check_all(self, X_train, y_train): X_train_ones = np.ones([len(X_train) ,1]) X_train_extend = np.concatenate([X_train_ones, X_train], axis = 1) y_predict = np.dot(X_train_extend, self.W) y_predict = np.array(list(map(self.sign, y_predict))) return (y_predict == y_train).all() def fit(self, X_train, y_train): self.w = np.zeros([X_train.shape[1], 1]) self.b = 0 self.W = np.append(self.b, self.w) times = 0 while True: # 如果全部预测正确则退出 if self.check_all(X_train, y_train): break else: random_index = np.random.randint(0, len(X_train)) x_train_single = X_train[random_index, :].reshape(-1, 1) y_train_single = y_train[random_index] if not self.check_one(x_train_single, y_train_single): self.w = self.w + self.eta * x_train_single * y_train_single self.b = self.b + self.eta * y_train_single self.W = np.append(self.b, self.w) times += 1 def predict(self, X_test): X_test_ones = np.ones([len(X_test), 1]) X_test_new = np.concatenate([X_test_ones, X_test], axis=1) y_test_predict = np.dot(X_test_new, self.W) y_test_predict = np.array(list(map(self.sign, y_test_predict))) return y_test_predict 

1.3 效果评估

perceptron = PerceptronSeparable() perceptron.fit(X_train, y_train) print(accuracy_score(y_train, perceptron.predict(X_train))) 

  由于是线性可分数据,打印结果应该是1.0,否则说明代码是有bug。

2. 线性不可分数据

2.1 读取数据

import pandas as pd import numpy as np from sklearn.metrics import accuracy_score data = pd.read_csv(r'E:/data/perceptron_18.dat', sep = '\s', header = None) data.columns = ['x1', 'x2', 'x3', 'x4', 'y'] X_train = data.loc[:, ['x1', 'x2', 'x3', 'x4']] y_train = data.loc[:, 'y'] X_train = X_train.values y_train = y_train.values 

2.2 感知机实现

  实现是大同小异的。最大的区别在于,每次改变参数以后会和之前最优的结果进行比较,如果不如以前的结果,则模型回退到上一个版本。

class PerceptronNonSeparable(): def __init__(self): self.w = None self.b = None self.W = None self.eta = 0.01 def sign(self, x): if x >= 0: return 1 else: return -1 def check_one(self, x_train, y_train): x_train_one = 1 x_train_extend = np.append(x_train_one, x_train) y_predict = np.dot(x_train_extend, self.W) y_predict = self.sign(y_predict) return y_predict == y_train def check_all(self, X_train, y_train): X_train_ones = np.ones([len(X_train) ,1]) X_train_extend = np.concatenate([X_train_ones, X_train], axis = 1) y_predict = np.dot(X_train_extend, self.W) y_predict = np.array(list(map(self.sign, y_predict))) return (y_predict == y_train).all() def get_predict_right_nums(self, X_train, y_train): X_train_ones = np.ones([len(X_train) ,1]) X_train_extend = np.concatenate([X_train_ones, X_train], axis = 1) y_predict = np.dot(X_train_extend, self.W) y_predict = np.array(list(map(self.sign, y_predict))) return np.sum(y_predict == y_train) def fit(self, X_train, y_train): self.w = np.zeros([X_train.shape[1], 1]) self.b = 0 self.W = np.append(self.b, self.w) times = 0 predict_best_num = 0 for i in range(1000): if self.check_all(X_train, y_train): break else: random_index = np.random.randint(0, len(X_train)) x_train_single = X_train[random_index, :].reshape(-1, 1) y_train_single = y_train[random_index] if not self.check_one(x_train_single, y_train_single): bak_w, bak_b, bak_W = self.w, self.b, self.W self.w = self.w + self.eta * x_train_single * y_train_single self.b = self.b + self.eta * y_train_single self.W = np.append(self.b, self.w) predict_right_num = self.get_predict_right_nums(X_train, y_train) if predict_right_num < predict_best_num: self.w, self.b, self.W = bak_w, bak_b, bak_W def predict(self, X_test): X_test_ones = np.ones([len(X_test), 1]) X_test_new = np.concatenate([X_test_ones, X_test], axis=1) y_test_predict = np.dot(X_test_new, self.W) y_test_predict = np.array(list(map(self.sign, y_test_predict))) return y_test_predict 

2.3 效果评估

perceptron = PerceptronNonSeparable() perceptron.fit(X_train, y_train) print(accuracy_score(y_train, perceptron.predict(X_train))) 

  由于是线性不可分数据,打印结果应该小于1.0,但也应该大于0.7,否则说明代码是有bug。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/176489.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月26日 下午9:48
下一篇 2026年3月26日 下午9:48


相关推荐

  • MySQL二进制日志格式类型详解「建议收藏」

    MySQL二进制日志格式类型详解「建议收藏」mysql很多有类型的日志,按照组件划分的话,可以分为服务层日志和存储引擎层日志:-服务层日志:二进制日志、慢查日志、通用日志-存储引擎层日志:innodb(重做日志、回滚日志)其中比较重要的就是服务器层的二进制日志,其中记录了所有对mysql数据库的修改事件,包括增删改查事件和对表结构的修改事件。要注意的一点是,只有成功执行了的事件才会记录在二进制日志中,未执行成功的不会保存

    2022年6月5日
    75
  • PostgresSQL 分页查询 SQL语句

    PostgresSQL 分页查询 SQL语句SELECT*FROM“库名”.“表名”wheretellike‘%1%’orderbyidasclimit3OFFSET0;

    2022年10月19日
    5
  • qt串口通信接收数据不完整_qt串口接收数据

    qt串口通信接收数据不完整_qt串口接收数据高通QM215高速串口调试总结参考文档硬件和复用情况确认修改如下串口调试测试程序代码:将串口设置为高速串口,AP端收到的数据一直为0XFD参考文档1、sp80-pk881-6_a_qm215_linux_android_software_porting_manual.pdf2、80-pk881-21_a_qm215_linux_peripheral_(uart,_spi,_i2c)_ove…

    2022年10月10日
    13
  • 0元搭建卡盟主站_哪个卡盟平台好

    0元搭建卡盟主站_哪个卡盟平台好设置桶配额功能说明设置桶的配额值,单位为字节,支持的最大值为263-1,配额值设为0表示桶的配额没有上限。方法定义1.ObsClient->setBucketQuota(array$parameter)2.ObsClient->setBucketQuotaAsync(array何查看桶标签://引入依赖库require’vendor/autoload.php’;//…

    2022年8月13日
    8
  • Spring配置与第一Spring HelloWorld

    Spring配置与第一Spring HelloWorld

    2022年1月7日
    47
  • 解决ie8下onpropertychange事件间歇性失效的问题「建议收藏」

    有的时候onpropertychange事件一下好用,一下不好用网上有的说去掉&lt;!DOCTYPEhtml&gt;就好了,我试了下,虽然然管用,但doctype是推荐加上的,去掉他有些东西会乱。如果文本框的样式中有width属性,没有height属性就会出现此问题,不知道是什么原因&lt;inputtype="text"id="name"name="na…

    2022年4月7日
    85

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号