From: Yoav F. <yf...@us...> - 2008-09-03 17:32:31
|
Update of /cvsroot/jboost/jboost/src/jboost/visualization In directory sc8-pr-cvs17.sourceforge.net:/tmp/cvs-serv6664/src/jboost/visualization Added Files: DataSet.java HistogramFrame.java Log Message: Adding Histogram visualization code. --- NEW FILE: DataSet.java ---
// NOTE(review): this file is committed under src/jboost/visualization but declares
// package mljava.visualization -- confirm which package is intended.
package mljava.visualization;

import java.util.Random;
import java.util.SortedSet;
import java.util.TreeSet;

import org.jfree.data.xy.XYDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

/**
 * A data structure that stores a set of scores and labels for examples, kept
 * sorted by ascending score. Used to store the data for {@link HistogramFrame}.
 *
 * @author yoavfreund
 */
public class DataSet {

    /** The examples, ordered by ascending score (ties broken by index). */
    private TreeSet<DataElement> data;

    /**
     * Builds a reproducible synthetic two-class sample: {@code size} examples
     * labelled -1 and {@code size} examples labelled +1, whose scores are unit
     * Gaussians centred at 4.5 and 5.5 respectively (fixed random seed).
     *
     * @param size number of examples generated per class
     */
    public DataSet(int size) {
        data = new TreeSet<DataElement>();
        Random generator = new Random(12345678L);
        for (int i = 0; i < 2 * size; i++) {
            DataElement element = new DataElement();
            int label = (i < size) ? -1 : 1;
            element.value = generator.nextGaussian() + label / 2.0 + 5;
            element.index = i;
            element.label = label;
            data.add(element);
        }
    }

    /**
     * Builds a data set from parallel arrays; entry {@code i} of each array
     * describes one example. Examples that share a score are all retained.
     *
     * @param indexes example identifiers
     * @param values  example scores
     * @param labels  example labels
     */
    public DataSet(int[] indexes, double[] values, int[] labels) {
        data = new TreeSet<DataElement>();
        for (int i = 0; i < indexes.length; i++) {
            DataElement element = new DataElement();
            element.value = values[i];
            element.index = indexes[i];
            element.label = labels[i];
            data.add(element);
        }
    }

    /** @return the lowest stored score (throws if the set is empty) */
    public double getMin() {
        return data.first().value;
    }

    /** @return the highest stored score (throws if the set is empty) */
    public double getMax() {
        return data.last().value;
    }

    /**
     * Returns the rates stored on the lowest-scoring example whose score is at
     * least {@code v}, i.e. the operating point of the threshold {@code v}.
     * The rates are only populated by {@link #generateRoC(int, int)}.
     *
     * @param v score threshold
     * @return {@code {FPR, TPR}}; {@code {0, 0}} when no example scores at least {@code v}
     */
    public double[] getFPTP(double v) {
        // The minimal index makes the probe sort before every example whose score equals v.
        DataElement probe = new DataElement(v);
        probe.index = Integer.MIN_VALUE;
        SortedSet<DataElement> atOrAbove = data.tailSet(probe);
        if (atOrAbove.isEmpty()) {
            // Threshold above every score: nothing is classified positive.
            return new double[] {0.0, 0.0};
        }
        DataElement e = atOrAbove.first();
        return new double[] {e.FPR, e.TPR};
    }

    /**
     * Returns the scores of all examples carrying {@code label}, in ascending order.
     *
     * @param label the label to select
     * @return the matching scores, or the one-element array {@code {0.0}} if none match
     */
    public double[] getSet(int label) {
        int count = 0;
        for (DataElement e : data) {
            if (e.label == label) {
                count++;
            }
        }
        if (count == 0) {
            // Existing contract: a one-element placeholder instead of an empty array.
            return new double[] {0.0};
        }
        double[] answer = new double[count];
        int k = 0;
        for (DataElement e : data) {
            if (e.label == label) {
                answer[k++] = e.value;
            }
        }
        return answer;
    }

    /**
     * Computes the ROC curve of the stored scores. As a side effect, each
     * example's FPR/TPR fields are set to the rates obtained when every example
     * at or after it in score order is classified positive. Examples with a
     * label other than {@code neg_label}/{@code pos_label} are not counted but
     * still contribute a point.
     *
     * @param neg_label label of the negative class
     * @param pos_label label of the positive class
     * @return a collection holding a single series "ROC" of (FPR, TPR) points
     */
    public XYDataset generateRoC(int neg_label, int pos_label) {
        XYSeries roc = new XYSeries("ROC");
        DataElement[] a = data.toArray(new DataElement[data.size()]);

        int negCount = 0;
        int posCount = 0;
        for (DataElement e : a) {
            if (e.label == neg_label) {
                negCount++;
            } else if (e.label == pos_label) {
                posCount++;
            }
        }
        double totalNeg = negCount;
        double totalPos = posCount;

        // Sweep from the highest score down, accumulating the positives/negatives seen so far.
        negCount = 0;
        posCount = 0;
        for (int i = a.length - 1; i >= 0; i--) {
            DataElement e = a[i];
            if (e.label == neg_label) {
                negCount++;
            } else if (e.label == pos_label) {
                posCount++;
            }
            // A missing class would otherwise give 0/0 = NaN.
            e.TPR = (totalPos > 0) ? posCount / totalPos : 0.0;
            e.FPR = (totalNeg > 0) ? negCount / totalNeg : 0.0;
            roc.add(e.FPR, e.TPR);
        }
        XYSeriesCollection dataset = new XYSeriesCollection();
        dataset.addSeries(roc);
        return dataset;
    }

    /** Small smoke test on synthetic data. */
    public static void main(String[] args) {
        DataSet test = new DataSet(50);
        test.generateRoC(-1, 1);
        double[] a = test.getFPTP(5.0);
        System.out.printf("%f; %f%n", a[0], a[1]);
        System.out.println("yoav was here!");
    }
}

/**
 * One example: its score, label and index, plus the ROC rates filled in by
 * {@link DataSet#generateRoC(int, int)}. Ordered by score with ties broken by
 * index, so a {@link TreeSet} keeps distinct examples that share a score.
 */
class DataElement implements Comparable<DataElement> {
    protected double value;
    protected int label;
    protected int index;
    protected double FPR, TPR;

    public DataElement(double v) {
        value = v;
    }

    public DataElement() {}

    public int compareTo(DataElement that) {
        int byValue = Double.compare(this.value, that.value);
        if (byValue != 0) {
            return byValue;
        }
        // Without this tie-break TreeSet would treat equal-score examples as duplicates and drop them.
        return (this.index < that.index) ? -1 : ((this.index == that.index) ? 0 : 1);
    }
}
--- NEW FILE: HistogramFrame.java ---
// NOTE(review): same package/directory mismatch as DataSet.java -- confirm intended package.
package mljava.visualization;

import java.awt.Color;
import javax.swing.BoxLayout;
import javax.swing.JComponent;
import javax.swing.JEditorPane;
import javax.swing.JPanel;
import javax.swing.JSlider;
import javax.swing.JSplitPane;
import javax.swing.JTextArea;
import javax.swing.WindowConstants;
import org.jdesktop.layout.GroupLayout;
import org.jdesktop.layout.LayoutStyle;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.IntervalMarker;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.ValueMarker;
import org.jfree.chart.plot.XYPlot;
import
org.jfree.chart.renderer.xy.XYBarRenderer;
import org.jfree.data.Range;
import org.jfree.data.general.Dataset;
import org.jfree.data.statistics.HistogramDataset;
import org.jfree.data.xy.IntervalXYDataset;
import org.jfree.data.xy.XYDataset;
import org.jfree.ui.Layer;
import javax.swing.SwingUtilities;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;

/**
 * A Swing/JFreeChart frame that visualizes a two-class score distribution,
 * e.g. the scores generated by boosting. The right pane shows one histogram
 * per class (labels -1 and +1) above two threshold sliders. The left pane shows
 * the ROC curve. The interval between the two thresholds is shaded on the
 * histogram. The (FPR, TPR) operating point of each threshold is marked on the
 * ROC curve: blue for the lower slider, red for the upper one.
 *
 * @author yoavfreund
 */
public class HistogramFrame extends javax.swing.JFrame {

    private static final long serialVersionUID = 1L;

    /** Examples per class in the synthetic demo data used by {@link #main}. */
    private static final int SampleSize = 10000;

    /** Number of histogram bins per class. */
    private static final int BIN_COUNT = 100;

    private JSplitPane jSplitPane1;
    private JSlider jSlider1; // upper threshold: end of the shaded interval
    private JSlider jSlider2; // lower threshold: start of the shaded interval
    private JPanel jPanel1;
    private JPanel jPanel2;
    private JFreeChart histogramChart;
    private ChartPanel histogramPanel;
    private JFreeChart rocChart;
    private ChartPanel rocPanel;

    // Per-frame state (instance fields), so several frames can be open without
    // overwriting each other's data and markers.
    private final DataSet rawData;
    private double upperLimit;
    private double lowerLimit;
    private IntervalMarker histMarker;
    private ValueMarker lowerTprMarker, lowerFprMarker; // ROC markers driven by jSlider2
    private ValueMarker upperTprMarker, upperFprMarker; // ROC markers driven by jSlider1

    /**
     * Displays the frame on synthetic demo data.
     */
    public static void main(String[] args) {
        final DataSet dataset = new DataSet(SampleSize);
        SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                HistogramFrame inst = new HistogramFrame(dataset);
                inst.setLocationRelativeTo(null);
                inst.setVisible(true);
            }
        });
    }

    /**
     * Builds the frame for {@code dataset}. Note that this calls
     * {@link DataSet#generateRoC(int, int)} with labels -1/+1 on the data set.
     *
     * @param dataset the scores and labels to visualize
     */
    public HistogramFrame(DataSet dataset) {
        super();
        rawData = dataset;
        initGUI();
    }

    /** Lays out the split pane so that it fills the content pane, then sizes the frame. */
    private void initGUI() {
        try {
            GroupLayout thisLayout = new GroupLayout((JComponent) getContentPane());
            getContentPane().setLayout(thisLayout);
            thisLayout.setVerticalGroup(thisLayout.createSequentialGroup()
                    .add(getJSplitPane1(), 0, 407, Short.MAX_VALUE)
                    .addPreferredGap(LayoutStyle.RELATED));
            thisLayout.setHorizontalGroup(thisLayout.createSequentialGroup()
                    .add(getJSplitPane1(), 0, 632, Short.MAX_VALUE));
            setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            pack();
            this.setSize(632, 434);
        } catch (Exception e) {
            // NOTE(review): failures are only printed; the frame may be left partially built.
            e.printStackTrace();
        }
    }

    /** Lazily builds the split pane with the histogram/slider panel and the ROC panel. */
    private JSplitPane getJSplitPane1() {
        if (jSplitPane1 == null) {
            jSplitPane1 = new JSplitPane();

            // Right pane: histogram, then the two sliders underneath.
            jPanel1 = new JPanel();
            jSplitPane1.add(jPanel1, JSplitPane.RIGHT);
            jPanel1.setPreferredSize(new java.awt.Dimension(263, 393));
            jPanel1.setLayout(new BoxLayout(jPanel1, javax.swing.BoxLayout.Y_AXIS));
            histogramChart = createHistChart(createHistogramDataset());
            histogramPanel = new ChartPanel(histogramChart);
            jPanel1.add(histogramPanel);
            histogramPanel.setPopupMenu(null);

            // Left pane: ROC curve (generateRoC also caches the per-threshold rates).
            jPanel2 = new JPanel();
            jSplitPane1.add(jPanel2, JSplitPane.LEFT);
            jPanel2.setPreferredSize(new java.awt.Dimension(291, 393));
            jPanel2.setLayout(new BoxLayout(jPanel2, javax.swing.BoxLayout.Y_AXIS));
            rocChart = createRocChart(rawData.generateRoC(-1, 1));
            rocPanel = new ChartPanel(rocChart);
            jPanel2.add(rocPanel);
            rocPanel.setPopupMenu(null);

            jSlider1 = new JSlider();
            jPanel1.add(jSlider1);
            jSlider1.setLayout(null);
            jSlider1.addChangeListener(new ChangeListener() {
                public void stateChanged(ChangeEvent evt) {
                    double v = sliderToScore(jSlider1);
                    histMarker.setEndValue(v);
                    // The ROC lookup is deferred until the user releases the slider.
                    if (!jSlider1.getValueIsAdjusting()) {
                        moveRocMarkers(v, upperFprMarker, upperTprMarker);
                    }
                }
            });

            jSlider2 = new JSlider();
            jPanel1.add(jSlider2);
            jSlider2.setLayout(null);
            jSlider2.addChangeListener(new ChangeListener() {
                public void stateChanged(ChangeEvent evt) {
                    double v = sliderToScore(jSlider2);
                    histMarker.setStartValue(v);
                    if (!jSlider2.getValueIsAdjusting()) {
                        moveRocMarkers(v, lowerFprMarker, lowerTprMarker);
                    }
                }
            });

            // Place the ROC markers at the sliders' initial thresholds so that they
            // agree with the histogram marker from the start.
            moveRocMarkers(sliderToScore(jSlider1), upperFprMarker, upperTprMarker);
            moveRocMarkers(sliderToScore(jSlider2), lowerFprMarker, lowerTprMarker);
        }
        return jSplitPane1;
    }

    /**
     * Maps a slider's position linearly onto the histogram's domain
     * [{@code lowerLimit}, {@code upperLimit}], using the slider's own range.
     */
    private double sliderToScore(JSlider slider) {
        int span = slider.getMaximum() - slider.getMinimum();
        if (span == 0) {
            return lowerLimit;
        }
        return lowerLimit + (slider.getValue() - slider.getMinimum()) * (upperLimit - lowerLimit) / span;
    }

    /**
     * Moves a pair of ROC markers to the operating point of threshold {@code v}.
     * {@link DataSet#getFPTP(double)} looks up the first example scoring at least
     * {@code v}. It is therefore consulted only when such an example exists.
     * Otherwise, nothing is classified positive and the point is (0, 0). The
     * histogram's upper bound can exceed the top score by rounding.
     */
    private void moveRocMarkers(double v, ValueMarker fprMarker, ValueMarker tprMarker) {
        double fpr = 0.0;
        double tpr = 0.0;
        if (v <= rawData.getMax()) {
            double[] rates = rawData.getFPTP(v);
            fpr = rates[0];
            tpr = rates[1];
        }
        fprMarker.setValue(fpr);
        tprMarker.setValue(tpr);
    }

    /**
     * Creates one histogram series per class ("H1" for label -1, "H2" for label +1),
     * both binned over the full score range.
     *
     * @return the dataset.
     */
    private IntervalXYDataset createHistogramDataset() {
        HistogramDataset dataset = new HistogramDataset();
        double[] negSet = rawData.getSet(-1);
        double[] posSet = rawData.getSet(1);
        System.out.printf("negSet:%d,posSet:%d%n", negSet.length, posSet.length);
        dataset.addSeries("H1", negSet, BIN_COUNT, rawData.getMin(), rawData.getMax());
        dataset.addSeries("H2", posSet, BIN_COUNT, rawData.getMin(), rawData.getMax());
        return dataset;
    }

    /**
     * Creates the histogram chart. As a side effect, this records its domain range
     * in {@code lowerLimit}/{@code upperLimit} and installs the threshold interval
     * marker, which initially collapses to the midpoint.
     *
     * @param dataset a dataset.
     * @return The chart.
     */
    private JFreeChart createHistChart(IntervalXYDataset dataset) {
        JFreeChart chart = ChartFactory.createHistogram(
                "Histogram",
                null,
                null,
                dataset,
                PlotOrientation.VERTICAL,
                false, /* No Legend */
                true,  // tooltips
                false  // urls
        );
        XYPlot plot = (XYPlot) chart.getPlot();
        Range range = plot.getDataRange(plot.getDomainAxis());
        System.out.println(range);
        upperLimit = range.getUpperBound();
        lowerLimit = range.getLowerBound();
        plot.setForegroundAlpha(0.85f);
        XYBarRenderer renderer = (XYBarRenderer) plot.getRenderer();
        renderer.setDrawBarOutline(false);
        double mid = (lowerLimit + upperLimit) / 2.0;
        histMarker = new IntervalMarker(mid, mid);
        plot.addDomainMarker(histMarker, Layer.BACKGROUND);
        return chart;
    }

    /**
     * Creates the ROC chart and its marker pairs: blue for the lower threshold and
     * red for the upper one.
     *
     * @param dataset the (FPR, TPR) series
     * @return The chart.
     */
    private JFreeChart createRocChart(XYDataset dataset) {
        JFreeChart chart = ChartFactory.createXYLineChart(
                "ROC",                  // chart title
                "False positive rate",  // x axis label
                "True positive rate",   // y axis label
                dataset,                // data
                PlotOrientation.VERTICAL,
                false,                  // include legend
                true,                   // tooltips
                false                   // urls
        );
        XYPlot plot = (XYPlot) chart.getPlot();
        plot.setBackgroundPaint(Color.lightGray);
        plot.setDomainGridlinePaint(Color.white);
        plot.setRangeGridlinePaint(Color.white);

        lowerTprMarker = new ValueMarker(0.5);
        lowerTprMarker.setPaint(Color.blue);
        lowerFprMarker = new ValueMarker(0.5);
        lowerFprMarker.setPaint(Color.blue);
        plot.addRangeMarker(lowerTprMarker);
        plot.addDomainMarker(lowerFprMarker);

        upperTprMarker = new ValueMarker(0.5);
        upperTprMarker.setPaint(Color.red);
        upperFprMarker = new ValueMarker(0.5);
        upperFprMarker.setPaint(Color.red);
        plot.addRangeMarker(upperTprMarker);
        plot.addDomainMarker(upperFprMarker);
        return chart;
    }
}
|