Skip to content
Snippets Groups Projects
Commit 68231f08 authored by Denis John PC's avatar Denis John PC
Browse files

Marginalization implemented with tests

parent aac64176
No related branches found
No related tags found
No related merge requests found
......@@ -65,6 +65,9 @@ class ProbabilityTable(Density):
return numpy.sum(self.table) == 1.0
def multiplication(self, inputFactor):
'''This method returns a unified ProbabilityTable which contains the variables of both; the inputFactor
and this factor(self). The new values of the returned factor is the product of the values from the input factors
which are compatible to the variable instantiation of the returned value.'''
#init a new probability tabel
factor1 = ProbabilityTable()
......@@ -108,7 +111,7 @@ class ProbabilityTable(Density):
#pointwise multiplication
if factor1.table.shape != factor2.table.shape:
raise Exception("Multiplication: The probability tables have the wrong dimensions for unification")
raise Exception("Multiplication: The probability tables have the wrong dimensions for unification!")
factor1.table = factor1.table *factor2.table;
......@@ -116,7 +119,23 @@ class ProbabilityTable(Density):
def marginalization(self, variable):
raise Exception("Called unimplemented function")
'''This method returns a new instantiation with the given variable summed out.'''
if not variable in self.variables:
raise Exception("Marginalization: The given variable isn't in the ProbabilityTable!")
#new instance for returning
retInstance = ProbabilityTable()
retInstance.table = copy.copy(self.table)
retInstance.variables = copy.copy(self.variables)
ax = retInstance.variables.index(variable)
retInstance.table = numpy.sum(retInstance.table,ax)
retInstance.variables.remove(variable)
return retInstance
def reduction(self, evidence):
'''Returns a reduced version of this ProbabilityTable, evidence is a list of pairs.
......
......@@ -56,6 +56,28 @@ class MultiplicationTest(unittest.TestCase):
cptN = numpy.array([[20, 28, 36],[45, 63, 81]])
numpy.testing.assert_array_equal(c3.table,cptN)
class MarginalizationTest(unittest.TestCase):
def test_easy_marginalize(self):
n1 = DiscreteNode("Some Node", [True, False])
n2 = DiscreteNode("Second Node" , [True, False, "other"])
cpt1 = numpy.array([2,3])
cpt2 = numpy.array([5,7,3])
n1.set_probability_table(cpt1,[n1])
n2.set_probability_table(cpt2,[n2])
s = n1.get_cpd().multiplication(n2.get_cpd())
s =s.marginalization(n2)
print s.table
cptN = numpy.array([30,45])
numpy.testing.assert_array_equal(s.table,cptN)
self.asserEqual(s.variables[0],n1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment