Skip to content
Snippets Groups Projects
FactorTree.py 4.94 KiB
Newer Older
Denis John PC's avatar
Denis John PC committed

import networkx as nx
import primo.reasoning.density.ProbabilityTable as ProbabilityTable
Denis John PC's avatar
Denis John PC committed


class FactorTree(object):
    '''The factor tree contains one factor for each node of the BayesNet. It
    is a directed graph with a single root node. To speed up reasoning it uses
    a message-based approach which stores calculated intermediate results on
    the edges. Thus the first query is expensive and all following queries
    are cheap to compute. The cost of the first message calculation depends on
    how the tree was built.

    Literature: Modeling and Reasoning with Bayesian Networks - Adnan Darwiche,
    Chapter 7.

    Attributes of ``self.graph`` used by this class:
      - ``graph.graph['messagesValid']``: whether the messages stored on the
        edges match the currently set evidence.
      - edge ``'separator'``: variables a message is projected onto.
      - edge ``'msgAgainstWay'``: message sent against the edge direction
        (child -> parent), written by ``pull_phase``.
      - edge ``'msgRightWay'``: message sent along the edge direction
        (parent -> child), written by ``push_phase``.
    '''

    def __init__(self, graph, rootNode):
        '''
        :param graph: directed graph whose nodes are factor objects
            (presumably a networkx DiGraph -- confirm against the builder).
        :param rootNode: the factor at the root of the tree.
        '''
        self.graph = graph
        self.rootNode = rootNode

    def _ensure_messages(self):
        '''Recomputes the edge messages unless they are still valid.'''
        # A graph whose messages were never computed may not carry the flag
        # at all; treat a missing flag as "invalid" instead of a KeyError.
        if not self.graph.graph.get('messagesValid', False):
            self.calculate_messages()

    def calculate_PoE(self):
        '''Calculates the probability of evidence with the currently set
        evidence.

        :returns: the root factor's marginal with every variable summed out.
        '''
        self._ensure_messages()

        cpd = self.calculate_marginal_forOne(self.rootNode)
        # Iterate over a copy: ``cpd`` is rebound on every step.
        for v in cpd.get_variables()[:]:
            cpd = cpd.marginalization(v)

        return cpd

    def calculate_marginal(self, variables):
        '''If evidence is set, this method calculates the posterior marginal.
        With empty evidence this is automatically the prior marginal.

        :param variables: collection of BayesNet nodes whose marginal is wanted.
        :returns: the product of the single-node marginals of every factor
            whose node is in ``variables``, normalized with
            ``normalize_as_jpt``.
        '''
        self._ensure_messages()

        resPT = ProbabilityTable.get_neutral_multiplication_PT()
        for f in self.graph.nodes():
            if f.get_node() in variables:
                resPT = resPT.multiplication(self.calculate_marginal_forOne(f))
        # NOTE(review): for several variables this multiplies the individual
        # marginals, which equals their joint marginal only if they are
        # independent given the evidence -- confirm this is intended.
        resPT = resPT.normalize_as_jpt()

        return resPT

    def calculate_marginal_forOne(self, factor):
        '''Calculates the marginal of the single node owned by ``factor``.

        Combines the factor's own calculation CPD with every stored message
        arriving at it. These are the messages from its parents along the edge
        direction and from its children against it. Every variable except the
        factor's node is then summed out.

        Requires the messages to be computed (see ``calculate_messages``).
        '''
        # Work on a copy so the factor's own table is never altered.
        curCPD = factor.get_calculation_CDP().copy()
        for p in self.graph.predecessors(factor):
            curCPD = curCPD.multiplication(self.graph[p][factor]['msgRightWay'])

        for p in self.graph.neighbors(factor):
            curCPD = curCPD.multiplication(self.graph[factor][p]['msgAgainstWay'])

        for v in curCPD.get_variables()[:]:
            if v != factor.get_node():
                curCPD = curCPD.marginalization(v)

        return curCPD

    def draw(self):
        '''Draws the FactorTree'''
        # Imported lazily so matplotlib is only needed when drawing.
        import matplotlib.pyplot as plt
        nx.draw_circular(self.graph)
        plt.show()

    def calculate_messages(self):
        ''' Calculates the messages and stores the intermediate results.

        Runs the pull phase (leaves -> root) and then the push phase
        (root -> leaves), and marks the stored messages as valid.
        '''
        self.pull_phase(self.rootNode, self.graph)
        self.push_phase(self.rootNode, self.graph,
                        ProbabilityTable.get_neutral_multiplication_PT())
        self.graph.graph['messagesValid'] = True

    def set_evidences(self, evidences):
        '''Sets evidence on the factors and invalidates the stored messages.

        :param evidences: iterable of tuples whose first element is a BayesNet
            node. The whole tuple is passed to ``factor.set_evidence``. If a
            node occurs several times, its first occurrence wins.
        '''
        self.graph.graph['messagesValid'] = False
        # Materialize once: the original ``zip(*evidences)[0]`` is not
        # subscriptable in Python 3 and failed for empty evidence.
        evidences = list(evidences)
        evNodes = [ev[0] for ev in evidences]

        for factor in self.graph.nodes():
            node = factor.get_node()
            if node in evNodes:
                factor.set_evidence(evidences[evNodes.index(node)])

    def pull_phase(self, factor, graph):
        '''Recursively sends messages from the leaves up towards ``factor``.

        For each child, the child's subtree result is projected onto the
        edge separator and stored as the edge's ``'msgAgainstWay'``.

        :returns: the factor's calculation CPD multiplied by all projected
            child messages.
        '''
        calCPD = factor.get_calculation_CDP()
        # calculate the messages of the children
        for child in graph.neighbors(factor):
            tmpInput = self.pull_phase(child, graph)
            # project each factor on the specific separator
            separator = graph[factor][child]['separator']
            for var in tmpInput.variables[:]:
                if var not in separator:
                    tmpInput = tmpInput.marginalization(var)

            # save message on edge: it's the opposite of the direction of the edge
            graph[factor][child]['msgAgainstWay'] = tmpInput
            # calculate the new message
            calCPD = calCPD.multiplication(tmpInput)

        return calCPD

    def push_phase(self, factor, graph, inCPD):
        '''Recursively sends messages from ``factor`` down to its subtree.

        The message to each child is built from three parts:
          - the incoming message ``inCPD``,
          - the factor's own calculation CPD,
          - the pull-phase messages of all *other* children.
        It is projected onto the edge separator, stored as ``'msgRightWay'``,
        and then pushed further down.
        '''
        for child in graph.neighbors(factor):
            tmpCPD = inCPD.multiplication(factor.get_calculation_CDP())
            for child2 in graph.neighbors(factor):
                if child != child2:
                    tmpCPD = tmpCPD.multiplication(graph[factor][child2]['msgAgainstWay'])

            separator = graph[factor][child]['separator']
            # project on outgoing edge separator; iterate over a copy (as in
            # pull_phase) because tmpCPD is rebound inside the loop
            for var in tmpCPD.variables[:]:
                if var not in separator:
                    tmpCPD = tmpCPD.marginalization(var)

            # Message with the direction of the edge
            graph[factor][child]['msgRightWay'] = tmpCPD

            self.push_phase(child, graph, tmpCPD)