From: Aaron A. <aa...@us...> - 2007-05-29 05:04:03
|
Update of /cvsroot/jboost/jboost/src/jboost/booster In directory sc8-pr-cvs6.sourceforge.net:/tmp/cvs-serv16590/booster Modified Files: AbstractBooster.java MulticlassWrapMH.java Log Message: multiclass and multilabel are handled a little differently when outputting margins. Index: MulticlassWrapMH.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/booster/MulticlassWrapMH.java,v retrieving revision 1.1.1.1 retrieving revision 1.2 diff -C2 -d -r1.1.1.1 -r1.2 *** MulticlassWrapMH.java 16 May 2007 04:06:02 -0000 1.1.1.1 --- MulticlassWrapMH.java 29 May 2007 05:03:57 -0000 1.2 *************** *** 23,29 **** class MulticlassWrapMH extends AbstractBooster { ! private AbstractBooster booster; // underlying booster ! private int num_classes; // number of labels /** --- 23,35 ---- class MulticlassWrapMH extends AbstractBooster { + + /** The underlying m_booster */ + private AbstractBooster m_booster; + + /** The number of labels */ + private int m_numLabels; ! /** The number of labels */ ! private boolean m_isMultiLabel; /** *************** *** 32,43 **** * @param k total number of m_labels */ ! MulticlassWrapMH(AbstractBooster booster, int k) { ! this.booster = booster; ! num_classes = k; } public String toString() { ! if (booster==null) { ! String msg = "MulticlassWrapMH.toString: booster is null"; if (Monitor.logLevel>3) { Monitor.log(msg); --- 38,50 ---- * @param k total number of m_labels */ ! MulticlassWrapMH(AbstractBooster booster, int numLabels, boolean isMultiLabel) { ! m_booster = booster; ! m_numLabels = numLabels; ! m_isMultiLabel = isMultiLabel; } public String toString() { ! if (m_booster==null) { ! String msg = "MulticlassWrapMH.toString: m_booster is null"; if (Monitor.logLevel>3) { Monitor.log(msg); *************** *** 46,51 **** } ! return ("MulticlassWrapMH. # of classes = " + num_classes + ! ".\nUnderlying booster:\n" + booster); } --- 53,58 ---- } ! return ("MulticlassWrapMH. # of classes = " + m_numLabels + ! ".\nUnderlying m_booster:\n" + m_booster); } *************** *** 58,64 **** */ public void addExample(int index, Label label, double weight) { ! int s = index * num_classes; ! for (int j = 0; j < num_classes; j++) { ! booster.addExample(s+j, new Label(label.getMultiValue(j) ? 1 : 0), --- 65,71 ---- */ public void addExample(int index, Label label, double weight) { ! int s = index * m_numLabels; ! for (int j = 0; j < m_numLabels; j++) { ! m_booster.addExample(s+j, new Label(label.getMultiValue(j) ? 1 : 0), *************** *** 68,76 **** public void finalizeData() { ! booster.finalizeData(); } public void clear() { ! booster.clear(); } --- 75,83 ---- public void finalizeData() { ! m_booster.finalizeData(); } public void clear() { ! m_booster.clear(); } *************** *** 87,91 **** */ public double calculateWeight(double margin) { ! return booster.calculateWeight(margin); } --- 94,98 ---- */ public double calculateWeight(double margin) { ! return m_booster.calculateWeight(margin); } *************** *** 93,104 **** public void update(Prediction[] preds, int[][] index) { int num_preds = preds.length; ! Prediction[] upreds = new Prediction[num_preds * num_classes]; ! int[][] uindex = new int[index.length * num_classes][]; int i, j, t, k; ! // create array of predictions to pass to underlying booster t = 0; for (i = 0; i < num_preds; i++) { ! for (j = 0; j < num_classes; j++) { upreds[t] = ((MultiPrediction) preds[i]).preds[j]; t++; --- 100,112 ---- public void update(Prediction[] preds, int[][] index) { int num_preds = preds.length; ! Prediction[] upreds = new Prediction[num_preds * m_numLabels]; ! int[][] uindex = new int[index.length * m_numLabels][]; int i, j, t, k; ! ! // create array of predictions to pass to underlying booster t = 0; for (i = 0; i < num_preds; i++) { ! for (j = 0; j < m_numLabels; j++) { upreds[t] = ((MultiPrediction) preds[i]).preds[j]; t++; *************** *** 109,129 **** t = 0; for (i = 0; i < index.length; i++) { ! for (j = 0; j < num_classes; j++) { uindex[t] = new int[index[i].length]; for (k = 0; k < index[i].length; k++) ! uindex[t][k] = index[i][k] * num_classes + j; t++; } } ! booster.update(upreds, uindex); } /** ! * computes theoretical bound as (num_classes/2) * theoretical * bound for underlying booster. This computation may not be * correct in all cases. */ public double getTheoryBound() { ! return 0.5 * num_classes * booster.getTheoryBound(); } --- 117,137 ---- t = 0; for (i = 0; i < index.length; i++) { ! for (j = 0; j < m_numLabels; j++) { uindex[t] = new int[index[i].length]; for (k = 0; k < index[i].length; k++) ! uindex[t][k] = index[i][k] * m_numLabels + j; t++; } } ! m_booster.update(upreds, uindex); } /** ! * computes theoretical bound as (m_numLabels/2) * theoretical * bound for underlying booster. This computation may not be * correct in all cases. */ public double getTheoryBound() { ! return 0.5 * m_numLabels * m_booster.getTheoryBound(); } *************** *** 133,137 **** */ public double[] getMargins() { ! return booster.getMargins(); } --- 141,145 ---- */ public double[] getMargins() { ! return m_booster.getMargins(); } *************** *** 141,158 **** */ public Prediction[] getPredictions(Bag[] b, int[][] exampleIndex) { ! Bag[] ubags = new Bag[b.length * num_classes]; for (int i = 0; i < b.length; i++) ! for (int j = 0; j < num_classes; j++) ! ubags[i * num_classes + j] = ((MultiBag) b[i]).bags[j]; ! Prediction[] upreds = booster.getPredictions(ubags, exampleIndex); Prediction[] preds = new Prediction[b.length]; for (int i = 0; i < b.length; i++) { preds[i] = new MultiPrediction(); ! for (int j = 0; j < num_classes; j++) ((MultiPrediction) preds[i]).preds[j] = ! upreds[i * num_classes + j]; } return preds; --- 149,166 ---- */ public Prediction[] getPredictions(Bag[] b, int[][] exampleIndex) { ! Bag[] ubags = new Bag[b.length * m_numLabels]; for (int i = 0; i < b.length; i++) ! for (int j = 0; j < m_numLabels; j++) ! ubags[i * m_numLabels + j] = ((MultiBag) b[i]).bags[j]; ! Prediction[] upreds = m_booster.getPredictions(ubags, exampleIndex); Prediction[] preds = new Prediction[b.length]; for (int i = 0; i < b.length; i++) { preds[i] = new MultiPrediction(); ! for (int j = 0; j < m_numLabels; j++) ((MultiPrediction) preds[i]).preds[j] = ! upreds[i * m_numLabels + j]; } return preds; *************** *** 167,177 **** public double getLoss(Bag[] b) { ! Bag[] ubags = new Bag[b.length * num_classes]; for (int i = 0; i < b.length; i++) ! for (int j = 0; j < num_classes; j++) ! ubags[i * num_classes + j] = ((MultiBag) b[i]).bags[j]; ! return booster.getLoss(ubags); } --- 175,185 ---- public double getLoss(Bag[] b) { ! Bag[] ubags = new Bag[b.length * m_numLabels]; for (int i = 0; i < b.length; i++) ! for (int j = 0; j < m_numLabels; j++) ! ubags[i * m_numLabels + j] = ((MultiBag) b[i]).bags[j]; ! return m_booster.getLoss(ubags); } *************** *** 186,197 **** private MultiBag() { ! bags = new Bag[num_classes]; ! for (int j = 0; j < num_classes; j++) ! bags[j] = booster.newBag(); } public String toString() { String s = "MultiBag.\n"; ! for (int j = 0; j < num_classes; j++) s += "bag " + j + ":\n" + bags[j]; return s; --- 194,205 ---- private MultiBag() { ! bags = new Bag[m_numLabels]; ! for (int j = 0; j < m_numLabels; j++) ! bags[j] = m_booster.newBag(); } public String toString() { String s = "MultiBag.\n"; ! for (int j = 0; j < m_numLabels; j++) s += "bag " + j + ":\n" + bags[j]; return s; *************** *** 199,208 **** public void reset() { ! for (int j = 0; j < num_classes; j++) bags[j].reset(); } public boolean isWeightless() { ! for (int j = 0; j < num_classes; j++) if (! bags[j].isWeightless() ) { return false; --- 207,216 ---- public void reset() { ! for (int j = 0; j < m_numLabels; j++) bags[j].reset(); } public boolean isWeightless() { ! for (int j = 0; j < m_numLabels; j++) if (! bags[j].isWeightless() ) { return false; *************** *** 212,223 **** public void addExample(int index) { ! int s = index * num_classes; ! for (int j = 0; j < num_classes; j++) bags[j].addExample(s + j); } public void subtractExample(int index) { ! int s = index * num_classes; ! for (int j = 0; j < num_classes; j++) bags[j].subtractExample(s + j); } --- 220,231 ---- public void addExample(int index) { ! int s = index * m_numLabels; ! for (int j = 0; j < m_numLabels; j++) bags[j].addExample(s + j); } public void subtractExample(int index) { ! int s = index * m_numLabels; ! for (int j = 0; j < m_numLabels; j++) bags[j].subtractExample(s + j); } *************** *** 227,233 **** int[] s = new int[l.length]; for (i = 0; i < l.length; i++) ! s[i] = l[i] * num_classes; bags[0].addExampleList(s); ! for (int j = 1; j < num_classes; j++) { for (i = 0; i < l.length; i++) s[i]++; --- 235,241 ---- int[] s = new int[l.length]; for (i = 0; i < l.length; i++) ! s[i] = l[i] * m_numLabels; bags[0].addExampleList(s); ! for (int j = 1; j < m_numLabels; j++) { for (i = 0; i < l.length; i++) s[i]++; *************** *** 240,246 **** int[] s = new int[l.length]; for (i = 0; i < l.length; i++) ! s[i] = l[i] * num_classes; bags[0].subtractExampleList(s); ! for (int j = 1; j < num_classes; j++) { for (i = 0; i < l.length; i++) s[i]++; --- 248,254 ---- int[] s = new int[l.length]; for (i = 0; i < l.length; i++) ! s[i] = l[i] * m_numLabels; bags[0].subtractExampleList(s); ! for (int j = 1; j < m_numLabels; j++) { for (i = 0; i < l.length; i++) s[i]++; *************** *** 251,255 **** public void addBag(Bag b) { MultiBag other = (MultiBag) b; ! for (int j = 0; j < num_classes; j++) bags[j].addBag(other.bags[j]); } --- 259,263 ---- public void addBag(Bag b) { MultiBag other = (MultiBag) b; ! for (int j = 0; j < m_numLabels; j++) bags[j].addBag(other.bags[j]); } *************** *** 257,261 **** public void subtractBag(Bag b) { MultiBag other = (MultiBag) b; ! for (int j = 0; j < num_classes; j++) bags[j].subtractBag(other.bags[j]); } --- 265,269 ---- public void subtractBag(Bag b) { MultiBag other = (MultiBag) b; ! for (int j = 0; j < m_numLabels; j++) bags[j].subtractBag(other.bags[j]); } *************** *** 263,273 **** public void copyBag(Bag b) { MultiBag other = (MultiBag) b; ! for (int j = 0; j < num_classes; j++) bags[j].copyBag(other.bags[j]); } public void refresh(int index) { ! int s = index * num_classes; ! for (int j = 0; j < num_classes; j++) bags[j].refresh(s + j); } --- 271,281 ---- public void copyBag(Bag b) { MultiBag other = (MultiBag) b; ! for (int j = 0; j < m_numLabels; j++) bags[j].copyBag(other.bags[j]); } public void refresh(int index) { ! int s = index * m_numLabels; ! for (int j = 0; j < m_numLabels; j++) bags[j].refresh(s + j); } *************** *** 277,283 **** int[] s = new int[l.length]; for (i = 0; i < l.length; i++) ! s[i] = l[i] * num_classes; bags[0].refreshList(s); ! for (int j = 1; j < num_classes; j++) { for (i = 0; i < l.length; i++) s[i]++; --- 285,291 ---- int[] s = new int[l.length]; for (i = 0; i < l.length; i++) ! s[i] = l[i] * m_numLabels; bags[0].refreshList(s); ! for (int j = 1; j < m_numLabels; j++) { for (i = 0; i < l.length; i++) s[i]++; *************** *** 292,296 **** public double getLoss() { double loss = 0.0; ! for (int j = 0; j < num_classes; j++) loss += bags[j].getLoss(); return loss; --- 300,304 ---- public double getLoss() { double loss = 0.0; ! for (int j = 0; j < m_numLabels; j++) loss += bags[j].getLoss(); return loss; *************** *** 300,304 **** throws jboost.NotSupportedException { double loss = 0.0; ! for (int j = 0; j < num_classes; j++) loss += bags[j].getLoss(s); return loss; --- 308,312 ---- throws jboost.NotSupportedException { double loss = 0.0; ! for (int j = 0; j < m_numLabels; j++) loss += bags[j].getLoss(s); return loss; *************** *** 311,314 **** --- 319,325 ---- * underlying booster, one for each class. */ class MultiPrediction extends Prediction { + /** + * The predictions made. Has same length as the number of classes. + */ private Prediction[] preds; *************** *** 317,321 **** */ private MultiPrediction() { ! preds = new Prediction[num_classes]; } --- 328,332 ---- */ private MultiPrediction() { ! preds = new Prediction[m_numLabels]; } *************** *** 323,327 **** MultiPrediction newpred = new MultiPrediction(); ! for (int j = 0; j < num_classes; j++) { newpred.preds[j] = (Prediction) preds[j].clone(); } --- 334,338 ---- MultiPrediction newpred = new MultiPrediction(); ! for (int j = 0; j < m_numLabels; j++) { newpred.preds[j] = (Prediction) preds[j].clone(); } *************** *** 330,334 **** public Prediction add(Prediction p) { ! for (int j = 0; j < num_classes; j++) { preds[j].add(((MultiPrediction) p).preds[j]); } --- 341,345 ---- public Prediction add(Prediction p) { ! for (int j = 0; j < m_numLabels; j++) { preds[j].add(((MultiPrediction) p).preds[j]); } *************** *** 337,341 **** public Prediction scale(double w) { ! for (int j = 0; j < num_classes; j++) preds[j].scale(w); return this; --- 348,352 ---- public Prediction scale(double w) { ! for (int j = 0; j < m_numLabels; j++) preds[j].scale(w); return this; *************** *** 343,347 **** public Prediction add(double w, Prediction p) { ! for (int j = 0; j < num_classes; j++) { preds[j].add(w, ((MultiPrediction) p).preds[j]); } --- 354,358 ---- public Prediction add(double w, Prediction p) { ! for (int j = 0; j < m_numLabels; j++) { preds[j].add(w, ((MultiPrediction) p).preds[j]); } *************** *** 350,357 **** public double[] getMargins(Label l) { ! double[] ret = new double[num_classes]; ! for (int j = 0; j < num_classes; j++) { ! ret[j] = preds[j].getMargins(new Label(l.getMultiValue(j) ? ! 1 : 0))[0]; } return ret; --- 361,380 ---- public double[] getMargins(Label l) { ! double[] ret; ! if (m_isMultiLabel) { ! ret = new double[m_numLabels]; ! } else { ! ret = new double[1]; ! } ! for (int j = 0; j < m_numLabels; j++) { ! if (m_isMultiLabel) { ! ret[j] = preds[j].getMargins(new Label(l.getMultiValue(j) ? ! 1 : 0))[0]; ! } else { ! if (l.getMultiValue(j)){ ! ret[0] = preds[j].getMargins(new Label(l.getMultiValue(j) ? ! 1 : 0))[0]; ! } ! } } return ret; *************** *** 359,366 **** public double[] getClassScores() { ! double[] scores = new double[num_classes]; double[] uscore; ! for (int j = 0; j < num_classes; j++) { uscore = preds[j].getClassScores(); scores[j] = uscore[1]; --- 382,389 ---- public double[] getClassScores() { ! double[] scores = new double[m_numLabels]; double[] uscore; ! for (int j = 0; j < m_numLabels; j++) { uscore = preds[j].getClassScores(); scores[j] = uscore[1]; *************** *** 377,381 **** boolean retval= true; MultiPrediction other= (MultiPrediction) p; ! for (int k=0; k < num_classes; k++) { if (!preds[k].equals(other.preds[k])) { retval= false; --- 400,404 ---- boolean retval= true; MultiPrediction other= (MultiPrediction) p; ! for (int k=0; k < m_numLabels; k++) { if (!preds[k].equals(other.preds[k])) { retval= false; *************** *** 387,392 **** public String toString() { String s = "MultiPrediction.\n"; ! for (int j = 0; j < num_classes; j++) ! s += "prediction " + j + preds[j] + "\n"; return s; } --- 410,415 ---- public String toString() { String s = "MultiPrediction.\n"; ! for (int j = 0; j < m_numLabels; j++) ! s += "prediction " + j + ": " + preds[j] + "\n"; return s; } *************** *** 394,398 **** public String shortText() { String s = "[,"+preds[0]; ! for (int j = 0; j < num_classes; j++) s += ","+preds[j]; return s+"]"; --- 417,421 ---- public String shortText() { String s = "[,"+preds[0]; ! for (int j = 0; j < m_numLabels; j++) s += ","+preds[j]; return s+"]"; *************** *** 402,417 **** String code = ""; ! code += "typedef double Prediction_t[" + num_classes + "];\n"; code += "#define reset_pred() { \\\n"; ! for (int i = 0; i < num_classes; i++) code += " p["+i+"] = 0.0; \\\n"; code += " }\n"; code += "#define add_pred("; ! for (int i = 0; i < num_classes; i++) code += (i == 0 ? "" : ",") + "X" + i; code += ") { \\\n"; ! for (int i = 0; i < num_classes; i++) code += " p["+i+"] += X"+i+"; \\\n"; code += " }\n"; --- 425,440 ---- String code = ""; ! code += "typedef double Prediction_t[" + m_numLabels + "];\n"; code += "#define reset_pred() { \\\n"; ! for (int i = 0; i < m_numLabels; i++) code += " p["+i+"] = 0.0; \\\n"; code += " }\n"; code += "#define add_pred("; ! for (int i = 0; i < m_numLabels; i++) code += (i == 0 ? "" : ",") + "X" + i; code += ") { \\\n"; ! for (int i = 0; i < m_numLabels; i++) code += " p["+i+"] += X"+i+"; \\\n"; code += " }\n"; *************** *** 419,423 **** code += "#define finalize_pred() \\\n"; code += " (r ? ( \\\n"; ! for (int i = 0; i < num_classes; i++) code += " r["+i+"] = p["+i+"], \\\n"; code += " p[0]) : p[0])\n"; --- 442,446 ---- code += "#define finalize_pred() \\\n"; code += " (r ? ( \\\n"; ! for (int i = 0; i < m_numLabels; i++) code += " r["+i+"] = p["+i+"], \\\n"; code += " p[0]) : p[0])\n"; *************** *** 430,442 **** code += "" ! + " static private double[] p = new double[" + num_classes + "];\n" + " static private void reset_pred() {\n" + " java.util.Arrays.fill(p, 0.0);\n" + " }\n" + " static private void add_pred("; ! for (int i = 0; i < num_classes; i++) code += (i == 0 ? "" : ",") + "double x" + i; code += ") {\n"; ! for (int i = 0; i < num_classes; i++) code += " p["+i+"] += x"+i+";\n"; code += "" --- 453,465 ---- code += "" ! + " static private double[] p = new double[" + m_numLabels + "];\n" + " static private void reset_pred() {\n" + " java.util.Arrays.fill(p, 0.0);\n" + " }\n" + " static private void add_pred("; ! for (int i = 0; i < m_numLabels; i++) code += (i == 0 ? "" : ",") + "double x" + i; code += ") {\n"; ! for (int i = 0; i < m_numLabels; i++) code += " p["+i+"] += x"+i+";\n"; code += "" *************** *** 457,461 **** try { AbstractBooster ada = ! new DebugWrap(new MulticlassWrapMH(new DebugWrap(new AdaBoost(0.0)), 2)); for(int i=0; i< 10; i++) { --- 480,484 ---- try { AbstractBooster ada = ! new DebugWrap(new MulticlassWrapMH(new DebugWrap(new AdaBoost(0.0)), 2, true)); for(int i=0; i< 10; i++) { Index: AbstractBooster.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/booster/AbstractBooster.java,v retrieving revision 1.2 retrieving revision 1.3 diff -C2 -d -r1.2 -r1.3 *** AbstractBooster.java 19 May 2007 01:57:57 -0000 1.2 --- AbstractBooster.java 29 May 2007 05:03:57 -0000 1.3 *************** *** 81,85 **** // If we have a multilable or multiclass problem, we need to wrap it. if (num_labels > 2 || isMultiLabel) { ! result= new MulticlassWrapMH(result, num_labels); } --- 81,85 ---- // If we have a multilable or multiclass problem, we need to wrap it. if (num_labels > 2 || isMultiLabel) { ! result= new MulticlassWrapMH(result, num_labels, isMultiLabel); } |