-
Denis John PC authoredDenis John PC authored
ProbabilityTable.py 5.15 KiB
# -*- coding: utf-8 -*-
import numpy
import copy
from primo.reasoning.density import Density
class ProbabilityTable(Density):
'''TODO: write doc'''
def __init__(self):
super(ProbabilityTable, self).__init__()
#self.owner = owner
#self.variables = [owner]
#size_of_range = len(owner.value_range)
#self.table = numpy.ones(size_of_range) / size_of_range
self.variables = []
self.table = numpy.array(0)
def add_variable(self, variable):
self.variables.append(variable)
ax = self.table.ndim
self.table=numpy.expand_dims(self.table,ax)
self.table=numpy.repeat(self.table,len(variable.value_range),axis = ax)
def set_probability_table(self, table, nodes):
if not set(nodes) == set(self.variables):
raise Exception("The list which should define the ordering of the variables does not match"
" the variables that this cpt depends on (plus the node itself)")
if not self.table.ndim == table.ndim:
raise Exception("The provided probability table does not have the right number of dimensions")
for d,node in enumerate(nodes):
if len(node.value_range) != table.shape[d]:
raise Exception("The size of the provided probability table does not match the number of possible values of the node "+node.name+" in dimension "+str(d))
self.table = table
self.variables = nodes
def set_probability(self, value, node_value_pairs):
index = self.get_cpt_index(node_value_pairs)
self.table[tuple(index)]=value
def get_cpt_index(self, node_value_pairs):
nodes, values = zip(*node_value_pairs)
index = []
for node in self.variables:
index_in_values_list = nodes.index(node)
value = values[index_in_values_list]
index.append(node.value_range.index(value))
return index
def is_normalized_as_cpt(self,owner):
dim_of_owner = self.variables.index(owner)
sum_of_owner_probs = numpy.sum(self.table, dim_of_owner)
return set(sum_of_owner_probs.flatten()) == set([1])
def is_normalized_as_jpt(self):
return numpy.sum(self.table) == 1.0
def multiplication(self, inputFactor):
#init a new probability tabel
factor1 = ProbabilityTable()
#all variables from both factors are needed
factor1.variables = copy.copy(self.variables)
for v in (inputFactor.variables):
if not v in factor1.variables:
factor1.variables.append(v)
#the table from the first factor is copied
factor1.table = copy.copy(self.table)
#and extended by the dimensions for the left variables
for curIdx in range(factor1.table.ndim, len(factor1.variables)):
ax = factor1.table.ndim
factor1.table=numpy.expand_dims(factor1.table,ax)
factor1.table=numpy.repeat(factor1.table,len(factor1.variables[curIdx].value_range),axis = ax)
#copy factor 2 and it's variables ...
factor2 = ProbabilityTable()
factor2.variables = copy.copy(inputFactor.variables)
factor2.table = copy.copy(inputFactor.table)
#extend the dimensions of factors 2 to the dimensions of factor 1
for v in factor1.variables:
if not v in factor2.variables:
factor2.variables.append(v)
for curIdx in range(factor2.table.ndim, len(factor2.variables)):
ax = factor2.table.ndim
factor2.table=numpy.expand_dims(factor2.table,ax)
factor2.table=numpy.repeat(factor2.table,len(factor2.variables[curIdx].value_range),axis = ax)
#sort the variables to the same order
for endDim,variable in enumerate(factor1.variables):
startDim = factor2.variables.index(variable);
if not startDim == endDim:
factor2.table = numpy.rollaxis(factor2.table, startDim, endDim)
factor2.variables.insert(endDim,factor2.variables.pop(startDim))
#pointwise multiplication
if factor1.table.shape != factor2.table.shape:
raise Exception("Multiplication: The probability tables have the wrong dimensions for unification")
factor1.table = factor1.table *factor2.table;
return factor1
def marginalization(self, variable):
raise Exception("Called unimplemented function")
def reduction(self, evidence):
'''Returns a reduced version of this ProbabilityTable, evidence is a list of pairs.
Important: This node is not being changed!'''
reduced = ProbabilityTable()
reduced.variables = copy.copy(self.variables)
reduced.table = self.table
for node,value in evidence:
axis=reduced.variables.index(node)
position=node.value_range.index(value)
reduced.table = numpy.take(reduced.table,[position],axis=axis)
reduced.table=reduced.table.squeeze()
reduced.variables.remove(node)
return reduced
def division(self, factor):
raise Exception("Called unimplemented function")
def __str__(self):
return str(self.table)