SourceForge has been redesigned. Learn more.
Close

[50073b]: / src / modules / bugs / samplers / DSumFactory.cc  Maximize  Restore  History

Download this file

123 lines (107 with data), 2.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <config.h>
#include <distribution/Distribution.h>
#include <graph/Graph.h>
#include <graph/StochasticNode.h>
#include <graph/NodeError.h>
#include <sampler/MutableSampler.h>
#include <sampler/GraphView.h>
#include "DSumFactory.h"
#include "RealDSum.h"
#include "DiscreteDSum.h"
#include "DMultiDSum.h"
#include <algorithm>
using std::set;
using std::vector;
using std::string;
namespace jags {
namespace bugs {
/*
 * Return the first stochastic child of `node` that is observed and whose
 * distribution is named "dsum", or 0 if `node` has no such child.
 */
static StochasticNode const *getDSumChild(StochasticNode *node)
{
    set<StochasticNode*> const *children = node->stochasticChildren();
    for (set<StochasticNode*>::const_iterator it = children->begin();
         it != children->end(); ++it)
    {
        StochasticNode *child = *it;
        // Unobserved children are ignored; only observed dsum nodes count
        if (!isObserved(child))
            continue;
        if (child->distribution()->name() == "dsum")
            return child;
    }
    return 0;
}
/*
 * Try to build a sampler that jointly updates all parents of one observed
 * dsum node.
 *
 * @param dsum_node  Observed stochastic node with a "dsum" distribution.
 * @param nodes      The candidate sample set.
 * @param graph      Graph passed on to RWDSum::canSample and GraphView.
 * @return A new MutableSampler using RealDSum, DiscreteDSum or DMultiDSum
 *         methods (one per chain), or 0 if some parent of dsum_node is not
 *         in `nodes` or none of the three variants can sample the parents.
 */
static Sampler *makeDSumSampler(StochasticNode const *dsum_node,
                                set<StochasticNode*> const &nodes,
                                Graph const &graph)
{
    //See if we can sample the parents. This can only be done if they
    //are unobserved stochastic nodes in the sample set
    vector<StochasticNode *> parameters;
    vector<Node const *> const &parents = dsum_node->parents();
    for (vector<Node const *>::const_iterator pp = parents.begin();
         pp != parents.end(); ++pp)
    {
        // Linear search: the set is keyed on StochasticNode*, while the
        // parent is only known as a Node const*.
        set<StochasticNode*>::const_iterator q =
            std::find(nodes.begin(), nodes.end(), *pp);
        if (q == nodes.end())
            return 0;
        parameters.push_back(*q);
    }
    // parameters[0] is dereferenced below, so never proceed with no parents
    if (parameters.empty())
        return 0;

    //Pick the first sampling method that accepts these parameters
    bool discrete = false;
    bool multinom = false;
    string name;
    if (RWDSum::canSample(parameters, graph, false, false)) {
        discrete = false;
        name = "bugs::RealDSum";
    }
    else if (RWDSum::canSample(parameters, graph, true, false)) {
        discrete = true;
        name = "bugs::DiscreteDSum";
    }
    else if (RWDSum::canSample(parameters, graph, true, true)) {
        discrete = true;
        multinom = true;
        name = "bugs::DMultiDSum";
    }
    else {
        return 0;
    }

    GraphView *gv = new GraphView(parameters, graph, true);
    unsigned int nchain = parameters[0]->nchain();
    vector<MutableSampleMethod*> methods(nchain, 0);
    for (unsigned int ch = 0; ch < nchain; ++ch) {
        if (discrete) {
            if (multinom)
                methods[ch] = new DMultiDSum(gv, ch);
            else
                methods[ch] = new DiscreteDSum(gv, ch);
        }
        else {
            methods[ch] = new RealDSum(gv, ch);
        }
    }
    // NOTE(review): gv and methods are handed to MutableSampler, which
    // presumably takes ownership -- not visible from this file.
    return new MutableSampler(gv, methods, name);
}

/*
 * Build a sampler for the parents of an observed dsum node that is a
 * stochastic child of some node in `nodes`.
 *
 * Every distinct dsum child found via getDSumChild is tried, at most once,
 * until one yields a sampler. A candidate whose parents are not all in
 * `nodes` (or cannot be handled by any DSum method) no longer stops the
 * search. Otherwise the result would depend on the arbitrary iteration
 * order of `nodes`.
 *
 * @return A new sampler, or 0 if no dsum candidate can be sampled.
 */
Sampler * DSumFactory::makeSampler(set<StochasticNode*> const &nodes,
                                   Graph const &graph) const
{
    set<StochasticNode const*> tried;
    for (set<StochasticNode*>::const_iterator p = nodes.begin();
         p != nodes.end(); ++p)
    {
        StochasticNode const *dsum_node = getDSumChild(*p);
        // Skip nodes with no dsum child, and dsum nodes already rejected
        // (several parents usually share the same dsum child)
        if (!dsum_node || !tried.insert(dsum_node).second)
            continue;
        Sampler *s = makeDSumSampler(dsum_node, nodes, graph);
        if (s)
            return s;
    }
    return 0;
}
/*
 * Identifier under which this sampler factory is registered/reported.
 */
string DSumFactory::name() const
{
    static char const factory_name[] = "bugs::DSum";
    return factory_name;
}
/*
 * Wrap makeSampler: the result holds the single sampler it produced, or is
 * empty when makeSampler returned 0.
 */
vector<Sampler*> DSumFactory::makeSamplers(set<StochasticNode*> const &nodes,
                                           Graph const &graph) const
{
    vector<Sampler*> samplers;
    Sampler *sampler = makeSampler(nodes, graph);
    if (sampler != 0) {
        samplers.push_back(sampler);
    }
    return samplers;
}
}}