RBM代码Python

RBM代码Pythoncoding utf 8 Createdon201 年 4 月 1 日 author LIU importsysimp pylabaspltim linalgimport Imag

# -*- coding: utf-8 -*- ''' Created on 2016年4月1日 @author: LIU ''' import sys import numpy import matplotlib.pylab as plt import numpy as np import random from scipy.linalg import norm import PIL.Image from utils import * class RBM(object): def __init__(self, input=None, n_visible=2, n_hidden=3, \ W=None, hbias=None, vbias=None, rng=None): self.n_visible = n_visible # num of units in visible (input) layer self.n_hidden = n_hidden # num of units in hidden layer if rng is None: rng = numpy.random.RandomState(1234) if W is None: a = 1. / n_visible initial_W = numpy.array(rng.uniform( # initialize W uniformly(随机生成实数在-a-a之间) low=-a, high=a, size=(n_visible, n_hidden))) W = initial_W if hbias is None: hbias = numpy.zeros(n_hidden) # initialize h bias 0 if vbias is None: vbias = numpy.zeros(n_visible) # initialize v bias 0 self.rng = rng self.input = input self.W = W self.hbias = hbias self.vbias = vbias def contrastive_divergence(self, lr=0.1, k=1, input=None): if input is not None: self.input = input ''' CD-ks算法 ''' ph_mean, ph_sample = self.sample_h_given_v(self.input) chain_start = ph_sample #实现一步吉布斯采样通过给隐层采样 for step in xrange(k): if step == 0: nv_means, nv_samples,\ nh_means, nh_samples = self.gibbs_hvh(chain_start) else: nv_means, nv_samples,\ nh_means, nh_samples = self.gibbs_hvh(nh_samples) # chain_end = nv_samples self.W += lr * (numpy.dot(self.input.T, ph_mean) - numpy.dot(nv_samples.T, nh_means)) self.vbias += lr * numpy.mean(self.input - nv_samples, axis=0) self.hbias += lr * numpy.mean(ph_mean - nh_means, axis=0) # cost = self.get_reconstruction_cross_entropy() # return cost # 通过给出显层单元推断出隐层单元的  #计算隐层单元的激活率通过给出显层,得到一个采样通过给他们的 def sample_h_given_v(self, v0_sample): h1_mean = self.propup(v0_sample) h1_sample = self.rng.binomial(size=h1_mean.shape, # discrete: binomial n=1, p=h1_mean) return [h1_mean, h1_sample] #一一步吉布斯采样通过从隐层率开始 def sample_v_given_h(self, h0_sample): v1_mean = self.propdown(h0_sample) v1_sample = self.rng.binomial(size=v1_mean.shape, # discrete: binomial n=1, p=v1_mean) return [v1_mean, v1_sample] def propup(self, v): pre_sigmoid_activation = numpy.dot(v, self.W) + self.hbias return sigmoid(pre_sigmoid_activation) def propdown(self, h): pre_sigmoid_activation = numpy.dot(h, self.W.T) + self.vbias return sigmoid(pre_sigmoid_activation) #转换函数主要功能是通过给定的隐层采样来执行cd更新 def gibbs_hvh(self, h0_sample): v1_mean, v1_sample = self.sample_v_given_h(h0_sample) h1_mean, h1_sample = self.sample_h_given_v(v1_sample) return [v1_mean, v1_sample, h1_mean, h1_sample] #计算重构误差  def get_reconstruction_cross_entropy(self): pre_sigmoid_activation_h = numpy.dot(self.input, self.W) + self.hbias sigmoid_activation_h = sigmoid(pre_sigmoid_activation_h) pre_sigmoid_activation_v = numpy.dot(sigmoid_activation_h, self.W.T) + self.vbias sigmoid_activation_v = sigmoid(pre_sigmoid_activation_v) cross_entropy = - numpy.mean( numpy.sum(self.input * numpy.log(sigmoid_activation_v) + (1 - self.input) * numpy.log(1 - sigmoid_activation_v), axis=1)) return cross_entropy def reconstruct(self, v): h = sigmoid(numpy.dot(v, self.W) + self.hbias) reconstructed_v = sigmoid(numpy.dot(h, self.W.T) + self.vbias) return reconstructed_v def readData(path): data = [] for line in open(path, 'r'): ele = line.split(' ') tmp = [] for e in ele: if e != '': tmp.append(float(e.strip(' '))) data.append(tmp) return data def test_rbm(learning_rate=0.1, k=1, training_epochs=50): # data = numpy.array([[1,1,1,0,0,0], # [1,0,1,0,0,0], # [1,1,1,0,0,0], # [0,0,1,1,1,0], # [0,0,1,1,0,0], # [0,0,1,1,1,0]]) data = readData('data.txt') data = np.array(data) data = data.transpose() rng = numpy.random.RandomState(123) # construct RBM # rbm = RBM(input=data, n_visible=6, n_hidden=2, rng=rng) rbm = RBM(input=data, n_visible=784, n_hidden=2, rng=rng) # train for epoch in xrange(training_epochs): rbm.contrastive_divergence(lr=learning_rate, k=k) cost = rbm.get_reconstruction_cross_entropy() print >> sys.stderr, 'Training epoch %d, cost is ' % epoch, cost # test # v = numpy.array([[1, 1, 0, 0, 0, 0], # [0, 0, 0, 1, 1, 0]]) v=data[1,:] print rbm.reconstruct(v) if __name__ == "__main__": test_rbm()
# -*- coding: utf-8 -*- ''' Created on 2016年4月1日 @author: LIU ''' import numpy numpy.seterr(all='ignore') def sigmoid(x): return 1. / (1 + numpy.exp(-x)) def dsigmoid(x): return x * (1. - x) # def tanh(x): # return numpy.tanh(x) #  # def dtanh(x): # return 1. - x * x #  # def softmax(x): # e = numpy.exp(x - numpy.max(x)) # prevent overflow # if e.ndim == 1: # return e / numpy.sum(e, axis=0) # else:  # return e / numpy.array([numpy.sum(e, axis=1)]).T # ndim = 2 #  #  # def ReLU(x): # return x * (x > 0) #  # def dReLU(x): # return 1. * (x > 0)
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/228773.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月16日 下午6:14
下一篇 2026年3月16日 下午6:14


