Skip to content
Snippets Groups Projects
DiscreteNode.py 1.71 KiB
Newer Older
Lukas Kettenbach's avatar
Lukas Kettenbach committed
# -*- coding: utf-8 -*-

import random

from primo.reasoning import RandomNode
from primo.reasoning.density import ProbabilityTable
Lukas Kettenbach's avatar
Lukas Kettenbach committed


class DiscreteNode(RandomNode):
    """Random node whose values come from a given ``value_range``.

    The node's probabilities are stored in ``self.cpd``, a
    ``ProbabilityTable`` in which this node is registered as a variable at
    construction time. Most methods simply delegate to that table.

    Attributes:
        value_range: the values this node can take (also used as the
            candidate set when sampling).
        cpd: ``ProbabilityTable`` holding this node's probabilities.
    """

    def __init__(self, name, value_range):
        super(DiscreteNode, self).__init__(name)

        self.value_range = value_range
        self.cpd = ProbabilityTable()
        # Register this node as a variable of its own probability table.
        self.cpd.add_variable(self)

    def __str__(self):
        return self.name

    def __repr__(self):
        return "DiscreteNode(" + self.name + ")"

    def set_probability(self, value, node_value_pairs):
        """Forward ``value`` and ``node_value_pairs`` to ``cpd.set_probability``.

        NOTE(review): here ``value`` appears to be the probability to store,
        not a value of this node (contrast ``get_probability``), and this
        node's own (node, value) pair is NOT prepended automatically --
        presumably callers include it in ``node_value_pairs``; confirm
        against ``ProbabilityTable.set_probability``.
        """
        self.cpd.set_probability(value, node_value_pairs)

    def get_probability(self, value, node_value_pairs):
        """Return the table entry for this node taking ``value``.

        Args:
            value: a value of this node.
            node_value_pairs: iterable of (node, value) pairs for the other
                variables of the table; any iterable is accepted (it is
                copied into a list before ``(self, value)`` is prepended).

        Returns:
            Whatever ``cpd.get_probability`` returns for the combined
            assignment ``[(self, value)] + node_value_pairs``.
        """
        return self.cpd.get_probability([(self, value)] + list(node_value_pairs))

    def set_probability_table(self, table, nodes):
        """Replace the table contents via ``cpd.set_probability_table``."""
        self.cpd.set_probability_table(table, nodes)

    def is_valid(self):
        """Return ``cpd.is_normalized_as_cpt(self)`` for this node."""
        return self.cpd.is_normalized_as_cpt(self)

    def _compatible_values(self, evidence):
        """Return the values of ``value_range`` allowed by ``evidence``.

        With no evidence for this node (``evidence`` is None or has no entry
        for ``self``) the full ``value_range`` object itself is returned.
        """
        if evidence is None or self not in evidence:
            return self.value_range
        observed = evidence[self]
        # NOTE(review): the original filtering loop body was lost in this
        # copy of the file. Evidence entries are assumed either to expose
        # ``is_compatible(value)`` or to be a plain observed value compared
        # by equality -- confirm against the evidence classes in use.
        if hasattr(observed, "is_compatible"):
            return [v for v in self.value_range if observed.is_compatible(v)]
        return [v for v in self.value_range if v == observed]

    def sample_global(self, state, evidence):
        """Sample a value via ``cpd.sample_global`` restricted by evidence.

        Args:
            state: passed through unchanged to ``cpd.sample_global``.
            evidence: mapping from nodes to evidence, or None.

        Returns:
            The result of ``cpd.sample_global(state, self, compatibles)``,
            where ``compatibles`` are the values allowed by ``evidence``.
        """
        compatibles = self._compatible_values(evidence)
        return self.cpd.sample_global(state, self, compatibles)

    def sample_local(self, x, evidence=None):
        """Pick a value uniformly at random among those allowed by evidence.

        Args:
            x: not used by this implementation (kept for interface
                compatibility).
            evidence: mapping from nodes to evidence, or None.

        Returns:
            A tuple ``(value, 1.0)``.

        Raises:
            IndexError: from ``random.choice`` if no value is compatible
                with the evidence.
        """
        compatibles = self._compatible_values(evidence)
        return random.choice(compatibles), 1.0