Skip to content
Snippets Groups Projects
Commit 2345c104 authored by Denis John PC's avatar Denis John PC
Browse files

FactorTree input und output messages

parent f6e62e9f
No related branches found
No related tags found
No related merge requests found
......@@ -193,6 +193,15 @@ class ProbabilityTable(Density):
ev.table[pos_value] = tmpCpd[pos_value]
return ev
def copy(self):
'''Returns a copied version of this probabilityTable'''
ev = ProbabilityTable()
ev.variables = copy.copy(self.variables)
ev.table = copy.copy(self.table)
return ev
......
......@@ -9,7 +9,7 @@ class Factor(object):
def __init__(self,node):
self.node = node
self.calPT = node.get_cpd()
self.calCPD = node.get_cpd()
def __str__(self):
return self.node.name
......@@ -23,10 +23,10 @@ class Factor(object):
def set_cluster(self,cluster):
self.cluster = cluster
def add_to_cluster(self,node):
self.cluster.add(node)
def get_variables(self):
return self.node.get_cpd().variables
def get_calculation_CDP(self):
return self.calCPD;
import networkx as nx
import primo.reasoning.density.ProbabilityTable as ProbabilityTable
import copy
......@@ -14,6 +16,50 @@ class FactorTree(object):
import matplotlib.pyplot as plt
nx.draw_circular(self.graph)
plt.show()
def pull_phase(self,factor,graph):
calCPD = factor.get_calculate_CPD()
#calculate the messages of the children
for child in graph.neighbors(factor):
tmpInput = self.pull_phase(child,graph)
#project each factor on the specific seperator
seperator = graph[factor][child]['seperator']
for var in tmpInput.variables[:]:
if var not in seperator:
tmpInput = tmpInput.marginalization(var)
#save message on edge
graph[factor][child]['inMessage'] = tmpInput
#calculate the new message
calCPD = calCPD.multiplication(tmpInput)
return calCPD
def push_phase(self,factor,graph,inCPD):
#add local vars to set
calCPD = inCPD.copy()
for child in graph.neighbors(factor):
tmpCPD = calCPD
for child2 in graph.neighbors(factor):
if (child != child2):
tmpCPD = tmpCPD.multiplication(graph[factor][child2]['inMessage'])
seperator = graph[factor][child]['seperator']
#project on outgoing edge seperator
for var in tmpCPD.variables[:]:
if var not in seperator:
tmpCPD = tmpCPD.marginalization(var)
#add setOut to outgoing vars from child
graph[factor][child]['outMessage'] = tmpCPD
self.push_phase(child,graph,tmpCPD)
......
......@@ -30,7 +30,9 @@ class FactorTreeFactory(object):
self.calculate_seperators_pull(rootFactor,graph)
self.calculate_seperators_push(rootFactor,graph,set())
self.inject_seperators(graph)
self.intersect_seperators(graph)
self.calculate_clusters(rootFactor,graph,set())
return FactorTree(graph,rootFactor)
......@@ -46,18 +48,6 @@ class FactorTreeFactory(object):
# add s to incoming vars from child
tmp = graph[factor][child]['inVars']
graph[factor][child]['inVars'] = tmp | s
# for c2 in graph.children(factor):
# if child != c2:
# #add s to outgoing vars from c2
# if graph[factor][child]['outVars'] == None:
# graph[factor][child]['outVars'] = copy.copy(s)
# else:
# tmp = graph[factor][child]['outVars']
# graph[factor][child]['outVars'] = tmp | s
#self.calculate_seperators_push(child,graph,s)
pullSet = s | pullSet
......@@ -84,11 +74,22 @@ class FactorTreeFactory(object):
self.calculate_seperators_push(child,graph,tmpSet)
def inject_seperators(self,graph):
def intersect_seperators(self,graph):
for n,nbrs in graph.adjacency_iter():
for nbr,eattr in nbrs.items():
eattr['seperator'] = eattr['inVars'] & eattr['outVars']
eattr['seperator'] = eattr['inVars'] & eattr['outVars']
def calculate_clusters(self,factor,graph,parent_seperator):
localCluster = parent_seperator | set(factor.get_variables())
for n in graph.neighbors(factor):
tmpSeperator = graph[factor][n]['seperator']
localCluster = localCluster | tmpSeperator
self.calculate_clusters(n,graph,tmpSeperator)
factor.set_cluster(localCluster)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment