From: Aaron A. <aa...@us...> - 2007-05-30 07:41:59
|
Update of /cvsroot/jboost/jboost/src/jboost/booster In directory sc8-pr-cvs6.sourceforge.net:/tmp/cvs-serv14220 Modified Files: MulticlassWrapMH.java Log Message: muliclass and multilabel margins bug fixed Index: MulticlassWrapMH.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/booster/MulticlassWrapMH.java,v retrieving revision 1.2 retrieving revision 1.3 diff -C2 -d -r1.2 -r1.3 *** MulticlassWrapMH.java 29 May 2007 05:03:57 -0000 1.2 --- MulticlassWrapMH.java 30 May 2007 07:41:55 -0000 1.3 *************** *** 360,384 **** } ! public double[] getMargins(Label l) { ! double[] ret; ! if (m_isMultiLabel) { ! ret = new double[m_numLabels]; ! } else { ! ret = new double[1]; ! } for (int j = 0; j < m_numLabels; j++) { ! if (m_isMultiLabel) { ! ret[j] = preds[j].getMargins(new Label(l.getMultiValue(j) ? ! 1 : 0))[0]; ! } else { ! if (l.getMultiValue(j)){ ! ret[0] = preds[j].getMargins(new Label(l.getMultiValue(j) ? ! 1 : 0))[0]; ! } } } return ret; } public double[] getClassScores() { double[] scores = new double[m_numLabels]; --- 360,399 ---- } ! public double[] getMarginsSingleLabel(Label l) { ! int maxClass = -1; ! double maxScore = Double.MIN_VALUE; ! int thisClass = -1; ! double thisScore = 0; for (int j = 0; j < m_numLabels; j++) { ! double predScore = preds[j].getClassScores()[1]; ! if (l.getMultiValue(j)){ ! thisClass = j; ! thisScore = predScore; ! } else if (maxScore < predScore) { ! maxScore = predScore; ! maxClass = j; } } + double[] ret = new double[1]; + ret[0] = thisScore - maxScore; + return ret; + } + + public double[] getMarginsMultiLabel(Label l) { + double[] ret = new double[m_numLabels]; + for (int j = 0; j < m_numLabels; j++) { + ret[j] = preds[j].getMargins(new Label(l.getMultiValue(j) ? + 1 : 0))[0]; + } return ret; } + + public double[] getMargins(Label l) { + if (m_isMultiLabel) + return getMarginsMultiLabel(l); + return getMarginsSingleLabel(l); + } + public double[] getClassScores() { double[] scores = new double[m_numLabels]; |