多层感知器算法
# Multilayer perceptron (MLP) classifier demo on the wine dataset.
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.neural_network import MLPClassifier

# NOTE: the original notebook ran the IPython magic `%matplotlib inline` here.
# It is not valid Python syntax in a .py file; in Jupyter, run it in its own cell.

# Load the wine dataset.
wine = load_wine()
# Keep only the first two features so the samples can be drawn in 2-D.
X = wine.data[:, :2]
# Class labels (target variable).
y = wine.target

# Visualize the data, colored by class label.
plt.figure(dpi=100)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.spring, edgecolors='k')

# Split into training and test sets: 30% held out for testing, fixed seed
# so the split (and therefore the scores below) is reproducible.
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=8)
# The MLPClassifier "three basic moves": instantiate, fit, score.
# Only the random seed is set; everything else uses scikit-learn's defaults
# (one hidden layer of 100 units, 'relu' activation, 'adam' solver).
mlp = MLPClassifier(random_state=8)
mlp.fit(X_train, y_train)
# Print training accuracy, then test accuracy.
print(mlp.score(X_train, y_train), mlp.score(X_test, y_test))
0.35484 0.
# Parameter tuning: two hidden layers of 100 units each.
mlp = MLPClassifier(hidden_layer_sizes=(100, 100), random_state=8)
mlp.fit(X_train, y_train)
# Print training accuracy, then test accuracy.
print(mlp.score(X_train, y_train), mlp.score(X_test, y_test))
0.70968 0.77778
# Three hidden layers of 100 units each.
mlp = MLPClassifier(hidden_layer_sizes=(100, 100, 100), random_state=8)
mlp.fit(X_train, y_train)
# Print training accuracy, then test accuracy.
print(mlp.score(X_train, y_train), mlp.score(X_test, y_test))
0.32258 0.77778
# Wider hidden layers (120, 120, 200) with the tanh activation function.
mlp = MLPClassifier(hidden_layer_sizes=(120, 120, 200), activation='tanh', random_state=8)
mlp.fit(X_train, y_train)
# Print training accuracy, then test accuracy.
print(mlp.score(X_train, y_train), mlp.score(X_test, y_test))
0.03226 0.48148
# Two hidden layers of 100 units, relu activation, trained with plain
# stochastic gradient descent instead of the default 'adam' solver.
mlp = MLPClassifier(
    hidden_layer_sizes=(100, 100),
    activation='relu',
    solver='sgd',
    random_state=8,
)
mlp.fit(X_train, y_train)
# Print training accuracy, then test accuracy.
print(mlp.score(X_train, y_train), mlp.score(X_test, y_test))
0.03226 0.5
# Two hidden layers of 100 units with relu activation. The solver='lbfgs'
# option is left commented out, so MLPClassifier's default solver ('adam')
# is used. This makes the configuration equivalent to the
# hidden_layer_sizes=(100, 100) model trained earlier.
mlp = MLPClassifier(
    hidden_layer_sizes=(100, 100),
    activation='relu',
    # solver='lbfgs',  # uncomment to try the L-BFGS solver
    random_state=8,
)
mlp.fit(X_train, y_train)
# Print training accuracy, then test accuracy.
print(mlp.score(X_train, y_train), mlp.score(X_test, y_test))
0.70968 0.77778
import numpy as np

# Visualize the decision regions of the most recently fitted `mlp`.
# Build a 0.02-step grid covering the data with a 0.5 margin on every side.
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
                     np.arange(y_min, y_max, 0.02))

# Predict a class for every grid point, then restore the grid's 2-D shape.
Z = mlp.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.figure()
# Z has the same shape as xx/yy. shading='auto' draws one cell per grid point
# instead of relying on the version-dependent default (Matplotlib 3.3-3.4
# defaulted to 'flat', which warns and drops the last row/column here).
plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Pastel1, shading='auto')
# Overlay the actual samples, colored by their true class.
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.winter, edgecolors='k')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.title("Classifier:MLPClassifier")
plt.show()

发布者：全栈程序员-站长，转载请注明出处：https://javaforall.net/177088.html ，原文链接：https://javaforall.net
