From: Sunsern C. <sch...@us...> - 2009-02-09 10:45:36
|
Update of /cvsroot/jboost/jboost/src/jboost/booster In directory fdv4jf1.ch3.sourceforge.com:/tmp/cvs-serv31004/src/jboost/booster Modified Files: RobustBinaryPrediction.java RobustBoostTest.java RobustBoost.java Log Message: fixed bugs and tested with Long&Servedio data Index: RobustBoostTest.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/booster/RobustBoostTest.java,v retrieving revision 1.1 retrieving revision 1.2 diff -C2 -d -r1.1 -r1.2 *** RobustBoostTest.java 5 Feb 2009 05:53:27 -0000 1.1 --- RobustBoostTest.java 9 Feb 2009 10:45:30 -0000 1.2 *************** *** 3,7 **** --- 3,14 ---- + import java.io.BufferedReader; + import java.io.FileNotFoundException; + import java.io.FileReader; + import java.util.StringTokenizer; + import jboost.booster.RobustBoost.NewtonSolver; + import jboost.booster.RobustBoost.RobustBinaryBag; + import jboost.examples.Label; /** *************** *** 15,19 **** RobustBoost m_robustBoost; /** ! * Constructor for BrownBoostTest. * @param arg0 */ --- 22,26 ---- RobustBoost m_robustBoost; /** ! * Constructor for RobustBoostTest. * @param arg0 */ *************** *** 23,27 **** /** ! * Tests the BrownBoost constructor and sets up boosters for * other tests. * @see TestCase#setUp() --- 30,34 ---- /** ! * Tests the RobustBoost constructor and sets up boosters for * other tests. * @see TestCase#setUp() *************** *** 94,111 **** } ! final public void testNewtonSolver() { ! double sigma_f = 0.1d; ! double theta = 0.2d; ! double rho = 1.6; - double[] margins = new double[] {-10,-10,-10,-10,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,10,10,10,10}; - double[] steps = new double[] {-1,-1,1,1,-1,-1,1,1,1,-1,1,-1,1,-1,1,1,1,1,-1,1,1}; ! NewtonSolver ns = m_robustBoost.new NewtonSolver(margins,steps,0.1,0.6,10,rho,theta,sigma_f); ! //System.out.print(ns.getLog()); ! assertEquals(ns.getDs(),1.427, 0.001); ! assertEquals(ns.getDt(),0.263, 0.001); } - } --- 101,237 ---- } ! ! ! final public void testLongData() { ! int[] best = new int[] {1,2,10,12,1,7,4,17,8,1,9,2,6,15,3,20,19,13,16,14}; ! double[] times = new double[] { 0.125, 0.172, 0.189, 0.203, 0.212, 0.226, 0.234, 0.240, 0.246, 0.252, ! 0.256, 0.260, 0.264, 0.268, 0.271, 0.273, 0.276, 0.278, 0.2791, 0.280}; ! ! int numFeatures = 21; ! int numExamples = 800; ! ! try { ! ! int[][] data = new int[numExamples][numFeatures]; ! int[] labels = new int[numExamples]; ! ! RobustBoost rBoost = new RobustBoost(false,0.14,new double[] {0.2,0.2}, new double[] {0.1,0.1}, new double[] {1,1}); ! ! assertEquals(rBoost.m_epsilon,0.14,0.0001); ! assertEquals(rBoost.m_theta[0],0.2,0.0001); ! assertEquals(rBoost.m_sigma_f[0],0.1,0.0001); ! assertEquals(rBoost.m_rho[0],0.7233,0.0001); ! ! BufferedReader br = new BufferedReader(new FileReader("demo/Long.train")); ! // read input ! String line = br.readLine(); ! int j=0; ! while (line != null) { ! ! StringTokenizer st = new StringTokenizer(line,",;"); ! ! labels[j] = Math.round(Float.parseFloat(st.nextToken())); ! ! int i =0; ! while (st.hasMoreTokens()) { ! String s = st.nextToken(); ! data[j][i] = Math.round(Float.parseFloat(s)); ! assertEquals(data[j][i] == -1 || data[j][i] == 1, true); ! i++; ! } ! j++; ! line = br.readLine(); ! } ! br.close(); ! ! for (int i=0;i<numExamples;i++) { ! assertEquals((labels[i] == -1) || (labels[i] == 1), true); ! ! if (labels[i] == -1) labels[i] = 0; ! ! rBoost.addExample(i, new Label(labels[i])); ! } ! ! rBoost.finalizeData(); ! int iter=0; ! while (iter < 20 && !rBoost.isFinished()) { ! ! int[][] best_lists = new int[][] {null, null}; ! int best_feature = -1; ! double best_gain = -1; ! ! // find the best weak rule ! for (int i=0;i<numFeatures;i++) { ! ! double gain, gain0 = 0, gain1 = 0; ! int count0 = 0, count1 = 0, p=0; ! for (int k=0;k<numExamples;k++) { ! ! if (data[k][i] == -1) count0++; ! else count1++; ! ! if (labels[k] == 0 && data[k][i] == -1) { ! gain0 += rBoost.m_weights[k]; ! } ! else if (labels[k] == 0 && data[k][i] == 1) { ! gain0 -= rBoost.m_weights[k]; ! } ! ! if (labels[k] == 1 && data[k][i] == 1) { ! gain1 += rBoost.m_weights[k]; ! } ! else if (labels[k] == 1 && data[k][i] == -1) { ! gain1 -= rBoost.m_weights[k]; ! } ! ! } ! ! gain = gain0 + gain1; ! ! int[] list0 = new int[count0]; ! int[] list1 = new int[count1]; ! int list0_idx = 0, list1_idx = 0; ! for (int k=0;k<numExamples;k++) { ! if (data[k][i] == -1) list0[list0_idx++] = k; ! else if (data[k][i] == 1) list1[list1_idx++] = k; ! } ! ! gain = Math.abs(gain); ! ! ! if (gain > best_gain) { ! best_feature = i; ! best_gain = gain; ! best_lists[0] = list0; ! best_lists[1] = list1; ! } ! } ! ! ! Bag[] best_bags = new Bag[2]; ! best_bags[0] = rBoost.newBag(best_lists[0]); ! best_bags[1] = rBoost.newBag(best_lists[1]); ! ! int[][] exampleIndex = new int[2][]; ! exampleIndex[0] = best_lists[0]; ! exampleIndex[1] = best_lists[1]; ! ! Prediction[] predictions = rBoost.getPredictions(best_bags, exampleIndex); ! rBoost.update(predictions, exampleIndex); ! ! assertEquals(best_feature+1,best[iter]); ! assertEquals(rBoost.m_t,times[iter],0.001); ! ! iter++; ! } ! ! } catch (Exception e) { ! // TODO Auto-generated catch block ! e.printStackTrace(); ! } ! } } Index: RobustBinaryPrediction.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/booster/RobustBinaryPrediction.java,v retrieving revision 1.1 retrieving revision 1.2 diff -C2 -d -r1.1 -r1.2 *** RobustBinaryPrediction.java 5 Feb 2009 05:53:27 -0000 1.1 --- RobustBinaryPrediction.java 9 Feb 2009 10:45:30 -0000 1.2 *************** *** 12,15 **** --- 12,17 ---- public class RobustBinaryPrediction extends Prediction{ + /** starting point for Newton's method */ + protected double init_ds; /** time step used to scale previous predictions */ protected double dt; Index: RobustBoost.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/booster/RobustBoost.java,v retrieving revision 1.2 retrieving revision 1.3 diff -C2 -d -r1.2 -r1.3 *** RobustBoost.java 5 Feb 2009 06:32:38 -0000 1.2 --- RobustBoost.java 9 Feb 2009 10:45:30 -0000 1.3 *************** *** 9,14 **** import jboost.examples.Label; - - /** * Java implemantation of RobustBoost. --- 9,12 ---- *************** *** 31,68 **** protected double[] m_sampleWeights; [...1463 lines suppressed...] if (!succeeded) { ! log.append("BinarySearch failed!\n"); ds = Double.NaN; dt = Double.NaN; } else { ! log.append("BinarySearch completed successfully!\n"); } } *************** *** 1137,1140 **** } ! ! } /** end of class AdaBoost */ --- 1403,1406 ---- } ! ! } |