import networkx as nx
import primo.reasoning.density.ProbabilityTable as ProbabilityTable


class FactorTree(object):
    '''The factor tree holds one factor for each node of the BayesNet. It
    is a directed graph with a single root node. To speed up reasoning it uses
    a message-passing approach that stores intermediate results on the
    edges. Thus, the first query is expensive and all following queries are
    cheap. The cost of the first message calculation depends on how the tree
    was built.

    Literature: Modeling and Reasoning with Bayesian Networks - Adnan Darwiche,
    Chapter 7
    '''

    def __init__(self, graph, rootNode):
        '''
        @param graph: directed tree whose nodes are factors. Edge data holds a
            'separator' and, once computed, the cached messages
            'msgAgainstWay' (child -> parent) and 'msgRightWay'
            (parent -> child). graph.graph['messagesValid'] flags whether the
            cached messages are up to date.
        @param rootNode: the factor at the root of the tree.
        '''
        self.graph = graph
        self.rootNode = rootNode

    def calculate_PoE(self):
        '''Calculates the probability of evidence for the currently set evidence.

        Computes the marginal at the root factor and then sums out every
        remaining variable.

        @return: the table left after all variables have been marginalized.
        '''
        # .get: a graph that never set the flag is treated as "not computed yet"
        if not self.graph.graph.get('messagesValid', False):
            self.calculate_messages()

        cpd = self.calculate_marginal_forOne(self.rootNode)

        # iterate over a copy of the variable list so marginalization
        # cannot alter the list being iterated
        for v in cpd.get_variables()[:]:
            cpd = cpd.marginalization(v)

        return cpd

    def calculate_marginal(self, variables):
        '''Calculates the marginal for the given variables.

        If evidence is set, this is the posterior marginal; with empty evidence
        it is the prior marginal. The marginal of every factor whose node is in
        `variables` is multiplied in, and the result is normalized as a joint
        probability table.

        NOTE(review): for more than one variable this is the product of the
        single-node marginals, which equals the true joint only if those
        variables are independent given the evidence - confirm this is intended.

        @param variables: collection of BayesNet nodes to include.
        @return: the normalized table.
        '''
        if not self.graph.graph.get('messagesValid', False):
            self.calculate_messages()

        resPT = ProbabilityTable.get_neutral_multiplication_PT()

        for f in self.graph.nodes():
            if f.get_node() in variables:
                resPT = resPT.multiplication(self.calculate_marginal_forOne(f))

        resPT = resPT.normalize_as_jpt()

        return resPT

    def calculate_marginal_forOne(self, factor):
        '''Calculates the marginal of the single BayesNet node of `factor`.

        Multiplies a copy of the factor's table with the message from its
        parent ('msgRightWay' on incoming edges) and the messages from its
        children ('msgAgainstWay' on outgoing edges). It then sums out every
        variable except factor.get_node(). The messages must already be
        calculated (see calculate_messages).

        @param factor: a node of the factor tree.
        @return: the (unnormalized) table over factor.get_node().
        '''
        # copy so the factor's own table is never touched by the products
        curCPD = factor.get_calculation_CDP().copy()

        for p in self.graph.predecessors(factor):
            tmpCPD = self.graph[p][factor]['msgRightWay']
            curCPD = curCPD.multiplication(tmpCPD)

        for p in self.graph.neighbors(factor):
            tmpCPD = self.graph[factor][p]['msgAgainstWay']
            curCPD = curCPD.multiplication(tmpCPD)

        for v in curCPD.get_variables()[:]:
            if v != factor.get_node():
                curCPD = curCPD.marginalization(v)

        return curCPD

    def draw(self):
        '''Draws the FactorTree'''
        import matplotlib.pyplot as plt
        nx.draw_circular(self.graph)
        plt.show()

    def calculate_messages(self):
        '''Calculates all messages and stores them on the edges.

        Runs an upward pass (pull) from the leaves to the root, then a
        downward pass (push) from the root seeded with the neutral table.
        Finally it marks the cached messages as valid.
        '''
        self.pull_phase(self.rootNode, self.graph)
        self.push_phase(self.rootNode, self.graph, ProbabilityTable.get_neutral_multiplication_PT())
        self.graph.graph['messagesValid'] = True

    def set_evidences(self, evidences):
        '''Sets evidence on the factors and invalidates the cached messages.

        @param evidences: indexable sequence of tuples whose first element is a
            BayesNet node. Each tuple is passed unchanged to the factor of that
            node. If a node occurs more than once, the first tuple wins. An
            empty sequence only invalidates the messages.
        '''
        self.graph.graph['messagesValid'] = False

        # explicit list instead of zip(*evidences)[0]: on Python 3 zip returns
        # a non-indexable iterator, and an empty evidence list made [0] fail
        evNodes = [ev[0] for ev in evidences]

        for factor in self.graph.nodes():
            node = factor.get_node()
            if node in evNodes:
                factor.set_evidence(evidences[evNodes.index(node)])

    def pull_phase(self, factor, graph):
        '''Collects messages from the leaves towards `factor` (upward pass).

        Recursively computes each child's message, projects it onto the
        separator of the edge to that child, and stores it as 'msgAgainstWay'
        on that edge. It then multiplies the message into this factor's table.

        @return: this factor's table multiplied by all projected child messages.
        '''
        calCPD = factor.get_calculation_CDP()
        # calculate the messages of the children
        for child in graph.neighbors(factor):
            tmpInput = self.pull_phase(child, graph)

            # project the child's result onto the edge separator
            separator = graph[factor][child]['separator']
            for var in tmpInput.variables[:]:
                if var not in separator:
                    tmpInput = tmpInput.marginalization(var)

            # save message on edge: it flows opposite to the edge direction
            graph[factor][child]['msgAgainstWay'] = tmpInput
            # fold the message into this factor's outgoing result
            calCPD = calCPD.multiplication(tmpInput)

        return calCPD

    def push_phase(self, factor, graph, inCPD):
        '''Distributes messages from `factor` towards the leaves (downward pass).

        For every child, the message is inCPD times this factor's table times
        the upward messages of all *other* children, projected onto the
        separator of the edge to that child. It is stored as 'msgRightWay' on
        that edge and then pushed further down recursively.

        @param inCPD: the message arriving at `factor` from its parent (the
            neutral table when called for the root).
        '''
        for child in graph.neighbors(factor):
            tmpCPD = inCPD.multiplication(factor.get_calculation_CDP())
            # include the upward messages of all siblings, but not the one
            # that came from `child` itself
            for child2 in graph.neighbors(factor):
                if child != child2:
                    tmpCPD = tmpCPD.multiplication(graph[factor][child2]['msgAgainstWay'])

            # project onto the outgoing edge separator; iterate over a copy of
            # the variable list (as pull_phase does) so marginalization cannot
            # disturb the iteration
            separator = graph[factor][child]['separator']
            for var in tmpCPD.variables[:]:
                if var not in separator:
                    tmpCPD = tmpCPD.marginalization(var)

            # message in the direction of the edge (parent -> child)
            graph[factor][child]['msgRightWay'] = tmpCPD

            self.push_phase(child, graph, tmpCPD)