[01a998]: src / modules / glm / samplers / BinaryFactory.cc Maximize Restore History

Download this file

BinaryFactory.cc    88 lines (76 with data), 1.9 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <config.h>
#include <string>
#include "BinaryFactory.h"
#include "BinaryGLM.h"
#include "Linear.h"
#include <graph/StochasticNode.h>
#include <graph/LinkNode.h>
#include <distribution/Distribution.h>
using std::string;
using std::vector;
namespace glm {
/*
 * Creates a factory for binary GLM samplers.
 *
 * name  - factory name, handed unchanged to the GLMFactory base class
 * gibbs - stored in _gibbs. newMethod forwards it to the Linear sampler,
 *         and when true canSample only accepts nodes of length 1.
 */
BinaryFactory::BinaryFactory(string const &name, bool gibbs)
    : GLMFactory(name),
      _gibbs(gibbs)
{
}
/*
 * Decides whether an outcome node is one this factory can handle.
 *
 * snode - the stochastic outcome node; its family comes from
 *         GLMMethod::getFamily
 * lnode - link node wrapping the linear predictor, or null when there
 *         is no link function
 *
 * Returns true for:
 *   - Bernoulli outcomes with a "probit" or "logit" link;
 *   - binomial outcomes with a "probit" or "logit" link whose size
 *     parameter (second parent) is a scalar, observed node equal to 1;
 *   - normal outcomes with no link node.
 * Every other combination returns false.
 */
bool BinaryFactory::checkOutcome(StochasticNode const *snode,
                                 LinkNode const *lnode) const
{
    // With no link node the name stays empty, which matches neither
    // of the binary link names below.
    string lname;
    if (lnode) {
        lname = lnode->linkName();
    }
    bool const binaryLink = (lname == "logit") || (lname == "probit");

    switch (GLMMethod::getFamily(snode)) {
    case GLM_NORMAL:
        return lnode == 0;
    case GLM_BERNOULLI:
        return binaryLink;
    case GLM_BINOMIAL:
    {
        // A binomial with fixed size 1 behaves like a Bernoulli outcome.
        // Short-circuiting guarantees value() is only read once the size
        // node is known to be a scalar, observed node.
        Node const *size = snode->parents()[1];
        bool const unitSize = size->length() == 1
            && size->isObserved()
            && size->value(0)[0] == 1;
        return unitSize && binaryLink;
    }
    default:
        return false;
    }
}
/*
 * Builds the sampling method for a graph view.
 *
 * If every stochastic child of the view is in the normal family, the
 * model is a pure Gaussian linear model. In that case the simpler
 * conjugate Linear sampler is returned, constructed with the _gibbs
 * setting. Otherwise construction is delegated to newBinary.
 *
 * The returned object is allocated here. Presumably the caller takes
 * ownership; confirm against GLMFactory.
 */
GLMMethod *
BinaryFactory::newMethod(GraphView const *view,
                         vector<GraphView const *> const &sub_views,
                         unsigned int chain) const
{
    vector<StochasticNode const*> const &outcomes =
        view->stochasticChildren();

    // Skip over normal outcomes. Stopping early means a non-normal
    // outcome was found.
    vector<StochasticNode const*>::const_iterator p = outcomes.begin();
    while (p != outcomes.end() && GLMMethod::getFamily(*p) == GLM_NORMAL) {
        ++p;
    }

    if (p != outcomes.end()) {
        return newBinary(view, sub_views, chain);
    }
    return new Linear(view, sub_views, chain, _gibbs);
}
/*
 * Reports whether snode can be sampled by this factory.
 *
 * When _gibbs is true only nodes of length 1 are accepted. Otherwise
 * a node is accepted exactly when isBounded(snode) is false.
 */
bool BinaryFactory::canSample(StochasticNode const *snode) const
{
    return _gibbs ? (snode->length() == 1) : !isBounded(snode);
}
}