Skip to content
Snippets Groups Projects
Commit c8be19fc authored by Denis John's avatar Denis John
Browse files

easy factorelimination

parent 3a87d612
Branches
Tags
No related merge requests found
......@@ -173,6 +173,26 @@ class ProbabilityTable(Density):
reduced.variables.remove(node)
return reduced
def set_evidence(self,evidence):
'''Returns a new version of the ProbabilityTable with only the evidence
not equal zero'''
ev = ProbabilityTable()
ev.variables = copy.copy(self.variables)
ev.table = numpy.zeros(self.table.shape)
tmpCpd = self.table
pos_variable = ev.variables.index(evidence[0])
pos_value = ev.variables[pos_variable].value_range.index(evidence[1])
ev.table = numpy.rollaxis(ev.table,pos_variable,0)
tmpCpd = numpy.rollaxis(tmpCpd,pos_variable,0)
ev.variables.insert(0,ev.variables.pop(pos_variable))
ev.table[pos_value] = tmpCpd[pos_value]
return ev
......
......@@ -31,10 +31,59 @@ class EasiestFactorElimination(object):
return finCpd
def calculate_PosteriorMarginal(self,variables,evidence):
bn = bn.copy()
#TODO
# Verbundwahrscheinlichkeit / PoE
# Erst wie Prior Marginal nur mit setzen der Evidence
# Dann PoE berechnen und damit normalisieren
nodes = self.bn.get_all_nodes()
unzipped_list = zip(*evidence)
node1 = nodes.pop()
if node1 in unzipped_list[0]:
ind = unzipped_list[0].index(node1)
finCpd = node1.get_cpd().set_evidence(evidence[ind])
else:
finCpd = node1.get_cpd()
for n in nodes:
if n in unzipped_list[0]:
ind = unzipped_list[0].index(n)
nCPD = n.get_cpd().set_evidence(evidence[ind])
finCpd = finCpd.multiplication(nCPD)
else:
finCpd = finCpd.multiplication(n.get_cpd())
for v in finCpd.get_variables()[:]:
finCpd = finCpd.marginalization(v)
return finCpd
def calculate_PoE(self,evidence):
bn = bn.copy()
def calculate_PoE(self,evidence):
nodes = self.bn.get_all_nodes()
unzipped_list = zip(*evidence)
node1 = nodes.pop()
if node1 in unzipped_list[0]:
ind = unzipped_list[0].index(node1)
finCpd = node1.get_cpd().set_evidence(evidence[ind])
else:
finCpd = node1.get_cpd()
for n in nodes:
if n in unzipped_list[0]:
ind = unzipped_list[0].index(n)
nCPD = n.get_cpd().set_evidence(evidence[ind])
finCpd = finCpd.multiplication(nCPD)
else:
finCpd = finCpd.multiplication(n.get_cpd())
for v in finCpd.get_variables()[:]:
finCpd = finCpd.marginalization(v)
return finCpd
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment