Skip to content
Snippets Groups Projects
Commit aac64176 authored by Denis John PC's avatar Denis John PC
Browse files

multiplication of ProbabilityTables done

parent a73b004b
No related branches found
No related tags found
No related merge requests found
...@@ -6,11 +6,11 @@ from primo.reasoning.density import Density ...@@ -6,11 +6,11 @@ from primo.reasoning.density import Density
class ProbabilityTable(Density): class ProbabilityTable(Density):
'''TODO: write doc''' '''TODO: write doc'''
def __init__(self): def __init__(self):
super(ProbabilityTable, self).__init__() super(ProbabilityTable, self).__init__()
#self.owner = owner #self.owner = owner
#self.variables = [owner] #self.variables = [owner]
...@@ -25,7 +25,7 @@ class ProbabilityTable(Density): ...@@ -25,7 +25,7 @@ class ProbabilityTable(Density):
ax = self.table.ndim ax = self.table.ndim
self.table=numpy.expand_dims(self.table,ax) self.table=numpy.expand_dims(self.table,ax)
self.table=numpy.repeat(self.table,len(variable.value_range),axis = ax) self.table=numpy.repeat(self.table,len(variable.value_range),axis = ax)
def set_probability_table(self, table, nodes): def set_probability_table(self, table, nodes):
if not set(nodes) == set(self.variables): if not set(nodes) == set(self.variables):
...@@ -41,7 +41,7 @@ class ProbabilityTable(Density): ...@@ -41,7 +41,7 @@ class ProbabilityTable(Density):
self.variables = nodes self.variables = nodes
def set_probability(self, value, node_value_pairs): def set_probability(self, value, node_value_pairs):
index = self.get_cpt_index(node_value_pairs) index = self.get_cpt_index(node_value_pairs)
self.table[tuple(index)]=value self.table[tuple(index)]=value
def get_cpt_index(self, node_value_pairs): def get_cpt_index(self, node_value_pairs):
...@@ -62,53 +62,62 @@ class ProbabilityTable(Density): ...@@ -62,53 +62,62 @@ class ProbabilityTable(Density):
return set(sum_of_owner_probs.flatten()) == set([1]) return set(sum_of_owner_probs.flatten()) == set([1])
def is_normalized_as_jpt(self): def is_normalized_as_jpt(self):
return numpy.sum(table) == 1.0 return numpy.sum(self.table) == 1.0
def multiplication(self, inputFactor): def multiplication(self, inputFactor):
raise Exception("Called unimplemented function")
#init a new probability tabel #init a new probability tabel
factor1 = ProbabilityTable() factor1 = ProbabilityTable()
#all variables from both factors are needed #all variables from both factors are needed
factor1.variables = copy.copy(self.variables) factor1.variables = copy.copy(self.variables)
for v in factor.variables: for v in (inputFactor.variables):
if not v in factor1.variables: if not v in factor1.variables:
factor1.variables.append(v) factor1.variables.append(v)
#the table from the first factor is copied #the table from the first factor is copied
factor1.table = copy.copy(self.table) factor1.table = copy.copy(self.table)
#and extended by the dimensions for the left variables #and extended by the dimensions for the left variables
for curIdx in range(factor1.table.ndim, len(factor1.variables)): for curIdx in range(factor1.table.ndim, len(factor1.variables)):
ax = factor1.table.ndim ax = factor1.table.ndim
factor1.table=numpy.expand_dims(factor1.table,ax) factor1.table=numpy.expand_dims(factor1.table,ax)
factor1.table=numpy.repeat(factor1.table,len(factor1.variables[curIdx].values),axis = ax) factor1.table=numpy.repeat(factor1.table,len(factor1.variables[curIdx].value_range),axis = ax)
#copy factor 2 and it's variables ... #copy factor 2 and it's variables ...
factor2 = ProbabilityTable() factor2 = ProbabilityTable()
factor2.variables = copy.copy(inputFactor.variables) factor2.variables = copy.copy(inputFactor.variables)
factor2.table = copy.copy(inputFactor.table) factor2.table = copy.copy(inputFactor.table)
#extend the dimensions of factors 2 to the dimensions of factor 1 #extend the dimensions of factors 2 to the dimensions of factor 1
for v in factor1.variables: for v in factor1.variables:
if not v in factor2.variables: if not v in factor2.variables:
factor2.variables.append(v) factor2.variables.append(v)
for curIdx in range(factor2.table.ndim, len(factor2.variables)): for curIdx in range(factor2.table.ndim, len(factor2.variables)):
ax = factor2.table.ndim ax = factor2.table.ndim
factor2.table=numpy.expand_dims(factor2.table,ax) factor2.table=numpy.expand_dims(factor2.table,ax)
factor2.table=numpy.repeat(factor2.table,len(factor2.variables[curIdx].values),axis = ax) factor2.table=numpy.repeat(factor2.table,len(factor2.variables[curIdx].value_range),axis = ax)
#sort the variables to the same order #sort the variables to the same order
for endDim,variable in enumerate(factor1.variables):
#pointwise multiplication startDim = factor2.variables.index(variable);
if not startDim == endDim:
factor2.table = numpy.rollaxis(factor2.table, startDim, endDim)
factor2.variables.insert(endDim,factor2.variables.pop(startDim))
#pointwise multiplication
if factor1.table.shape != factor2.table.shape:
raise Exception("Multiplication: The probability tables have the wrong dimensions for unification")
factor1.table = factor1.table *factor2.table;
return factor1
def marginalization(self, variable): def marginalization(self, variable):
raise Exception("Called unimplemented function") raise Exception("Called unimplemented function")
def reduction(self, evidence): def reduction(self, evidence):
'''Returns a reduced version of this ProbabilityTable, evidence is a list of pairs. '''Returns a reduced version of this ProbabilityTable, evidence is a list of pairs.
Important: This node is not being changed!''' Important: This node is not being changed!'''
...@@ -120,13 +129,13 @@ class ProbabilityTable(Density): ...@@ -120,13 +129,13 @@ class ProbabilityTable(Density):
axis=reduced.variables.index(node) axis=reduced.variables.index(node)
position=node.value_range.index(value) position=node.value_range.index(value)
reduced.table = numpy.take(reduced.table,[position],axis=axis) reduced.table = numpy.take(reduced.table,[position],axis=axis)
reduced.table=reduced.table.squeeze() reduced.table=reduced.table.squeeze()
reduced.variables.remove(node) reduced.variables.remove(node)
return reduced return reduced
def division(self, factor): def division(self, factor):
raise Exception("Called unimplemented function") raise Exception("Called unimplemented function")
......
import unittest
import numpy
from primo.reasoning.density import ProbabilityTable
from primo.reasoning import DiscreteNode
class MultiplicationTest(unittest.TestCase):
def setUp(self):
self.pt = ProbabilityTable();
def tearDown(self):
self.pt = None
def test_easy_shape(self):
n1 = DiscreteNode("Some Node", [True, False])
n2 = DiscreteNode("Second Node" , [True, False])
s = n1.get_cpd().multiplication(n2.get_cpd())
self.assertEqual(s.table.shape, (2,2));
s = n1.get_cpd().multiplication(n1.get_cpd())
self.assertEqual(s.table.shape,(2,))
def test_easy_values(self):
n1 = DiscreteNode("Some Node", [True, False])
n2 = DiscreteNode("Second Node" , [True, False])
cpt1 = numpy.array([2,3])
cpt2 = numpy.array([5,7])
n1.set_probability_table(cpt1,[n1])
n2.set_probability_table(cpt2,[n2])
s = n1.get_cpd().multiplication(n2.get_cpd())
cptN = numpy.array([[10,14],[15,21]])
numpy.testing.assert_array_equal(s.table,cptN)
self.assertEqual(s.variables[0],n1)
def test_complicated_multi(self):
n1 = DiscreteNode("Some Node", [True, False])
n2 = DiscreteNode("Second Node" , [True, False,"noIdea"])
cpt1 = numpy.array([2,3])
cpt2 = numpy.array([5,7,9])
n1.set_probability_table(cpt1,[n1])
n2.set_probability_table(cpt2,[n2])
c3 = n1.get_cpd().multiplication(n2.get_cpd())
c3 = n1.get_cpd().multiplication(c3)
cptN = numpy.array([[20, 28, 36],[45, 63, 81]])
numpy.testing.assert_array_equal(c3.table,cptN)
#include this so you can run this test without nose
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment