开源一个基于Storm 分布式BP神经网络的Demo(Java版)

撸大湿 2014-02-26 03:01:27
加精
RT,这个Demo是我去年写,功能上没有大问题
该demo有3个缺点
1、代码性能一般,有一定的优化空间。由于是Demo,当时写的时候只考虑功能,没考虑性能
2、神经细胞层数被写死了,只有2层,如果要添加细胞层需要改源码。
3、训练任然是串行的,没有充分利用分布式框架并行计算优势。但是后续的计算任然是并行的,只要添加DRPC Client或者自定义DSpout就行了

帖子里我就不介绍Storm和神经网络原理了
我会单独写一篇博文介绍,预计本周内完成,连接:http://blog.csdn.net/tntzbzc/article/details/19974515

Spout启动+Drpc Client 训练++Drpc Client 计算
三个我放在一个main中了
第一个是服务端job 启动
第二个是训练(串行),测试的训练数据是任意整数,把它转成32个由0或1组成的double数组作为输入参数(input),以及一个double[4]的结果值(real)
real[4]:1 0 0 0 正奇数 0 1 0 0 正偶数 0 0 1 0 负奇数 0 0 0 1 负偶数
两层细胞的权重weight是随机生成的
随机生成了1000个样本,训练2000次

第三个是计算,如果想实现分布式并行计算,可以自己添加Client


public class DrpcClient {

/****************************************
* by CSDN 撸大湿 Email : tntzhou@hotmail.com
****************************************/

public static void main(String[] args) throws Exception {
int InputHideCount = 32; /* Input Hide输入数量 */
int HideOutCount = 10; /* Hide Bolt输出数量,等于是Out Bolt的输入数量 */
int OutCount = 4; /* Out输出数量,等于Real的数量 */

TopologyBuilder builder = new TopologyBuilder();
DRPCSpout drpcSpout = new DRPCSpout("BPTrain");
builder.setSpout("drpcSpout", drpcSpout, 1);
builder.setBolt("hide", new HideBolt(), HideOutCount).allGrouping("drpcSpout");

// OutBolt的传参 必须等于 HideBolt的个数
builder.setBolt("out", new OutBolt(HideOutCount), OutCount).allGrouping("hide");

// TrainBPFinsh的传参 必须等于 OutBolt的个数
builder.setBolt("finsh", new TrainBPFinsh(OutCount), 1).allGrouping("out");

builder.setBolt("return", new ReturnResults(), 1).allGrouping("finsh");
Config conf = new Config();
conf.setNumWorkers(Integer.parseInt(args[1]));
StormSubmitter.submitTopology(args[0], conf, builder.createTopology());

String hideweight = getWeightStr(InputHideCount + 1, HideOutCount); // 隐藏层的权重
String outweight = getWeightStr(HideOutCount + 1, OutCount);// 输出层的权重
DRPCClient client = new DRPCClient("mynode001", 3772);
int[] ranInt = new int[1000];
for (int i = 0; i < ranInt.length; i++) {
ranInt[i] = new java.util.Random().nextInt();
}
int Num = 0;
System.out.println("开始训练");
for (int i = 0; i < 2000; i++) {
double r = 0d;
for (int j = 0; j < ranInt.length; j++) {//
RandomNum MyRandom = new RandomNum(ranInt[j]);
String input = MyRandom.getInputDataStr();
String real = MyRandom.getRealDataStr();

String[] result = client.execute(
"BPTrain",
String.valueOf(Num) + "::" + input + "::" + real + "::" + hideweight + "::"
+ outweight).split("::");
// 参数传入全部靠一个字符串,收取也是字符串,最基本的DRPC Client调用
hideweight = result[0];
outweight = result[1];
r = Double.parseDouble(result[2]);
if (j % 100 == 0)
System.out.println(r); // 输出收敛度
}
if (r < 0.005) {
System.out.println("训练结束");
break;
}
}

while (true) {
byte[] input = new byte[10];
System.in.read(input);
int value = 0;
try {
value = Integer.parseInt(new String(input).trim());
} catch (Exception e) {
break;
}
RandomNum rawVal = new RandomNum(value);

String[] resultstr = client.execute(
"BPTrain",
String.valueOf(++Num) + "::" + rawVal.getInputDataStr() + "::"
+ rawVal.getRealDataStr() + "::" + hideweight + "::" + outweight).split("::");

double max = -Integer.MIN_VALUE;
int idx = -1;
String[] result = resultstr[3].split(",");
for (int i = 0; i != result.length; i++) {
if (Double.valueOf(result[i]) > max) {
max = Double.valueOf(result[i]);
idx = i;
}
}

switch (idx) {
case 0:
System.out.format("%d是一个正奇数\n", value);
break;
case 1:
System.out.format("%d是一个正偶数\n", value);
break;
case 2:
System.out.format("%d是一个负奇数\n", value);
break;
case 3:
System.out.format("%d是一个负偶数\n", value);
break;
}
}
}

static String getWeightStr(int inCount, int outCount) {
StringBuilder Wgt = new StringBuilder();
for (int i = 0; i < outCount; i++) {
for (int j = 0; j < inCount; j++) {
Random random = new Random();
double v = random.nextDouble();
double rand = random.nextDouble() > 0.5 ? v : -v;
Wgt.append(rand / 2);
if (j != inCount - 1)
Wgt.append(",");
}
if (i != outCount - 1)
Wgt.append(":");
}
// System.out.println(Wgt.toString().split(":")[0].split(",").length);
return Wgt.toString();
}
}


