[marf-cvs] marf/src/marf/Classification/NeuralNetwork Neuron.java,NONE,1.1
Brought to you by:
mokhov
From: <icl...@pr...> - 2002-11-18 01:27:43
|
// Update of /cvsroot/marf/marf/src/marf/Classification/NeuralNetwork
// In directory sc8-pr-cvs1:/tmp/cvs-serv17380
// Added Files: Neuron.java
// Log Message: DEV: From /resources/NeuralNetwork/
// --- NEW FILE: Neuron.java ---

import java.util.ArrayList;
import java.util.List;
import java.lang.Double;
import java.lang.Math;

/**
 * A single node of a feed-forward neural network.
 *
 * <p>A neuron keeps a list of upstream neurons (its inputs) with one weight per
 * input, and a list of downstream neurons (its outputs). {@link #eval()} computes
 * the neuron's activation from its inputs, {@link #train(double, double, double)}
 * computes its error term and stages new weights in a buffer, and
 * {@link #commit()} copies the staged weights over the live ones.
 *
 * <p>Links are one-directional bookkeeping: {@link #addInput(Neuron, double)} on
 * the downstream neuron and {@link #addOutput(Neuron)} on the upstream neuron must
 * both be called for a connection to take part in training.
 */
public class Neuron {

    /** Value of {@link #type} for a neuron whose role has not been set. */
    public static final int TYPE_UNDEFINED = -1;

    /** Value of {@link #type} for an input neuron; {@link #result} is set by the caller. */
    public static final int TYPE_INPUT = 0;

    /** Value of {@link #type} for a hidden neuron. */
    public static final int TYPE_HIDDEN = 1;

    /** Value of {@link #type} for an output neuron. */
    public static final int TYPE_OUTPUT = 2;

    // Data members

    /** Identifier of this neuron; emitted as the {@code index} attribute by {@link #printXML(int)}. */
    public String name;

    /** Role of this neuron: input=0, hidden=1, output=2, undefined=-1. */
    public int type = TYPE_UNDEFINED;

    /** Upstream neurons; element i is paired with weights[i] and weightsBuf[i]. */
    private final List<Neuron> inputs = new ArrayList<Neuron>();

    /** Live weights used by {@link #eval()}, parallel to {@link #inputs}. */
    private final List<Double> weights = new ArrayList<Double>();

    /** Weights staged by {@link #train(double, double, double)} until {@link #commit()}. */
    private final List<Double> weightsBuf = new ArrayList<Double>();

    /** Downstream neurons that read this neuron's result. */
    private final List<Neuron> outputs = new ArrayList<Neuron>();

    /** Error term computed by the last call to {@link #train(double, double, double)}. */
    public double delta = 0.0;

    /** Value subtracted from the weighted input sum before the activation is applied. */
    public double threshold = 0.0;

    /** Current activation; for an input neuron this is assigned directly by the caller. */
    public double result = 0.0;

    /**
     * Creates a neuron with no connections.
     *
     * @param n name of the neuron; must not be null
     * @param t role of the neuron (input=0, hidden=1, output=2, undefined=-1)
     * @throws NullPointerException if {@code n} is null
     */
    Neuron(String n, int t) {
        if (n == null) {
            throw new NullPointerException("n");
        }
        name = n;
        type = t;
    }

    // Methods

    /**
     * Registers an upstream neuron together with the initial weight of the link.
     *
     * @param in     the upstream neuron
     * @param weight initial weight, stored as both the live and the staged weight
     * @return {@code true} if the link was added; {@code false} if {@code in} is
     *         null, in which case nothing is changed
     */
    public boolean addInput(Neuron in, double weight) {
        if (in == null) {
            // A null input would make eval(), train() and printXML() fail later.
            return false;
        }
        inputs.add(in);
        weights.add(Double.valueOf(weight));
        weightsBuf.add(Double.valueOf(weight));
        return true;
    }

    /**
     * Registers a downstream neuron.
     *
     * @param out the downstream neuron
     * @return {@code true} if the link was added; {@code false} if {@code out} is
     *         null, in which case nothing is changed
     */
    public boolean addOutput(Neuron out) {
        if (out == null) {
            return false;
        }
        return outputs.add(out);
    }

