Skip to content
Snippets Groups Projects
EasiestFactorElimination.py 3.14 KiB
Newer Older
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from  primo.core import BayesNet
from  primo.reasoning import DiscreteNode
import numpy

class EasiestFactorElimination(object):
    '''The simplest way to do exact inference by factor elimination.

    Every query is answered the same brute-force way, which is why this
    strategy has the worst runtime:

    1. Evidence, if any, is applied to the CPD of each observed node.
    2. The CPDs of *all* nodes in the network are multiplied together.
    3. Every variable that is not queried is summed out.
    '''

    def __init__(self, bayesNet):
        # Network whose nodes/CPDs are queried; this class only reads it.
        self.bn = bayesNet

    @staticmethod
    def _evidence_for(node, evidence):
        '''Return the first entry of ``evidence`` whose first element equals
        ``node``, or None when ``node`` is not observed.'''
        for entry in evidence:
            if entry[0] == node:
                return entry
        return None

    def _node_cpd(self, node, evidence):
        '''Return the CPD of ``node``, with its evidence entry applied via
        ``set_evidence`` if ``evidence`` contains one.'''
        entry = self._evidence_for(node, evidence)
        cpd = node.get_cpd()
        if entry is None:
            return cpd
        return cpd.set_evidence(entry)

    def _joint_cpd(self, evidence):
        '''Multiply the (evidence-reduced) CPDs of all nodes of the network.

        Raises ValueError if the network has no nodes.'''
        # Copy first: popping from the sequence returned by get_all_nodes()
        # must never remove nodes from the network itself, and a copy also
        # works when a view/iterator is returned instead of a list.
        nodes = list(self.bn.get_all_nodes())
        if not nodes:
            raise ValueError("The Bayesian network contains no nodes.")
        # Start with the last node, then multiply in the others in order;
        # this keeps the factor ordering of the resulting CPD unchanged.
        joint = self._node_cpd(nodes.pop(), evidence)
        for node in nodes:
            joint = joint.multiplication(self._node_cpd(node, evidence))
        return joint

    @staticmethod
    def _sum_out_all_except(cpd, variables):
        '''Marginalize every variable of ``cpd`` that is not in ``variables``.'''
        # Iterate over a snapshot of the variables: each marginalization
        # yields a new CPD with fewer variables.
        for var in list(cpd.get_variables()):
            if var not in variables:
                cpd = cpd.marginalization(var)
        return cpd

    def calculate_PriorMarginal(self, variables):
        '''Calculate the prior marginal for the given variables.

        :param variables: the variables to keep; all others are summed out.
        :returns: the resulting CPD (no normalization is applied here).
        :raises ValueError: if the network has no nodes.
        '''
        return self._sum_out_all_except(self._joint_cpd(()), variables)

    def calculate_PosteriorMarginal(self, variables, evidence):
        '''Calculate the posterior marginal for the given variables and evidence.

        :param variables: the variables to keep; all others are summed out.
        :param evidence: sequence of evidence entries whose first element is
            the observed node (presumably ``(node, value)`` pairs -- confirm
            against ``set_evidence``); each matching entry is passed unchanged
            to that node's ``cpd.set_evidence``. May be empty, in which case
            the normalized prior marginal is returned.
        :returns: the resulting CPD, normalized with ``normalize_as_jpt``.
        :raises ValueError: if the network has no nodes.
        '''
        joint = self._joint_cpd(evidence)
        marginal = self._sum_out_all_except(joint, variables)
        return marginal.normalize_as_jpt()

    def calculate_PoE(self, evidence):
        '''Calculate the probability of the given evidence.

        :param evidence: sequence of evidence entries, in the same format as
            for ``calculate_PosteriorMarginal``; may be empty.
        :returns: the CPD left after summing out every variable of the
            evidence-reduced joint (presumably a single-value factor holding
            P(evidence) -- confirm against the CPD class).
        :raises ValueError: if the network has no nodes.
        '''
        return self._sum_out_all_except(self._joint_cpd(evidence), ())