隐藏层神经细胞 Hide Bolt

public class HideBolt implements IRichBolt {
/****************************************
* by CSDN 撸大湿 Email : tntzhou@hotmail.com
****************************************/
private static final long serialVersionUID = -3242401692275116210L;
OutputCollector collector;
int TaskID = 0;

@SuppressWarnings("rawtypes")
@Override
public void prepare(Map stormConf, TopologyContext context, OutputCollector _collector) {
collector = _collector;
TaskID = context.getThisTaskIndex(); // 每个Bolt代表一个神经细胞
// Bolt TaskID == 细胞 id
}

@Override
public void execute(Tuple tuple) {

String[] t = tuple.getString(0).split("::");

String jobID = tuple.getString(1); // jobid,drpc rid
int TNum = Integer.valueOf(t[0]); // tuple id
String[] inputstr = t[1].split(",");
double[] input = new double[inputstr.length]; // 传入值
String[] hideweightstr = t[3].split(":");
double[][] hideweight = new double[hideweightstr.length][input.length + 1]; // 隐藏层的权重
for (int i = 0; i < inputstr.length; i++) {
input[i] = Double.parseDouble(inputstr[i]);
}
for (int i = 0; i < hideweightstr.length; i++) {
String[] tmpw = hideweightstr[i].split(",");
for (int j = 0; j < tmpw.length; j++) {
hideweight[i][j] = Double.parseDouble(tmpw[j]);
}
}
double inputsum = 0;

/****************************************
* 计算输出
****************************************/
for (int i = 0; i < input.length; i++) {

inputsum += hideweight[TaskID][i] * input[i];
}

inputsum += hideweight[TaskID][input.length];
double HideOut = 1.0 / (1.0 + Math.exp(-inputsum));

String[] realstr = t[2].split(",");
double[] real = new double[realstr.length];

for (int i = 0; i < realstr.length; i++) {
real[i] = Double.parseDouble(realstr[i]);
}

String[] outweightstr = t[4].split(":");
double[][] outweight = new double[outweightstr.length][outweightstr[0].split(",").length];

for (int i = 0; i < outweightstr.length; i++) {
String[] tmpw = outweightstr[i].split(",");
for (int j = 0; j < tmpw.length; j++) {
outweight[i][j] = Double.parseDouble(tmpw[j]);
}
}

collector.emit(new Values(TNum, TaskID, input, HideOut, real, hideweight, outweight, jobID));
collector.ack(tuple);
}

@Override
public void declareOutputFields(OutputFieldsDeclarer declarer) {
declarer.declare(new Fields("tnum", "HideTaskID", "Input", "HideOut", "Real", "HideWeight",
"OutWeight", "jobID"));
}

@Override
public void cleanup() {
}

@Override
public Map<String, Object> getComponentConfiguration() {
return null;
}

}


输出层神经细胞



