From a73b004b46d0c130cbc9b70eb546270a7bd9db54 Mon Sep 17 00:00:00 2001 From: Denis John PC <djohn@techfak.uni-bielefeld.de> Date: Sat, 26 Jan 2013 13:39:55 +0100 Subject: [PATCH] multiplication started --- primo/reasoning/density/ProbabilityTable.py | 43 ++++++++++++++++++++- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/primo/reasoning/density/ProbabilityTable.py b/primo/reasoning/density/ProbabilityTable.py index 817548b..2c60cc8 100644 --- a/primo/reasoning/density/ProbabilityTable.py +++ b/primo/reasoning/density/ProbabilityTable.py @@ -64,8 +64,47 @@ class ProbabilityTable(Density): def is_normalized_as_jpt(self): return numpy.sum(table) == 1.0 - def multiplication(self, factor): - raise Exception("Called unimplemented function") + def multiplication(self, inputFactor): + raise Exception("Called unimplemented function") + #init a new probability tabel + factor1 = ProbabilityTable() + + #all variables from both factors are needed + factor1.variables = copy.copy(self.variables) + + for v in factor.variables: + if not v in factor1.variables: + factor1.variables.append(v) + + #the table from the first factor is copied + factor1.table = copy.copy(self.table) + + #and extended by the dimensions for the left variables + for curIdx in range(factor1.table.ndim, len(factor1.variables)): + ax = factor1.table.ndim + factor1.table=numpy.expand_dims(factor1.table,ax) + factor1.table=numpy.repeat(factor1.table,len(factor1.variables[curIdx].values),axis = ax) + + #copy factor 2 and it's variables ... + factor2 = ProbabilityTable() + factor2.variables = copy.copy(inputFactor.variables) + factor2.table = copy.copy(inputFactor.table) + + #extend the dimensions of factors 2 to the dimensions of factor 1 + for v in factor1.variables: + if not v in factor2.variables: + factor2.variables.append(v) + + for curIdx in range(factor2.table.ndim, len(factor2.variables)): + ax = factor2.table.ndim + factor2.table=numpy.expand_dims(factor2.table,ax) + factor2.table=numpy.repeat(factor2.table,len(factor2.variables[curIdx].values),axis = ax) + + #sort the variables to the same order + + #pointwise multiplication + + def marginalization(self, variable): raise Exception("Called unimplemented function") -- GitLab