Skip to content
Snippets Groups Projects
EasiestFactorElimination.py 2.99 KiB
Newer Older
  • Learn to ignore specific revisions
  • #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    from  primo.core import BayesNet
    from  primo.reasoning import DiscreteNode
    import numpy
    
    class EasiestFactorElimination(object):
        '''The easiest way to do factor elimination, but not very efficient.

        Every query multiplies the CPDs of *all* nodes of the network into a
        single joint factor and then sums out the variables that are not
        wanted, rather than eliminating variables one at a time.
        '''

        def __init__(self):
            # Start with a freshly constructed network; callers are expected
            # to supply the real one via set_BayesNet().
            self.bn = BayesNet()

        def set_BayesNet(self, bayesnet):
            '''Replace the network that subsequent queries operate on.'''
            self.bn = bayesnet

        def calculate_PriorMarginal(self, variables):
            '''Return the prior marginal over ``variables``.

            :param variables: container of the variables to keep; every other
                variable of the joint factor is summed out.
            :returns: the factor left after marginalization (same type as the
                nodes' CPDs).
            :raises IndexError: if the network has no nodes.
            '''
            joint = self._joint_factor(())
            return self._sum_out_all_except(joint, variables)

        def calculate_PosteriorMarginal(self, variables, evidence):
            '''Return the posterior marginal over ``variables`` given ``evidence``.

            The evidence is applied to the CPDs of the observed nodes, all CPDs
            are multiplied, every variable not in ``variables`` is summed out,
            and the result is normalized with the factor's
            ``normalize_as_jpt()``. Per the original author's notes, this is
            the joint probability divided by the probability of evidence
            (see calculate_PoE).

            :param variables: container of the variables to keep.
            :param evidence: iterable of pairs whose first element is a node;
                presumably ``(node, state)`` — each matching pair is passed
                unchanged to that node's CPD ``set_evidence()``. May be empty.
            :returns: the normalized factor over ``variables``.
            :raises IndexError: if the network has no nodes.
            '''
            joint = self._joint_factor(evidence)
            marginal = self._sum_out_all_except(joint, variables)
            return marginal.normalize_as_jpt()

        def calculate_PoE(self, evidence):
            '''Return the probability of evidence.

            This is the evidence-restricted product of all CPDs with *every*
            variable summed out.

            :param evidence: iterable of pairs whose first element is a node
                (see calculate_PosteriorMarginal). May be empty.
            :returns: the fully marginalized factor.
            :raises IndexError: if the network has no nodes.
            '''
            joint = self._joint_factor(evidence)
            return self._sum_out_all_except(joint, ())

        def _joint_factor(self, evidence):
            '''Multiply the CPDs of all nodes, applying matching evidence.'''
            evidence = list(evidence)
            # Copy before popping: the previous code popped straight from the
            # returned list. That would shrink the network's own node list if
            # get_all_nodes() hands out internal state (not visible here), and
            # it fails outright if get_all_nodes() returns a non-list.
            nodes = list(self.bn.get_all_nodes())
            # Observed nodes in evidence order. A list (not a dict/set) keeps
            # the exact ``in`` / ``index()`` equality semantics of the
            # previous zip-based lookup, and works for empty evidence and on
            # Python 3.
            observed = [entry[0] for entry in evidence]

            def node_factor(node):
                # The node's CPD, restricted by its evidence pair if observed.
                cpd = node.get_cpd()
                if node in observed:
                    return cpd.set_evidence(evidence[observed.index(node)])
                return cpd

            # Start from the last node and multiply the others in list order,
            # the same multiplication order as before.
            joint = node_factor(nodes.pop())
            for node in nodes:
                joint = joint.multiplication(node_factor(node))
            return joint

        @staticmethod
        def _sum_out_all_except(factor, keep):
            '''Marginalize ``factor`` over every variable not in ``keep``.'''
            # Iterate over a snapshot; each marginalization returns a new
            # factor with a different variable list.
            for variable in list(factor.get_variables()):
                if variable not in keep:
                    factor = factor.marginalization(variable)
            return factor