public class OutBolt implements IRichBolt {

/****************************************
* by CSDN 撸大湿 Email : tntzhou@hotmail.com
****************************************/
private static final long serialVersionUID = -7483206983562705977L;
OutputCollector collector;
int TaskID = 0;
HashMap<Integer, double[]> HideOutMap = new HashMap<Integer, double[]>();
HashMap<Integer, ArrayList<Tuple>> MyTuple = new HashMap<Integer, ArrayList<Tuple>>();
int HideTaskCount = 0;

public OutBolt(int _hidetaskcount) {
this.HideTaskCount = _hidetaskcount;
}

@SuppressWarnings("rawtypes")
@Override
public void prepare(Map stormConf, TopologyContext context, OutputCollector _collector) {
collector = _collector;
TaskID = context.getThisTaskIndex();
}

@Override
public void execute(Tuple tuple) {

int TNum = tuple.getInteger(0);
int HideTaskID = tuple.getInteger(1);
double hideout = tuple.getDouble(3);
if (!HideOutMap.containsKey(TNum)) {
HideOutMap.put(TNum, new double[HideTaskCount]);
MyTuple.put(TNum, new ArrayList<Tuple>());
}
HideOutMap.get(TNum)[HideTaskID] = hideout;
MyTuple.get(TNum).add(tuple);
double[] input = null;
double[] real = null;
double[][] hideweight = null;
double[][] outweight = null;
String jobID = null;

outweight = (double[][]) tuple.getValue(6);
jobID = tuple.getString(7);

if (MyTuple.get(TNum).size() == HideTaskCount) {
input = (double[]) tuple.getValue(2);
real = (double[]) tuple.getValue(4);
hideweight = (double[][]) tuple.getValue(5);

/****************************************
* 计算输出
****************************************/
Double sum = 0d;
for (int i = 0; i < HideOutMap.get(TNum).length; i++) {
sum += HideOutMap.get(TNum)[i] * outweight[TaskID][i];
}
sum += outweight[TaskID][HideOutMap.get(TNum).length];
double Out = 1.0 / (1.0 + Math.exp(-sum));
collector.emit(new Values(TNum, TaskID, input, HideOutMap.get(TNum), Out, real,
hideweight, outweight, jobID));
HideOutMap.remove(TNum);
MyTuple.remove(TNum);
}
collector.ack(tuple);
}

@Override
public void declareOutputFields(OutputFieldsDeclarer declarer) {
declarer.declare(new Fields("tnum", "OutTaskID", "Input", "HideOut", "Out", "Real",
"HideWeight", "OutWeight", "jobID"));
}

@Override
public void cleanup() {
// TODO Auto-generated method stub

}

@Override
public Map<String, Object> getComponentConfiguration() {
// TODO Auto-generated method stub
return null;
}

}
...全文
10338 32 打赏 收藏 转发到动态 举报
写回复
用AI写文章
32 条回复
切换为时间正序
请发表友善的回复…
发表回复
smkio 2016-05-24
  • 打赏
  • 举报
回复
RandomNum是一个自定义类么
清明采薇 2015-01-16
  • 打赏
  • 举报
回复
大神啊 崇拜的五体投地……感谢分享。
鸥翔鱼游1 2014-04-29
  • 打赏
  • 举报
回复
收藏了,慢慢看。
sgfawegawegaweg 2014-04-25
  • 打赏
  • 举报
回复
必须学习下,刚好和毕业论文相关
Icehand哥 2014-04-24
  • 打赏
  • 举报
回复
好贴。
wowomusic 2014-03-06
  • 打赏
  • 举报
回复
先看看再说啊。。。
liviuslee 2014-03-06
  • 打赏
  • 举报
回复
收藏一个 thanks for sharing
待我功成名就 2014-03-06
  • 打赏
  • 举报
回复
大神
天天开心oo 2014-03-05
  • 打赏
  • 举报
回复
鲁大师的设计思路很棒,值得学习 但demo的性能还有待改进
八百说 2014-03-04
  • 打赏
  • 举报
回复
当年matlab编过,老板的课,不得不搞。。。
menglanxiang 2014-03-02
  • 打赏
  • 举报
回复
感觉很复杂,虽然看不懂。
lkf181 2014-03-01
  • 打赏
  • 举报
回复
动态建立asp.net类
software_gemeng 2014-03-01
  • 打赏
  • 举报
回复
虽然说看不懂 但是感觉楼主好牛逼呀 UP一下
心是菩提树 2014-02-28
  • 打赏
  • 举报
回复
有空再学习
mutouzuodeyu 2014-02-28
  • 打赏
  • 举报
回复
收藏了,慢慢看。
待我功成名就 2014-02-28
  • 打赏
  • 举报
回复
膜拜
撸大湿 2014-02-27
  • 打赏
  • 举报
回复
引用 13 楼 whos2002110 的回复:
input,output,weight,bias只记得这么多了,不过你这误差阀值要求也太低了,怎么也得0.0002以下
嗯,一些复杂的应用的确需要调低点,这个可以自己改的 Demo中的样例很简单,只是判断正负、奇偶,这个阀值足够用了
whos2002110 2014-02-27
  • 打赏
  • 举报
回复
input,output,weight,bias只记得这么多了,不过你这误差阀值要求也太低了,怎么也得0.0002以下
lvshuchenyin 2014-02-26
  • 打赏
  • 举报
回复
好像很厉害的样子
長胸為富 2014-02-26
  • 打赏
  • 举报
回复
收藏成功!!
加载更多回复(10)

67,513

社区成员

发帖
与我相关
我的任务
社区描述
J2EE只是Java企业应用。我们需要一个跨J2SE/WEB/EJB的微容器,保护我们的业务核心组件(中间件),以延续它的生命力,而不是依赖J2SE/J2EE版本。
社区管理员
  • Java EE
加入社区
  • 近7日
  • 近30日
  • 至今
社区公告
暂无公告

试试用AI创作助手写篇文章吧