    /**
     * Recomputes {@link #result} as {@code 1 / (1 + exp(-(sum - threshold)))}, where
     * {@code sum} is the weighted sum of the inputs' current results.
     *
     * <p>Does nothing for an input neuron (its result is taken to be the input value
     * itself) or for a neuron that has no inputs.
     */
    public void eval() {
        if (type == TYPE_INPUT) {
            return;
        }
        if (inputs.isEmpty()) {
            return;
        }

        double count = 0;
        for (int i = 0; i < inputs.size(); i++) {
            count += inputs.get(i).result * weights.get(i).doubleValue();
        }
        count -= threshold;

        result = 1.0 / (1.0 + Math.exp(-count));
    }

    /**
     * Returns the live weight this neuron applies to the given upstream neuron.
     *
     * @param n the upstream neuron to look up
     * @return the weight of the first link from {@code n}, or {@code 0.0} when
     *         {@code n} is not one of this neuron's inputs, so that a missing link
     *         contributes nothing to an error sum
     */
    private double getWeight(Neuron n) {
        int val = inputs.indexOf(n);
        if (val >= 0) {
            return weights.get(val).doubleValue();
        }
        return 0.0;
    }

    /**
     * Computes this neuron's error term and stages updated input weights.
     *
     * <p>For an output neuron, {@code delta = (expected - result) * result * (1 - result)}.
     * For a hidden neuron, {@code delta = result * (1 - result) * sum}, where
     * {@code sum} adds up, over every downstream neuron, that neuron's delta times the
     * weight it applies to this neuron. Neurons of any other type are left untouched.
     *
     * <p>Each staged weight is {@code beta * weight + alpha * delta * input.result}.
     * Live weights are not modified until {@link #commit()} is called.
     *
     * @param expected target activation; only read for an output neuron
     * @param alpha    factor applied to the {@code delta * input.result} correction
     * @param beta     factor applied to the current weight
     */
    public void train(double expected, double alpha, double beta) {
        if (type == TYPE_OUTPUT) {
            // Output nodes calculate delta differently, based on the expected result.
            delta = (expected - result) * result * (1.0 - result);
        } else if (type == TYPE_HIDDEN) {
            double sum = 0.0;
            for (int i = 0; i < outputs.size(); i++) {
                Neuron downstream = outputs.get(i);
                sum += downstream.delta * downstream.getWeight(this);
            }
            delta = result * (1.0 - result) * sum;
        } else {
            return;
        }

        for (int i = 0; i < inputs.size(); i++) {
            double updated = beta * weights.get(i).doubleValue()
                + alpha * delta * inputs.get(i).result;
            weightsBuf.set(i, Double.valueOf(updated));
        }
    }

    /** Replaces every live weight with the corresponding staged weight. */
    public void commit() {
        for (int i = 0; i < weights.size(); i++) {
            weights.set(i, weightsBuf.get(i));
        }
    }

    /**
     * Writes this neuron to standard output as a {@code <neuron>} element containing
     * one {@code <input>} element per input link and one {@code <output>} element per
     * output link.
     *
     * @param tab number of tab characters to indent the {@code <neuron>} element by;
     *            child elements are indented one level deeper
     */
    public void printXML(int tab) {
        indent(tab);
        System.out.println("<neuron index=\"" + name + "\" thresh=\"" + threshold + "\">");

        for (int i = 0; i < inputs.size(); i++) {
            indent(tab + 1);
            System.out.println("<input ref=\"" + inputs.get(i).name + "\" weight=\"" + weights.get(i) + "\"/>");
        }

        for (int i = 0; i < outputs.size(); i++) {
            indent(tab + 1);
            System.out.println("<output ref=\"" + outputs.get(i).name + "\"/>");
        }

        indent(tab);
        System.out.println("</neuron>");
    }

    /**
     * Prints {@code n} tab characters to standard output, without a line break.
     *
     * @param n number of tabs to print; nothing is printed when {@code n <= 0}
     */
    public void indent(int n) {
        for (int i = 0; i < n; i++) {
            System.out.print("\t");
        }
    }
}