Commit 785b7f46 authored by Manuel Baum's avatar Manuel Baum
Browse files

Encapsulated PoE and PosteriorMarginal in reasoning/MCMC object

parent def0efc8
......@@ -2,9 +2,8 @@
# -*- coding: utf-8 -*-
from primo.core import BayesNet
from primo.reasoning import DiscreteNode
from primo.reasoning import MarkovChainSampler
from primo.reasoning import GibbsTransitionModel
from primo.reasoning.density import ProbabilityTable
from primo.reasoning import MCMC
import numpy
import pdb
......@@ -24,58 +23,19 @@ burglary.set_probability_table(burglary_cpt, [burglary])
# NOTE(review): everything below is a GitLab diff rendering with the +/-
# markers stripped -- pre-commit lines (manual sampling and counting) and
# post-commit lines (calls into the new MCMC wrapper) are interleaved, and
# the original indentation of the loop bodies has been lost.  The bare
# print statements are Python 2 syntax.  Do not run this span as-is.
alarm_cpt=numpy.array([[0.8,0.15,0.05],[0.05,0.9,0.05]])
alarm.set_probability_table(alarm_cpt, [burglary,alarm])
#Construct a Markov Chain by sampling states from this Network
# -- presumably removed by this commit: manual sampler construction --
transition_model = GibbsTransitionModel()
mcs = MarkovChainSampler()
initial_state={burglary:"Safe",alarm:"Silent"}
chain = mcs.generateMarkovChain(bn, 5000, transition_model, initial_state)
#for c in chain:
# print c
# -- presumably added by this commit: the encapsulating reasoning object --
mcmc_ask=MCMC(bn)
pt = ProbabilityTable()
pt.add_variable(burglary)
pt.add_variable(alarm)
pt.to_jpt_by_states(chain)
print "----joint-probability----"
print pt
print "----burglary----"
print pt.marginalization(alarm)
print "----alarm----"
#print pt.division(burglary.get_cpd())
print "----ProbabilityOfEvidence----"
evidence={burglary:"Intruder"}
# -- presumably removed: hand-rolled probability-of-evidence estimate that
#    counts chain samples compatible with the evidence --
chain = mcs.generateMarkovChain(bn, 5000, transition_model, initial_state)
compatible_count=0
number_of_samples=0
for state in chain:
compatible = True
for node,value in evidence.items():
if state[node] != value:
compatible = False
break
print compatible
if compatible:
compatible_count = compatible_count + 1
number_of_samples = number_of_samples + 1
print compatible_count
print number_of_samples
probability_of_evidence = float(compatible_count)/float(number_of_samples)
print probability_of_evidence
# -- presumably added: the same query routed through the wrapper --
print "ProbabilityOfEvidence: "
poe=mcmc_ask.calculate_PoE(evidence)
print poe
print "----EVIDENCE: burglary=Intruder----"
evidence={burglary:"Intruder"}
initial_state={burglary:"Intruder",alarm:"Silent"}
chain = mcs.generateMarkovChain(bn, 5000, transition_model, initial_state, evidence)
pt.to_jpt_by_states(chain)
print pt
# -- presumably added: posterior marginal through the wrapper --
print "PosteriorMarginal:"
pm=mcmc_ask.calculate_PosteriorMarginal([alarm],evidence)
print pm
from MarkovChainSampler import MarkovChainSampler
from primo.reasoning import MarkovChainSampler
from primo.reasoning import GibbsTransitionModel
from primo.reasoning.density import ProbabilityTable
class MCMC(object):
    """Markov-Chain-Monte-Carlo approximate inference on a BayesNet.

    Bundles a GibbsTransitionModel and a MarkovChainSampler so callers can
    ask posterior-marginal and probability-of-evidence queries without
    wiring up the sampling machinery themselves.
    """

    def __init__(self, bn):
        """Remember the network and construct the sampling machinery.

        bn -- the BayesNet all queries are run against.
        """
        self.transition_model = GibbsTransitionModel()
        self.bn = bn
        self.mcs = MarkovChainSampler()
        # Number of states drawn from the chain per query.
        self.times = 5000

    def calculate_PriorMarginal(self, variables):
        """Approximate the prior marginal P(variables).

        Equivalent to a posterior query with empty evidence.
        """
        # NOTE(review): the original body referenced undefined module globals
        # (burglary, alarm, mcs, ...) and Python-2 print statements, so it
        # could never run; delegating to the evidence-free posterior query
        # matches the apparent intent.
        return self.calculate_PosteriorMarginal(variables, {})

    def calculate_PosteriorMarginal(self, variables_of_interest, evidence):
        """Approximate P(variables_of_interest | evidence).

        Returns a ProbabilityTable over variables_of_interest built from the
        sampled chain (evidence variables are clamped by the sampler).
        """
        initial_state = self._generateInitialStateWithEvidence(evidence)
        chain = self.mcs.generateMarkovChain(
            self.bn, self.times, self.transition_model,
            initial_state, evidence, variables_of_interest)
        pt = ProbabilityTable()
        pt.add_variables(variables_of_interest)
        # to_jpt_by_states consumes the chain and normalizes the counts.
        pt.to_jpt_by_states(chain)
        return pt

    def calculate_PoE(self, evidence):
        """Approximate the probability of *evidence*.

        Samples WITHOUT clamping the evidence and returns the fraction of
        freely drawn states that agree with every evidence assignment.
        """
        initial_state = self._generateInitialStateWithEvidence(evidence)
        chain = self.mcs.generateMarkovChain(
            self.bn, self.times, self.transition_model, initial_state)
        compatible_count = 0
        number_of_samples = 0
        for state in chain:
            if all(state[node] == value for node, value in evidence.items()):
                compatible_count += 1
            number_of_samples += 1
        if number_of_samples == 0:
            # An empty chain carries no information; report zero rather than
            # dividing by zero.
            return 0.0
        return float(compatible_count) / float(number_of_samples)

    def _generateInitialStateWithEvidence(self, evidence):
        """Build a complete starting assignment honouring the evidence.

        Evidence variables take their observed value; every other node falls
        back to the first entry of its value_range.
        """
        return dict(
            (var, evidence.get(var, var.value_range[0]))
            for var in self.bn.get_nodes([]))
