From aac641766718f37d2b10a8aa998c521784694767 Mon Sep 17 00:00:00 2001
From: Denis John PC <djohn@techfak.uni-bielefeld.de>
Date: Sun, 27 Jan 2013 19:05:09 +0100
Subject: [PATCH] multiplication of ProbabilityTables done

---
 primo/reasoning/density/ProbabilityTable.py | 103 +++++++++++---------
 primo/tests/ProbabilityTable_test.py        |  65 ++++++++++++
 2 files changed, 121 insertions(+), 47 deletions(-)
 create mode 100644 primo/tests/ProbabilityTable_test.py

diff --git a/primo/reasoning/density/ProbabilityTable.py b/primo/reasoning/density/ProbabilityTable.py
index 2c60cc8..5c366f4 100644
--- a/primo/reasoning/density/ProbabilityTable.py
+++ b/primo/reasoning/density/ProbabilityTable.py
@@ -6,11 +6,11 @@ from primo.reasoning.density import Density
 class ProbabilityTable(Density):
     '''TODO: write doc'''
 
-    
+
 
     def __init__(self):
         super(ProbabilityTable, self).__init__()
-        
+
         #self.owner = owner
         #self.variables = [owner]
 
@@ -25,7 +25,7 @@ class ProbabilityTable(Density):
 
         ax = self.table.ndim
         self.table=numpy.expand_dims(self.table,ax)
-        self.table=numpy.repeat(self.table,len(variable.value_range),axis = ax)   
+        self.table=numpy.repeat(self.table,len(variable.value_range),axis = ax)
 
     def set_probability_table(self, table, nodes):
         if not set(nodes) == set(self.variables):
@@ -41,7 +41,7 @@ class ProbabilityTable(Density):
         self.variables = nodes
 
     def set_probability(self, value, node_value_pairs):
-        index = self.get_cpt_index(node_value_pairs) 
+        index = self.get_cpt_index(node_value_pairs)
         self.table[tuple(index)]=value
 
     def get_cpt_index(self, node_value_pairs):
@@ -62,53 +62,62 @@ class ProbabilityTable(Density):
         return set(sum_of_owner_probs.flatten()) == set([1])
 
     def is_normalized_as_jpt(self):
-        return numpy.sum(table) == 1.0
+        return numpy.sum(self.table) == 1.0
 
     def multiplication(self, inputFactor):
-        raise Exception("Called unimplemented function") 
         #init a new probability tabel
         factor1 = ProbabilityTable()
-        
+
         #all variables from both factors are needed
         factor1.variables = copy.copy(self.variables)
 
-        for v in factor.variables:
-			if not v in factor1.variables:
-				factor1.variables.append(v)
-		
-		#the table from the first factor is copied
-		factor1.table = copy.copy(self.table)
-		
-		#and extended by the dimensions for the left variables
-		for curIdx in range(factor1.table.ndim, len(factor1.variables)):
-			ax = factor1.table.ndim
-			factor1.table=numpy.expand_dims(factor1.table,ax)
-			factor1.table=numpy.repeat(factor1.table,len(factor1.variables[curIdx].values),axis = ax)
-			
-		#copy factor 2 and it's variables ...
-		factor2 = ProbabilityTable()
-		factor2.variables = copy.copy(inputFactor.variables)
-		factor2.table = copy.copy(inputFactor.table)
-		
-		#extend the dimensions of factors 2 to the dimensions of factor 1
-		for v in factor1.variables:
-			if not v in factor2.variables:
-				factor2.variables.append(v)
-				
-		for curIdx in range(factor2.table.ndim, len(factor2.variables)):
-			ax = factor2.table.ndim
-			factor2.table=numpy.expand_dims(factor2.table,ax)
-			factor2.table=numpy.repeat(factor2.table,len(factor2.variables[curIdx].values),axis = ax)
-		
-		#sort the variables to the same order
-		
-		#pointwise multiplication
-			
-		
-    
+        for v in (inputFactor.variables):
+            if not v in factor1.variables:
+                factor1.variables.append(v)
+
+            #the table from the first factor is copied
+            factor1.table = copy.copy(self.table)
+
+        #and extended by the dimensions for the left variables
+        for curIdx in range(factor1.table.ndim, len(factor1.variables)):
+            ax = factor1.table.ndim
+            factor1.table=numpy.expand_dims(factor1.table,ax)
+            factor1.table=numpy.repeat(factor1.table,len(factor1.variables[curIdx].value_range),axis = ax)
+
+        #copy factor 2 and it's variables ...
+        factor2 = ProbabilityTable()
+        factor2.variables = copy.copy(inputFactor.variables)
+        factor2.table = copy.copy(inputFactor.table)
+
+        #extend the dimensions of factors 2 to the dimensions of factor 1
+        for v in factor1.variables:
+            if not v in factor2.variables:
+                factor2.variables.append(v)
+
+        for curIdx in range(factor2.table.ndim, len(factor2.variables)):
+            ax = factor2.table.ndim
+            factor2.table=numpy.expand_dims(factor2.table,ax)
+            factor2.table=numpy.repeat(factor2.table,len(factor2.variables[curIdx].value_range),axis = ax)
+
+        #sort the variables to the same order
+        for endDim,variable in enumerate(factor1.variables):
+            startDim = factor2.variables.index(variable);
+            if not startDim == endDim:
+                factor2.table = numpy.rollaxis(factor2.table, startDim, endDim)
+                factor2.variables.insert(endDim,factor2.variables.pop(startDim))
+
+        #pointwise multiplication
+        if factor1.table.shape != factor2.table.shape:
+            raise Exception("Multiplication: The probability tables have the wrong dimensions for unification")
+
+        factor1.table = factor1.table *factor2.table;
+
+        return factor1
+
+
     def marginalization(self, variable):
-        raise Exception("Called unimplemented function")       
-        
+        raise Exception("Called unimplemented function")
+
     def reduction(self, evidence):
         '''Returns a reduced version of this ProbabilityTable, evidence is a list of pairs.
             Important: This node is not being changed!'''
@@ -120,13 +129,13 @@ class ProbabilityTable(Density):
             axis=reduced.variables.index(node)
             position=node.value_range.index(value)
             reduced.table = numpy.take(reduced.table,[position],axis=axis)
-            
+
             reduced.table=reduced.table.squeeze()
             reduced.variables.remove(node)
-            
+
         return reduced
-            
-            
+
+
 
     def division(self, factor):
         raise Exception("Called unimplemented function")
diff --git a/primo/tests/ProbabilityTable_test.py b/primo/tests/ProbabilityTable_test.py
new file mode 100644
index 0000000..d485ee0
--- /dev/null
+++ b/primo/tests/ProbabilityTable_test.py
@@ -0,0 +1,65 @@
+import unittest
+import numpy
+
+from primo.reasoning.density import ProbabilityTable
+from primo.reasoning import DiscreteNode
+
+
+class MultiplicationTest(unittest.TestCase):
+    def setUp(self):
+        self.pt = ProbabilityTable();
+
+    def tearDown(self):
+        self.pt = None
+
+    def test_easy_shape(self):
+        n1 = DiscreteNode("Some Node", [True, False])
+        n2 = DiscreteNode("Second Node" , [True, False])
+
+        s = n1.get_cpd().multiplication(n2.get_cpd())
+
+        self.assertEqual(s.table.shape, (2,2));
+
+        s = n1.get_cpd().multiplication(n1.get_cpd())
+        self.assertEqual(s.table.shape,(2,))
+        
+    def test_easy_values(self):
+        n1 = DiscreteNode("Some Node", [True, False])
+        n2 = DiscreteNode("Second Node" , [True, False])
+        
+        cpt1 = numpy.array([2,3])
+        cpt2 = numpy.array([5,7])
+        
+        n1.set_probability_table(cpt1,[n1])
+        n2.set_probability_table(cpt2,[n2])
+        
+        s = n1.get_cpd().multiplication(n2.get_cpd())
+        
+        cptN = numpy.array([[10,14],[15,21]])
+        
+        numpy.testing.assert_array_equal(s.table,cptN)
+        self.assertEqual(s.variables[0],n1)
+        
+    def test_complicated_multi(self):
+        n1 = DiscreteNode("Some Node", [True, False])
+        n2 = DiscreteNode("Second Node" , [True, False,"noIdea"])
+        
+        cpt1 = numpy.array([2,3])
+        cpt2 = numpy.array([5,7,9])
+        
+        n1.set_probability_table(cpt1,[n1])
+        n2.set_probability_table(cpt2,[n2])
+        
+        c3 = n1.get_cpd().multiplication(n2.get_cpd())
+        c3 = n1.get_cpd().multiplication(c3)
+        
+        cptN = numpy.array([[20, 28, 36],[45, 63, 81]])        
+        numpy.testing.assert_array_equal(c3.table,cptN)
+        
+        
+
+
+
+#include this so you can run this test without nose
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab