import networkx as nx
import primo.reasoning.density.ProbabilityTable as ProbabilityTable



class FactorTree(object):
    """Rooted factor tree supporting two-pass message passing for marginals.

    ``graph`` is a directed (networkx-style) graph whose nodes are factor
    objects and whose edges point from ``rootNode`` towards the leaves. Each
    edge carries a ``'seperator'`` attribute (the variables the two factors
    share) and, once messages have been computed, two messages:

    * ``'inMessage'``  -- message from the child up to the parent
      (written by :meth:`pull_phase`);
    * ``'outMessage'`` -- message from the parent down to the child
      (written by :meth:`push_phase`).

    Whether the stored messages are current is tracked by the
    ``'messagesValid'`` entry of the graph-level attribute dict
    (``graph.graph``).
    """

    def __init__(self, graph, rootNode):
        self.graph = graph
        self.rootNode = rootNode
        # networkx graphs do not support ``graph[key] = value``; graph-level
        # flags belong in the ``graph.graph`` attribute dict.
        self.graph.graph['messagesValid'] = False

    def calculate_marginal(self, variables):
        """Return the product of the single-node marginals of ``variables``.

        Messages are (re)computed first if they are not marked valid.

        :param variables: collection of nodes; every factor whose
            ``get_node()`` is in it contributes its marginal.
        :returns: product of those marginal tables (the product of the
            marginals, which equals the joint only for independent nodes).
        """
        if not self.graph.graph.get('messagesValid', False):
            self.calculateMessages()

        # NOTE(review): assumes an empty ProbabilityTable() is the neutral
        # element of multiplication, and that the module-level name is bound
        # to the class rather than the submodule -- confirm.
        resPT = ProbabilityTable()
        for factor in self.graph.nodes():
            if factor.get_node() in variables:
                resPT = resPT.multiplication(
                    self.calculate_marginal_forOne(factor))
        return resPT

    def calculate_marginal_forOne(self, factor):
        """Return the marginal over ``factor.get_node()``.

        Multiplies the factor's own table with every message flowing *into*
        the factor, then sums out every variable except the factor's node.
        Messages must already have been computed.
        """
        curCPD = factor.get_calculation_CDP().copy()
        # Message parent -> factor is the outMessage on edge (parent, factor).
        for parent in self.graph.predecessors(factor):
            curCPD = curCPD.multiplication(
                self.graph[parent][factor]['outMessage'])
        # Message child -> factor is the inMessage on edge (factor, child).
        for child in self.graph.neighbors(factor):
            curCPD = curCPD.multiplication(
                self.graph[factor][child]['inMessage'])

        target = factor.get_node()
        # Iterate over a snapshot: each marginalization yields a new table.
        for var in list(curCPD.get_variables()):
            if var != target:
                curCPD = curCPD.marginalization(var)
        return curCPD

    def draw(self):
        """Draw the tree with a circular layout and show it (matplotlib)."""
        import matplotlib.pyplot as plt
        nx.draw_circular(self.graph)
        plt.show()

    def calculateMessages(self):
        """Compute all edge messages and mark them valid.

        The upward pass must run first: the downward pass reads the
        ``'inMessage'`` entries it produces.
        """
        self.pull_phase(self.rootNode, self.graph)
        self.push_phase(self.rootNode, self.graph, ProbabilityTable())
        self.graph.graph['messagesValid'] = True

    def setEvidences(self, evidences):
        """Apply evidence to the matching factors and invalidate messages.

        :param evidences: sequence of tuples whose first element is a node
            (presumably ``(node, value)`` pairs -- confirm with callers). The
            whole tuple is forwarded to ``factor.set_evidence``. If a node
            appears more than once, its first entry is used.
        """
        self.graph.graph['messagesValid'] = False

        evidences = list(evidences)
        # Equality-based lookup (not a dict) so nodes need not be hashable.
        evNodes = [evidence[0] for evidence in evidences]

        for factor in self.graph.nodes():
            node = factor.get_node()
            if node in evNodes:
                factor.set_evidence(evidences[evNodes.index(node)])

    def pull_phase(self, factor, graph):
        """Upward pass: compute the child-to-parent messages below ``factor``.

        Each child's result is projected onto the edge separator, stored as
        that edge's ``'inMessage'``, and multiplied into this factor's table.

        :returns: this factor's table times all incoming child messages.
        """
        # Copy so the factor's own table is never stored on an edge by reference.
        calCPD = factor.get_calculation_CDP().copy()
        for child in graph.neighbors(factor):
            message = self.pull_phase(child, graph)
            seperator = graph[factor][child]['seperator']
            # Sum out everything the two factors do not share.
            for var in message.variables[:]:
                if var not in seperator:
                    message = message.marginalization(var)
            graph[factor][child]['inMessage'] = message
            calCPD = calCPD.multiplication(message)
        return calCPD

    def push_phase(self, factor, graph, inCPD):
        """Downward pass: compute the parent-to-child messages below ``factor``.

        The message to each child is ``inCPD`` times this factor's own table
        times the upward messages of all *other* children, projected onto the
        edge separator. It is stored as the edge's ``'outMessage'``.
        :meth:`pull_phase` must have run first.

        :param inCPD: message arriving at ``factor`` from its parent (an empty
            ``ProbabilityTable()`` at the root).
        """
        children = list(graph.neighbors(factor))
        # The incoming message times the factor's own table is common to
        # every outgoing message, so compute it once.
        base = inCPD.multiplication(factor.get_calculation_CDP())
        for child in children:
            tmpCPD = base.copy()
            for sibling in children:
                if sibling != child:
                    tmpCPD = tmpCPD.multiplication(
                        graph[factor][sibling]['inMessage'])

            seperator = graph[factor][child]['seperator']
            # Project onto the outgoing edge separator.
            for var in tmpCPD.variables[:]:
                if var not in seperator:
                    tmpCPD = tmpCPD.marginalization(var)

            graph[factor][child]['outMessage'] = tmpCPD
            self.push_phase(child, graph, tmpCPD)