[marf-cvs] marf/src/marf/Classification/NeuralNetwork NeuralNetwork.java,1.5,1.6 Neuron.java,1.1,1.2
Brought to you by:
mokhov
From: <mo...@us...> - 2002-11-23 16:46:59
|
Update of /cvsroot/marf/marf/src/marf/Classification/NeuralNetwork In directory sc8-pr-cvs1:/tmp/cvs-serv9540/NeuralNetwork Modified Files: NeuralNetwork.java Neuron.java Log Message: Update of frameworked NN. Ian, if you're working on it, incorporate some chages in your working copy. Index: NeuralNetwork.java =================================================================== RCS file: /cvsroot/marf/marf/src/marf/Classification/NeuralNetwork/NeuralNetwork.java,v retrieving revision 1.5 retrieving revision 1.6 diff -C2 -d -r1.5 -r1.6 *** NeuralNetwork.java 19 Nov 2002 13:04:11 -0000 1.5 --- NeuralNetwork.java 23 Nov 2002 16:46:56 -0000 1.6 *************** *** 17,24 **** /** * Class NeuralNetwork * ! * $Header$ */ - public class NeuralNetwork extends Classification { --- 17,24 ---- /** * Class NeuralNetwork + * @author Ian Clement * ! * <p>$Header$ */ public class NeuralNetwork extends Classification { *************** *** 30,38 **** private int currLayerBuf=0; private Neuron currNeuron; ! private int neuron_type = -1; private ArrayList inputs = new ArrayList(); private ArrayList outputs = new ArrayList(); ! //All output will use this encoding static final String outputEncoding = "UTF-8"; --- 30,38 ---- private int currLayerBuf=0; private Neuron currNeuron; ! private int neuron_type = -1; private ArrayList inputs = new ArrayList(); private ArrayList outputs = new ArrayList(); ! //All output will use this encoding static final String outputEncoding = "UTF-8"; *************** *** 47,88 **** ! public NeuralNetwork(FeatureExtraction poFeatureExtraction) ! { ! super(poFeatureExtraction); ! } public boolean classify() throws ClassificationException { ! double [] adFeatures = oFeatureExtraction.getFeatures(); ! if(adFeatures.length != inputs.size()) ! throw new ClassificationException("Input array size not consistent with input layer."); ! ! for(int i=0; i<adFeatures.length; i++) ! ((Neuron)inputs.get(i)).result = adFeatures[i]; ! ! ArrayList current = null; ! ! for(int i=0; i<layers.size(); i++) { ! current = (ArrayList)layers.get(i); ! ! for(int j=0; j<current.size(); j++) ! ((Neuron)current.get(j)).eval(); ! } ! //Make result... ! double [] ret = new double[outputs.size()]; ! for(int i=0; i<outputs.size(); i++) ! ret[i] =((Neuron)outputs.get(i)).result; ! ! //... and then something to convert this into a result. ! //oResult = (Result)ret; ! } ! ! public Result getResult() ! { ! return oResult; } --- 47,86 ---- ! public NeuralNetwork(FeatureExtraction poFeatureExtraction) ! { ! super(poFeatureExtraction); ! } public boolean classify() throws ClassificationException { ! double [] adFeatures = oFeatureExtraction.getFeatures(); ! if(adFeatures.length != inputs.size()) ! throw new ClassificationException("Input array size not consistent with input layer."); ! for(int i=0; i<adFeatures.length; i++) ! ((Neuron)inputs.get(i)).result = adFeatures[i]; ! ArrayList current = null; ! for(int i=0; i<layers.size(); i++) { ! current = (ArrayList)layers.get(i); ! ! for(int j=0; j<current.size(); j++) ! ((Neuron)current.get(j)).eval(); ! } ! ! //Make result... ! double [] ret = new double[outputs.size()]; ! ! for(int i=0; i<outputs.size(); i++) ! ret[i] =((Neuron)outputs.get(i)).result; ! ! //... and then something to convert this into a result. ! //oResult = (Result)ret; ! // TODO: ! oResult = new Result(ret); ! ! return (outputs.size() > 0); } *************** *** 144,151 **** createLinks(doc); } ! ! //DOM tree traversal -- build NNet structure private void buildNNet(Node n) { ! int type = n.getNodeType(); String name, value; --- 142,149 ---- createLinks(doc); } ! ! //DOM tree traversal -- build NNet structure private void buildNNet(Node n) { ! int type = n.getNodeType(); String name, value; *************** *** 166,170 **** for (int i = 0; i < atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); --- 164,168 ---- for (int i = 0; i < atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); *************** *** 177,181 **** else if(attValue.equals("output")) { curr = outputs; ! neuron_type = 2; } else { --- 175,179 ---- else if(attValue.equals("output")) { curr = outputs; ! neuron_type = 2; } else { *************** *** 194,198 **** } else if(name.equals("neuron")) { ! String neuron_name = new String(); double thresh = 0.0; --- 192,196 ---- } else if(name.equals("neuron")) { ! String neuron_name = new String(); double thresh = 0.0; *************** *** 200,213 **** for (int i = 0; i < atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); ! if(attName.equals("index")) { System.out.println("Setting neuron name to " + attValue); neuron_name = new String(attValue); } ! else if(attName.equals("thresh")) { ! try { thresh = Double.valueOf(attValue.trim()).doubleValue(); System.out.println("Setting threshold to " + thresh + "."); --- 198,211 ---- for (int i = 0; i < atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); ! if(attName.equals("index")) { System.out.println("Setting neuron name to " + attValue); neuron_name = new String(attValue); } ! else if(attName.equals("thresh")) { ! try { thresh = Double.valueOf(attValue.trim()).doubleValue(); System.out.println("Setting threshold to " + thresh + "."); *************** *** 222,233 **** //System.out.println("Making new neuron " + neuron_name + " of type " + type); ! ! Neuron tmp = new Neuron(neuron_name, neuron_type); tmp.threshold = thresh; curr.add(tmp); ! } } ! // Recurse for children if any for (Node child = n.getFirstChild(); child != null; child = child.getNextSibling()) { --- 220,231 ---- //System.out.println("Making new neuron " + neuron_name + " of type " + type); ! ! Neuron tmp = new Neuron(neuron_name, neuron_type); tmp.threshold = thresh; curr.add(tmp); ! } } ! // Recurse for children if any for (Node child = n.getFirstChild(); child != null; child = child.getNextSibling()) { *************** *** 235,239 **** } } ! //DOM tree traversal -- create input and output links private void createLinks(Node n) throws ClassificationException --- 233,237 ---- } } ! //DOM tree traversal -- create input and output links private void createLinks(Node n) throws ClassificationException *************** *** 241,258 **** int type = n.getNodeType(); String name, value; ! if(type == Node.ELEMENT_NODE) { ! name = n.getNodeName(); ! NamedNodeMap atts = n.getAttributes(); ! if(name.equals("layer")) { for (int i = 0; i < atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); ! if(attName.equals("type")) { if(attValue.equals("input")) { --- 239,256 ---- int type = n.getNodeType(); String name, value; ! if(type == Node.ELEMENT_NODE) { ! name = n.getNodeName(); ! NamedNodeMap atts = n.getAttributes(); ! if(name.equals("layer")) { for (int i = 0; i < atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); ! if(attName.equals("type")) { if(attValue.equals("input")) { *************** *** 273,309 **** } else if(name.equals("neuron")) { ! String index = new String(); ! for (int i = 0; i < atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); ! if(attName.equals("index")) { index = new String(attValue); } } ! currNeuron = getNeuron(curr, index); ! } else if(name.equals("input")) { ! String index = null; double weight = -1.0; ! for(int i=0; i<atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); ! if(attName.equals("ref")) { index = new String(attValue); } else if(attName.equals("weight")) { ! try { weight = Double.valueOf(attValue.trim()).doubleValue(); } catch (NumberFormatException nfe) { --- 271,307 ---- } else if(name.equals("neuron")) { ! String index = new String(); ! for (int i = 0; i < atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); ! if(attName.equals("index")) { index = new String(attValue); } } ! currNeuron = getNeuron(curr, index); ! } else if(name.equals("input")) { ! String index = null; double weight = -1.0; ! for(int i=0; i<atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); ! if(attName.equals("ref")) { index = new String(attValue); } else if(attName.equals("weight")) { ! try { weight = Double.valueOf(attValue.trim()).doubleValue(); } catch (NumberFormatException nfe) { *************** *** 311,315 **** } } ! } if(weight < 0.0) { --- 309,313 ---- } } ! } if(weight < 0.0) { *************** *** 320,326 **** throw new ClassificationException("No \'ref\' value assigned for neuron " + currNeuron.name + " in layer " + currLayer); } ! System.out.println("Adding input " + index + " with weight " + weight); ! if(currLayer > 0) { Neuron toAdd = getNeuron((ArrayList)layers.get(currLayer-1), index); --- 318,324 ---- throw new ClassificationException("No \'ref\' value assigned for neuron " + currNeuron.name + " in layer " + currLayer); } ! System.out.println("Adding input " + index + " with weight " + weight); ! if(currLayer > 0) { Neuron toAdd = getNeuron((ArrayList)layers.get(currLayer-1), index); *************** *** 336,345 **** } else if(name.equals("output")) { ! String index = null; ! for(int i=0; i<atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); --- 334,343 ---- } else if(name.equals("output")) { ! String index = null; ! for(int i=0; i<atts.getLength(); i++) { Node att = atts.item(i); ! String attName = att.getNodeName(); String attValue = att.getNodeValue(); *************** *** 348,359 **** index = new String(attValue); } ! } ! if(index == null || index.equals("")) { throw new ClassificationException("No \'ref\' value assigned for neuron " + currNeuron.name + " in layer " + currLayer); } ! System.out.println("Adding output " + index); ! if(currLayer >= 0) { Neuron toAdd = getNeuron((ArrayList)layers.get(currLayer+1), index); --- 346,357 ---- index = new String(attValue); } ! } ! if(index == null || index.equals("")) { throw new ClassificationException("No \'ref\' value assigned for neuron " + currNeuron.name + " in layer " + currLayer); } ! System.out.println("Adding output " + index); ! if(currLayer >= 0) { Neuron toAdd = getNeuron((ArrayList)layers.get(currLayer+1), index); *************** *** 364,371 **** currNeuron.addOutput(toAdd); } ! } ! } ! // Recurse for children if any for (Node child = n.getFirstChild(); child != null; child = child.getNextSibling()) { --- 362,369 ---- currNeuron.addOutput(toAdd); } ! } ! } ! // Recurse for children if any for (Node child = n.getFirstChild(); child != null; child = child.getNextSibling()) { *************** *** 383,387 **** public void dumpXML() { ! System.out.println("<?xml version=\"1.0\"?>"); System.out.println("<net>"); --- 381,385 ---- public void dumpXML() { ! System.out.println("<?xml version=\"1.0\"?>"); System.out.println("<net>"); *************** *** 401,415 **** System.out.println("\" index=\"" + i + "\"/>"); ! for(int j=0; j<tmp_layer.size(); j++) { Neuron tmp_neuron = (Neuron)tmp_layer.get(j); ! tmp_neuron.printXML(2); } ! indent(1); System.out.println("</layer>"); } ! System.out.println("</net>"); } --- 399,413 ---- System.out.println("\" index=\"" + i + "\"/>"); ! for(int j=0; j<tmp_layer.size(); j++) { Neuron tmp_neuron = (Neuron)tmp_layer.get(j); ! tmp_neuron.printXML(2); } ! indent(1); System.out.println("</layer>"); } ! System.out.println("</net>"); } *************** *** 433,437 **** public void train(double [] in, double [] expected, double trainconst) throws ClassificationException { ! if(trainconst <= 0.0) throw new ClassificationException("Training constant must be >= 0.0"); --- 431,435 ---- public void train(double [] in, double [] expected, double trainconst) throws ClassificationException { ! if(trainconst <= 0.0) throw new ClassificationException("Training constant must be >= 0.0"); *************** *** 468,472 **** ArrayList tmp = (ArrayList)layers.get(i); ! for(int j=0; j<tmp.size(); j++) ((Neuron)tmp.get(j)).commit(); } --- 466,470 ---- ArrayList tmp = (ArrayList)layers.get(i); ! for(int j=0; j<tmp.size(); j++) ((Neuron)tmp.get(j)).commit(); } *************** *** 475,490 **** /* From Storage Manager */ ! ! public void dump() throws IOException ! { ! throw new NotImplementedException("NeuralNetwork.dump()"); ! } ! public void restore() throws IOException { ! throw new NotImplementedException("NeuralNetwork.restore()"); } ! //iclement: This may need revision: // Error handler to report errors and warnings --- 473,490 ---- /* From Storage Manager */ ! ! public void dump() throws IOException ! { ! //throw new NotImplementedException("NeuralNetwork.dump()"); ! dumpXML(); ! } ! public void restore() throws IOException { ! throw new NotImplementedException("NeuralNetwork.restore()"); } ! //iclement: This may need revision: + //mokhov: i guess so // Error handler to report errors and warnings *************** *** 496,500 **** this.out = out; } ! /** * Returns a string describing parse exception details --- 496,500 ---- this.out = out; } ! /** * Returns a string describing parse exception details *************** *** 517,526 **** out.println("Warning: " + getParseExceptionInfo(spe)); } ! public void error(SAXParseException spe) throws SAXException { String message = "Error: " + getParseExceptionInfo(spe); throw new SAXException(message); } ! public void fatalError(SAXParseException spe) throws SAXException { String message = "Fatal Error: " + getParseExceptionInfo(spe); --- 517,526 ---- out.println("Warning: " + getParseExceptionInfo(spe)); } ! public void error(SAXParseException spe) throws SAXException { String message = "Error: " + getParseExceptionInfo(spe); throw new SAXException(message); } ! public void fatalError(SAXParseException spe) throws SAXException { String message = "Fatal Error: " + getParseExceptionInfo(spe); Index: Neuron.java =================================================================== RCS file: /cvsroot/marf/marf/src/marf/Classification/NeuralNetwork/Neuron.java,v retrieving revision 1.1 retrieving revision 1.2 diff -C2 -d -r1.1 -r1.2 *** Neuron.java 18 Nov 2002 01:27:36 -0000 1.1 --- Neuron.java 23 Nov 2002 16:46:56 -0000 1.2 *************** *** 1,2 **** --- 1,4 ---- + package marf.Classification.NeuralNetwork; + import java.util.ArrayList; import java.lang.Double; *************** *** 13,17 **** private ArrayList weights = new ArrayList(); private ArrayList weightsBuf = new ArrayList(); ! private ArrayList outputs = new ArrayList(); --- 15,19 ---- private ArrayList weights = new ArrayList(); private ArrayList weightsBuf = new ArrayList(); ! private ArrayList outputs = new ArrayList(); *************** *** 28,33 **** public boolean addInput(Neuron in, double weight) { ! return inputs.add(in) && ! weights.add(new Double(weight)) && weightsBuf.add(new Double(weight)); } --- 30,35 ---- public boolean addInput(Neuron in, double weight) { ! return inputs.add(in) && ! weights.add(new Double(weight)) && weightsBuf.add(new Double(weight)); } *************** *** 42,46 **** if(inputs.isEmpty()) return; ! double count = 0; --- 44,48 ---- if(inputs.isEmpty()) return; ! double count = 0; *************** *** 54,66 **** //System.out.println("Neuron: " + name + ", Sum: " + count + ", Result: " + result); } ! private double getWeight(Neuron n) { int val = inputs.indexOf(n); ! ! if(val>=0) return ((Double)weights.get(val)).doubleValue(); ! //System.out.println("There is no pointer n in neuron"); ! return -1.0; } --- 56,68 ---- //System.out.println("Neuron: " + name + ", Sum: " + count + ", Result: " + result); } ! private double getWeight(Neuron n) { int val = inputs.indexOf(n); ! ! if(val>=0) return ((Double)weights.get(val)).doubleValue(); ! //System.out.println("There is no pointer n in neuron"); ! return -1.0; } *************** *** 72,81 **** } else if(type == 1) { ! double sum = 0.0; ! for(int i=0; i<outputs.size(); i++) sum += ((Neuron)outputs.get(i)).delta * ((Neuron)outputs.get(i)).getWeight(this); ! delta = result * (1.0 - result) * sum; } --- 74,83 ---- } else if(type == 1) { ! double sum = 0.0; ! for(int i=0; i<outputs.size(); i++) sum += ((Neuron)outputs.get(i)).delta * ((Neuron)outputs.get(i)).getWeight(this); ! delta = result * (1.0 - result) * sum; } *************** *** 87,91 **** for(int i=0; i<inputs.size(); i++) { ! weightsBuf.set(i, new Double(beta * ((Double)weights.get(i)).doubleValue() + alpha * delta * ((Neuron)inputs.get(i)).result)); --- 89,93 ---- for(int i=0; i<inputs.size(); i++) { ! weightsBuf.set(i, new Double(beta * ((Double)weights.get(i)).doubleValue() + alpha * delta * ((Neuron)inputs.get(i)).result)); *************** *** 103,110 **** public void printXML(int tab) { ! indent(tab); System.out.println("<neuron index=\"" + name + "\" thresh=\"" + threshold + "\">"); ! for(int i=0; i<inputs.size(); i++) { indent(tab+1); --- 105,112 ---- public void printXML(int tab) { ! indent(tab); System.out.println("<neuron index=\"" + name + "\" thresh=\"" + threshold + "\">"); ! for(int i=0; i<inputs.size(); i++) { indent(tab+1); |