Newer
Older
import primo.reasoning.density.ProbabilityTable as ProbabilityTable
class FactorTree(object):
    """Directed tree of factors with messages cached on its edges.

    ``graph`` is accessed through the networkx directed-graph interface:
    ``nodes()``, ``neighbors()``, ``predecessors()``, the per-edge attribute
    dicts ``graph[u][v]`` and the graph-level attribute dict ``graph.graph``.
    Every edge ``parent -> child`` carries:

    * ``'seperator'``  -- variables a message on this edge is projected onto
      (only read here; the spelling of the key is part of the data contract),
    * ``'inMessage'``  -- message sent from ``child`` up to ``parent``,
      written by :meth:`pull_phase`,
    * ``'outMessage'`` -- message sent from ``parent`` down to ``child``,
      written by :meth:`push_phase`.

    The graph-level flag ``graph.graph['messagesValid']`` records whether
    the cached messages match the current evidence.
    """

    def __init__(self, graph, rootNode):
        """Store the factor graph and the factor used as root of the tree."""
        self.graph = graph
        self.rootNode = rootNode

    def calculate_marginal(self, variables):
        """Return the product of the single-node marginals of ``variables``.

        Messages are recomputed first when they are not flagged as valid
        (a missing flag is treated as "not valid").
        """
        if not self.graph.graph.get('messagesValid', False):
            self.calculateMessages()

        # NOTE(review): relies on an empty ProbabilityTable being neutral
        # under multiplication -- confirm against ProbabilityTable.
        resPT = ProbabilityTable()
        for f in self.graph.nodes():
            if f.get_node() in variables:
                resPT = resPT.multiplication(self.calculate_marginal_forOne(f))
        return resPT

    def calculate_marginal_forOne(self, factor):
        """Return the marginal over ``factor``'s own node.

        The factor's table is multiplied with every message flowing *into*
        the factor and all variables except ``factor.get_node()`` are then
        marginalized out.  Requires that both message phases have run.
        """
        # NOTE(review): this accessor is spelled get_calculation_CDP here but
        # get_calculate_CPD in pull_phase -- verify against the Factor class.
        curCPD = factor.get_calculation_CDP().copy()

        # From the parent side the incoming message is the one pushed down.
        for parent in self.graph.predecessors(factor):
            curCPD = curCPD.multiplication(
                self.graph[parent][factor]['outMessage'])

        # From the children the incoming message is the one pulled up.
        for child in self.graph.neighbors(factor):
            curCPD = curCPD.multiplication(
                self.graph[factor][child]['inMessage'])

        # Iterate over a copy: curCPD is rebound while variables are removed.
        for v in curCPD.get_variables()[:]:
            if v != factor.get_node():
                curCPD = curCPD.marginalization(v)
        return curCPD

    def draw(self):
        """Show the factor graph in a circular layout using matplotlib."""
        import matplotlib.pyplot as plt
        # NOTE(review): `nx` (networkx) must be importable at module level;
        # no such import is visible in this part of the file.
        nx.draw_circular(self.graph)
        plt.show()

    def calculateMessages(self):
        """Run both message phases from the root and mark messages valid.

        The pull phase has to run first: the push phase reads the
        ``'inMessage'`` entries that the pull phase stores on the edges.
        """
        self.pull_phase(self.rootNode, self.graph)
        self.push_phase(self.rootNode, self.graph, ProbabilityTable())
        self.graph.graph['messagesValid'] = True

    def setEvidences(self, evidences):
        """Hand each evidence entry to the factor of the node it refers to.

        ``evidences`` is an indexable sequence of entries; the first element
        of each entry is taken to be the node it belongs to (assumption --
        confirm against callers).  The complete entry is passed on to
        ``factor.set_evidence``.  Cached messages are flagged as invalid.
        """
        self.graph.graph['messagesValid'] = False

        evNodes = [evidence[0] for evidence in evidences]
        for factor in self.graph.nodes():
            node = factor.get_node()
            if node in evNodes:
                # First matching entry wins, as with list.index().
                factor.set_evidence(evidences[evNodes.index(node)])

    def pull_phase(self, factor, graph):
        """Collect messages from the subtree below ``factor`` (leaves to root).

        For every child the child's result is projected onto the edge
        separator, stored as ``'inMessage'`` on the edge and multiplied into
        this factor's table.  Returns the resulting (unprojected) table.
        """
        # NOTE(review): this accessor is spelled get_calculate_CPD here but
        # get_calculation_CDP in calculate_marginal_forOne -- verify against
        # the Factor class.
        calCPD = factor.get_calculate_CPD()

        for child in graph.neighbors(factor):
            # Calculate the message of the child's subtree first.
            tmpInput = self.pull_phase(child, graph)

            # Project the child's result onto the separator of this edge.
            seperator = graph[factor][child]['seperator']
            for var in tmpInput.variables[:]:
                if var not in seperator:
                    tmpInput = tmpInput.marginalization(var)

            # Save the message on the edge and fold it into our own table.
            graph[factor][child]['inMessage'] = tmpInput
            calCPD = calCPD.multiplication(tmpInput)
        return calCPD

    def push_phase(self, factor, graph, inCPD):
        """Distribute messages from ``factor`` down to its subtree.

        ``inCPD`` is the message ``factor`` received from its parent (for
        the root: a table that is neutral under multiplication).  The
        message for each child is ``inCPD`` times the factor's own table
        times the ``'inMessage'`` of every *other* child, projected onto the
        edge separator.  It is stored as ``'outMessage'`` on the edge and
        passed on recursively.  Requires that :meth:`pull_phase` has run.
        """
        # Materialize once: the child list is walked in a nested loop.
        children = list(graph.neighbors(factor))
        if not children:
            return

        # Part of the message that is identical for all children.
        # NOTE(review): accessor spelling follows pull_phase -- see note there.
        baseCPD = inCPD.multiplication(factor.get_calculate_CPD())

        for child in children:
            tmpCPD = baseCPD
            for child2 in children:
                if child != child2:
                    tmpCPD = tmpCPD.multiplication(
                        graph[factor][child2]['inMessage'])

            # Project onto the separator of the outgoing edge.
            seperator = graph[factor][child]['seperator']
            for var in tmpCPD.variables[:]:
                if var not in seperator:
                    tmpCPD = tmpCPD.marginalization(var)

            graph[factor][child]['outMessage'] = tmpCPD
            self.push_phase(child, graph, tmpCPD)