From: Aaron A. <aa...@us...> - 2009-01-03 05:13:38
|
Update of /cvsroot/jboost/jboost/src/jboost/atree In directory fdv4jf1.ch3.sourceforge.com:/tmp/cvs-serv2442/atree Modified Files: AlternatingTree.java Log Message: Added python to the code output options. Index: AlternatingTree.java =================================================================== RCS file: /cvsroot/jboost/jboost/src/jboost/atree/AlternatingTree.java,v retrieving revision 1.2 retrieving revision 1.3 diff -C2 -d -r1.2 -r1.3 *** AlternatingTree.java 2 Oct 2007 02:28:06 -0000 1.2 --- AlternatingTree.java 3 Jan 2009 05:13:32 -0000 1.3 *************** *** 893,898 **** } ! private static final boolean allDefined = false; ! /* true if should assume all attributes are defined in Matlab code */ /** Converts this AlternatingTree to Matlab. Assumes that all --- 893,1087 ---- } ! /** true if should assume all attributes are defined in Matlab code */ ! private boolean allDefined = false; ! ! /** Description of all attributes */ ! private AttributeDescription[] ad = null; ! ! /** Converts this AlternatingTree to Matlab. Assumes that all ! attributes are of type number or finite. If allDefined = true ! then all attributes are assumed to be defined. */ ! public String toPython(String fname, ! ExampleDescription exampleDescription) { ! allDefined = true; ! String code = ""; ! code += "# The map from index to spec file description is:\n\n"; ! code += "# index attr.type data.type name\n" ! + "# ------------------------------------------\n"; ! ad = exampleDescription.getAttributes(); ! for (int i = 0; i < ad.length; i++) { ! String s; ! String key = ""; ! String t = ad[i].getType(); ! if (t.equals("number")) { ! s = "number Double "; ! } else if (t.equals("text")) { ! s = "text String "; ! } else if (t.equals("finite")) { ! s = "finite Integer "; ! for (int j = 0; j < ad[i].getNoOfValues(); j++) ! key += (j == 0 ! ? "# key: " ! : "# ") ! + padInteger(j, 5) + " = " ! + ad[i].getAttributeValue(j) + "\n"; ! } else { ! System.err.println("Warning: unrecognized type for attribute " ! + i + ": " + t); ! s = "??? ??? "; ! } ! code += "# " + padInteger(i, 5) + " " + s + " " + ! ad[i].getAttributeName() + "\n" + key; ! } ! code += "" ! + "#return an array of scores correpsonding to the classes:\n" ! + "# index class name\n" ! + "# ------------------------\n"; ! AttributeDescription la = exampleDescription.getLabelDescription(); ! ! for (int j = 0; j < la.getNoOfValues(); j++) ! code += "# " + padInteger(j, 5) + " " ! + la.getAttributeValue(j) + "\n"; ! ! code += "\n\n" ! + "# This class evaluates a jboost-trained classifier. The " ! + (allDefined ? "" : "first ") + "\n" ! + "# argument is an array of values corresponding to the formatted data\n" ! + "# used during training." ! + (allDefined ! ? "\n" ! : " The second argument is an array of values\n" ! + "# indicating which of the attributes are defined (where non-zero means\n" ! + "# that the corresponding attribute is defined).\n") ! + "# This classifier was automatically generated by jboost on:\n" ! + "# " + (new Date()) + "\n" ! + "class ATree:\n" ! + "\t\"\"\"\n" ! + "\tThe ATree class provides access to scoring algorithms" ! + "\tderived from an ATree data structure. This class only\n" ! + "\tprovides scores and does not learn the structure.\n" ! + "\t\"\"\"\n" ! + "\t\n" ! + "\tdef __init__(self, datafile, specfile):\n" ! + "\t\tself.datafile = datafile\n" ! + "\t\tself.specfile = specfile\n" ! + "\t\tself.data = self.__read_in_data()\n" ! + "\t\tself.spec_dict = self.__read_in_spec()\n" ! + "\t\n" ! + "\tdef get_scores(self):\n" ! + "\t\tret = []\n" ! + "\t\tfor i,x in enumerate(self.data):\n" ! + "\t\t\tret.append(self.__score(i))\n" ! + "\t\treturn ret\n" ! + "\t\n" ! + "\tdef __score(self, i):\n" ! + "\t\treturn self.__predict(i)\n" ! + "\t\n" ! + "\tdef get_data_value(self, x, feature):\n" ! + "\t\tidx = self.spec_dict[feature]\n" ! + "\t\treturn float(x[idx])\n" ! + "\t\n" ! + "\tdef __read_in_data(self):\n" ! + "\t\t\"\"\"\n" ! + "\t\tReads in the datafile.\n" ! + "\t\t\"\"\"\n" ! + "\t\tf = open(self.datafile)\n" ! + "\t\tlines = f.readlines()\n" ! + "\t\tf.close()\n" ! + "\t\tret = []\n" ! + "\t\tfor line in lines:\n" ! + "\t\t\tline = line[:-2]\n" ! + "\t\t\tret.append([x.strip() for x in line.split(',')])\n" ! + "\t\treturn ret\n" ! + "\t\n" ! + "\tdef __read_in_spec(self):\n" ! + "\t\t\"\"\"\n" ! + "\t\tReturns a dictionary with feature names => data index\n" ! + "\t\t\"\"\"\n" ! + "\t\tf = open(self.specfile)\n" ! + "\t\tlines = f.readlines()\n" ! + "\t\tf.close()\n" ! + "\t\tret = {}\n" ! + "\t\ti = 0\n" ! + "\t\tfor line in lines:\n" ! + "\t\t\tif '=' in line:\n" ! + "\t\t\t\tcontinue\n" ! + "\t\t\tdescription = line.split()[0]\n" ! + "\t\t\tret[description] = i\n" ! + "\t\t\ti += 1\n" ! + "\t\treturn ret\n" ! + "\t\n" ! + "\tdef __predict(self, i" ! + (allDefined ? "" : ", def") + "):\n" ! + "\t\tx=self.data[i]\n" ! + "\t\tpred = [0 for i in range(" + root.prediction.toCodeArray().length + ")]\n" ! + makePythonCode(root, "\t\t") ! + "\t\treturn pred" ! + "\n" ! + "\n" ! + "import sys\n" ! + "def main():\n" ! + "\tdatafile = sys.argv[1]\n" ! + "\tspecfile = sys.argv[2]\n" ! + "\tatree = ATree(datafile, specfile)\n" ! + "\tscores = atree.get_scores()\n" ! + "\tprint scores\n" ! + "if __name__=='__main__':\n" ! + "\tmain()\n" ! + "\n" ! + "\n"; ! return code; ! } ! ! ! private String makePythonCode(PredictorNode pn, String tab) { ! String code = ""; ! code += tab + "newpred = [ # " + pn.id + "\n"; ! double[] v = pn.prediction.toCodeArray(); ! for (int i = 0; i < v.length; i++) ! code += tab + " " + v[i] + ! (i < v.length - 1 ? "\n" : "]\n"); ! code += tab + "pred = [n+p for n,p in zip(newpred,pred)]\n"; ! int size = pn.splitterNodes.size(); ! for (int i = 0; i < size; i++) ! code += makePythonCode((SplitterNode)pn.splitterNodes.get(i), tab); ! return code; ! } ! ! private String makePythonCode(SplitterNode sn, String tab) { ! String code = ""; ! ! Summary summary=sn.splitter.getSummary(); ! ! code += tab + (allDefined ! ? " " ! : "if def(" + (summary.index+1) + "):") ! + " # " + sn.id + "\n"; ! String stab = (allDefined ? tab : tab); ! switch(summary.type) { ! case Summary.EQUALITY: ! code += stab + "if self.get_data_value(x,'" + ad[(summary.index)].getAttributeName() + "') == " + ! ((Integer)summary.val) + ":\n"; ! code += makeMatlabCode(sn.predictorNodes[0], stab + "\t"); ! code += stab + "else:\n"; ! code += makeMatlabCode(sn.predictorNodes[1], stab + "\t"); ! break; ! case Summary.LESS_THAN: ! code += stab + "if self.get_data_value(x,'" + ad[(summary.index)].getAttributeName() + "') <= " + ! ((Double) summary.val) + ":\n"; ! code += makePythonCode(sn.predictorNodes[0], stab + "\t"); ! code += stab + "else:\n"; ! code += makePythonCode(sn.predictorNodes[1], stab + "\t"); ! break; ! default: ! throw new RuntimeException("Type of split not allowed"); ! } ! ! return code; ! } ! ! ! ! /** Converts this AlternatingTree to Matlab. Assumes that all |