Skip to content
Snippets Groups Projects
Commit a0e4c0eb authored by Denis John PC's avatar Denis John PC
Browse files

Working on FactorTree

parent 30b6c912
No related branches found
No related tags found
No related merge requests found
......@@ -3,24 +3,26 @@ from primo.core import Node
class Factor(object):
evidence = ""
cluster = set()
isEvidence = False
def __init__(self,node):
self.node = node
self.calCPD = node.get_cpd()
self.calCPD = node.get_cpd().copy()
def __str__(self):
return self.node.name
def set_evidence(self,evd):
#TODO: copy from originalCPD
# setCPD
self.evidence = evd
self.calCPD = self.node.get_cpd().copy()
self.calCPD = self.calCPD.setEvidence(evd)
self.isEvidene = True
def clear_evidence(self):
self.evidence = ""
self.calCPD = self.node.get_cpd().copy()
self.isEvidence = False
def set_cluster(self,cluster):
self.cluster = cluster
......@@ -33,5 +35,8 @@ class Factor(object):
def get_node(self):
return self.node
def contains_node(self,node):
return self.node == node
......@@ -12,6 +12,33 @@ class FactorTree(object):
self.rootNode = rootNode
self.graph['messagesValid'] = False
def calculate_marginal(self,variables):
if not self.graph['messagesValid']:
self.calculateMessages()
resPT = ProbabilityTable()
for f in self.graph.nodes():
if f.get_node() in variables:
resPT = resPT.multiplication.calculate_marginal_forOne(f)
return resPT
def calculate_marginal_forOne(self,factor):
curCPD = factor.get_calculation_CDP().copy()
for p in self.graph.predecessors(factor):
curCPD = curCPD.multiplication(self.graph[p][factor]['inMessage'])
for p in self.graph.neighbors(factor):
curCPD = curCPD.multiplication(self.graph[factor][p]['outMessage'])
for v in curCPD.get_variables()[:]:
if v != factor.get_node():
curCPD = curCPD.marginalization(v)
return curCPD
def draw(self):
import matplotlib.pyplot as plt
nx.draw_circular(self.graph)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment