Skip to content
Snippets Groups Projects
Commit b65cd2f6 authored by Hendrik Buschmeier's avatar Hendrik Buschmeier
Browse files

Refactoring: Moved decision networks.

parent 8679ad44
No related branches found
No related tags found
No related merge requests found
......@@ -5,10 +5,7 @@ import re
import scipy
from primo.decision.UtilityTable import UtilityTable
from primo.reasoning.density import Beta
from primo.reasoning.density import Exponential
from primo.reasoning.density import Gauss
from primo.reasoning.density import ProbabilityTable
import primo.reasoning.density
class Node(object):
......@@ -102,7 +99,7 @@ class DiscreteNode(RandomNode):
super(DiscreteNode, self).__init__(name)
self.value_range = value_range
self.cpd = ProbabilityTable()
self.cpd = primo.reasoning.density.ProbabilityTable()
self.cpd.add_variable(self)
def __str__(self):
......@@ -279,7 +276,11 @@ class ContinuousNodeFactory(object):
@param name: The name of the node.
'''
return self.createContinuousNode(name,(-float("Inf"),float("Inf")),Gauss)
return self.createContinuousNode(
name,
(-float("Inf"),
float("Inf")),
primo.reasoning.density.Gauss)
def createExponentialNode(self, name):
'''
......@@ -287,7 +288,10 @@ class ContinuousNodeFactory(object):
@param name: The name of the node.
'''
return self.createContinuousNode(name,(0,float("Inf")),Exponential)
return self.createContinuousNode(
name,
(0,float("Inf")),
primo.reasoning.density.Exponential)
def createBetaNode(self, name):
'''
......@@ -295,9 +299,12 @@ class ContinuousNodeFactory(object):
@param name: The name of the node.
'''
return self.createContinuousNode(name,(0,1),Beta)
return self.createContinuousNode(
name,
(0, 1),
primo.reasoning.density.Beta)
def createContinuousNode(self,name,value_range,DensityClass):
def createContinuousNode(self,name,value_range,density_class):
'''
Create a ContinuousNode. This method should only be invoked from
outside this class if no specialized method is available.
......@@ -308,7 +315,10 @@ class ContinuousNodeFactory(object):
@param DensityClass: A class from primo.reasoning.density that shall be
the node's pdf
'''
return ContinuousNode(name,value_range,DensityClass)
return ContinuousNode(
name,
value_range,
density_class)
class DecisionNode(Node):
......
# -*- coding: utf-8 -*-
import itertools
from primo.core import BayesianDecisionNetwork
from primo.decision import DecisionNode
from primo.decision import UtilityTable
from primo.reasoning import DiscreteNode
import operator
import numpy as np
import primo.nodes
class UtilityTable(object):
'''
self.variables -- list of the parent nodes
self.table -- utility table which contains the utility
'''
def __init__(self):
super(UtilityTable, self).__init__()
self.table = np.array(0)
self.variables = []
def add_variable(self, variable):
self.variables.append(variable)
ax = self.table.ndim
self.table=np.expand_dims(self.table,ax)
self.table=np.repeat(self.table,len(variable.value_range),axis = ax)
def get_ut_index(self, node_value_pairs):
nodes, values = zip(*node_value_pairs)
index = []
for node in self.variables:
index_in_values_list = nodes.index(node)
value = values[index_in_values_list]
index.append(node.value_range.index(value))
return tuple(index)
def set_utility_table(self, table, nodes):
if not set(nodes) == set(self.variables):
raise Exception("The list which should define the ordering of the variables does not match"
" the variables that this cpt depends on (plus the node itself)")
if not self.table.ndim == table.ndim:
raise Exception("The provided probability table does not have the right number of dimensions")
for d,node in enumerate(nodes):
if len(node.value_range) != table.shape[d]:
raise Exception("The size of the provided probability table does not match the number of possible values of the node "+node.name+" in dimension "+str(d))
self.table = table
self.variables = nodes
def set_utility(self, value, node_value_pairs):
index = self.get_ut_index(node_value_pairs)
self.table[index]=value
def get_utility_table(self):
return self.table
def get_variables(self):
return self.variables
def get_utility(self, node_value_pairs):
index = self.get_ut_index(node_value_pairs)
return self.table[index]
def __str__(self):
return str(self.table)
class MakeDecision(object):
"""
......@@ -65,7 +124,7 @@ class MakeDecision(object):
#Check if the Decision Nodes that are ordered before the provided Decision Node have a state
for node in partialOrder:
if isinstance(node, DecisionNode):
if isinstance(node, primo.nodes.DecisionNode):
if not decisionNode.name == node.name:
if node.get_state() is None:
raise Exception("Decision Nodes that are ordered before the provided Decision Node must have a state!")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment