From: <mig...@us...> - 2006-06-01 15:23:51
|
Revision: 6331 Author: miguelrojasch Date: 2006-06-01 08:10:54 -0700 (Thu, 01 Jun 2006) ViewCVS: http://svn.sourceforge.net/cdk/?rev=6331&view=rev Log Message: ----------- Extraction of the correct predicted value object. Double for regression and String for classification. Modified Paths: -------------- trunk/cdk/src/org/openscience/cdk/libio/weka/Weka.java trunk/cdk/src/org/openscience/cdk/qsar/model/weka/IWekaModel.java trunk/cdk/src/org/openscience/cdk/qsar/model/weka/J48WModel.java trunk/cdk/src/org/openscience/cdk/qsar/model/weka/LinearRegressionWModel.java trunk/cdk/src/org/openscience/cdk/test/libio/weka/WekaTest.java trunk/cdk/src/org/openscience/cdk/test/qsar/model/J48WModelTest.java trunk/cdk/src/org/openscience/cdk/test/qsar/model/LinearRegressionWModelTest.java Modified: trunk/cdk/src/org/openscience/cdk/libio/weka/Weka.java =================================================================== --- trunk/cdk/src/org/openscience/cdk/libio/weka/Weka.java 2006-05-31 14:58:54 UTC (rev 6330) +++ trunk/cdk/src/org/openscience/cdk/libio/weka/Weka.java 2006-06-01 15:10:54 UTC (rev 6331) @@ -29,6 +29,7 @@ import java.io.InputStreamReader; import java.io.Reader; import java.io.StringReader; +import java.util.Vector; import weka.classifiers.Classifier; import weka.core.Instance; @@ -67,8 +68,10 @@ /** type of classifier*/ private Classifier classifier = null; - + /** Class for handling an ordered set of weighted instances*/ private Instances instances; + /**String with the attribut class*/ + private String[] classAttrib = null; /** * Constructor of the Weka */ @@ -85,14 +88,73 @@ public Instances setDataset(String pathTable, Classifier classifier) throws Exception{ this.classifier = classifier; InputStream ins = this.getClass().getClassLoader().getResourceAsStream(pathTable); - Reader insr = new InputStreamReader(ins); - instances = new Instances(new BufferedReader(insr)); + BufferedReader insr = new BufferedReader(new InputStreamReader(ins)); + this.classAttrib = extractClass(insr); + + ins = this.getClass().getClassLoader().getResourceAsStream(pathTable); + insr = new BufferedReader(new InputStreamReader(ins)); + instances = new Instances(insr); instances.setClassIndex(instances.numAttributes() - 1); classifier.buildClassifier(instances); return instances; } - /** + /** + * Extract the class name attribute manually from the file * + * @param insr The BufferedReader + * @return Array with the class attributes + */ + private String[] extractClass(BufferedReader input) { + Vector attribV = new Vector(); + String[] classAttrib = null; + String line = ""; + try { + while ((line = input.readLine()) != null) { + if(line.startsWith("@attribute class {")){ + int strlen = line.length(); + String line_ = null; + out: + for (int i = 0; i < strlen; i++){ + switch(line.charAt(i)){ + case '{': + line_ = line.substring(i); + break out; + } + } + StringBuffer edited = new StringBuffer(); + strlen = line_.length(); + edited = new StringBuffer(); + for (int i = 0; i < strlen; i++){ + switch(line_.charAt(i)){ + case '"': + break; + case ',': + attribV.add(edited.toString()); + edited = new StringBuffer(); + break; + case '{': + break; + case '}': + attribV.add(edited.toString()); + break; + default: + edited.append(line_.charAt(i)); + } + } + + } + } + if(attribV.size() > 0){ + classAttrib = new String[attribV.size()]; + attribV.copyInto(classAttrib); + } + } catch (IOException e) { + e.printStackTrace(); + } + return classAttrib; + } + /** + * * Set the array which contains the dataset and the type of classifier. This method * will be used for classifier which work with numerical values. * @@ -125,6 +187,7 @@ */ public Instances setDataset(String[] attrib, int[] typAttrib, String[] classAttrib, Object[]y, Object[][] x, Classifier classifier) throws Exception{ this.classifier = classifier; + this.classAttrib = classAttrib; Reader reader = createAttributes(attrib,typAttrib,classAttrib,y,x); instances = new Instances(reader); instances.setClassIndex(instances.numAttributes() - 1); @@ -138,8 +201,8 @@ * @return Result of the prediction * @throws Exception */ - public double[] getPrediction(Object[][] value) throws Exception{ - double[] results = new double[value.length]; + public Object[] getPrediction(Object[][] value) throws Exception{ + Object[] object = new Object[value.length]; for(int j = 0 ; j < value.length ; j++){ Instance instance = new Instance(instances.numAttributes()); instance.setDataset(instances); @@ -150,9 +213,14 @@ instance.setValue(i, ""+value[j][i]); } instance.setValue(value[0].length, 0.0); - results[j] = classifier.classifyInstance(instance); + double result = classifier.classifyInstance(instance); + if(classAttrib != null){ + object[j] = classAttrib[(new Double(result)).intValue()]; + } + else + object[j] = new Double(result); } - return results; + return object; } /** * Return of the predicted value @@ -161,15 +229,19 @@ * @return Result of the prediction. * @throws Exception */ - public double[] getPrediction(String pathARFF) throws Exception{ + public Object[] getPrediction(String pathARFF) throws Exception{ InputStream ins = this.getClass().getClassLoader().getResourceAsStream(pathARFF); Reader insr = new InputStreamReader(ins); Instances test = new Instances(new BufferedReader(insr)); - double[] result = new double[test.numInstances()]; + Object[] object = new Object[test.numInstances()]; for(int i = 0 ; i < test.numInstances(); i++){ - result[i] = classifier.classifyInstance(test.instance(i)); + double result = classifier.classifyInstance(test.instance(i)); + if(classAttrib != null) + object[i] = classAttrib[(new Double(result)).intValue()]; + else + object[i] = new Double(result); } - return result; + return object; } /** * create a Reader with necessary attributes to iniziate a Instances for weka. @@ -220,5 +292,16 @@ Reader reader = new StringReader(string); return reader; } - + + /** + * get the value which belongs this position in the classification + * @param result Position in the classification + * @return Real value + */ + private double[] getValue(double[] result) { + Instance instance = instances.instance(0); + instance.numClasses(); + return null; + } + } \ No newline at end of file Modified: trunk/cdk/src/org/openscience/cdk/qsar/model/weka/IWekaModel.java =================================================================== --- trunk/cdk/src/org/openscience/cdk/qsar/model/weka/IWekaModel.java 2006-05-31 14:58:54 UTC (rev 6330) +++ trunk/cdk/src/org/openscience/cdk/qsar/model/weka/IWekaModel.java 2006-06-01 15:10:54 UTC (rev 6331) @@ -81,9 +81,9 @@ * This function only returns meaningful results if the <code>predict</code> * method of this class has been called. * - * @return A double[] containing the predicted values + * @return A Object[] containing the predicted values */ - abstract public double[] getPredictPredicted(); + abstract public Object[] getPredictPredicted(); } Modified: trunk/cdk/src/org/openscience/cdk/qsar/model/weka/J48WModel.java =================================================================== --- trunk/cdk/src/org/openscience/cdk/qsar/model/weka/J48WModel.java 2006-05-31 14:58:54 UTC (rev 6330) +++ trunk/cdk/src/org/openscience/cdk/qsar/model/weka/J48WModel.java 2006-06-01 15:10:54 UTC (rev 6331) @@ -15,7 +15,7 @@ * j48.build(); * j48.setParameters(newX); * j48.predict(); - * double[] predictedvalues = j48.getPredictPredicted(); + * String[] predictedvalues = j48.getPredictPredicted(); * } catch (QSARModelException qme) { * System.out.println(qme.toString()); * } @@ -83,7 +83,7 @@ * which contians the variables and attributes with whose to test.*/ private String pathTest = null; /** results of the prediction*/ - private double[] results; + private String[] results = null; /**A Array Object containing the independent variable*/ private Object[][] newX = null; /**A String specifying the path of the file, format arff, @@ -109,7 +109,7 @@ this.x = x; } /** - * Constructor of the J48WModel object from varibles + * Constructor of the J48WModel object from file * @param pathTest Path of the dataset file format arff to train */ public J48WModel(String pathTest){ @@ -205,10 +205,20 @@ */ public void predict() throws QSARModelException { try{ - if(pathNewX != null) - results = weka.getPrediction(pathNewX); - else if(newX != null) - results = weka.getPrediction(newX); + if(pathNewX != null){ + Object[] object = weka.getPrediction(pathNewX); + results = new String[object.length]; + for(int i = 0 ; i < object.length; i++){ + results[i] = (String)object[i]; + } + } + else if(newX != null){ + Object[] object = weka.getPrediction(newX); + results = new String[object.length]; + for(int i = 0 ; i < results.length; i++){ + results[i] = (String)object[i]; + } + } } catch ( Exception e){ e.printStackTrace(); @@ -220,11 +230,10 @@ * This function only returns meaningful results if the <code>predict</code> * method of this class has been called. * - * @return A double[] containing the predicted values + * @return A String[] containing the predicted values */ - public double[] getPredictPredicted() { + public String[] getPredictPredicted() { return results; } - } Modified: trunk/cdk/src/org/openscience/cdk/qsar/model/weka/LinearRegressionWModel.java =================================================================== --- trunk/cdk/src/org/openscience/cdk/qsar/model/weka/LinearRegressionWModel.java 2006-05-31 14:58:54 UTC (rev 6330) +++ trunk/cdk/src/org/openscience/cdk/qsar/model/weka/LinearRegressionWModel.java 2006-06-01 15:10:54 UTC (rev 6331) @@ -14,7 +14,7 @@ * lrm.build(); * lrm.setParameters(newX); * lrm.predict(); - * double[] predictedvalues = lrm.getPredictPredicted(); + * Double[] predictedvalues = lrm.getPredictPredicted(); * * } catch (QSARModelException qme) { * System.out.println(qme.toString()); @@ -71,7 +71,7 @@ * which contians the variables and attributes with whose to test.*/ private String pathTest = null; /** results of the prediction*/ - private double[] results; + private Double[] results; /**A Array Object containing the independent variable*/ private Object[][] newX = null; /**A String specifying the path of the file, format arff, @@ -88,7 +88,7 @@ this.x = x; } /** - * Constructor of the LinearRegressionWModel object from varibles + * Constructor of the LinearRegressionWModel object from file * @param pathTest Path of the dataset file format arff to train */ public LinearRegressionWModel(String pathTest){ @@ -181,10 +181,20 @@ */ public void predict() throws QSARModelException { try{ - if(pathNewX != null) - results = weka.getPrediction(pathNewX); - else if(newX != null) - results = weka.getPrediction(newX); + if(pathNewX != null){ + Object[] object = weka.getPrediction(pathNewX); + results = new Double[object.length]; + for(int i = 0 ; i < object.length; i++){ + results[i] = (Double)object[i]; + } + } + else if(newX != null){ + Object[] object = weka.getPrediction(newX); + results = new Double[object.length]; + for(int i = 0 ; i < object.length; i++){ + results[i] = (Double)object[i]; + } + } } catch ( Exception e){ e.printStackTrace(); @@ -196,9 +206,9 @@ * This function only returns meaningful results if the <code>predict</code> * method of this class has been called. * - * @return A double[] containing the predicted values + * @return A Double[] containing the predicted values */ - public double[] getPredictPredicted() { + public Double[] getPredictPredicted() { return results; } Modified: trunk/cdk/src/org/openscience/cdk/test/libio/weka/WekaTest.java =================================================================== --- trunk/cdk/src/org/openscience/cdk/test/libio/weka/WekaTest.java 2006-05-31 14:58:54 UTC (rev 6330) +++ trunk/cdk/src/org/openscience/cdk/test/libio/weka/WekaTest.java 2006-06-01 15:10:54 UTC (rev 6331) @@ -72,7 +72,7 @@ lr.setOptions(options); Weka weka = new Weka(); weka.setDataset("data/arff/Table1.arff", lr); - double[] result = weka.getPrediction("data/arff/Table2.arff"); + Object[] result = weka.getPrediction("data/arff/Table2.arff"); assertNotNull(result); } /** @@ -92,7 +92,7 @@ Object[][] testX = {{new Double(2),new Double(2)}, {new Double(5),new Double(5)} }; - double[] result = weka.getPrediction(testX); + Object[] result = weka.getPrediction(testX); assertNotNull(result); } /** @@ -118,7 +118,7 @@ Double[][] testX = {{new Double(2),new Double(2)}, {new Double(5),new Double(5)} }; - double[] result = weka.getPrediction(testX); + Object[] result = weka.getPrediction(testX); assertNotNull(result); } /** @@ -144,7 +144,7 @@ weka.setDataset(attrib, typAttrib, classAttrib, y, xD, j48); Double[][] testX = {{new Double(11),new Double(-11),new Double(-11)}, {new Double(-10),new Double(-10),new Double(-10)}}; - double[] resultY = weka.getPrediction(testX); + Object[] resultY = weka.getPrediction(testX); assertNotNull(resultY); } } Modified: trunk/cdk/src/org/openscience/cdk/test/qsar/model/J48WModelTest.java =================================================================== --- trunk/cdk/src/org/openscience/cdk/test/qsar/model/J48WModelTest.java 2006-05-31 14:58:54 UTC (rev 6330) +++ trunk/cdk/src/org/openscience/cdk/test/qsar/model/J48WModelTest.java 2006-06-01 15:10:54 UTC (rev 6331) @@ -84,9 +84,9 @@ j48.setParameters(testX); j48.predict(); - double[] preds = j48.getPredictPredicted(); - assertEquals(preds[0], 1.0, 0.001); - assertEquals(preds[1], 2.0, 0.001); + String[] preds = j48.getPredictPredicted(); + assertEquals(preds[0], "B_"); + assertEquals(preds[1], "C_"); } /** * @@ -104,10 +104,9 @@ {new Double(-10),new Double(-10),new Double(-10)}}; j48.setParameters(testX); j48.predict(); - double[] result = j48.getPredictPredicted(); - assertNotNull(result); - assertEquals(result[0], 1.0, 0.001); - assertEquals(result[1], 2.0, 0.001); + String[] preds = j48.getPredictPredicted(); + assertEquals(preds[0], "B_"); + assertEquals(preds[1], "C_"); } } Modified: trunk/cdk/src/org/openscience/cdk/test/qsar/model/LinearRegressionWModelTest.java =================================================================== --- trunk/cdk/src/org/openscience/cdk/test/qsar/model/LinearRegressionWModelTest.java 2006-05-31 14:58:54 UTC (rev 6330) +++ trunk/cdk/src/org/openscience/cdk/test/qsar/model/LinearRegressionWModelTest.java 2006-06-01 15:10:54 UTC (rev 6331) @@ -90,7 +90,7 @@ lrm.setParameters(newx); lrm.predict(); - double[] preds = lrm.getPredictPredicted(); + Double[] preds = lrm.getPredictPredicted(); assertEquals(preds[0], 1.0, 0.001); assertEquals(preds[1], 4.0, 0.001); } @@ -111,7 +111,7 @@ lrm.build(); lrm.setParameters("data/arff/Table2.arff"); lrm.predict(); - double[] result = lrm.getPredictPredicted(); + Double[] result = lrm.getPredictPredicted(); assertNotNull(result); assertEquals(result[0], 1.0, 0.001); assertEquals(result[1], 4.0, 0.001); This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |