Newer
Older
import primo.reasoning.density.ProbabilityTable as ProbabilityTable
'''The factor tree contains one factor for each node of the BayesNet. It
is a directed graph with a single root node. To speed up reasoning it uses
a message-passing approach that stores intermediate results on the
edges. Thus, the first query is expensive and all subsequent queries are cheap.
The cost of the first message calculation depends on how the tree was built.
Literature: Modeling and Reasoning with Bayesian Networks - Adnan Darwiche
Chapter 7
'''
def __init__(self, graph, rootNode):
    '''Store the factor-tree graph and its root factor.

    :param graph: graph whose nodes are the factors of the tree.
    :param rootNode: the factor that acts as the root of the tree.
    '''
    self.graph, self.rootNode = graph, rootNode
'''Calculate the probability of the currently set evidence.

If the stored edge messages are flagged as invalid they are recomputed
first. The marginal at the root factor is then taken and every one of its
variables is marginalized out, presumably leaving a variable-free table
that holds the probability of evidence.
'''
# Recompute the cached edge messages only when they are flagged invalid.
if not self.graph.graph['messagesValid']:
self.calculate_messages()
# Combine the root factor with the messages stored on its edges.
cpd = self.calculate_marginal_forOne(self.rootNode)
# Marginalize out every variable; iterate over a copy of the variable list so
# changes made by marginalization cannot disturb the iteration.
for v in cpd.get_variables()[:]:
cpd = cpd.marginalization(v)
return cpd
''' If evidence is set, this method calculates the posterior marginal of the
requested variables. With empty evidence this is automatically the prior
marginal.'''
# NOTE(review): no `messagesValid` check is visible in these lines before the
# stored edge messages are read via calculate_marginal_forOne; confirm the
# messages are guaranteed to be current here.
# Start from the neutral element of multiplication.
resPT = ProbabilityTable.get_neutral_multiplication_PT()
# NOTE(review): `variables` is not bound by any visible line; presumably it is
# a parameter holding the requested nodes — confirm.
for f in self.graph.nodes():
if f.get_node() in variables:
# Multiply in the single-node marginal of each requested factor.
# NOTE(review): for several variables this yields the product of their
# individual marginals, which equals their joint only if they are
# independent given the evidence — confirm this is intended.
resPT = resPT.multiplication(self.calculate_marginal_forOne(f))
return resPT
def calculate_marginal_forOne(self,factor):
'''Return the marginal of a single factor's own node.

A copy of the factor's calculation table is multiplied with the messages
stored on its edges: 'msgRightWay' on an edge p -> factor and
'msgAgainstWay' on an edge factor -> p. Every variable other than
factor.get_node() is then marginalized out.

:param factor: a node of the factor tree.
:returns: the resulting table over factor.get_node().
'''
# Copied, presumably so the factor's stored table is not altered below.
curCPD = factor.get_calculation_CDP().copy()
# NOTE(review): `p` is not bound by any line visible here; presumably a loop
# over the factor's predecessors (edges p -> factor) wraps the next two
# lines — confirm.
tmpCPD = self.graph[p][factor]['msgRightWay']
curCPD = curCPD.multiplication(tmpCPD)
# NOTE(review): likewise, `p` here presumably iterates over the factor's
# successors (edges factor -> p) — confirm.
tmpCPD = self.graph[factor][p]['msgAgainstWay']
curCPD = curCPD.multiplication(tmpCPD)
# Keep only the factor's own node; iterate over a copy of the variable list
# so marginalization cannot disturb the iteration.
for v in curCPD.get_variables()[:]:
if v != factor.get_node():
curCPD = curCPD.marginalization(v)
return curCPD
# Draw the factor tree with a circular layout and display the figure.
# matplotlib is presumably imported here, rather than at module level, so it
# is only required when drawing — confirm.
import matplotlib.pyplot as plt
# NOTE(review): `nx` is not imported in the visible part of the file;
# presumably `import networkx as nx` exists elsewhere — confirm.
nx.draw_circular(self.graph)
plt.show()
''' Calculate all edge messages and store them as intermediate results.

The pull phase runs first from the root and stores messages against the
edge direction ('msgAgainstWay'). The push phase then stores messages
along the edge direction ('msgRightWay'), starting from the neutral
multiplication table. Finally the graph-level 'messagesValid' flag is set.'''
# The pull phase must run first: push_phase reads the 'msgAgainstWay'
# messages that pull_phase writes.
self.pull_phase(self.rootNode,self.graph)
self.push_phase(self.rootNode,self.graph,ProbabilityTable.get_neutral_multiplication_PT())
# Mark the cached messages as up to date.
self.graph.graph['messagesValid'] = True
# For each factor whose node appears in the evidence, look up the index of
# that node among the evidence entries.
# `evidences` is presumably a sequence of tuples whose first element is a
# node — TODO confirm. zip(*evidences) transposes it, so evNodes[0] holds the
# first element of every entry.
# NOTE(review): indexing the result of zip() only works on Python 2, where zip
# returns a list; on Python 3 evNodes[0] raises TypeError. An empty
# `evidences` also makes evNodes[0] fail. Consider list(zip(*evidences)).
# NOTE(review): no line visible here resets graph.graph['messagesValid'];
# cached messages must be recomputed after evidence changes — confirm this
# happens elsewhere.
evNodes = zip(*evidences)
for factor in self.graph.nodes():
if factor.get_node() in evNodes[0]:
idx = evNodes[0].index(factor.get_node())
# NOTE(review): the excerpt ends here; `idx` is not used in the visible lines.
def pull_phase(self,factor,graph):
'''Recursively collect messages from the subtree below `factor`.

For every child, the child's recursive result is meant to be projected
onto the 'separator' of the connecting edge. It is then stored on that
edge as 'msgAgainstWay' and multiplied into this factor's result, which
is returned to the caller.

:param factor: node of the factor tree to start from.
:param graph: the factor-tree graph; edge data supplies 'separator' and
    receives 'msgAgainstWay'.
:returns: the combined table for `factor` and its subtree.
'''
# NOTE(review): `calCPD` is used below but never initialised in the visible
# lines; presumably it starts as a copy of factor.get_calculation_CDP() —
# confirm.
#calculate the messages of the children
for child in graph.neighbors(factor):
tmpInput = self.pull_phase(child,graph)
#project each factor on the specific separator
separator = graph[factor][child]['separator']
# NOTE(review): `var` is not bound by any visible line and `separator` is
# otherwise unused; presumably a loop such as
# "for var in <copy of tmpInput's variables>: if var not in separator:"
# is missing here — confirm.
tmpInput = tmpInput.marginalization(var)
#save message on edge: it's the opposite of the direction of the edge
graph[factor][child]['msgAgainstWay'] = tmpInput
#calculate the new message
calCPD = calCPD.multiplication(tmpInput)
return calCPD
def push_phase(self,factor,graph,inCPD):
    '''Recursively send messages from `factor` down to its children.

    For every child of `factor`, the outgoing message is the product of
    `inCPD`, the factor's own calculation table, and the 'msgAgainstWay'
    messages stored on the edges to all *other* children. That product is
    projected onto the 'separator' of the edge to the child. It is stored on
    the edge as 'msgRightWay' (the message in the direction of the edge) and
    passed on recursively as the child's incoming message.

    :param factor: node of the factor tree whose children receive messages.
    :param graph: the factor-tree graph; edge data supplies 'separator' and
        'msgAgainstWay' and receives 'msgRightWay'.
    :param inCPD: the incoming message for `factor`.
    '''
    # Materialise the children once. The same list is walked for every
    # outgoing message, and only edge data (not graph structure) changes below.
    children = list(graph.neighbors(factor))
    for child in children:
        tmpCPD = inCPD.multiplication(factor.get_calculation_CDP())
        # Combine with the messages coming from all siblings of `child`.
        for child2 in children:
            if child != child2:
                tmpCPD = tmpCPD.multiplication(graph[factor][child2]['msgAgainstWay'])
        separator = graph[factor][child]['separator']
        # Project onto the outgoing edge separator. Iterate over a snapshot of
        # the variable list (as the other marginalization loops do with
        # get_variables()[:]), so that a marginalization which shrinks the
        # list in place cannot cause variables to be skipped.
        for var in list(tmpCPD.variables):
            if var not in separator:
                tmpCPD = tmpCPD.marginalization(var)
        # Message with the direction of the edge.
        graph[factor][child]['msgRightWay'] = tmpCPD
        self.push_phase(child, graph, tmpCPD)