Skip to content
Snippets Groups Projects
Commit b65cd2f6 authored by Hendrik Buschmeier's avatar Hendrik Buschmeier
Browse files

Refactoring: Moved decision networks.

parent 8679ad44
No related branches found
No related tags found
No related merge requests found
...@@ -5,10 +5,7 @@ import re ...@@ -5,10 +5,7 @@ import re
import scipy import scipy
from primo.decision.UtilityTable import UtilityTable from primo.decision.UtilityTable import UtilityTable
from primo.reasoning.density import Beta import primo.reasoning.density
from primo.reasoning.density import Exponential
from primo.reasoning.density import Gauss
from primo.reasoning.density import ProbabilityTable
class Node(object): class Node(object):
...@@ -102,7 +99,7 @@ class DiscreteNode(RandomNode): ...@@ -102,7 +99,7 @@ class DiscreteNode(RandomNode):
super(DiscreteNode, self).__init__(name) super(DiscreteNode, self).__init__(name)
self.value_range = value_range self.value_range = value_range
self.cpd = ProbabilityTable() self.cpd = primo.reasoning.density.ProbabilityTable()
self.cpd.add_variable(self) self.cpd.add_variable(self)
def __str__(self): def __str__(self):
...@@ -279,7 +276,11 @@ class ContinuousNodeFactory(object): ...@@ -279,7 +276,11 @@ class ContinuousNodeFactory(object):
@param name: The name of the node. @param name: The name of the node.
''' '''
return self.createContinuousNode(name,(-float("Inf"),float("Inf")),Gauss) return self.createContinuousNode(
name,
(-float("Inf"),
float("Inf")),
primo.reasoning.density.Gauss)
def createExponentialNode(self, name): def createExponentialNode(self, name):
''' '''
...@@ -287,7 +288,10 @@ class ContinuousNodeFactory(object): ...@@ -287,7 +288,10 @@ class ContinuousNodeFactory(object):
@param name: The name of the node. @param name: The name of the node.
''' '''
return self.createContinuousNode(name,(0,float("Inf")),Exponential) return self.createContinuousNode(
name,
(0,float("Inf")),
primo.reasoning.density.Exponential)
def createBetaNode(self, name): def createBetaNode(self, name):
''' '''
...@@ -295,9 +299,12 @@ class ContinuousNodeFactory(object): ...@@ -295,9 +299,12 @@ class ContinuousNodeFactory(object):
@param name: The name of the node. @param name: The name of the node.
''' '''
return self.createContinuousNode(name,(0,1),Beta) return self.createContinuousNode(
name,
(0, 1),
primo.reasoning.density.Beta)
def createContinuousNode(self,name,value_range,DensityClass): def createContinuousNode(self,name,value_range,density_class):
''' '''
Create a ContinuousNode. This method should only be invoked from Create a ContinuousNode. This method should only be invoked from
outside this class if no specialized method is available. outside this class if no specialized method is available.
...@@ -308,7 +315,10 @@ class ContinuousNodeFactory(object): ...@@ -308,7 +315,10 @@ class ContinuousNodeFactory(object):
@param DensityClass: A class from primo.reasoning.density that shall be @param DensityClass: A class from primo.reasoning.density that shall be
the node's pdf the node's pdf
''' '''
return ContinuousNode(name,value_range,DensityClass) return ContinuousNode(
name,
value_range,
density_class)
class DecisionNode(Node): class DecisionNode(Node):
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import itertools import itertools
from primo.core import BayesianDecisionNetwork import operator
from primo.decision import DecisionNode
from primo.decision import UtilityTable import numpy as np
from primo.reasoning import DiscreteNode
import primo.nodes
class UtilityTable(object):
'''
self.variables -- list of the parent nodes
self.table -- utility table which contains the utility
'''
def __init__(self):
super(UtilityTable, self).__init__()
self.table = np.array(0)
self.variables = []
def add_variable(self, variable):
self.variables.append(variable)
ax = self.table.ndim
self.table=np.expand_dims(self.table,ax)
self.table=np.repeat(self.table,len(variable.value_range),axis = ax)
def get_ut_index(self, node_value_pairs):
nodes, values = zip(*node_value_pairs)
index = []
for node in self.variables:
index_in_values_list = nodes.index(node)
value = values[index_in_values_list]
index.append(node.value_range.index(value))
return tuple(index)
def set_utility_table(self, table, nodes):
if not set(nodes) == set(self.variables):
raise Exception("The list which should define the ordering of the variables does not match"
" the variables that this cpt depends on (plus the node itself)")
if not self.table.ndim == table.ndim:
raise Exception("The provided probability table does not have the right number of dimensions")
for d,node in enumerate(nodes):
if len(node.value_range) != table.shape[d]:
raise Exception("The size of the provided probability table does not match the number of possible values of the node "+node.name+" in dimension "+str(d))
self.table = table
self.variables = nodes
def set_utility(self, value, node_value_pairs):
index = self.get_ut_index(node_value_pairs)
self.table[index]=value
def get_utility_table(self):
return self.table
def get_variables(self):
return self.variables
def get_utility(self, node_value_pairs):
index = self.get_ut_index(node_value_pairs)
return self.table[index]
def __str__(self):
return str(self.table)
class MakeDecision(object): class MakeDecision(object):
""" """
...@@ -65,7 +124,7 @@ class MakeDecision(object): ...@@ -65,7 +124,7 @@ class MakeDecision(object):
#Check if the Decision Nodes that are ordered before the provided Decision Node have a state #Check if the Decision Nodes that are ordered before the provided Decision Node have a state
for node in partialOrder: for node in partialOrder:
if isinstance(node, DecisionNode): if isinstance(node, primo.nodes.DecisionNode):
if not decisionNode.name == node.name: if not decisionNode.name == node.name:
if node.get_state() is None: if node.get_state() is None:
raise Exception("Decision Nodes that are ordered before the provided Decision Node must have a state!") raise Exception("Decision Nodes that are ordered before the provided Decision Node must have a state!")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment