From aac641766718f37d2b10a8aa998c521784694767 Mon Sep 17 00:00:00 2001 From: Denis John PC <djohn@techfak.uni-bielefeld.de> Date: Sun, 27 Jan 2013 19:05:09 +0100 Subject: [PATCH] multiplication of ProbabilityTables done --- primo/reasoning/density/ProbabilityTable.py | 103 +++++++++++--------- primo/tests/ProbabilityTable_test.py | 65 ++++++++++++ 2 files changed, 121 insertions(+), 47 deletions(-) create mode 100644 primo/tests/ProbabilityTable_test.py diff --git a/primo/reasoning/density/ProbabilityTable.py b/primo/reasoning/density/ProbabilityTable.py index 2c60cc8..5c366f4 100644 --- a/primo/reasoning/density/ProbabilityTable.py +++ b/primo/reasoning/density/ProbabilityTable.py @@ -6,11 +6,11 @@ from primo.reasoning.density import Density class ProbabilityTable(Density): '''TODO: write doc''' - + def __init__(self): super(ProbabilityTable, self).__init__() - + #self.owner = owner #self.variables = [owner] @@ -25,7 +25,7 @@ class ProbabilityTable(Density): ax = self.table.ndim self.table=numpy.expand_dims(self.table,ax) - self.table=numpy.repeat(self.table,len(variable.value_range),axis = ax) + self.table=numpy.repeat(self.table,len(variable.value_range),axis = ax) def set_probability_table(self, table, nodes): if not set(nodes) == set(self.variables): @@ -41,7 +41,7 @@ class ProbabilityTable(Density): self.variables = nodes def set_probability(self, value, node_value_pairs): - index = self.get_cpt_index(node_value_pairs) + index = self.get_cpt_index(node_value_pairs) self.table[tuple(index)]=value def get_cpt_index(self, node_value_pairs): @@ -62,53 +62,62 @@ class ProbabilityTable(Density): return set(sum_of_owner_probs.flatten()) == set([1]) def is_normalized_as_jpt(self): - return numpy.sum(table) == 1.0 + return numpy.sum(self.table) == 1.0 def multiplication(self, inputFactor): - raise Exception("Called unimplemented function") #init a new probability tabel factor1 = ProbabilityTable() - + #all variables from both factors are needed factor1.variables = copy.copy(self.variables) - for v in factor.variables: - if not v in factor1.variables: - factor1.variables.append(v) - - #the table from the first factor is copied - factor1.table = copy.copy(self.table) - - #and extended by the dimensions for the left variables - for curIdx in range(factor1.table.ndim, len(factor1.variables)): - ax = factor1.table.ndim - factor1.table=numpy.expand_dims(factor1.table,ax) - factor1.table=numpy.repeat(factor1.table,len(factor1.variables[curIdx].values),axis = ax) - - #copy factor 2 and it's variables ... - factor2 = ProbabilityTable() - factor2.variables = copy.copy(inputFactor.variables) - factor2.table = copy.copy(inputFactor.table) - - #extend the dimensions of factors 2 to the dimensions of factor 1 - for v in factor1.variables: - if not v in factor2.variables: - factor2.variables.append(v) - - for curIdx in range(factor2.table.ndim, len(factor2.variables)): - ax = factor2.table.ndim - factor2.table=numpy.expand_dims(factor2.table,ax) - factor2.table=numpy.repeat(factor2.table,len(factor2.variables[curIdx].values),axis = ax) - - #sort the variables to the same order - - #pointwise multiplication - - - + for v in (inputFactor.variables): + if not v in factor1.variables: + factor1.variables.append(v) + + #the table from the first factor is copied + factor1.table = copy.copy(self.table) + + #and extended by the dimensions for the left variables + for curIdx in range(factor1.table.ndim, len(factor1.variables)): + ax = factor1.table.ndim + factor1.table=numpy.expand_dims(factor1.table,ax) + factor1.table=numpy.repeat(factor1.table,len(factor1.variables[curIdx].value_range),axis = ax) + + #copy factor 2 and it's variables ... + factor2 = ProbabilityTable() + factor2.variables = copy.copy(inputFactor.variables) + factor2.table = copy.copy(inputFactor.table) + + #extend the dimensions of factors 2 to the dimensions of factor 1 + for v in factor1.variables: + if not v in factor2.variables: + factor2.variables.append(v) + + for curIdx in range(factor2.table.ndim, len(factor2.variables)): + ax = factor2.table.ndim + factor2.table=numpy.expand_dims(factor2.table,ax) + factor2.table=numpy.repeat(factor2.table,len(factor2.variables[curIdx].value_range),axis = ax) + + #sort the variables to the same order + for endDim,variable in enumerate(factor1.variables): + startDim = factor2.variables.index(variable); + if not startDim == endDim: + factor2.table = numpy.rollaxis(factor2.table, startDim, endDim) + factor2.variables.insert(endDim,factor2.variables.pop(startDim)) + + #pointwise multiplication + if factor1.table.shape != factor2.table.shape: + raise Exception("Multiplication: The probability tables have the wrong dimensions for unification") + + factor1.table = factor1.table *factor2.table; + + return factor1 + + def marginalization(self, variable): - raise Exception("Called unimplemented function") - + raise Exception("Called unimplemented function") + def reduction(self, evidence): '''Returns a reduced version of this ProbabilityTable, evidence is a list of pairs. Important: This node is not being changed!''' @@ -120,13 +129,13 @@ class ProbabilityTable(Density): axis=reduced.variables.index(node) position=node.value_range.index(value) reduced.table = numpy.take(reduced.table,[position],axis=axis) - + reduced.table=reduced.table.squeeze() reduced.variables.remove(node) - + return reduced - - + + def division(self, factor): raise Exception("Called unimplemented function") diff --git a/primo/tests/ProbabilityTable_test.py b/primo/tests/ProbabilityTable_test.py new file mode 100644 index 0000000..d485ee0 --- /dev/null +++ b/primo/tests/ProbabilityTable_test.py @@ -0,0 +1,65 @@ +import unittest +import numpy + +from primo.reasoning.density import ProbabilityTable +from primo.reasoning import DiscreteNode + + +class MultiplicationTest(unittest.TestCase): + def setUp(self): + self.pt = ProbabilityTable(); + + def tearDown(self): + self.pt = None + + def test_easy_shape(self): + n1 = DiscreteNode("Some Node", [True, False]) + n2 = DiscreteNode("Second Node" , [True, False]) + + s = n1.get_cpd().multiplication(n2.get_cpd()) + + self.assertEqual(s.table.shape, (2,2)); + + s = n1.get_cpd().multiplication(n1.get_cpd()) + self.assertEqual(s.table.shape,(2,)) + + def test_easy_values(self): + n1 = DiscreteNode("Some Node", [True, False]) + n2 = DiscreteNode("Second Node" , [True, False]) + + cpt1 = numpy.array([2,3]) + cpt2 = numpy.array([5,7]) + + n1.set_probability_table(cpt1,[n1]) + n2.set_probability_table(cpt2,[n2]) + + s = n1.get_cpd().multiplication(n2.get_cpd()) + + cptN = numpy.array([[10,14],[15,21]]) + + numpy.testing.assert_array_equal(s.table,cptN) + self.assertEqual(s.variables[0],n1) + + def test_complicated_multi(self): + n1 = DiscreteNode("Some Node", [True, False]) + n2 = DiscreteNode("Second Node" , [True, False,"noIdea"]) + + cpt1 = numpy.array([2,3]) + cpt2 = numpy.array([5,7,9]) + + n1.set_probability_table(cpt1,[n1]) + n2.set_probability_table(cpt2,[n2]) + + c3 = n1.get_cpd().multiplication(n2.get_cpd()) + c3 = n1.get_cpd().multiplication(c3) + + cptN = numpy.array([[20, 28, 36],[45, 63, 81]]) + numpy.testing.assert_array_equal(c3.table,cptN) + + + + + +#include this so you can run this test without nose +if __name__ == '__main__': + unittest.main() -- GitLab