Skip to content
Snippets Groups Projects
FactorTree.py 3.43 KiB
Newer Older
Denis John PC's avatar
Denis John PC committed

import networkx as nx
import primo.reasoning.density.ProbabilityTable as ProbabilityTable
Denis John PC's avatar
Denis John PC committed



class FactorTree(object):
    
    
    def __init__(self,graph,rootNode):
        self.graph = graph
        self.rootNode = rootNode
        self.graph['messagesValid'] = False
Denis John PC's avatar
Denis John PC committed
        
Denis John PC's avatar
Denis John PC committed
    def calculate_marginal(self,variables):
        if not self.graph['messagesValid']:
            self.calculateMessages()
            
        resPT = ProbabilityTable()
            
        for f in self.graph.nodes():
            if f.get_node() in variables:
                resPT = resPT.multiplication.calculate_marginal_forOne(f)
                
        return resPT
                
    def calculate_marginal_forOne(self,factor):
        curCPD = factor.get_calculation_CDP().copy()
        for p in self.graph.predecessors(factor):
            curCPD = curCPD.multiplication(self.graph[p][factor]['inMessage'])
            
        for p in self.graph.neighbors(factor):
            curCPD = curCPD.multiplication(self.graph[factor][p]['outMessage'])
            
        for v in curCPD.get_variables()[:]:
            if v != factor.get_node():
                curCPD = curCPD.marginalization(v)
                
        return curCPD
        
        
Denis John PC's avatar
Denis John PC committed
    def draw(self):
        import matplotlib.pyplot as plt
        nx.draw_circular(self.graph)
        plt.show()
    def calculateMessages(self):
        self.push_phase(self.rootNode,self.graph)
        self.pull_phase(self.rootNode,self.graph,ProbabilityTable())
        self.graph['messagesValid'] = True
        
        
    def setEvidences(self,evidences):
        self.graph['messagesValid'] = False
        
        evNodes = zip(*evidences)        
        
        for factor in self.graph.get_all_nodes():
            if factor.get_node() in evNodes:
                idx = evNodes.index(factor.get_node())
                factor.set_evidence(evidences[idx])
        
    
        
        
    def pull_phase(self,factor,graph):
        
        calCPD = factor.get_calculate_CPD()
        #calculate the messages of the children
        for child in graph.neighbors(factor):
            tmpInput = self.pull_phase(child,graph)
            #project each factor on the specific seperator
            seperator = graph[factor][child]['seperator']
            for var in tmpInput.variables[:]:
                if var not in seperator:
                    tmpInput = tmpInput.marginalization(var)
                
            #save message on edge
            graph[factor][child]['inMessage'] = tmpInput
            #calculate the new message
            calCPD = calCPD.multiplication(tmpInput)
              
        return calCPD
        
    def push_phase(self,factor,graph,inCPD):

        for child in graph.neighbors(factor):
            tmpCPD = inCPD.copy()
            for child2 in graph.neighbors(factor):
                if (child != child2):
                    tmpCPD = tmpCPD.multiplication(graph[factor][child2]['inMessage'])
            
            seperator = graph[factor][child]['seperator']
            #project on outgoing edge seperator
            for var in tmpCPD.variables[:]:
                if var not in seperator:
                    tmpCPD = tmpCPD.marginalization(var)
            
            #add setOut to outgoing vars from child
            graph[factor][child]['outMessage'] = tmpCPD
                
           
            self.push_phase(child,graph,tmpCPD)