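An Introduction to the DeepLearnToolbox DBN Source Code
I have spent the past few days reading the DeepLearnToolbox source code, and this post records my understanding of the DBN code.
test_example_DBN.m: the test script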
function test_example_DBN
load ../data/mnist_40000_10000;
addpath('../DBN');
addpath('../NN');
addpath('../util');

train_x = double(train_x) / 255;
test_x  = double(test_x)  / 255;
train_y = double(train_y);
test_y  = double(test_y);

rand('state', 0)

% train dbn
% DBN structure: v1 is the raw pixels (the input image),
% h1/v2 has 100 units, h2/v3 has 200 units
dbn.sizes = [100 200];
opts.numepochs = 3;
opts.batchsize = 100;
% momentum keeps the previous update direction and blends it with the
% current one, which speeds up learning (0 disables it)
opts.momentum  = 0;
opts.alpha     = 1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);

% unfold dbn to nn
nn = dbnunfoldtonn(dbn, 10);
nn.activation_function = 'sigm';

% train nn
% the DBN parameters serve as the initialization; the nn is then fine-tuned
opts.numepochs = 3;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);

assert(er < 0.10, 'Too big error');
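The momentum comment in the script can be illustrated with a toy update. The sketch below is not toolbox code: the gradient values and the variable names (grads, momenta, w_final) are made up for illustration. It applies the same rule that rbmtrain.m uses, v = momentum * v + alpha * grad, and compares momentum = 0 (the setting in this test script) with a hypothetical momentum of 0.5.

```matlab
% Toy illustration of a momentum update (hypothetical values, not toolbox code).
alpha   = 1;                 % learning rate, playing the role of opts.alpha
grads   = [0.5 0.5 0.5];     % made-up gradient at three consecutive steps
momenta = [0 0.5];           % momentum = 0 is what test_example_DBN uses
w_final = zeros(size(momenta));

for k = 1 : numel(momenta)
    v = 0;                   % velocity, playing the role of rbm.vW
    w = 0;                   % parameter, playing the role of rbm.W
    for t = 1 : numel(grads)
        % blend the previous update direction with the current gradient
        v = momenta(k) * v + alpha * grads(t);
        w = w + v;
    end
    w_final(k) = w;
end

disp(w_final)
```

With momentum = 0 each step depends only on the current gradient; with a positive momentum, steps in a consistent direction accumulate, so the parameter moves further in the same number of updates.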
dbnsetup.m: builds the DBN
function dbn = dbnsetup(dbn, x, opts)
    n = size(x, 2);
    dbn.sizes = [n, dbn.sizes];   % [784, 100, 200]

    % initialize W, b, c (and their momentum terms vW, vb, vc) for each RBM
    for u = 1 : numel(dbn.sizes) - 1
        dbn.rbm{u}.alpha    = opts.alpha;
        dbn.rbm{u}.momentum = opts.momentum;

        dbn.rbm{u}.W  = zeros(dbn.sizes(u + 1), dbn.sizes(u));
        dbn.rbm{u}.vW = zeros(dbn.sizes(u + 1), dbn.sizes(u));

        dbn.rbm{u}.b  = zeros(dbn.sizes(u), 1);       % visible-layer bias
        dbn.rbm{u}.vb = zeros(dbn.sizes(u), 1);

        dbn.rbm{u}.c  = zeros(dbn.sizes(u + 1), 1);   % hidden-layer bias
        dbn.rbm{u}.vc = zeros(dbn.sizes(u + 1), 1);
    end
end
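To make the shapes concrete, here is a small sketch that repeats the allocation loop for the [784, 100, 200] layout mentioned in the comment. It assumes 784 input pixels; the variable names (sizes, shapes) are illustrative and not part of the toolbox.

```matlab
% Shape check for the parameters allocated by a dbnsetup-style loop.
sizes  = [784, 100, 200];                 % assumed layout: input, h1, h2
shapes = zeros(numel(sizes) - 1, 2);      % one row per RBM: size of W

for u = 1 : numel(sizes) - 1
    W = zeros(sizes(u + 1), sizes(u));    % hidden x visible
    b = zeros(sizes(u), 1);               % one bias per visible unit
    c = zeros(sizes(u + 1), 1);           % one bias per hidden unit
    shapes(u, :) = size(W);
    fprintf('rbm{%d}: W is %dx%d, b is %dx1, c is %dx1\n', ...
            u, size(W, 1), size(W, 2), numel(b), numel(c));
end
```

Each W has one row per hidden unit and one column per visible unit, which is why the training code multiplies by rbm.W' when going from the visible to the hidden layer.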
dbntrain.m: trains the DBN
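function dbn = dbntrain(dbn, x, opts)
    n = numel(dbn.rbm);

    % train the first RBM directly on the input data
    dbn.rbm{1} = rbmtrain(dbn.rbm{1}, x, opts);
    for i = 2 : n
        % the hidden activations of the previous RBM become the input
        % of the next one
        x = rbmup(dbn.rbm{i - 1}, x);   % i.e. sigm(W*x + c)
        dbn.rbm{i} = rbmtrain(dbn.rbm{i}, x, opts);
    end
end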
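rbmup.m itself is not listed in this post. A minimal sketch of the upward pass, assuming it computes sigm(W*x + c) for every row of a batch as the comment says, might look like the following. The name rbmup_sketch and the toy values are hypothetical.

```matlab
% Hypothetical stand-in for rbmup: propagate a batch (one sample per row)
% from the visible layer to the hidden layer of one RBM.
sigm = @(z) 1 ./ (1 + exp(-z));           % logistic sigmoid
rbmup_sketch = @(rbm, x) sigm(x * rbm.W' + repmat(rbm.c', size(x, 1), 1));

rbm.W = zeros(3, 4);                      % toy RBM: 4 visible, 3 hidden units
rbm.c = [0; 1; -1];                       % made-up hidden biases
x = rand(5, 4);                           % 5 samples with values in [0, 1]

h = rbmup_sketch(rbm, x);                 % 5 x 3 matrix of probabilities
disp(size(h))
```

Because the output is a matrix of probabilities in (0, 1), it satisfies the input check at the top of rbmtrain and can be used directly as training data for the next RBM.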
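rbmtrain.m: trains an RBM
Training uses the Contrastive Divergence (CD) algorithm, a fast learning algorithm for RBMs proposed by Hinton in 2002.
The algorithm is described in Algorithm 1 of "Learning Deep Architectures for AI"; its main flow is as follows: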
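As a reading aid, one CD-1 step can be written in the notation of rbmtrain.m. This is a sketch derived from the code rather than a transcription of Algorithm 1: sigma denotes the logistic sigmoid, B the mini-batch size (opts.batchsize), and m the momentum; these symbols are introduced here for illustration.

```latex
\begin{aligned}
h_1 &\sim \mathrm{Bernoulli}\big(\sigma(W v_1 + c)\big), \\
v_2 &\sim \mathrm{Bernoulli}\big(\sigma(W^{\top} h_1 + b)\big), \\
h_2 &= \sigma(W v_2 + c), \\[4pt]
\Delta W &= \frac{\alpha}{B} \sum_{n=1}^{B}
  \Big( h_1^{(n)} v_1^{(n)\top} - h_2^{(n)} v_2^{(n)\top} \Big), \\
\Delta b &= \frac{\alpha}{B} \sum_{n=1}^{B} \big( v_1^{(n)} - v_2^{(n)} \big), \qquad
\Delta c = \frac{\alpha}{B} \sum_{n=1}^{B} \big( h_1^{(n)} - h_2^{(n)} \big), \\[4pt]
v_W &\leftarrow m\, v_W + \Delta W, \qquad W \leftarrow W + v_W
\quad \text{(and likewise for } b \text{ and } c\text{)}.
\end{aligned}
```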
function rbm = rbmtrain(rbm, x, opts)
    assert(isfloat(x), 'x must be a float');
    assert(all(x(:)>=0) && all(x(:)<=1), 'all data in x must be in [0:1]');
    m = size(x, 1);
    numbatches = m / opts.batchsize;

    assert(rem(numbatches, 1) == 0, 'numbatches not integer');

    for i = 1 : opts.numepochs   % number of epochs
        kk = randperm(m);        % shuffle the samples
        err = 0;
        for l = 1 : numbatches
            batch = x(kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize), :);

            v1 = batch;
            h1 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W');   % Gibbs sampling
            v2 = sigmrnd(repmat(rbm.b', opts.batchsize, 1) + h1 * rbm.W);    % Gibbs sampling
            h2 = sigm(repmat(rbm.c', opts.batchsize, 1) + v2 * rbm.W');      % sigm(W*v2 + c)

            % compare with Algorithm 1 referenced above
            c1 = h1' * v1;
            c2 = h2' * v2;

            % rbm.momentum keeps the previous update direction and blends it
            % with the current one, which speeds up learning
            rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2)     / opts.batchsize;
            rbm.vb = rbm.momentum * rbm.vb + rbm.alpha * sum(v1 - v2)' / opts.batchsize;
            rbm.vc = rbm.momentum * rbm.vc + rbm.alpha * sum(h1 - h2)' / opts.batchsize;

            rbm.W = rbm.W + rbm.vW;
            rbm.b = rbm.b + rbm.vb;
            rbm.c = rbm.c + rbm.vc;

            err = err + sum(sum((v1 - v2) .^ 2)) / opts.batchsize;
        end

        disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Average reconstruction error is: ' num2str(err / numbatches)]);
    end
end
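The helpers sigm and sigmrnd come from the toolbox's util directory and are not shown in this post. The sketch below shows what they presumably compute: sigm is the logistic function, and sigmrnd draws a binary sample that is 1 with that probability. The names sigm_sketch and sigmrnd_sketch are hypothetical, and the inputs are chosen to be extreme so the behaviour is easy to see.

```matlab
% Hypothetical stand-ins for the toolbox helpers sigm and sigmrnd.
sigm_sketch    = @(P) 1 ./ (1 + exp(-P));                      % probability of being "on"
sigmrnd_sketch = @(P) double(sigm_sketch(P) > rand(size(P)));  % Bernoulli sample in {0, 1}

% Three columns of pre-activations: strongly off, undecided, strongly on.
P = repmat([-50 0 50], 1000, 1);
S = sigmrnd_sketch(P);

% Fraction of units switched on in each column.
disp(mean(S))
```

Sampling h1 and v2 as binary states while leaving h2 as probabilities, as rbmtrain does, is a common choice in CD-1 implementations; the sampled states supply the stochasticity of the Gibbs step.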
dbnunfoldtonn.m: initializes an NN with the DBN parameters; the NN is then fine-tuned with nn = nntrain(nn, train_x, train_y, opts);
function nn = dbnunfoldtonn(dbn, outputsize)
% DBNUNFOLDTONN Unfolds a DBN to a NN
%   dbnunfoldtonn(dbn, outputsize) returns the unfolded dbn with a final
%   layer of size outputsize added.
    if(exist('outputsize','var'))
        size = [dbn.sizes outputsize];
    else
        size = [dbn.sizes];
    end
    nn = nnsetup(size);
    for i = 1 : numel(dbn.rbm)
        % initialize the NN parameters from each RBM's W and c;
        % the hidden bias c becomes the first column of the weight matrix
        nn.W{i} = [dbn.rbm{i}.c dbn.rbm{i}.W];
    end
end
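The concatenation [c W] folds the hidden bias into the weight matrix. The sketch below checks numerically that this gives the same activations as the RBM upward pass, under the assumption that the feed-forward code prepends a constant 1 to every input sample. The toy sizes and names (W_nn, x_aug) are illustrative.

```matlab
% Check that [c W] applied to [1 x] equals sigm(W*x + c) (toy sizes).
sigm = @(z) 1 ./ (1 + exp(-z));

rbm.W = randn(3, 4);                      % 3 hidden units, 4 visible units
rbm.c = randn(3, 1);                      % hidden biases
x = rand(5, 4);                           % 5 samples, one per row

% RBM upward pass with an explicit bias term
h_rbm = sigm(x * rbm.W' + repmat(rbm.c', size(x, 1), 1));

% Feed-forward layer with the bias folded into the weights
W_nn  = [rbm.c rbm.W];                    % same concatenation as dbnunfoldtonn
x_aug = [ones(size(x, 1), 1) x];          % assumed: a column of ones is prepended
h_nn  = sigm(x_aug * W_nn');

disp(max(abs(h_rbm(:) - h_nn(:))))        % difference is at rounding level
```

Only c is carried over: the visible bias b is used for reconstruction during RBM training and has no counterpart in the feed-forward network.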
CNN source code walkthrough: http://blog.csdn.net/zouxy09/article/details/9993743
http://blog.csdn.net/dark_scope/article/details/9495505
Reference:
(1) Learning Deep Architectures for AI
(2) A Practical Guide to Training Restricted Boltzmann Machines, 2010