From: Aaron A. <aa...@us...> - 2007-10-02 02:28:10
|
Update of /cvsroot/jboost/jboost/src/jboost/atree In directory sc8-pr-cvs6.sourceforge.net:/tmp/cvs-serv11031 Modified Files: AlternatingTree.java AtreePredictor.java AtreeTestSuite.java InstrumentedAlternatingTree.java PredictorNode.java SplitterNode.java Added Files: AlternatingTreeTest.java Log Message: Added ordered prediction capabilities, currently set as the default Index: PredictorNode.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/PredictorNode.java,v retrieving revision 1.1.1.1 retrieving revision 1.2 diff -C2 -d -r1.1.1.1 -r1.2 *** PredictorNode.java 16 May 2007 04:06:02 -0000 1.1.1.1 --- PredictorNode.java 2 Oct 2007 02:28:06 -0000 1.2 *************** *** 8,11 **** --- 8,12 ---- import jboost.booster.Prediction; + import jboost.booster.NormalizedPrediction; import jboost.examples.Instance; import jboost.learner.IncompAttException; *************** *** 19,22 **** --- 20,48 ---- */ class PredictorNode implements Serializable { + + /** the prediction value associated with this node. */ + protected Prediction prediction; + + /** A textual identifier, has the format <parentSplitterNodeID>:index + * The id of the root predictor node is "R". */ + protected String id; + + /** An index signifying the iteration in which this node was added + * to the tree. */ + protected int index; + + + /** + * The parent splitter node (or null if root) + */ + SplitterNode parent; + + /** + * The branch index (as a value returned by Splitter) of the + * parent split that leads to this predictor node. + */ + int branchIndex; + + /** constructor */ public PredictorNode(Prediction p,String ID,int ind,Vector sp, *************** *** 42,45 **** --- 68,145 ---- return(retval); } + + + /** + * Calculate the prediction of the subtree starting at this node in + * order of iteration. This is important for normalized predictors + * such as NormalBoost. This function can only be called on the + * root node. 
+ * @author Aaron Arvey + */ + + protected Prediction orderPredict(Instance instance, int numIterations) throws IncompAttException, RuntimeException { + // Are we the root node? + if (parent!=null || id!="R") { + throw new RuntimeException("Cannot perform ordered prediction on a node other then the root"); + } + + Prediction retval=(Prediction)prediction.clone(); + Prediction tmp=null; + for (int i=0; i < numIterations; i++) { + PredictorNode p = findPrediction(instance, i, this); + if (p==null) { // we could not get to this iteration, so we continue to the next iteration + continue; + //throw new Exception("Cannot find prediction for iteration " + i); + } + retval.add(p.prediction); + } + + /* + if (numIterations > 3 && numIterations < 5) { + System.out.println("Doing ordered prediction"); + } + + if (numIterations > 3 && numIterations < 5) { + try { + Thread.currentThread().sleep(9999); + } catch (Exception e) { + // do nothing + } + } + */ + + + return retval; + } + + private PredictorNode findPrediction(Instance instance, int iter, PredictorNode pn) { + if(pn.splitterNodes==null && pn.index != iter) return null; + if(pn.splitterNodes==null && pn.index == iter) return pn; + + // Search for the SplitterNode/PredictorNode of interest + for(int i=0;i<pn.splitterNodes.size();i++){ + if ( ((SplitterNode)pn.splitterNodes.elementAt(i)).getIndex() == iter ) { + return ((SplitterNode)pn.splitterNodes.elementAt(i)).predictNode(instance); + } + } + + // We couldn't find the node of interest, so continue with search + PredictorNode tmp = null; + for(int i=0;i<pn.splitterNodes.size();i++){ + tmp = ((SplitterNode)pn.splitterNodes.elementAt(i)).predictNode(instance); + tmp = findPrediction(instance, iter, tmp); + if (tmp==null) { + // The node is not down there or this instance does + // not fulfill the predicate. 
Search down the other + // paths + } else { + return tmp; + } + } + + return null; + } + + /** Generate a textual explanation of the prediction */ public String explain(Instance instance) throws IncompAttException { *************** *** 57,89 **** } - /** returns the number of child splitternodes. */ - int getSplitterNodeNo(){ - return(splitterNodes.size()); - } - - /** the prediction value associated with this node. */ - protected Prediction prediction; - - /** A textual identifier, has the format <parentSplitterNodeID>:index - * The id of the root predictor node is "R". */ - protected String id; - - /** An index signifying the iteration in which this node was added - * to the tree. */ - protected int index; - /** - * Return the ID of this PredictorNode - * @return id of this node - */ - public String getID() { - return id; - } - - - public int getIndex() { - return index; - } - /** output self in human-readable format. */ public String toString() { --- 157,161 ---- *************** *** 120,123 **** --- 192,199 ---- /** Add a prediction to its prediction value */ public void addToPrediction(Prediction p) { + if (p instanceof NormalizedPrediction) { + System.err.println("Cannot add normalized prediction to existing node"); + System.exit(2); + } prediction.add(p); } *************** *** 127,145 **** sums their predictions */ protected Vector splitterNodes; ! public Vector getSplitterNodes() { return splitterNodes; } /** ! * The parent splitter node (or null if root) */ ! SplitterNode parent; ! /** ! * The branch index (as a value returned by Splitter) of the ! * parent split that leads to this predictor node. */ ! int branchIndex; } --- 203,233 ---- sums their predictions */ protected Vector splitterNodes; ! public Vector getSplitterNodes() { return splitterNodes; } + /** + * Returns the number of child splitternodes. + */ + int getSplitterNodeNo(){ + return(splitterNodes.size()); + } + /** ! * Return the ID of this PredictorNode ! * @return id of this node */ ! 
public String getID() { ! return id; ! } ! /** ! * Return the index of this PredictorNode ! * @return index of this node */ ! public int getIndex() { ! return index; ! } } Index: AtreePredictor.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/AtreePredictor.java,v retrieving revision 1.1.1.1 retrieving revision 1.2 diff -C2 -d -r1.1.1.1 -r1.2 *** AtreePredictor.java 16 May 2007 04:06:02 -0000 1.1.1.1 --- AtreePredictor.java 2 Oct 2007 02:28:06 -0000 1.2 *************** *** 30,33 **** --- 30,38 ---- AtreePredictor(Splitter s, PredictorNode p, Prediction[] pred, Booster b) { splitter = s; + + if (p==null){ + System.err.println("Predictor node given to constructor is null"); + System.err.println(s); + } pNode = p; this.pred = pred; *************** *** 35,42 **** zeroPred = b.getPredictions(new Bag[] {b.newBag()} , new int[0][])[0]; } ! public Prediction predict(Instance x) throws IncompAttException { if (isConstant) return pred[0]; for(PredictorNode p = pNode; p.parent != null; p = p.parent.parent) { if (p.parent.splitter.eval(x) != p.branchIndex) --- 40,70 ---- zeroPred = b.getPredictions(new Bag[] {b.newBag()} , new int[0][])[0]; } ! ! /** ! * Check to see if we get to this node. If we reach this node, ! * return the prediction for it. ! */ public Prediction predict(Instance x) throws IncompAttException { + // if this is the root, we have no parent + if (pNode==null) { + return predict(x,0); + } else { + return predict(x,pNode.index); + } + } + + + /** + * Check to see if we get to this node and this node is iteration + * iter. If we reach this node, return the prediction for it. 
+ */ + public Prediction predict(Instance x, int iter) throws IncompAttException { if (isConstant) return pred[0]; + + if (pNode.index != iter) + return zeroPred; + + // If we don't reach this node, then return zero for(PredictorNode p = pNode; p.parent != null; p = p.parent.parent) { if (p.parent.splitter.eval(x) != p.branchIndex) *************** *** 46,49 **** --- 74,78 ---- return (v < 0 ? zeroPred : pred[v]); } + } Index: AlternatingTree.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/AlternatingTree.java,v retrieving revision 1.1.1.1 retrieving revision 1.2 diff -C2 -d -r1.1.1.1 -r1.2 *** AlternatingTree.java 16 May 2007 04:06:02 -0000 1.1.1.1 --- AlternatingTree.java 2 Oct 2007 02:28:06 -0000 1.2 *************** *** 52,55 **** --- 52,67 ---- } + /** Make a prediction */ + public Prediction predict(Instance instance, int numIters) throws IncompAttException { + //return predict(instance); + return orderPredict(instance, numIters); + } + + /** Make a iteration orderd prediction */ + public Prediction orderPredict(Instance instance, int numIters) throws IncompAttException { + Prediction retval=root.orderPredict(instance, numIters); + return(retval); + } + /** Generate a textual explanation of the prediction */ public String explain(Instance instance) throws IncompAttException { --- NEW FILE: AlternatingTreeTest.java --- /* * */ package jboost.atree; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.Vector; import jboost.CandidateSplit; import jboost.Predictor; import jboost.booster.AdaBoost; import jboost.booster.Booster; import jboost.controller.Configuration; import jboost.examples.Attribute; import jboost.examples.AttributeDescription; import jboost.examples.DiscreteAttribute; import jboost.examples.Example; import jboost.examples.ExampleDescription; 
import jboost.examples.ExampleSet; import jboost.examples.Instance; import jboost.examples.Label; import jboost.learner.EqualitySplitterBuilder; import jboost.learner.IncompAttException; import jboost.learner.SplitterBuilder; import jboost.tokenizer.DataStream; import jboost.tokenizer.jboost_DataStream; import junit.framework.TestCase; /** * */ public class AlternatingTreeTest extends TestCase { DataStream m_datastream; Booster m_booster; Booster m_booster2; int m_numRounds; int[] m_trainLabels; int[] m_trainFeature1; int[] m_trainFeature2; int[] m_testLabels; int[] m_testValues; int[] m_exampleIndices= new int[12]; ExampleSet m_examples; SplitterBuilder m_builder; Vector m_builders; /* * @see TestCase#setUp() */ protected void setUp() throws Exception { super.setUp(); // build examples m_datastream= new jboost_DataStream(false,"feature1 (zero,one,two)\n labels (one,two)\n"); ExampleDescription description= m_datastream.getExampleDescription(); m_examples= new ExampleSet(description); m_booster= new AdaBoost(); m_booster2= new AdaBoost(); m_builder= new EqualitySplitterBuilder(0, m_booster, new AttributeDescription[] {description.getAttributeDescription(0)}); m_trainLabels= new int[] { 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0}; m_trainFeature1= new int[] { 0, 2, 2, 2, 1, 2, 0, 1, 0, 0, 2, 1}; m_testLabels= new int[] { 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0}; m_testValues= new int[] { 0, 2, 2, 2, 1, 2, 0, 1, 0, 0, 2, 1}; Example x; Attribute[] attributes= new Attribute[1]; Label l; m_exampleIndices= new int[m_trainLabels.length]; for (int i= 0; i < m_trainLabels.length; i++) { l= new Label(m_trainLabels[i]); attributes[0]= new DiscreteAttribute(m_trainFeature1[i]); x= new Example(attributes, l); try { m_builder.addExample(i, x); m_booster.addExample(i, l); m_booster2.addExample(i, l); m_examples.addExample(i, x); m_exampleIndices[i]= i; } catch (IncompAttException e) { } } m_builder.finalizeData(); m_booster.finalizeData(); m_booster2.finalizeData(); 
m_examples.finalizeData(); m_builders= new Vector(); m_builders.add(m_builder); } /* * @see TestCase#tearDown() */ protected void tearDown() throws Exception { super.tearDown(); } public final void testGetCandidates() { //TODO Implement getCandidates(). InstrumentedAlternatingTree iat; try { iat= new InstrumentedAlternatingTree(m_builders, m_booster, m_exampleIndices, new Configuration()); m_numRounds= 2; for (int j=0; j < m_numRounds; j++) { Vector cand= iat.getCandidates(); CandidateSplit bC= null; for (int i= 0; i < cand.size(); i++) { bC= (CandidateSplit) cand.get(i); iat.addCandidate(bC); } } Predictor c= iat.getCombinedPredictor(); } catch(Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } } public final void testInstrumentAlternatingTree() { // create an instrumented tree InstrumentedAlternatingTree first= null; try { first= new InstrumentedAlternatingTree(m_builders, m_booster, m_exampleIndices, new Configuration()); m_numRounds= 100; for (int j=0; j < m_numRounds; j++) { Vector cand= first.getCandidates(); CandidateSplit bC= null; // This piece should be replaced by a more general tool to measure the // goodness of a // split. 
int best= 0; double bestLoss= ((CandidateSplit) cand.get(0)).getLoss(); double tmpLoss; for (int i= 1; i < cand.size(); i++) { if ((tmpLoss= ((CandidateSplit) cand.get(i)).getLoss()) < bestLoss) { bestLoss= tmpLoss; best= i; } } bC= (CandidateSplit) cand.get(best); first.addCandidate(bC); } // turn the instrumented tree into an alternating tree AlternatingTree tree= (AlternatingTree) first.getCombinedPredictor(); // serialize the tree ByteArrayOutputStream bos= new ByteArrayOutputStream(); ObjectOutputStream os; os= new ObjectOutputStream(bos); os.writeObject(tree); os.flush(); os.close(); // de-serialize ByteArrayInputStream bis= new ByteArrayInputStream(bos.toByteArray()); ObjectInputStream is; AlternatingTree newTree= null; is= new ObjectInputStream(bis); newTree= (AlternatingTree) is.readObject(); is.close(); InstrumentedAlternatingTree second= new InstrumentedAlternatingTree(newTree, m_builders, m_booster2, m_exampleIndices, new Configuration()); AlternatingTree secondTree= (AlternatingTree) second.getCombinedPredictor(); for (int i=0; i < m_trainLabels.length; i++) { Instance test= m_examples.getExample(i).getInstance(); assertTrue(tree.predict(test).equals(secondTree.predict(test))); } testPredict(tree, secondTree); // assert that the boosters are equivalent assertTrue(m_booster.toString().equals(m_booster2.toString())); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); fail(); } } protected final void testPredict(AlternatingTree t1, AlternatingTree t2) { for (int i=0; i < m_trainLabels.length; i++) { Instance test= m_examples.getExample(i).getInstance(); System.out.println(t1.predict(test)); System.out.println(t1.predict(test).getClassScores()[0]); System.out.println(t1.predict(test).getClassScores()[1]); System.out.println(t1.orderPredict(test,m_numRounds-1)); System.out.println(t2.predict(test)); System.out.println(t2.orderPredict(test,m_numRounds-1)); System.out.flush(); double EPS = 0.000001; 
assertEquals(t1.predict(test).getClassScores()[0], t1.orderPredict(test,m_numRounds-1).getClassScores()[0], EPS); assertEquals(t1.predict(test).getClassScores()[1], t1.orderPredict(test,m_numRounds-1).getClassScores()[1], EPS); assertEquals(t2.predict(test).getClassScores()[0], t2.orderPredict(test,m_numRounds-1).getClassScores()[0], EPS); assertEquals(t2.predict(test).getClassScores()[1], t2.orderPredict(test,m_numRounds-1).getClassScores()[1], EPS); } } } Index: AtreeTestSuite.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/AtreeTestSuite.java,v retrieving revision 1.1.1.1 retrieving revision 1.2 diff -C2 -d -r1.1.1.1 -r1.2 *** AtreeTestSuite.java 16 May 2007 04:06:02 -0000 1.1.1.1 --- AtreeTestSuite.java 2 Oct 2007 02:28:06 -0000 1.2 *************** *** 18,21 **** --- 18,22 ---- //$JUnit-BEGIN$ suite.addTestSuite(InstrumentedAlternatingTreeTest.class); + suite.addTestSuite(AlternatingTreeTest.class); //$JUnit-END$ return suite; Index: SplitterNode.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/SplitterNode.java,v retrieving revision 1.1.1.1 retrieving revision 1.2 diff -C2 -d -r1.1.1.1 -r1.2 *** SplitterNode.java 16 May 2007 04:06:02 -0000 1.1.1.1 --- SplitterNode.java 2 Oct 2007 02:28:06 -0000 1.2 *************** *** 24,37 **** class SplitterNode implements Serializable{ ! /** Calculate the prediction of the subtree starting at this node. ! Depends on the fact that all splits are binary.*/ protected Prediction predict(Instance instance) throws IncompAttException { ! if(splitter==null) throw(new RuntimeException("Splitter node: "+id+" has no splitter.")); int which_branch=splitter.eval(instance); ! if(which_branch<0) return(null); ! 
if(predictorNodes==null) throw(new RuntimeException("Splitter node: "+id+" has no predictor nodes.")); return(predictorNodes[which_branch].predict(instance)); } /** Generate a textual explanation of the prediction */ protected String explain(Instance instance) throws IncompAttException { --- 24,56 ---- class SplitterNode implements Serializable{ ! /** ! * Calculate the prediction of the subtree starting at this node. ! * Depends on the fact that all splits are binary. ! */ protected Prediction predict(Instance instance) throws IncompAttException { ! if(splitter==null) ! throw(new RuntimeException("Splitter node: "+id+" has no splitter.")); int which_branch=splitter.eval(instance); ! if(which_branch<0) ! return(null); ! if(predictorNodes==null) ! throw(new RuntimeException("Splitter node: "+id+" has no predictor nodes.")); return(predictorNodes[which_branch].predict(instance)); } + /** + * Return the prediction node that comes next for this instance + */ + protected PredictorNode predictNode(Instance instance) throws IncompAttException { + if(splitter==null) + throw(new RuntimeException("Splitter node: "+id+" has no splitter.")); + int which_branch=splitter.eval(instance); + if(which_branch<0) + return(null); + if(predictorNodes==null) + throw(new RuntimeException("Splitter node: "+id+" has no predictor nodes.")); + return(predictorNodes[which_branch]); + } + /** Generate a textual explanation of the prediction */ protected String explain(Instance instance) throws IncompAttException { Index: InstrumentedAlternatingTree.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/InstrumentedAlternatingTree.java,v retrieving revision 1.2 retrieving revision 1.3 diff -C2 -d -r1.2 -r1.3 *** InstrumentedAlternatingTree.java 27 May 2007 11:24:07 -0000 1.2 --- InstrumentedAlternatingTree.java 2 Oct 2007 02:28:06 -0000 1.3 *************** *** 15,18 **** --- 15,19 ---- import jboost.booster.BrownBoost; import 
jboost.booster.Prediction; + import jboost.booster.NormalizedPrediction; import jboost.controller.Configuration; import jboost.controller.ConfigurationException; *************** *** 112,116 **** // Use the number of boosting iterations as the default // size for the internal lists used by this tree ! int listSize= config.getInt("numRounds", 25); // initialize the data structures used by the tree --- 113,117 ---- // Use the number of boosting iterations as the default // size for the internal lists used by this tree ! int listSize= config.getInt("numRounds", 200); // initialize the data structures used by the tree *************** *** 215,220 **** Vector splitters=new Vector(m_splitterBuilders.size()); ! for(Iterator i=m_splitterBuilders.iterator();i.hasNext();) { PredictorNodeSB pSB=(PredictorNodeSB)i.next(); if (m_treeType == AtreeType.ADD_ROOT && pSB.pNode != 0) { while(sbCount.currentCount()!=0) { --- 216,223 ---- Vector splitters=new Vector(m_splitterBuilders.size()); ! for (Iterator i = m_splitterBuilders.iterator(); i.hasNext(); ) { ! PredictorNodeSB pSB=(PredictorNodeSB)i.next(); + if (m_treeType == AtreeType.ADD_ROOT && pSB.pNode != 0) { while(sbCount.currentCount()!=0) { *************** *** 318,323 **** Prediction[] pred= m_booster.getPredictions(bags, partition); m_booster.update(pred, partition); ! ((PredictorNode) m_predictors.get(0)).addToPrediction(pred[0]); lastBasePredictor= new AtreePredictor(pred); } --- 321,334 ---- Prediction[] pred= m_booster.getPredictions(bags, partition); m_booster.update(pred, partition); ! ! if (pred.length > 0 && pred[0] instanceof NormalizedPrediction) { ! System.err.println("Cannot update root with mixed binary pred"); ! System.exit(2); ! } ! ((PredictorNode) m_predictors.get(0)).addToPrediction(pred[0]); + if (pred==null) { + System.err.println("Updating root pred is null!"); + } lastBasePredictor= new AtreePredictor(pred); } *************** *** 351,355 **** * @param partition */ ! 
private void insert(Bag[] bags, PredictorNode parent, Splitter splitter, SplitterBuilder[] parentArray, Prediction[] predictions, int[][] partition) { --- 362,366 ---- * @param partition */ ! private SplitterNode insert(Bag[] bags, PredictorNode parent, Splitter splitter, SplitterBuilder[] parentArray, Prediction[] predictions, int[][] partition) { *************** *** 384,387 **** --- 395,399 ---- parent.addSplitterNode(sNode); m_splitters.add(pInt); + return sNode; } *************** *** 393,396 **** --- 405,409 ---- public void addCandidate(CandidateSplit candidate) throws InstrumentException { AtreeCandidateSplit acand= null; + SplitterNode node= null; try { acand= (AtreeCandidateSplit) candidate; *************** *** 410,416 **** Prediction[] predictions= m_booster.getPredictions(bags, partition); m_booster.update(predictions, partition); lastBasePredictor= new AtreePredictor(splitter, parent, predictions, m_booster); ! SplitterNode node= findSplitter(parent, splitter); ! if (node != null) { for (int i=0; i < node.predictorNodes.length; i++) { node.predictorNodes[i].addToPrediction(predictions[i]); --- 423,432 ---- Prediction[] predictions= m_booster.getPredictions(bags, partition); m_booster.update(predictions, partition); + if (parent==null) { + System.err.println("Parent is null!"); + } lastBasePredictor= new AtreePredictor(splitter, parent, predictions, m_booster); ! node= findSplitter(parent, splitter); ! if (node != null && (predictions.length > 0 && !(predictions[0] instanceof NormalizedPrediction)) ) { for (int i=0; i < node.predictorNodes.length; i++) { node.predictorNodes[i].addToPrediction(predictions[i]); *************** *** 418,424 **** } else { SplitterBuilder[] parentArray= ((PredictorNodeSB) m_splitterBuilders.get(acand.getPredictorNode())).SB; ! 
insert(bags, parent, splitter, parentArray, predictions, partition); } } } --- 434,463 ---- } else { SplitterBuilder[] parentArray= ((PredictorNodeSB) m_splitterBuilders.get(acand.getPredictorNode())).SB; ! node = insert(bags, parent, splitter, parentArray, predictions, partition); } } + /* + System.out.println("Adding Candidate: " + candidate); + + System.out.println("Node being added:" ); + System.out.println("" + node ); + + System.out.println("m_predictors:"); + for (int i=0; i<m_predictors.size(); i++) { + System.out.println("i: "+ i + m_predictors.get(i)); + } + System.out.println("m_predictors.index:"); + for (int i=0; i<m_predictors.size(); i++) { + System.out.println("i: "+ i + ((PredictorNode)(m_predictors.get(i))).getIndex()); + } + System.out.println("m_splitters:\n"); + for (int i=0; i<m_splitters.size(); i++) { + System.out.println("i: "+i+m_splitters.get(i)); + } + System.out.println("m_splitterBuilders:\n"); + for (int i=0; i<m_splitterBuilders.size(); i++) { + System.out.println("i: "+i+m_splitterBuilders.get(i)); + } + */ } *************** *** 449,452 **** --- 488,494 ---- m_booster.update(predictions, partition); + if (parent==null) { + System.err.println("Adding candidate and the parent is null!"); + } lastBasePredictor= new AtreePredictor(splitter, parent, predictions, m_booster); SplitterNode node= findSplitter(parent, splitter); |