相关推荐

  • 满世界都在喊的Agent智能体,是真需求还是伪概念?

    满世界都在喊的Agent智能体,是真需求还是伪概念?

    2026年3月15日
    3
  • data:image/png;base64[通俗易懂]

    data:image/png;base64[通俗易懂]<imgsrc=”data:image/gif;base64,R0lGODlhJQAlAJECAL3L2AYrTv///wAAACH/C05FVFNDQVBFMi4wAwEAAAAh+QQFCgACACwAAAAAJQAlAAACi5SPqcvtDyGYIFpF690i8xUw3qJBwUlSadmcLqYmGQu6KDIeM13beGzYWWy3DlB4IYaMk+Dso2RWkFCfLPcRvFbZxFLUDTt21BW56TyjRep1e20+i+eYMR145W2eefj+6VFmgTQ

    2022年10月12日
    4
  • 编写两分钟的倒计时c语言(c语言倒计时几分几秒)

    集团文件版本号:(M928-T898-M248-WU2669-I2896-DQ586-M1988)集团文件版本号:(M928-T898-M248-WU2669-I2896-DQ586-M1988)C语言分钟倒计时代码C语言-2分钟倒计时代码#include#include#includeintmain(){inta=1,i=59;printf(“2:00”);Sleep(1000);sy…

    2022年4月17日
    134
  • linux解压gz.gz文件,linux解压tar.gz并重命名_linux解压tar.gz文件

    linux解压gz.gz文件,linux解压tar.gz并重命名_linux解压tar.gz文件原标题:linux解压tar.gz并重命名_linux解压tar.gz文件命名为jpg.tar.gztar–cjfjpg.tar.bz2*.jpg//将目录里所有jpg文件打包成jpg.tar后,并且将其需要先下载zipforlinux解压tar–xvffile.tar//解压tar包tar-xzvffile.tar.gz//解压tar.gztaCSDN提供了…

    2022年6月18日
    37
  • matlab中的normrnd函数,MATLAB中normrnd函数的使用方法

    matlab中的normrnd函数,MATLAB中normrnd函数的使用方法基本结构为 1 r normrnd mu sigma 生成服从正态分布 mu 参数代表均值 sigma 参数代表标准差 的随机数 输入的向量或矩阵 mu 和 sigma 必须形式相同 输出 r 也和它们形式相同 标量输入将被扩展成和其它输入具有相同维数的矩阵 2 r normrnd mu sigma m 生成服从正态分布 mu 参数代表均值 sigma 参数代表标准差 的随机数矩阵 矩阵的形式由 m 定义 m 是一

    2026年3月18日
    1
  • 在线客服系统源码 自适应手机移动端 支持多商家 带搭建教程

    在线客服系统源码 自适应手机移动端 支持多商家 带搭建教程下载链接:在线客服系统源码自适应手机移动端支持多商家支持微信公众号/微信小程序带搭建教程-PHP文档类资源-CSDN下载PHP轻量级人工在线客服系统源码自适应手机移动端支持多商家带搭建教程支持多商家支持多商家,每个注册用户为一个商家,每个商家可以添加多个客服。不限坐席每个商家可以无限添加坐席,不限制坐席数支持H5移动端系统自动适配移动端,也可以接入app(h5方式)支持微信公众号/微信小程序客服可以与微信公众号/小程序里的访客实时沟通常见问题自动回复…

    2022年7月19日
    21

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号