|
From: <sy...@us...> - 2009-08-06 01:09:21
|
Revision: 20864
http://personalrobots.svn.sourceforge.net/personalrobots/?rev=20864&view=rev
Author: syrnick
Date: 2009-08-06 01:09:14 +0000 (Thu, 06 Aug 2009)
Log Message:
-----------
Added an explicit, configurable delay between the messages published by the bagserver (~message_publishing_delay parameter).
Created a full test for the functional_m3n classifier. It runs training, loads the trained model into the predictor, makes predictions and computes performance.
Modified Paths:
--------------
pkg/trunk/sandbox/functional_m3n_ros/CMakeLists.txt
pkg/trunk/sandbox/functional_m3n_ros/include/functional_m3n_ros/m3n_prediction_node.h
pkg/trunk/sandbox/functional_m3n_ros/manifest.xml
pkg/trunk/sandbox/functional_m3n_ros/src/m3n_learning_node.cpp
pkg/trunk/sandbox/functional_m3n_ros/src/m3n_prediction_node.cpp
pkg/trunk/sandbox/functional_m3n_ros/test_data/m3n_predictor.launch
pkg/trunk/util/bagserver/src/bagserver_srv.py
Added Paths:
-----------
pkg/trunk/sandbox/functional_m3n_ros/performance_notes.txt
pkg/trunk/sandbox/functional_m3n_ros/srv/QueryPerformanceStats.srv
pkg/trunk/sandbox/functional_m3n_ros/srv/SetModel.srv
pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training.launch
pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training.xml
pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training_1.py
Property Changed:
----------------
pkg/trunk/sandbox/functional_m3n_ros/test_data/
Modified: pkg/trunk/sandbox/functional_m3n_ros/CMakeLists.txt
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/CMakeLists.txt 2009-08-06 00:59:27 UTC (rev 20863)
+++ pkg/trunk/sandbox/functional_m3n_ros/CMakeLists.txt 2009-08-06 01:09:14 UTC (rev 20864)
@@ -35,3 +35,13 @@
rospack_add_executable (m3n_learning_node src/m3n_learning_node.cpp)
target_link_libraries (m3n_learning_node functional_m3n)
rospack_link_boost(m3n_learning_node filesystem)
+
+
+rospack_download_test_data(http://pr.willowgarage.com/data/${PROJECT_NAME}/pcd_all_1.bag test_data/pcd_all_1.bag)
+rospack_download_test_data(http://pr.willowgarage.com/data/${PROJECT_NAME}/pcd_all_1.index test_data/pcd_all_1.index)
+rospack_download_test_data(http://pr.willowgarage.com/data/${PROJECT_NAME}/pcd_test_1.bag test_data/pcd_test_1.bag)
+rospack_download_test_data(http://pr.willowgarage.com/data/${PROJECT_NAME}/pcd_train_1.bag test_data/pcd_train_1.bag)
+
+#rospack_add_rostest(test/test_full_training.xml)
+
+
Modified: pkg/trunk/sandbox/functional_m3n_ros/include/functional_m3n_ros/m3n_prediction_node.h
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/include/functional_m3n_ros/m3n_prediction_node.h 2009-08-06 00:59:27 UTC (rev 20863)
+++ pkg/trunk/sandbox/functional_m3n_ros/include/functional_m3n_ros/m3n_prediction_node.h 2009-08-06 01:09:14 UTC (rev 20864)
@@ -55,6 +55,9 @@
#include <functional_m3n/m3n_model.h>
+#include <functional_m3n_ros/SetModel.h>
+#include <functional_m3n_ros/QueryPerformanceStats.h>
+
namespace m3n
{
@@ -67,7 +70,12 @@
PredictionNode();
void cloudCallback(const sensor_msgs::PointCloudConstPtr& the_cloud);
+ bool setModel(functional_m3n_ros::SetModel::Request &req,
+ functional_m3n_ros::SetModel::Response &res );
+ bool queryPerformanceStats(
+ functional_m3n_ros::QueryPerformanceStats::Request &req,
+ functional_m3n_ros::QueryPerformanceStats::Response &res );
boost::shared_ptr<PtCloudRFCreator> rf_creator_;
bool use_colors_;
@@ -76,9 +84,25 @@
M3NModel m3n_model2;
std::string model_file_name_;
+ std::string ground_truth_channel_name_;
+
ros::Subscriber cloud_sub_;
ros::Publisher cloud_pub_;
+ ros::ServiceServer set_model_svc_;
+ ros::ServiceServer query_perf_stats_svc_;
+
+ unsigned int nbr_correct;
+ unsigned int nbr_gt;
+ double total_accuracy;
+
+ void computeClassificationRates(const vector<float>& inferred_labels, const vector<float>& gt_labels,
+ const vector<unsigned int>& labels,
+ unsigned int& nbr_correct,
+ unsigned int& nbr_gt,
+ double& accuracy);
+
+
};
Modified: pkg/trunk/sandbox/functional_m3n_ros/manifest.xml
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/manifest.xml 2009-08-06 00:59:27 UTC (rev 20863)
+++ pkg/trunk/sandbox/functional_m3n_ros/manifest.xml 2009-08-06 01:09:14 UTC (rev 20864)
@@ -16,6 +16,7 @@
<depend package="opencv_latest"/>
<depend package="object_names"/>
+ <depend package="bagserver"/>
</package>
Added: pkg/trunk/sandbox/functional_m3n_ros/performance_notes.txt
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/performance_notes.txt (rev 0)
+++ pkg/trunk/sandbox/functional_m3n_ros/performance_notes.txt 2009-08-06 01:09:14 UTC (rev 20864)
@@ -0,0 +1,4 @@
+Train: pcd_train_1.bag
+Train acc: 0.892634
+Test acc: 0.709794
+
Modified: pkg/trunk/sandbox/functional_m3n_ros/src/m3n_learning_node.cpp
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/src/m3n_learning_node.cpp 2009-08-06 00:59:27 UTC (rev 20863)
+++ pkg/trunk/sandbox/functional_m3n_ros/src/m3n_learning_node.cpp 2009-08-06 01:09:14 UTC (rev 20864)
@@ -114,7 +114,6 @@
{
ROS_INFO("Received learning command, starting to learn");
-
boost::filesystem::path model_base_path(model_file_path_);
if(~boost::filesystem::exists(model_base_path))
{
Modified: pkg/trunk/sandbox/functional_m3n_ros/src/m3n_prediction_node.cpp
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/src/m3n_prediction_node.cpp 2009-08-06 00:59:27 UTC (rev 20863)
+++ pkg/trunk/sandbox/functional_m3n_ros/src/m3n_prediction_node.cpp 2009-08-06 01:09:14 UTC (rev 20864)
@@ -61,17 +61,54 @@
cloud_pub_ = n_.advertise<PointCloud>("predictions_cloud",100);
+ n_.param(std::string("~ground_truth_channel"),ground_truth_channel_name_,std::string("NONE"));
+
if (m3n_model2.loadFromFile(model_file_name_) < 0)
{
ROS_ERROR("couldnt load model");
throw "ERR";
}
+
+ set_model_svc_ = n_.advertiseService(std::string("SetModel"), &PredictionNode::setModel,this);
+
+ query_perf_stats_svc_ = n_.advertiseService(std::string("Performance"), &PredictionNode::queryPerformanceStats,this);
+
+ nbr_correct=0;
+ nbr_gt=0;
+ total_accuracy=0.0;
+
ROS_INFO("Init done.");
+
}
+/** Service handler: report the prediction performance accumulated since the
+ *  node started or since the last SetModel call (both reset the counters).
+ *
+ *  Response fields:
+ *   - accuracy:       correct_weight / checked_weight (0 when nothing was checked)
+ *   - correct_weight: labelled points whose prediction matched the ground truth
+ *   - checked_weight: points carrying a non-zero ground-truth label
+ *
+ *  The request has no fields, and the call always succeeds. */
+bool PredictionNode::queryPerformanceStats(
+ functional_m3n_ros::QueryPerformanceStats::Request &req,
+ functional_m3n_ros::QueryPerformanceStats::Response &res )
+{
+  const double checked = static_cast<double>(nbr_gt);
+  const double correct = static_cast<double>(nbr_correct);
+
+  res.checked_weight = checked;
+  res.correct_weight = correct;
+  res.accuracy = total_accuracy;   // already derived in cloudCallback
+  return true;
+}
+/** Service handler: switch the predictor to a new model and reset the
+ *  accumulated performance counters.
+ *
+ *  Only req.model_reference (the path handed to M3NModel::loadFromFile) is
+ *  used; req.model_type and req.model_name are ignored.
+ *
+ *  Returns false when the model cannot be loaded, so the service call fails on
+ *  the client side. The previous version always returned true, which let
+ *  callers continue with a model whose state after a failed load is not
+ *  visible from here. */
+bool PredictionNode::setModel(functional_m3n_ros::SetModel::Request &req,
+ functional_m3n_ros::SetModel::Response &res )
+{
+  model_file_name_=req.model_reference;
+  ROS_INFO_STREAM("Loading new model "<<model_file_name_);
+
+  // Statistics gathered with the previous model are no longer meaningful,
+  // whether or not the new model loads.
+  nbr_correct=0;
+  nbr_gt=0;
+  total_accuracy=0.0;
+
+  if (m3n_model2.loadFromFile(model_file_name_) < 0)
+  {
+    ROS_ERROR("couldnt load model");
+    return false;   // report the failure to the caller instead of swallowing it
+  }
+  return true;
+}
+
+
void PredictionNode::cloudCallback(const PointCloudConstPtr& the_cloud)
{
@@ -154,13 +191,116 @@
}
}
+
+ int chan_gt=cloud_geometry::getChannelIndex(the_cloud,ground_truth_channel_name_);
+ if(chan_gt!=-1)
+ {
+ unsigned int nbr_correct_in_pcd;
+ unsigned int nbr_gt_in_pcd;
+ double accuracy;
+ computeClassificationRates( cloud_out.chan[chan_predictions_id].vals,
+ the_cloud->chan[chan_gt].vals,
+ m3n_model2.getTrainingLabels(),
+ nbr_correct_in_pcd,
+ nbr_gt_in_pcd,
+ accuracy);
+ nbr_correct += nbr_correct_in_pcd;
+ nbr_gt += nbr_gt_in_pcd;
+ if(nbr_gt==0)
+ {
+ total_accuracy=0.0;
+ }
+ else
+ {
+ total_accuracy = static_cast<double>(nbr_correct)/static_cast<double>(nbr_gt);
+ }
+ ROS_INFO("Total correct: %u / %u = %f", nbr_correct, nbr_gt, total_accuracy);
+
+ }
+
+
cloud_pub_.publish(cloud_out);
}
+/** Compare predicted labels against ground-truth labels for one point cloud
+ *  and log a per-label summary.
+ *
+ *  @param inferred_labels  predicted label per point (float channel values)
+ *  @param gt_labels        ground-truth label per point; 0 marks an unlabelled
+ *                          point that is not scored
+ *  @param labels           labels to list in the logged summary (the caller
+ *                          passes the model's training labels)
+ *  @param[out] correct_out number of scored points whose prediction matches
+ *  @param[out] checked_out number of scored (gt != 0) points
+ *  @param[out] accuracy    correct_out / checked_out, or 0.0 if nothing was scored
+ *
+ *  Only the first min(inferred_labels.size(), gt_labels.size()) points are
+ *  compared. A size mismatch is reported with ROS_WARN rather than reading past
+ *  the end of the shorter vector. */
+void PredictionNode::computeClassificationRates(const std::vector<float>& inferred_labels,
+                                                const std::vector<float>& gt_labels,
+                                                const std::vector<unsigned int>& labels,
+                                                unsigned int& correct_out,
+                                                unsigned int& checked_out,
+                                                double& accuracy)
+ {
+  // Per-label counters (label -> count). std::map::operator[] yields 0 for a
+  // label that was never seen, so no pre-seeding is required.
+  std::map<unsigned int, unsigned int> total_label_count;     // points with this gt label
+  std::map<unsigned int, unsigned int> correct_label_count;   // ... that were classified correctly
+  std::map<unsigned int, unsigned int> false_pos_label_count; // points wrongly given this label
+
+  correct_out = 0;
+  checked_out = 0;
+
+  // Guard against mismatched channels: the original loop indexed gt_labels with
+  // the size of inferred_labels only.
+  size_t nbr_points = inferred_labels.size();
+  if (gt_labels.size() != nbr_points)
+  {
+    ROS_WARN("Prediction / ground-truth size mismatch (%u vs %u); comparing the common prefix only",
+             static_cast<unsigned int>(inferred_labels.size()),
+             static_cast<unsigned int>(gt_labels.size()));
+    if (gt_labels.size() < nbr_points)
+    {
+      nbr_points = gt_labels.size();
+    }
+  }
+
+  for (size_t i = 0; i < nbr_points; i++)
+  {
+    // NOTE(review): assumes channel values are non-negative integral floats; a
+    // negative or NaN value makes this float->unsigned conversion undefined.
+    const unsigned int gt_label = static_cast<unsigned int>(gt_labels[i]);
+    const unsigned int inferred_label = static_cast<unsigned int>(inferred_labels[i]);
+
+    // Counted before the unlabelled check, so label 0's total includes the
+    // points that are not scored.
+    total_label_count[gt_label]++;
+    if (gt_label == 0)
+    {
+      continue;
+    }
+
+    checked_out++;
+    if (gt_label == inferred_label)
+    {
+      correct_out++;
+      correct_label_count[gt_label]++;
+    }
+    else
+    {
+      false_pos_label_count[inferred_label]++;
+    }
+  }
+
+  accuracy = (checked_out == 0)
+               ? 0.0
+               : static_cast<double>(correct_out) / static_cast<double>(checked_out);
+
+  // Print statistics: [label: correct/total (false positives)]
+  ROS_INFO("Total correct: %u / %u = %f", correct_out, checked_out, accuracy);
+  std::stringstream ss;
+  ss << "Label distribution: ";
+  for (size_t i = 0 ; i < labels.size() ; i++)
+  {
+    const unsigned int label = labels[i];
+    ss << "[" << label << ": " << correct_label_count[label] << "/"
+       << total_label_count[label] << " (" << false_pos_label_count[label] << ")] ";
+  }
+  ROS_INFO("%s", ss.str().c_str());
+ }
+
+
+
/* ---[ */
int
main (int argc, char** argv)
Added: pkg/trunk/sandbox/functional_m3n_ros/srv/QueryPerformanceStats.srv
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/srv/QueryPerformanceStats.srv (rev 0)
+++ pkg/trunk/sandbox/functional_m3n_ros/srv/QueryPerformanceStats.srv 2009-08-06 01:09:14 UTC (rev 20864)
@@ -0,0 +1,4 @@
+---
+float64 accuracy
+float64 correct_weight
+float64 checked_weight
\ No newline at end of file
Added: pkg/trunk/sandbox/functional_m3n_ros/srv/SetModel.srv
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/srv/SetModel.srv (rev 0)
+++ pkg/trunk/sandbox/functional_m3n_ros/srv/SetModel.srv 2009-08-06 01:09:14 UTC (rev 20864)
@@ -0,0 +1,4 @@
+string model_type
+string model_name
+string model_reference
+---
Added: pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training.launch
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training.launch (rev 0)
+++ pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training.launch 2009-08-06 01:09:14 UTC (rev 20864)
@@ -0,0 +1,24 @@
+<launch>
+
+
+ <node pkg="functional_m3n_ros" type="m3n_learning_node" name="fm3n_training" output="screen">
+ <param name="model_file_path" value="$(find functional_m3n_ros)/test_data/test_model_root/"/>
+ </node>
+
+ <node pkg="functional_m3n_ros" type="m3n_prediction_node" name="fm3n_predictor" output="screen">
+ <remap from="cloud" to="/hist/training_cloud"/>
+ </node>
+ <node pkg="bagserver" type="bagserver_srv.py" name="hist_server" output="screen">
+ <param name="namespace" value="hist"/>
+ <param name="index" value="$(find functional_m3n_ros)/test_data/pcd_all_1.index"/>
+ <param name="message_publishing_delay" value="0.5"/>
+ </node>
+
+
+ <node pkg="rosrecord" type="rosplay" args="-s 1 $(find functional_m3n_ros)/test_data/pcd_train_1.bag" output="screen"/>
+
+ <node pkg="functional_m3n_ros" type="test_full_training_1.py" name="test_executive" output="screen"/>
+
+
+
+</launch>
Added: pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training.xml
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training.xml (rev 0)
+++ pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training.xml 2009-08-06 01:09:14 UTC (rev 20864)
@@ -0,0 +1,24 @@
+<launch>
+
+
+ <node pkg="functional_m3n_ros" type="m3n_learning_node" name="fm3n_training" output="screen">
+ <param name="model_file_path" value="$(find functional_m3n_ros)/test_data/test_model_root/"/>
+ </node>
+
+ <node pkg="functional_m3n_ros" type="m3n_prediction_node" name="fm3n_predictor" output="screen">
+ <remap from="cloud" to="/hist/training_cloud"/>
+ </node>
+ <node pkg="bagserver" type="bagserver_srv.py" name="hist_server" output="screen">
+ <param name="namespace" value="hist"/>
+ <param name="index" value="$(find functional_m3n_ros)/test_data/pcd_all_1.index"/>
+ <param name="message_publishing_delay" value="0.5"/>
+ </node>
+
+
+ <node pkg="rosrecord" type="rosplay" args="-s 1 $(find functional_m3n_ros)/test_data/pcd_train_1.bag" output="screen"/>
+
+ <test test-name="functional_m3n_basic" pkg="functional_m3n_ros" type="test_full_training_1.py" name="test_executive" time-limit="600"/>
+
+
+
+</launch>
Added: pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training_1.py
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training_1.py (rev 0)
+++ pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training_1.py 2009-08-06 01:09:14 UTC (rev 20864)
@@ -0,0 +1,99 @@
+#!/usr/bin/env python
+# Software License Agreement (BSD License)
+#
+# Copyright (c) 2009, Willow Garage, Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided
+# with the distribution.
+# * Neither the name of Willow Garage, Inc. nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+PKG = 'functional_m3n_ros'
+import roslib; roslib.load_manifest(PKG)
+
+import rospy
+
+import unittest
+
+from functional_m3n_ros.srv import *
+from bagserver.srv import *
+
+class FullTestLearningBasic(unittest.TestCase):
+
+
+ def testLearningSimple(self):
+ rospy.sleep(5);
+ rospy.wait_for_service('/learn')
+
+ learn_proxy = rospy.ServiceProxy('learn', Learn)
+ model_name="test_model_%s" % str(rospy.get_rostime());
+ result=learn_proxy(model_name);
+
+ print model_name
+ print result
+ predictor_proxy = rospy.ServiceProxy('SetModel', SetModel)
+ predictor_performance = rospy.ServiceProxy('Performance', QueryPerformanceStats)
+
+ set_model_resul=predictor_proxy(result.model_type,
+ model_name,
+ result.model_reference);
+
+ play_history = rospy.ServiceProxy('hist', History)
+
+ begin=rospy.Time(1247098041,895116000);
+ end =rospy.Time(1247098087,908848000);
+ play_history(begin,end,"ALL")
+ rospy.sleep(10);
+
+ perf1=predictor_performance();
+
+ rospy.loginfo("Accuracy %f (%f of %f )" %(perf1.accuracy,perf1.correct_weight,perf1.checked_weight))
+
+ self.failIf(perf1.accuracy<0.8);
+
+ set_model_resul=predictor_proxy(result.model_type,
+ model_name,
+ result.model_reference);
+
+ begin=rospy.Time( 1247098100, 316937000);
+ end =rospy.Time( 1247098133, 726237000);
+
+ play_history(begin,end,"ALL")
+
+ rospy.sleep(10);
+ perf2=predictor_performance();
+ rospy.loginfo("Accuracy %f (%f of %f )" %(perf2.accuracy,perf2.correct_weight,perf2.checked_weight))
+
+ self.failIf(perf2.accuracy<0.7);
+
+from threading import Thread
+
+if __name__ == "__main__":
+ import rostest
+ rospy.init_node("test_content");
+
+ rostest.rosrun('functional_m3n_ros','test_full_training_1',FullTestLearningBasic)
+
Property changes on: pkg/trunk/sandbox/functional_m3n_ros/test/test_full_training_1.py
___________________________________________________________________
Added: svn:executable
+ *
Property changes on: pkg/trunk/sandbox/functional_m3n_ros/test_data
___________________________________________________________________
Added: svn:ignore
+ pcd_all_1.index
pcd_all_1.bag
pcd_train_1.bag
pcd_test_1.bag
Modified: pkg/trunk/sandbox/functional_m3n_ros/test_data/m3n_predictor.launch
===================================================================
--- pkg/trunk/sandbox/functional_m3n_ros/test_data/m3n_predictor.launch 2009-08-06 00:59:27 UTC (rev 20863)
+++ pkg/trunk/sandbox/functional_m3n_ros/test_data/m3n_predictor.launch 2009-08-06 01:09:14 UTC (rev 20864)
@@ -1,13 +1,13 @@
<launch>
- <node pkg="functional_m3n_ros" type="m3n_prediction_node" name="fm3n_learner" output="screen"
- launch-prefix="xterm -e gdb -args"
+ <node pkg="functional_m3n_ros" type="m3n_prediction_node" name="fm3n_predictor" output="screen"
>
<!-- launch-prefix="xterm -e gdb -args" -->
<param name="use_color" value="True"/>
+ <param name="ground_truth_channel" value="ann-w-env-layer-1p"/>
- <param name="model_file" value="$(find functional_m3n_ros)/test_data/test_model_root/test_model_1248938712259804010/f_m3n"/>
+ <param name="model_file" value="$(find functional_m3n_ros)/test_data/test_model_root/test_model_1249495947974724054/f_m3n"/>
<!--param name="model_file" value="$(find functional_m3n_ros)/test_data/test_model_root/test_model_1248824163850214958/f_m3n" -->
<!--param name="model_file" value="$(find functional_m3n_ros)/test_data/test_model_root/test_model_1248845007846792936/f_m3n"-->
Modified: pkg/trunk/util/bagserver/src/bagserver_srv.py
===================================================================
--- pkg/trunk/util/bagserver/src/bagserver_srv.py 2009-08-06 00:59:27 UTC (rev 20863)
+++ pkg/trunk/util/bagserver/src/bagserver_srv.py 2009-08-06 01:09:14 UTC (rev 20864)
@@ -61,6 +61,7 @@
self.out_namespace=rospy.get_param("~namespace");
self.index_name_=rospy.get_param("~index");
+ self.pub_delay_=rospy.get_param("~message_publishing_delay",0);
self.setup_hist();
@@ -185,7 +186,7 @@
self.active_topics={};
if not req.topic == "":
- if req.topic =="*":
+ if req.topic =="*" or req.topic =="ALL":
topic_filter_dict=None;
else:
topic_filter_dict={};
@@ -223,7 +224,7 @@
def handle_query(self,req):
- rospy.logdebug(" Query %s - %s " % (req.begin,req.end))
+ rospy.loginfo(" Query %s - %s " % (req.begin,req.end))
self.setll(req);
@@ -234,6 +235,7 @@
if rospy.is_shutdown():
break
nextT=self.pick_next_topic();
+
if nextT is None:
break
(sec,nsec,idx,file_pos,topic,iBag)=self.active_topics[nextT]
@@ -252,7 +254,10 @@
sim_time.rostime.secs=sec;
sim_time.rostime.nsecs=nsec;
self.time_pub_.publish(sim_time)
- rospy.sleep(0.00001)
+ if self.pub_delay_>0:
+ rospy.sleep(self.pub_delay_);
+ else:
+ rospy.sleep(0.00001)
self.advance_topic(nextT,req.end)
This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site.
|