以下是加了注释的matlab代码,该版本流传甚广:
function rbm = rbmtrain(rbm, x, opts)
assert(isfloat(x), 'x must be a float');
assert(all(x(:)>=0) && all(x(:)<=1), 'all data in x must be in [0:1]');
m = size(x, 1);
% m返回样本个数, x矩阵行数即样本个数,列数即特征数 %
numbatches = m / opts.batchsize;
% batchsize即每批处理量,numbatches即分批次数;总样本数m=批次数numbatches*每批次含样本数batchsize %
assert(rem(numbatches, 1) == 0, 'numbatches not integer');
for i = 1 : opts.numepochs % opts.numbatches即迭代次数 %
kk = randperm(m);
% 把原来m个样本序号打乱,把打乱后序号序列存到kk中 %
err = 0;
for l = 1 : numbatches % 遍历每一批次样本 %
batch = x(kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize), :);
% 提取整体样本x中的batchsize+1~(batchsize+batchsize)这批次样本,存入batch %
v1 = batch;
h1 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W');[/color]
v2 = sigmrnd(repmat(rbm.b', opts.batchsize, 1) + h1 * rbm.W);
h2 = sigm(repmat(rbm.c', opts.batchsize, 1) + v2 * rbm.W');
% 以上为门德卡罗抽样的简化版:CD-k抽样 %
c1 = h1' * v1;
c2 = h2' * v2;
% c1和c2只是中间变量 %
rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2) / opts.batchsize;
rbm.vb = rbm.momentum * rbm.vb + rbm.alpha * sum(v1 - v2)' / opts.batchsize;
rbm.vc = rbm.momentum * rbm.vc + rbm.alpha * sum(h1 - h2)' / opts.batchsize;
rbm.W = rbm.W + rbm.vW;
rbm.b = rbm.b + rbm.vb;
rbm.c = rbm.c + rbm.vc;
err = err + sum(sum((v1 - v2) .^ 2)) / opts.batchsize;
%因为v1 v2元素取值只有1 0,故相减可能出现负值,平方的目的取绝对值以统计错误个数%
% 该批次遍历完毕后参数得到更新,将更新前的v2-v1即0-1误差和全部统计,并除以样本总数%
% 得到该批次下,每个样本平均所含的反编码错误数。err为之前所有批次的单位样本平均错误数的累积和 %
end
disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Average reconstruction error is: ' num2str(err / numbatches)]);
% 输出Gradient Descent迭代一次后,平均每批次单位样本的平均误差数,或认为是单位样本的平均出错编码个数 %
% 结果为0~n中的某个值 %
end
end
请注意我红字标记的地方:
h1 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W');
c1 = h1' * v1;
c2 = h2' * v2;
rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2)
以下是论文的公式

根据公式:代码应该为:
h1 = sigm(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W'); 这里应该是sigm,而不是sigmrnd
c1 = h1' * v1;
c2 = h2' * v2;
rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2)
我觉得代码里面那个h1是CD-K采样,用于后续估算v2、和p(h=1|v2))用的中间量,而公式中间要的是h1=p(h=1|v1),两者怎么能混为一谈?
请问是我哪里没看出来还是代码真的就错了??