|
From: Herton R. K. <he...@us...> - 2005-04-10 20:43:22
|
Update of /cvsroot/kimageprocess/kimageprocess/src/plugins/fann In directory sc8-pr-cvs1.sourceforge.net:/tmp/cvs-serv2680/src/plugins/fann Modified Files: fann.cpp fann.h Log Message: - Enhancements, cleanups, cosmetics and fixes on fann plugin. Index: fann.h =================================================================== RCS file: /cvsroot/kimageprocess/kimageprocess/src/plugins/fann/fann.h,v retrieving revision 1.1 retrieving revision 1.2 diff -u -d -r1.1 -r1.2 --- fann.h 5 Mar 2005 21:14:05 -0000 1.1 +++ fann.h 10 Apr 2005 20:42:58 -0000 1.2 @@ -50,52 +50,50 @@ KTFANNPlugin(QObject *parent, const char *name, const QStringList&); ~KTFANNPlugin(); - + void setupPlugin(); - + int iterations() { return m_iterations; } void setIterations(int value) { m_iterations = value; } - + int hiddenLayers() { return m_hiddenLayers; } void setHiddenLayers(int value) { m_hiddenLayers = value; } - + // Reimplemented from KTClassifBackends void parseResults(const QValueList<float> &results); - + void parseSampleResults(const QValueList<float> &results, int sampleClass); - + private: KAction *m_trainNetwork; KAction *m_classifyImage; - + struct fann *m_network; int m_input; int m_output; float m_connectionRate; float m_learningRate; - + QImage *m_img; int m_width; int m_height; int m_colorStep; - + int m_x; int m_y; int m_iterations; int m_hiddenLayers; - + double m_error; - long m_count; - + QValueList<pattern> m_patterns; - + private slots: void slotTrainNetwork(); // for now assume we are using one test image for each project. // TODO: Fix that (will require API change) void slotClassifyImage(); - }; -#endif +#endif \ No newline at end of file Index: fann.cpp =================================================================== RCS file: /cvsroot/kimageprocess/kimageprocess/src/plugins/fann/fann.cpp,v retrieving revision 1.1 retrieving revision 1.2 diff -u -d -r1.1 -r1.2 --- fann.cpp 5 Mar 2005 21:14:05 -0000 1.1 +++ fann.cpp 10 Apr 2005 20:42:58 -0000 1.2 @@ -35,24 +35,24 @@ K_EXPORT_COMPONENT_FACTORY( kimageprocess_fann, KGenericFactory<KTFANNPlugin>( "kimageprocess_fann" ) ) - + KTFANNPlugin::KTFANNPlugin(QObject *parent, const char* name, const QStringList&) : KTClassifBackend(parent, name) { m_trainNetwork = new KAction(i18n("Train &Network"), 0, this, SLOT(slotTrainNetwork()), actionCollection(), "fann_train_network"); - + m_classifyImage = new KAction(i18n("Classify &Image"), 0, this, SLOT(slotClassifyImage()), actionCollection(), "fann_classify_image"); - + setInstance(KGenericFactory<KTFANNPlugin>::instance()); setXMLFile("kimageprocess_fann.rc"); m_network = 0; m_img = 0; - m_iterations = 10000; + m_iterations = 100000; m_hiddenLayers = 1; m_connectionRate = 1.0; @@ -66,31 +66,59 @@ } void KTFANNPlugin::slotTrainNetwork() -{ +{ if (m_network) fann_destroy(m_network); - + m_input = imageManager()->activeFeaturesCount(); m_output = imageManager()->sampleCount(); - + //create the network - m_network = fann_create(m_connectionRate, m_learningRate, 3, m_input, m_output, m_output); - - m_count = 0; + m_network = fann_create(m_connectionRate, m_learningRate, 3, m_input, m_input, m_output); + fann_set_activation_steepness_hidden(m_network, 1.0); + fann_set_activation_steepness_output(m_network, 1.0); + fann_set_activation_function_hidden(m_network, FANN_SIGMOID_SYMMETRIC_STEPWISE); + fann_set_activation_function_output(m_network, FANN_SIGMOID_SYMMETRIC_STEPWISE); m_error = 0.0; - + //start training imageManager()->generatePatternFile(0, this); - - float outputs[m_output]; - - //prepare the output - for (int i = 0; i < m_output; i++) - outputs[i] = 0; - - int print_debug = 0; + + struct fann_train_data tfanndata; + float *tempf_outputs; + int i, j; + QValueList<pattern>::size_type c_patcount; + + c_patcount = m_patterns.count(); + tempf_outputs = new float[c_patcount * m_output]; + tfanndata.num_data = c_patcount; + tfanndata.num_input = m_input; + tfanndata.num_output = m_output; + tfanndata.input = new float*[c_patcount]; + tfanndata.output = new float*[c_patcount]; + for (i = 0; i < c_patcount; i++) + { + tfanndata.output[i] = tempf_outputs + i * m_output; + for (j = 0; j < m_output; j++) + { + tfanndata.output[i][j] = 0; + } + } + j = 0; QValueList<pattern>::iterator end = m_patterns.end(); - for (int i=0; i < m_iterations; ++i) + for (QValueList<pattern>::iterator it = m_patterns.begin(); it != end; ++it) + { + tfanndata.output[j][(*it).sampleClass-1] = 1; + tfanndata.input[j] = (*it).inputs; + j++; + } + fann_init_weights(m_network, &tfanndata); + fann_train_on_data(m_network, &tfanndata, m_iterations, m_iterations / 100, 0.001); + m_error = fann_get_MSE(m_network); + + /*int print_debug = 0; + QValueList<pattern>::iterator end = m_patterns.end(); + for (i = 0; i < m_iterations; ++i) { for (QValueList<pattern>::iterator it = m_patterns.begin(); it != end; ++it) { @@ -105,40 +133,40 @@ kdDebug() << "Network output average error:" << m_error << endl; print_debug = 0; } - } - + }*/ + //clear data end = m_patterns.end(); for (QValueList<pattern>::iterator it = m_patterns.begin(); it != end; ++it) delete (*it).inputs; + delete tfanndata.input; + delete tempf_outputs; + delete tfanndata.output; m_patterns.clear(); - - } void KTFANNPlugin::slotClassifyImage() { if (!m_network) return; - + KTImage *img = imageManager()->testingImage(); - + m_width = img->width(); m_height = img->height(); m_x = m_y = 0; - + if (m_img) delete m_img; m_img = new QImage(m_width,m_height,32); m_colorStep = 255/(m_output-1); kdDebug() << "Color step is: " << m_colorStep << endl; - + imageManager()->generateTestPatFile(0, this); - + //save the image - KURL dest = KFileDialog::getSaveURL(QString::null,i18n("*.pgm|PGM File"),0,i18n("Save classified image")); - + if (!dest.isEmpty()) { KTempFile tmpImg(QString::null,".pgm"); @@ -150,15 +178,15 @@ } void KTFANNPlugin::setupPlugin() -{ -//connect signals/slots here +{ + //connect signals/slots here } void KTFANNPlugin::parseResults(const QValueList<float> &results) { float *inputs = new float[m_input]; float *outputs; - + QValueList<float>::const_iterator it; int i = 0; kdDebug() << "------------------------" << endl; @@ -168,9 +196,9 @@ kdDebug() << "inputs[" << i << "] = " << (*it) << endl; inputs[i++] = (*it); } - + outputs = fann_run(m_network, inputs); - + int maxclass = 0; float maxvalue = 0; for (i = 0; i < m_output; i++) @@ -185,30 +213,30 @@ int level = maxclass * m_colorStep; //kdDebug() << "Pixel x=" << m_x << " y=" << m_y << " is of class " << maxclass << " with value = " << maxvalue << endl; m_img->setPixel(m_x++,m_y,qRgb(level,level,level)); - + if (m_x >= m_width) { m_x = 0; m_y++; } - + } void KTFANNPlugin::parseSampleResults(const QValueList<float> &results, int sampleClass) { float *inputs = new float[m_input]; - + QValueList<float>::const_iterator it; int i = 0; + //prepare the input for ( it = results.begin(); it != results.end(); ++it) inputs[i++] = (*it); - + pattern pat; pat.inputs = inputs; pat.sampleClass = sampleClass; m_patterns.append(pat); } - -#include "fann.moc" +#include "fann.moc" \ No newline at end of file |