From: Sunsern C. <sch...@us...> - 2009-02-19 14:20:36
|
Update of /cvsroot/jboost/jboost/src/jboost/atree In directory fdv4jf1.ch3.sourceforge.com:/tmp/cvs-serv31946/src/jboost/atree Modified Files: InstrumentedAlternatingTree.java Log Message: * improve how prediction nodes get normalized. * multiclass support for RobustBoost Index: InstrumentedAlternatingTree.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/InstrumentedAlternatingTree.java,v retrieving revision 1.8 retrieving revision 1.9 diff -C2 -d -r1.8 -r1.9 *** InstrumentedAlternatingTree.java 9 Feb 2009 10:47:50 -0000 1.8 --- InstrumentedAlternatingTree.java 19 Feb 2009 14:20:25 -0000 1.9 *************** *** 14,21 **** --- 14,23 ---- import jboost.booster.Booster; import jboost.booster.BrownBoost; + import jboost.booster.MulticlassWrapMH; import jboost.booster.Prediction; import jboost.booster.NormalizedPrediction; import jboost.booster.RobustBinaryPrediction; import jboost.booster.RobustBoost; + import jboost.booster.MulticlassWrapMH.MultiPrediction; import jboost.controller.Configuration; import jboost.controller.ConfigurationException; *************** *** 389,419 **** // 1.a) Add the new prediction nodes to the alternating tree list pInt[i]= m_predictors.size(); - - // RobustBoost needs to scale all of the previous - // hyphothesis by exp(-dt) - if (predictions[i] instanceof RobustBinaryPrediction) { - // for every RobustBinaryPrediction added before this one - // we scale all of them by exp(-dt) - - double dt = ((RobustBinaryPrediction)predictions[i]).getDt(); - - if (dt > 0) { - double exp_negative_dt = Math.exp(-dt) ; - - // for each prediction before this one - for (int nodeidx=0;nodeidx < m_predictors.size();nodeidx++) { - PredictorNode cpn = (PredictorNode)m_predictors.get(nodeidx); - if (cpn.prediction instanceof RobustBinaryPrediction) { - - //System.out.println(">> scaling predictions"); - //System.out.println("Before: " + ((RobustBinaryPrediction)cpn.prediction).getClassScores()[1]); - ((RobustBinaryPrediction)cpn.prediction).scale(exp_negative_dt); - //System.out.println("After: " + ((RobustBinaryPrediction)cpn.prediction).getClassScores()[1]); - - } - } - } - } - addPredictorNodeToList(pNode[i]); // 1.b) Generate the exampleMasks for the split. --- 391,394 ---- *************** *** 464,467 **** --- 439,510 ---- lastBasePredictor= new AtreePredictor(splitter, parent, predictions, m_booster); node= findSplitter(parent, splitter); + + //--------- RobustBoost ----------// + + // binary case + if (m_booster instanceof RobustBoost) { + + for (int i=0; i < predictions.length; i++) { + + // RobustBoost needs to scale all of the previous + // hyphothesis by exp(-dt) + if (predictions[i] instanceof RobustBinaryPrediction) { + // for every RobustBinaryPrediction added before this one + // we scale all of them by exp(-dt) + double dt = ((RobustBinaryPrediction)predictions[i]).getDt(); + double exp_negative_dt = Math.exp(-dt) ; + + // for each prediction before this one + for (int nodeidx=0;nodeidx < m_predictors.size();nodeidx++) { + PredictorNode cpn = (PredictorNode)m_predictors.get(nodeidx); + cpn.prediction.scale(exp_negative_dt); + } + + } + else { + throw new RuntimeException("RobustBinaryPrediction is expected. This should never happen!"); + } + } + } + // multiclass case + else if (m_booster instanceof MulticlassWrapMH) { + Booster b = ((MulticlassWrapMH)m_booster).m_booster; + if (b instanceof RobustBoost) { + for (int i=0; i < predictions.length; i++) { + + if (predictions[i] instanceof MultiPrediction) { + + Prediction[] preds = ((MultiPrediction)predictions[i]).preds; + + for (int j=0;j<preds.length;j++) { + + if (preds[j] instanceof RobustBinaryPrediction) { + + // for every RobustBinaryPrediction added before this one + // we scale all of them by exp(-dt) + double dt = ((RobustBinaryPrediction)preds[j]).getDt(); + double exp_negative_dt = Math.exp(-dt) ; + + // for each prediction before this one + for (int nodeidx=0;nodeidx < m_predictors.size();nodeidx++) { + PredictorNode cpn = (PredictorNode)m_predictors.get(nodeidx); + cpn.prediction.scale(exp_negative_dt); + } + + } + else { + throw new RuntimeException("RobustBinaryPrediction is expected. This should never happen!"); + } + } + } + else { + throw new RuntimeException("RobustBinaryPrediction is expected. This should never happen!"); + } + } + } + } + + //--------------------------------// + if (node != null && (predictions.length > 0 && !(predictions[0] instanceof NormalizedPrediction)) ) { for (int i=0; i < node.predictorNodes.length; i++) { *************** *** 473,477 **** } } ! /* System.out.println("Adding Candidate: " + candidate); --- 516,521 ---- } } ! ! /* System.out.println("Adding Candidate: " + candidate); *************** *** 655,668 **** public boolean boosterIsFinished() { ! if(m_booster instanceof BrownBoost){ BrownBoost b = (BrownBoost) m_booster; return b.isFinished(); } ! if(m_booster instanceof RobustBoost){ RobustBoost b = (RobustBoost) m_booster; return b.isFinished(); } ! double EPS = 1e-50; double w = m_booster.getTotalWeight(); --- 699,719 ---- public boolean boosterIsFinished() { ! if (m_booster instanceof BrownBoost){ BrownBoost b = (BrownBoost) m_booster; return b.isFinished(); } ! if (m_booster instanceof RobustBoost){ RobustBoost b = (RobustBoost) m_booster; return b.isFinished(); } ! ! if (m_booster instanceof MulticlassWrapMH) { ! if (((MulticlassWrapMH)m_booster).m_booster instanceof RobustBoost) { ! RobustBoost b = (RobustBoost)(((MulticlassWrapMH)m_booster).m_booster); ! return b.isFinished(); ! } ! } ! double EPS = 1e-50; double w = m_booster.getTotalWeight(); |