From: Sunsern C. <sch...@us...> - 2009-06-08 23:40:53
|
Update of /cvsroot/jboost/jboost/src/jboost/atree In directory fdv4jf1.ch3.sourceforge.com:/tmp/cvs-serv11232/src/jboost/atree Modified Files: Tag: jboost-2_0 AlternatingTreeTest.java InstrumentedAlternatingTree.java InstrumentedAlternatingTreeTest.java PredictorNode.java SplitterBuilderWorker.java SplitterNode.java Log Message: jboost 2.0 Index: InstrumentedAlternatingTreeTest.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/InstrumentedAlternatingTreeTest.java,v retrieving revision 1.1.1.1 retrieving revision 1.1.1.1.4.1 diff -C2 -d -r1.1.1.1 -r1.1.1.1.4.1 *** InstrumentedAlternatingTreeTest.java 16 May 2007 04:06:02 -0000 1.1.1.1 --- InstrumentedAlternatingTreeTest.java 8 Jun 2009 23:40:49 -0000 1.1.1.1.4.1 *************** *** 7,11 **** import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; - import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; --- 7,10 ---- Index: PredictorNode.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/PredictorNode.java,v retrieving revision 1.3 retrieving revision 1.3.4.1 diff -C2 -d -r1.3 -r1.3.4.1 *** PredictorNode.java 7 Apr 2008 17:09:06 -0000 1.3 --- PredictorNode.java 8 Jun 2009 23:40:49 -0000 1.3.4.1 *************** *** 5,9 **** import jboost.booster.Prediction; - import jboost.booster.NormalizedPrediction; import jboost.examples.Instance; import jboost.learner.IncompAttException; --- 5,8 ---- *************** *** 125,129 **** for(int i=0;i<pn.splitterNodes.size();i++){ tmp = ((SplitterNode)pn.splitterNodes.elementAt(i)).predictNode(instance); ! tmp = findPrediction(instance, iter, tmp); if (tmp==null) { // The node is not down there or this instance does --- 124,130 ---- for(int i=0;i<pn.splitterNodes.size();i++){ tmp = ((SplitterNode)pn.splitterNodes.elementAt(i)).predictNode(instance); ! ! if (tmp!=null) tmp = findPrediction(instance, iter, tmp); ! if (tmp==null) { // The node is not down there or this instance does *************** *** 189,196 **** /** Add a prediction to its prediction value */ public void addToPrediction(Prediction p) { - if (p instanceof NormalizedPrediction) { - System.err.println("Cannot add normalized prediction to existing node"); - System.exit(2); - } prediction.add(p); } --- 190,193 ---- Index: AlternatingTreeTest.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/AlternatingTreeTest.java,v retrieving revision 1.1 retrieving revision 1.1.4.1 diff -C2 -d -r1.1 -r1.1.4.1 *** AlternatingTreeTest.java 2 Oct 2007 02:28:06 -0000 1.1 --- AlternatingTreeTest.java 8 Jun 2009 23:40:49 -0000 1.1.4.1 *************** *** 6,10 **** import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; - import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; --- 6,9 ---- Index: SplitterBuilderWorker.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/SplitterBuilderWorker.java,v retrieving revision 1.1.1.1 retrieving revision 1.1.1.1.4.1 diff -C2 -d -r1.1.1.1 -r1.1.1.1.4.1 *** SplitterBuilderWorker.java 16 May 2007 04:06:02 -0000 1.1.1.1 --- SplitterBuilderWorker.java 8 Jun 2009 23:40:49 -0000 1.1.1.1.4.1 *************** *** 2,5 **** --- 2,6 ---- import java.util.Vector; + import java.util.concurrent.CountDownLatch; import jboost.CandidateSplit; *************** *** 7,11 **** import jboost.monitor.Monitor; import jboost.util.BaseCountWorker; - import EDU.oswego.cs.dl.util.concurrent.CountDown; /** --- 8,11 ---- *************** *** 19,23 **** Vector splitters; ! public SplitterBuilderWorker(PredictorNodeSB pSB, Vector splitters, CountDown count) { super(count); this.pSB=pSB; --- 19,23 ---- Vector splitters; ! public SplitterBuilderWorker(PredictorNodeSB pSB, Vector splitters, CountDownLatch count) { super(count); this.pSB=pSB; Index: SplitterNode.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/SplitterNode.java,v retrieving revision 1.2 retrieving revision 1.2.4.1 diff -C2 -d -r1.2 -r1.2.4.1 *** SplitterNode.java 2 Oct 2007 02:28:06 -0000 1.2 --- SplitterNode.java 8 Jun 2009 23:40:49 -0000 1.2.4.1 *************** *** 1,7 **** package jboost.atree; - import java.io.IOException; - import java.io.ObjectInputStream; - import java.io.ObjectOutputStream; import java.io.Serializable; --- 1,4 ---- Index: InstrumentedAlternatingTree.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/InstrumentedAlternatingTree.java,v retrieving revision 1.9 retrieving revision 1.9.2.1 diff -C2 -d -r1.9 -r1.9.2.1 *** InstrumentedAlternatingTree.java 19 Feb 2009 14:20:25 -0000 1.9 --- InstrumentedAlternatingTree.java 8 Jun 2009 23:40:49 -0000 1.9.2.1 *************** *** 5,8 **** --- 5,11 ---- import java.util.Iterator; import java.util.Vector; + import java.util.concurrent.CountDownLatch; + import java.util.concurrent.Executor; + import java.util.concurrent.RejectedExecutionException; import jboost.CandidateSplit; *************** *** 13,23 **** import jboost.booster.Bag; import jboost.booster.Booster; - import jboost.booster.BrownBoost; - import jboost.booster.MulticlassWrapMH; import jboost.booster.Prediction; - import jboost.booster.NormalizedPrediction; import jboost.booster.RobustBinaryPrediction; import jboost.booster.RobustBoost; - import jboost.booster.MulticlassWrapMH.MultiPrediction; import jboost.controller.Configuration; import jboost.controller.ConfigurationException; --- 16,22 ---- *************** *** 27,32 **** import jboost.monitor.Monitor; import jboost.util.ExecutorSinglet; ! import EDU.oswego.cs.dl.util.concurrent.CountDown; ! import EDU.oswego.cs.dl.util.concurrent.Executor; --- 26,30 ---- import jboost.monitor.Monitor; import jboost.util.ExecutorSinglet; ! *************** *** 38,41 **** --- 36,40 ---- */ + @SuppressWarnings("unchecked") public class InstrumentedAlternatingTree extends ComplexLearner { *************** *** 94,98 **** * @param config The configuration information. */ ! public InstrumentedAlternatingTree(Vector sb, Booster b, int[] ex, Configuration config) { --- 93,98 ---- * @param config The configuration information. */ ! ! public InstrumentedAlternatingTree(Vector sb, Booster b, int[] ex, Configuration config) { *************** *** 108,113 **** throws InstrumentException, NotSupportedException { ! init(splitterbuilders, booster, examples, config); ! createRoot(); instrumentAlternatingTree(tree); } --- 108,113 ---- throws InstrumentException, NotSupportedException { ! init(splitterbuilders, booster, examples, config); ! createRoot(tree.getRoot()); instrumentAlternatingTree(tree); } *************** *** 189,192 **** --- 189,201 ---- } + + private void createRoot(PredictorNode root) { + PredictorNode predictorNode= new PredictorNode(root.prediction, "R", + 0, null, null, 0); + + m_predictors.add(predictorNode); + } + + /** * Suggest a list of Candidate Splitters *************** *** 224,237 **** // create a synchronization barrier that counts the number // of processed splitter builders ! CountDown sbCount=new CountDown(m_splitterBuilders.size()); Vector splitters=new Vector(m_splitterBuilders.size()); for (Iterator i = m_splitterBuilders.iterator(); i.hasNext(); ) { PredictorNodeSB pSB=(PredictorNodeSB)i.next(); if (m_treeType == AtreeType.ADD_ROOT && pSB.pNode != 0) { ! while(sbCount.currentCount()!=0) { ! sbCount.release(); } break; --- 233,248 ---- // create a synchronization barrier that counts the number // of processed splitter builders ! CountDownLatch sbCount=new CountDownLatch(m_splitterBuilders.size()); Vector splitters=new Vector(m_splitterBuilders.size()); for (Iterator i = m_splitterBuilders.iterator(); i.hasNext(); ) { + //System.out.println("Creating new SplitterBuilderWorker and run it ... "); + PredictorNodeSB pSB=(PredictorNodeSB)i.next(); if (m_treeType == AtreeType.ADD_ROOT && pSB.pNode != 0) { ! while(sbCount.getCount()!=0) { ! sbCount.countDown(); } break; *************** *** 242,246 **** m_predictors.get(pSB.pNode)).getSplitterNodeNo(); if (childCount > 0) { ! sbCount.release(); continue; } --- 253,257 ---- m_predictors.get(pSB.pNode)).getSplitterNodeNo(); if (childCount > 0) { ! sbCount.countDown(); continue; } *************** *** 251,264 **** m_predictors.get(pSB.pNode)).getSplitterNodeNo(); if (childCount > 0) { ! sbCount.release(); continue; } } SplitterBuilderWorker sbw= new SplitterBuilderWorker(pSB,splitters,sbCount); try { pe.execute(sbw); ! } catch (InterruptedException ie) { System.err.println("exception ocurred while handing off the " + "splitter job to the pool: " --- 262,278 ---- m_predictors.get(pSB.pNode)).getSplitterNodeNo(); if (childCount > 0) { ! sbCount.countDown(); continue; } } + //System.out.println("Create new SplitterBuilderWorker and run it ... "); + SplitterBuilderWorker sbw= new SplitterBuilderWorker(pSB,splitters,sbCount); + try { pe.execute(sbw); ! } catch (RejectedExecutionException ie) { System.err.println("exception ocurred while handing off the " + "splitter job to the pool: " *************** *** 268,278 **** } // wait on all threads to finish try { ! sbCount.acquire(); } catch(InterruptedException ie) { ! if(sbCount.currentCount()!=0) { System.err.println("interrupted exception occurred, but the " ! + "sbCount is " + sbCount.currentCount()); } }; --- 282,294 ---- } + //System.out.println("Waiting on all threads to finish..."); + // wait on all threads to finish try { ! sbCount.await(); } catch(InterruptedException ie) { ! if(sbCount.getCount()!=0) { System.err.println("interrupted exception occurred, but the " ! + "sbCount is " + sbCount.getCount()); } }; *************** *** 281,323 **** } ! /** ! * Build a splitter using a single splitter buildier ! * @param retval ! * @throws NotSupportedException ! */ ! private void buildSplitter(PredictorNodeSB pSB, Vector splitters) ! throws NotSupportedException { ! CandidateSplit split; ! ! double trivLoss; ! long start; ! long stop; ! ! // Create bag containing all m_examples reaching this node: ! // tmpBag = m_booster.newBag(makeIndices((boolean []) ! // m_masks.get(pSB.pNode))); Compute loss for trivial split: ! // trivLoss = m_booster.getLoss(new Bag[] {tmpBag}); ! // TODO: ! // need to fix so that splits worse than trivial are not ! // added. In the meantime, allow all splits. ! trivLoss= Double.MAX_VALUE; ! start= System.currentTimeMillis(); ! int j=0; ! for (j= 0; j < pSB.SB.length; j++) { ! split= pSB.SB[j].build(); ! ! // only add candidates with loss better than trivial split ! // TODO: figure out what to do if no splits better ! // than trivial ! if (split != null && split.getLoss() < trivLoss) ! splitters.add(new AtreeCandidateSplit(pSB.pNode, split)); ! } ! stop= System.currentTimeMillis(); ! if (Monitor.logLevel > 3) { ! Monitor.log("It took an average of " + (stop-start)/(j*1000.0) + ! " seconds to build " + j + " splitterbuilders."); ! } ! } ! /** --- 297,301 ---- } ! /** *************** *** 333,341 **** m_booster.update(pred, partition); - if (pred.length > 0 && pred[0] instanceof NormalizedPrediction) { - System.err.println("Cannot update root with mixed binary pred"); - System.exit(2); - } - ((PredictorNode) m_predictors.get(0)).addToPrediction(pred[0]); if (pred==null) { --- 311,314 ---- *************** *** 398,402 **** SplitterBuilder[] childArray= new SplitterBuilder[parentArray.length]; for (int j= 0; j < parentArray.length; j++) { ! childArray[j]= parentArray[j].spawn(examplesMask, partition[i].length); } PredictorNodeSB pnSB= new PredictorNodeSB(pInt[i], childArray); --- 371,375 ---- SplitterBuilder[] childArray= new SplitterBuilder[parentArray.length]; for (int j= 0; j < parentArray.length; j++) { ! childArray[j]= parentArray[j].spawn(examplesMask, partition[i].length); } PredictorNodeSB pnSB= new PredictorNodeSB(pInt[i], childArray); *************** *** 446,450 **** for (int i=0; i < predictions.length; i++) { ! // RobustBoost needs to scale all of the previous // hyphothesis by exp(-dt) --- 419,423 ---- for (int i=0; i < predictions.length; i++) { ! // RobustBoost needs to scale all of the previous // hyphothesis by exp(-dt) *************** *** 455,458 **** --- 428,435 ---- double exp_negative_dt = Math.exp(-dt) ; + for (int j=0; j < i; j++) { + predictions[j].scale(exp_negative_dt); + } + // for each prediction before this one for (int nodeidx=0;nodeidx < m_predictors.size();nodeidx++) { *************** *** 467,511 **** } } - // multiclass case - else if (m_booster instanceof MulticlassWrapMH) { - Booster b = ((MulticlassWrapMH)m_booster).m_booster; - if (b instanceof RobustBoost) { - for (int i=0; i < predictions.length; i++) { - - if (predictions[i] instanceof MultiPrediction) { - - Prediction[] preds = ((MultiPrediction)predictions[i]).preds; - - for (int j=0;j<preds.length;j++) { - - if (preds[j] instanceof RobustBinaryPrediction) { - - // for every RobustBinaryPrediction added before this one - // we scale all of them by exp(-dt) - double dt = ((RobustBinaryPrediction)preds[j]).getDt(); - double exp_negative_dt = Math.exp(-dt) ; - - // for each prediction before this one - for (int nodeidx=0;nodeidx < m_predictors.size();nodeidx++) { - PredictorNode cpn = (PredictorNode)m_predictors.get(nodeidx); - cpn.prediction.scale(exp_negative_dt); - } - - } - else { - throw new RuntimeException("RobustBinaryPrediction is expected. This should never happen!"); - } - } - } - else { - throw new RuntimeException("RobustBinaryPrediction is expected. This should never happen!"); - } - } - } - } - - //--------------------------------// ! if (node != null && (predictions.length > 0 && !(predictions[0] instanceof NormalizedPrediction)) ) { for (int i=0; i < node.predictorNodes.length; i++) { node.predictorNodes[i].addToPrediction(predictions[i]); --- 444,449 ---- } } ! if (node != null && predictions.length > 0) { for (int i=0; i < node.predictorNodes.length; i++) { node.predictorNodes[i].addToPrediction(predictions[i]); *************** *** 566,570 **** Splitter splitter= acand.getSplitter(); PredictorNode parent= (PredictorNode) m_predictors.get(acand.getPredictorNode()); ! m_booster.update(predictions, partition); if (parent==null) { System.err.println("Adding candidate and the parent is null!"); --- 504,510 ---- Splitter splitter= acand.getSplitter(); PredictorNode parent= (PredictorNode) m_predictors.get(acand.getPredictorNode()); ! ! //m_booster.update(predictions, partition); ! if (parent==null) { System.err.println("Adding candidate and the parent is null!"); *************** *** 657,690 **** } - /** Adjusts the predictions of all of the existing {@link PredictorNode}s in - * the tree. */ - public void adjustPredictions() { - int[][] examples= null; - boolean[] exMask= null; - int count= 0; - PredictorNode[] pn= null; - int s= m_splitters.size(); - int[] pNodes= null; - Bag[] b= null; - int nodeNo= 0; - Prediction[] p= null; - for (int i= 0; i < s; i++) { - pNodes= (int[]) m_splitters.get(i); - pn= new PredictorNode[pNodes.length]; - examples= new int[pNodes.length][]; - b= new Bag[pNodes.length]; - for (int j= 0; j < pNodes.length; j++) { - nodeNo= pNodes[j]; - exMask= (boolean[]) m_masks.get(nodeNo); - pn[j]= (PredictorNode) m_predictors.get(nodeNo); - examples[j]= makeIndices(exMask); - b[j]= m_booster.newBag(examples[j]); - } - p= m_booster.getPredictions(b,examples); - m_booster.update(p, examples); - for (int j= 0; i < pNodes.length; j++) - pn[j].addToPrediction(p[j]); - } - } /** Produces a string describing this tree. */ public String toString() { --- 597,600 ---- *************** *** 699,718 **** public boolean boosterIsFinished() { - if (m_booster instanceof BrownBoost){ - BrownBoost b = (BrownBoost) m_booster; - return b.isFinished(); - } ! if (m_booster instanceof RobustBoost){ RobustBoost b = (RobustBoost) m_booster; return b.isFinished(); } ! if (m_booster instanceof MulticlassWrapMH) { ! if (((MulticlassWrapMH)m_booster).m_booster instanceof RobustBoost) { ! RobustBoost b = (RobustBoost)(((MulticlassWrapMH)m_booster).m_booster); ! return b.isFinished(); ! } ! } double EPS = 1e-50; --- 609,619 ---- public boolean boosterIsFinished() { ! if (m_booster instanceof RobustBoost){ RobustBoost b = (RobustBoost) m_booster; return b.isFinished(); } ! double EPS = 1e-50; *************** *** 786,803 **** - /** Make a integer array of m_examples given an ExampleMask */ - private int[] makeIndices(boolean[] exMask) { - int[] examples= null; - int count= 0; - for (int j= 0; j < exMask.length; j++) - if (exMask[j] == true) - count++; - examples= new int[count]; - count= 0; - for (int j= 0; j < exMask.length; j++) - if (exMask[j] == true) - examples[count++]= j; - return (examples); - } /** the last base predictor added to the tree */ private Predictor lastBasePredictor= null; --- 687,690 ---- |