[50073b]: src / modules / mix / samplers / MixSamplerFactory.cc Maximize Restore History

Download this file

MixSamplerFactory.cc    121 lines (106 with data), 3.0 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include <config.h>
#include "MixSamplerFactory.h"
#include "NormMix.h"
#include <graph/GraphMarks.h>
#include <graph/Graph.h>
#include <graph/StochasticNode.h>
#include <distribution/Distribution.h>
#include <sampler/MutableSampler.h>
#include <sampler/SingletonGraphView.h>
#include <set>
using std::set;
using std::vector;
using std::string;
/*
 * Tuning constants passed, in this order, to every NormMix constructed
 * by MixSamplerFactory::makeSampler. Typed, file-local constants are
 * used instead of macros so they obey scope and type rules.
 * NOTE(review): presumably the number of temperature levels, the
 * maximum temperature and the number of repetitions per level --
 * confirm against the NormMix constructor declaration.
 */
static const unsigned int NLEVEL = 200;
static const double MAX_TEMP = 100;
static const unsigned int NREP = 5;
namespace jags {
/*
 * Checks whether snode is a candidate for the mixture sampler.
 *
 * A view of snode is built so that its stochastic children can be
 * inspected. If any child has the distribution named "dnormmix", the
 * newly allocated SingletonGraphView is returned and the caller becomes
 * responsible for deleting it. Otherwise the view is deleted and a
 * null pointer is returned.
 */
static SingletonGraphView * isCandidate(StochasticNode *snode,
                                        Graph const &graph)
{
    SingletonGraphView *view = new SingletonGraphView(snode, graph);
    vector<StochasticNode *> const &children = view->stochasticChildren();

    bool found = false;
    for (vector<StochasticNode *>::const_iterator c = children.begin();
         !found && c != children.end(); ++c)
    {
        found = ((*c)->distribution()->name() == "dnormmix");
    }

    if (!found) {
        delete view;
        view = 0;
    }
    return view;
}
/*
 * One step of a greedy aggregation of nodes that share stochastic
 * children.
 *
 * The node viewed by gv is appended to "nodes" when either "nodes" is
 * still empty (the first node seeds the block) or at least one of its
 * stochastic children is already in "common_children". When it is
 * appended, all of its stochastic children are added to
 * "common_children". Otherwise both containers are left unchanged.
 *
 * Because this is applied in a single pass, the result depends on the
 * order of calls: a node examined before the block reaches one of its
 * children is not picked up later.
 */
static void aggregate(SingletonGraphView const *gv,
                      vector<StochasticNode *> &nodes,
                      set<StochasticNode const*> &common_children)
{
    vector<StochasticNode *> const &kids = gv->stochasticChildren();

    bool joins = nodes.empty();
    for (vector<StochasticNode *>::const_iterator k = kids.begin();
         !joins && k != kids.end(); ++k)
    {
        joins = (common_children.find(*k) != common_children.end());
    }

    if (!joins) {
        return;
    }

    common_children.insert(kids.begin(), kids.end());
    nodes.push_back(gv->node());
}
namespace mix {
/*
 * Builds at most one tempered NormMix sampler from the nodes in "nodes".
 *
 * Candidates are nodes with a "dnormmix" stochastic child (isCandidate).
 * They are aggregated in the iteration order of "nodes": the first
 * candidate seeds the block, and later candidates join it if they share
 * a stochastic child with the block built so far (aggregate).
 *
 * Returns a newly allocated MutableSampler with one NormMix method per
 * chain. Returns a null pointer if there is no candidate or if
 * NormMix::canSample rejects the aggregated block.
 */
Sampler * MixSamplerFactory::makeSampler(set<StochasticNode*> const &nodes,
                                         Graph const &graph) const
{
    vector<StochasticNode *> sample_nodes;
    set<StochasticNode const *> common_children;

    for (set<StochasticNode*>::const_iterator p = nodes.begin();
         p != nodes.end(); ++p)
    {
        // Each candidate view is only needed while aggregate() inspects
        // it. Release it straight away instead of collecting all views
        // first, so none is leaked if a later step throws.
        SingletonGraphView *sgv = isCandidate(*p, graph);
        if (sgv) {
            try {
                aggregate(sgv, sample_nodes, common_children);
            }
            catch (...) {
                delete sgv;
                throw;
            }
            delete sgv;
        }
    }

    // An empty block means no candidate was found. The first candidate
    // always seeds the block, so sample_nodes[0] is valid below.
    if (sample_nodes.empty() || !NormMix::canSample(sample_nodes)) {
        return 0;
    }

    GraphView *gv = new GraphView(sample_nodes, graph, true);
    unsigned int nchain = sample_nodes[0]->nchain();
    vector<MutableSampleMethod*> methods(nchain, 0);
    try {
        for (unsigned int ch = 0; ch < nchain; ++ch) {
            methods[ch] = new NormMix(gv, ch, NLEVEL, MAX_TEMP, NREP);
        }
    }
    catch (...) {
        // Nothing else refers to these objects yet. Free the methods
        // before the view they were built on. Unfilled entries are null.
        for (unsigned int ch = 0; ch < nchain; ++ch) {
            delete methods[ch];
        }
        delete gv;
        throw;
    }
    // NOTE(review): MutableSampler is presumed to take ownership of gv
    // and methods; nothing here frees them on the success path.
    return new MutableSampler(gv, methods, "mix::NormMix");
}
/*
 * Name identifying this factory. It is distinct from the label
 * "mix::NormMix" that is given to the samplers it creates.
 */
string MixSamplerFactory::name() const
{
    string const factory_name("mix::TemperedMix");
    return factory_name;
}
/*
 * Wraps makeSampler in the vector-returning factory interface.
 *
 * The result holds the single sampler produced by makeSampler, or is
 * empty when makeSampler returns a null pointer.
 */
vector<Sampler*>
MixSamplerFactory::makeSamplers(set<StochasticNode*> const &nodes,
                                Graph const &graph) const
{
    vector<Sampler*> samplers;
    Sampler *sampler = makeSampler(nodes, graph);
    if (sampler != 0) {
        samplers.push_back(sampler);
    }
    return samplers;
}
}}