......@@ -53,7 +53,7 @@ class MarkovChainSampler(object):
def __init__(self):
    """The sampler is stateless; all configuration is passed per call to generateMarkovChain."""
    pass
# NOTE(review): the two def lines below are stripped-diff residue -- the
# first is the pre-commit signature, the second the post-commit one that
# adds the variables_of_interest parameter.  Both use mutable default
# arguments (evidence=[]), a known Python pitfall, though the visible body
# only reads them.
def generateMarkovChain(self, network, time_steps, transition_model, initial_state, evidence=[]):
def generateMarkovChain(self, network, time_steps, transition_model, initial_state, evidence=[], variables_of_interest=[]):
state=initial_state
if evidence:
for node in evidence.keys():
# NOTE(review): a diff hunk boundary cuts the function here -- the lines
# that process the evidence nodes / build constant_nodes are not visible.
......@@ -63,6 +63,12 @@ class MarkovChainSampler(object):
else:
# With no evidence, no node is held constant during transitions.
constant_nodes=[]
# xrange is Python-2-only; one state is yielded per time step.
for t in xrange(time_steps):
# NOTE(review): the bare 'yield state' below is the pre-commit line;
# the if/else that follows is its post-commit replacement, which
# projects each emitted state onto variables_of_interest when given.
yield state
if variables_of_interest:
yield self._reduce_state_to_variables_of_interest(state, variables_of_interest)
else:
yield state
state=transition_model.transition(network, state, constant_nodes)
def _reduce_state_to_variables_of_interest(self, state, variables_of_interest):
return dict((k,v) for (k,v) in state.iteritems() if k in variables_of_interest)
from RandomNode import RandomNode
from DiscreteNode import DiscreteNode
from GaussNode import GaussNode
from MCMC import MarkovChainSampler
from MCMC import GibbsTransitionModel
\ No newline at end of file
from MarkovChainSampler import MarkovChainSampler
from MarkovChainSampler import GibbsTransitionModel
from MCMC import MCMC
......@@ -34,6 +34,10 @@ class ProbabilityTable(Density):
ax = self.table.ndim
self.table=numpy.expand_dims(self.table,ax)
self.table=numpy.repeat(self.table,len(variable.value_range),axis = ax)
def add_variables(self, variables):
    """Register each variable from *variables* with this table, in the given order."""
    for single_variable in variables:
        self.add_variable(single_variable)
def set_probability_table(self, table, nodes):
if not set(nodes) == set(self.variables):
......@@ -58,7 +62,6 @@ class ProbabilityTable(Density):
index = self.get_cpt_index(state.items())
self.table[index] = self.table[index] + 1
print self.table
return self.normalize_as_jpt()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment