From: Sunsern C. <sch...@us...> - 2009-03-12 23:42:13
|
Update of /cvsroot/jboost/jboost/src/jboost/visualization In directory fdv4jf1.ch3.sourceforge.com:/tmp/cvs-serv16854/src/jboost/visualization Modified Files: DataSet.java HistogramFrame.java Log Message: * Remove an early stopping criteria from RobustBoost * New HistogramFrame.java and DataSet.java that works with new VisualizeScores.py. Index: DataSet.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/visualization/DataSet.java,v retrieving revision 1.6 retrieving revision 1.7 diff -C2 -d -r1.6 -r1.7 *** DataSet.java 15 Sep 2008 22:24:44 -0000 1.6 --- DataSet.java 12 Mar 2009 23:42:04 -0000 1.7 *************** *** 110,116 **** */ public void addScoresList(ArrayList<DataElement> scores,int index) { ! Collections.sort(scores); Object[] a = (Object[]) scores.toArray(); ! System.out.printf("index=%d, a.length=%d%n",index,a.length); int neg_count=0; --- 110,118 ---- */ public void addScoresList(ArrayList<DataElement> scores,int index) { ! ! Collections.sort(scores); Object[] a = (Object[]) scores.toArray(); ! ! //System.out.printf("index=%d, a.length=%d%n",index,a.length); int neg_count=0; *************** *** 126,131 **** --- 128,135 ---- else if(e.label == pos_label) pos_count++; } + total_neg=neg_count; total_pos=pos_count; + //System.out.printf("total_neg=%d, total_pos=%d%n",total_neg,total_pos); *************** *** 139,146 **** for(int i=a.length-1; i >= 0; i--) { DataElement e = ((DataElement) a[i]); if(e.label == neg_label) neg_count++; else if(e.label == pos_label) pos_count++; e.truePositives=pos_count; ! e.falsePositives=neg_count; //System.out.println(e); } --- 143,153 ---- for(int i=a.length-1; i >= 0; i--) { DataElement e = ((DataElement) a[i]); + if(e.label == neg_label) neg_count++; else if(e.label == pos_label) pos_count++; + e.truePositives=pos_count; ! e.falsePositives=neg_count; ! //System.out.println(e); } *************** *** 165,177 **** for(int i=0; i<bins; i++) { DataElement e = iterData.get(binarySearch(iterData, s)); //System.out.printf("label=%d, i= %d, s=%f, prev=%f, e=",label,i,s,prev); //System.out.println(e); if(label==pos_label) { ! h[i]=prev-e.truePositives; ! prev=e.truePositives; } else { ! h[i]=prev-e.falsePositives; ! prev=e.falsePositives; } s=s+step; //System.out.println(prev); --- 172,191 ---- for(int i=0; i<bins; i++) { DataElement e = iterData.get(binarySearch(iterData, s)); + //System.out.printf("label=%d, i= %d, s=%f, prev=%f, e=",label,i,s,prev); //System.out.println(e); + + // exclude its own label + double tp = (e.label == pos_label)? e.truePositives-1: e.truePositives; + double fp = (e.label == neg_label)? e.falsePositives-1: e.falsePositives; + if(label==pos_label) { ! h[i]=prev-tp; ! prev=tp; } else { ! h[i]=prev-fp; ! prev=fp; } + s=s+step; //System.out.println(prev); *************** *** 196,200 **** int i=binarySearch(iterData, lowerScore); DataElement e = iterData.get(i); ! if(e.value<lowerScore && i<iterData.size()) { i++; e = iterData.get(i); --- 210,214 ---- int i=binarySearch(iterData, lowerScore); DataElement e = iterData.get(i); ! if(e.value<lowerScore && i+1<iterData.size()) { i++; e = iterData.get(i); *************** *** 219,230 **** if(s<list.get(0).value) return 0; if(s>list.get(l-1).value) return list.size()-1; - double l2=Math.floor(Math.log((double) l)/Math.log(2.0)); int index= 0; int step= (int) Math.pow(2, l2); DataElement e=list.get(index); ! while(e.value != s && step>0) { if(index+step<l) { ! if(list.get(index+step).value<=s) {index=index+step;} } e=list.get(index); --- 233,243 ---- if(s<list.get(0).value) return 0; if(s>list.get(l-1).value) return list.size()-1; double l2=Math.floor(Math.log((double) l)/Math.log(2.0)); int index= 0; int step= (int) Math.pow(2, l2); DataElement e=list.get(index); ! while(Math.abs(e.value-s) > 1e-7 && step>0) { if(index+step<l) { ! if(list.get(index+step).value <= s) {index=index+step;} } e=list.get(index); *************** *** 248,252 **** public double[] getFPTP(double v) { - ArrayList<DataElement> iterData = data.get(iteration); DataElement e = iterData.get(binarySearch(iterData,v)); --- 261,264 ---- *************** *** 255,258 **** --- 267,284 ---- } + public double getScoreAtTPThreshold(double threshold) { + Object[] a = (Object[]) data.get(iteration).toArray(); + + for(int i=a.length-1; i >= 0; i--) { + DataElement e = ((DataElement) a[i]); + if (e.truePositives / total_pos > threshold) { + return e.value; + } + } + + return Double.NEGATIVE_INFINITY; + } + + /** * @return the iteration Index: HistogramFrame.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/visualization/HistogramFrame.java,v retrieving revision 1.6 retrieving revision 1.7 diff -C2 -d -r1.6 -r1.7 *** HistogramFrame.java 15 Sep 2008 22:24:44 -0000 1.6 --- HistogramFrame.java 12 Mar 2009 23:42:04 -0000 1.7 *************** *** 1,19 **** package jboost.visualization; import java.awt.Color; ! import java.awt.GradientPaint; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; ! import java.awt.event.ComponentAdapter; ! import javax.swing.BorderFactory; import javax.swing.BoxLayout; ! import javax.swing.DefaultComboBoxModel; [...1576 lines suppressed...] + for (int i=0;i<prediction.length-1;i++) { + ret = ret + (prediction[i]+label_offset) + ","; + } + ret = ret + (prediction[prediction.length-1]+label_offset) + "]"; + return ret; + } + } + + public boolean contains(int k) { + for (int i=0;i<prediction.length;i++) { + if (prediction[i] == k) return true; + } + return false; + } + } + + + + } |