Softmax classification from scratch and with MXNet


1. Softmax from scratch

from mxnet.gluon import data as gdata
from sklearn import datasets
from mxnet import nd,autograd
# Load the digits dataset (8x8 images flattened to 64 features, 10 classes)
digits = datasets.load_digits()
features,labels = nd.array(digits['data']),nd.array(digits['target'])
print(features.shape,labels.shape)
labels_onehot = nd.one_hot(labels,10)
print(labels_onehot.shape)
(1797, 64) (1797,)
(1797, 10)
class softmaxClassifier:
    def __init__(self,inputs,outputs):
        self.inputs = inputs
        self.outputs = outputs
        
        self.weight = nd.random.normal(scale=0.01,shape=(inputs,outputs))
        self.bias = nd.zeros(shape=(1,outputs))
        self.weight.attach_grad()
        self.bias.attach_grad()
        
    def forward(self,x):
        output = nd.dot(x,self.weight) + self.bias
        return self._softmax(output)
        
    def _softmax(self,x):
        step1 = x.exp()
        step2 = step1.sum(axis=1,keepdims=True)
        return step1 / step2
    
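    def _bgd(self,params,learning_rate,batch_size):
        '''
        Mini-batch gradient descent update.
        '''
        # Gradients come from mxnet's autograd. loss() already averages over
        # the batch, so dividing by batch_size again makes the effective step
        # size learning_rate / batch_size.
        for param in params:
            param[:] = param - param.grad * learning_rate / batch_size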
            
    def loss(self,y_pred,y):
        return nd.sum((-y * y_pred.log())) / len(y)
            
    def dataIter(self,x,y,batch_size):
        dataset = gdata.ArrayDataset(x,y)
        return gdata.DataLoader(dataset,batch_size,shuffle=True)
    
    def fit(self,x,y,learning_rate,epoches,batch_size):
        for epoch in range(epoches):
            for x_batch,y_batch in self.dataIter(x,y,batch_size):
                with autograd.record():
                    y_pred = self.forward(x_batch)
                    l = self.loss(y_pred,y_batch)
                l.backward()
                self._bgd([self.weight,self.bias],learning_rate,batch_size)
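            # Runs at epochs 0, 50, ..., 450; the printed label is epoch+50,
            # so the first line ("epoch:50") reports the state after one epoch.
            if epoch % 50 == 0:
                y_all_pred = self.forward(x)
                print('epoch:{},loss:{},accuracy:{}'.format(epoch+50,self.loss(y_all_pred,y),self.accuracyScore(y_all_pred,y)))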
            
    def predict(self,x):
        y_pred = self.forward(x)
        # y_pred has shape (samples, classes): take the arg-max over classes
        return y_pred.argmax(axis=1)
    
    def accuracyScore(self,y_pred,y):
        acc_sum = (y_pred.argmax(axis=1) == y.argmax(axis=1)).sum().asscalar()
        return acc_sum / len(y)
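
A side note on _softmax: it exponentiates the raw scores directly. That is fine for this dataset with small initial weights, but large scores can overflow (with float32 arrays that would typically show up as inf or nan rather than an exception). A common remedy is to subtract the row maximum before exponentiating. Below is a minimal pure-Python sketch of that idea, not the code used for the results in this post; the function names and the sample scores are made up for illustration.

```python
import math

def softmax_naive(row):
    # Same arithmetic as _softmax, written for a single row of scores.
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(row):
    # Subtracting the row maximum does not change the result mathematically,
    # but keeps every exponent <= 0, so exp() cannot overflow.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

small = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]  # hypothetical scores, chosen to overflow

print(softmax_stable(small))
print(softmax_stable(large))
try:
    softmax_naive(large)
except OverflowError as err:
    print('naive softmax failed:', err)
```

Both inputs differ only by a constant shift, so the stable version returns the same probabilities for them. In the class above the same change would amount to subtracting x.max(axis=1,keepdims=True) before x.exp().
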
sfm_clf = softmaxClassifier(64,10)
sfm_clf.fit(features,labels_onehot,learning_rate=0.1,epoches=500,batch_size=200)
epoch:50,loss:
[1.9941667]
<NDArray 1 @cpu(0)>,accuracy:0.3550361713967724
epoch:100,loss:
[0.37214527]
<NDArray 1 @cpu(0)>,accuracy:0.9393433500278241
epoch:150,loss:
[0.25443634]
<NDArray 1 @cpu(0)>,accuracy:0.9549248747913188
epoch:200,loss:
[0.20699367]
<NDArray 1 @cpu(0)>,accuracy:0.9588202559821926
epoch:250,loss:
[0.1799827]
<NDArray 1 @cpu(0)>,accuracy:0.9660545353366722
epoch:300,loss:
[0.1619963]
<NDArray 1 @cpu(0)>,accuracy:0.9677239844184753
epoch:350,loss:
[0.14888664]
<NDArray 1 @cpu(0)>,accuracy:0.9716193656093489
epoch:400,loss:
[0.13875261]
<NDArray 1 @cpu(0)>,accuracy:0.9738452977184195
epoch:450,loss:
[0.13058177]
<NDArray 1 @cpu(0)>,accuracy:0.9760712298274903
epoch:500,loss:
[0.12379646]
<NDArray 1 @cpu(0)>,accuracy:0.9777406789092933
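
One detail worth knowing when reading these numbers: loss() divides by len(y) and _bgd divides the gradient by batch_size again, so the step actually taken is roughly learning_rate / batch_size (about 0.1 / 200 here) times the gradient of the mean loss. The toy sketch below shows the effect on a one-parameter squared loss. It is only an illustration under that assumption; the step function and the data are hypothetical and not part of the classifier.

```python
def step(w, xs, learning_rate, divide_again):
    # Per-sample loss is 0.5 * (w - x)^2, so d(loss)/dw = w - x.
    n = len(xs)
    grad_of_mean = sum(w - x for x in xs) / n   # loss already averaged over the batch
    if divide_again:
        grad_of_mean /= n                       # second division, as in _bgd
    return w - learning_rate * grad_of_mean

xs = [1.0, 2.0, 3.0, 4.0]   # hypothetical mini-batch
w0 = 0.0
lr = 0.1

once = step(w0, xs, lr, divide_again=False)
twice = step(w0, xs, lr, divide_again=True)
same = step(w0, xs, lr / len(xs), divide_again=False)
print(once, twice, same)
```

Dividing twice gives the same update as dividing once with a learning rate that is smaller by a factor of the batch size, which is one reason 500 epochs are needed to reach about 97.8% accuracy.
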
print('Predicted:',sfm_clf.predict(features[:10]))
print('Actual:',labels[:10])
Predicted: 
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
<NDArray 10 @cpu(0)>
Actual: 
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
<NDArray 10 @cpu(0)>
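
A caution about reading this output: the first ten samples carry the labels 0 through 9 in order, and on such a batch an arg-max along axis=0 (best sample per class) and an arg-max along axis=1 (best class per sample) return the same vector. Only the axis=1 form is a prediction. The pure-Python sketch below, with hypothetical helper names and made-up probabilities, shows the two disagreeing as soon as the samples are reordered.

```python
def argmax_rows(matrix):
    # One class index per sample (per row): what predict() should return.
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]

def argmax_cols(matrix):
    # One sample index per class (per column): a different quantity.
    n_cols = len(matrix[0])
    return [max(range(len(matrix)), key=lambda i: matrix[i][j]) for j in range(n_cols)]

# Hypothetical probabilities for three samples whose labels are 0, 1, 2 in order.
in_order = [[0.8, 0.1, 0.1],
            [0.1, 0.8, 0.1],
            [0.1, 0.1, 0.8]]
# The same samples reordered so that their labels are 2, 0, 1.
shuffled = [in_order[2], in_order[0], in_order[1]]

print(argmax_rows(in_order), argmax_cols(in_order))
print(argmax_rows(shuffled), argmax_cols(shuffled))
```

For the ordered batch both calls agree; for the reordered batch only the per-row result matches the labels.
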

2. Softmax classification with MXNet

from mxnet import gluon,nd,autograd,init
from mxnet.gluon import nn,trainer,loss as gloss,data as gdata
# Define the model: a single dense layer with 10 outputs
net = nn.Sequential()
net.add(nn.Dense(10))

# Initialize the parameters
net.initialize(init=init.Normal(sigma=0.01))

# Loss function (sparse_label=False because the labels are one-hot)
loss = gloss.SoftmaxCrossEntropyLoss(sparse_label=False)

# Optimizer
optimizer = trainer.Trainer(net.collect_params(),'sgd',{'learning_rate':0.1})

# Training
epoches = 500
batch_size = 200

dataset = gdata.ArrayDataset(features, labels_onehot)
data_iter = gdata.DataLoader(dataset,batch_size,shuffle=True)
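for epoch in range(epoches):
    for x_batch,y_batch in data_iter:
        with autograd.record():
            l = loss(net(x_batch), y_batch).sum() / batch_size
        l.backward()
        # step(batch_size) rescales the gradient by 1/batch_size; the loss above
        # is already divided by batch_size, so the effective step size is
        # learning_rate / batch_size, matching the from-scratch version.
        optimizer.step(batch_size)
    # Runs at epochs 0, 50, ..., 450; the printed label is epoch+50.
    if epoch % 50 == 0:
        y_all_pred = net(features)
        acc_sum = (y_all_pred.argmax(axis=1) == labels_onehot.argmax(axis=1)).sum().asscalar()
        print('epoch:{},loss:{},accuracy:{}'.format(epoch+50,loss(y_all_pred,labels_onehot).sum() / len(labels_onehot),acc_sum/len(y_all_pred)))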
epoch:50,loss:
[2.1232333]
<NDArray 1 @cpu(0)>,accuracy:0.24652198107957707
epoch:100,loss:
[0.37193483]
<NDArray 1 @cpu(0)>,accuracy:0.9410127991096272
epoch:150,loss:
[0.25408813]
<NDArray 1 @cpu(0)>,accuracy:0.9543683917640512
epoch:200,loss:
[0.20680156]
<NDArray 1 @cpu(0)>,accuracy:0.9627156371730662
epoch:250,loss:
[0.1799252]
<NDArray 1 @cpu(0)>,accuracy:0.9666110183639399
epoch:300,loss:
[0.16203885]
<NDArray 1 @cpu(0)>,accuracy:0.9699499165275459
epoch:350,loss:
[0.14899409]
<NDArray 1 @cpu(0)>,accuracy:0.9738452977184195
epoch:400,loss:
[0.13890252]
<NDArray 1 @cpu(0)>,accuracy:0.9749582637729549
epoch:450,loss:
[0.13076076]
<NDArray 1 @cpu(0)>,accuracy:0.9755147468002225
epoch:500,loss:
[0.1239901]
<NDArray 1 @cpu(0)>,accuracy:0.9777406789092933
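
A remark on the labels: the loss above is built with sparse_label=False because the targets are one-hot vectors. With a one-hot target only the true class contributes to the cross-entropy, so the same value can be computed from the integer label alone. As far as I know that is what SoftmaxCrossEntropyLoss does with its default sparse_label=True, in which case labels could be passed instead of labels_onehot. The pure-Python sketch below only checks the arithmetic; the function names and probabilities are made up for illustration.

```python
import math

def cross_entropy_onehot(probs, onehot):
    # -sum_k y_k * log(p_k): the form used with one-hot labels.
    return -sum(y * math.log(p) for y, p in zip(onehot, probs))

def cross_entropy_index(probs, label):
    # With a one-hot target only the true class contributes.
    return -math.log(probs[label])

probs = [0.1, 0.7, 0.2]   # hypothetical predicted probabilities
label = 1
onehot = [1.0 if k == label else 0.0 for k in range(len(probs))]

print(cross_entropy_onehot(probs, onehot), cross_entropy_index(probs, label))
```
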

Published by 全栈程序员-站长. Please credit the source when reposting: https://javaforall.net/120005.html Original link: https://javaforall.net
