From b65cd2f60968084347a946e3d965b956af1a8eb2 Mon Sep 17 00:00:00 2001
From: Hendrik Buschmeier <hbuschme@uni-bielefeld.de>
Date: Sat, 5 Apr 2014 00:18:07 +0200
Subject: [PATCH] Refactoring: Moved decision networks.

---
 primo/nodes.py                                | 30 +++++---
 .../MakeDecision.py => reasoning/decision.py} | 69 +++++++++++++++++--
 2 files changed, 84 insertions(+), 15 deletions(-)
 rename primo/{decision/make_decision/MakeDecision.py => reasoning/decision.py} (77%)

diff --git a/primo/nodes.py b/primo/nodes.py
index 67f5069..8fdd77c 100644
--- a/primo/nodes.py
+++ b/primo/nodes.py
@@ -5,10 +5,7 @@ import re
 import scipy
 
 from primo.decision.UtilityTable import UtilityTable
-from primo.reasoning.density import Beta
-from primo.reasoning.density import Exponential
-from primo.reasoning.density import Gauss
-from primo.reasoning.density import ProbabilityTable
+import primo.reasoning.density
 
 
 class Node(object):
@@ -102,7 +99,7 @@ class DiscreteNode(RandomNode):
         super(DiscreteNode, self).__init__(name)
 
         self.value_range = value_range
-        self.cpd = ProbabilityTable()
+        self.cpd = primo.reasoning.density.ProbabilityTable()
         self.cpd.add_variable(self)
         
     def __str__(self):
@@ -279,7 +276,11 @@ class ContinuousNodeFactory(object):
         
         @param name: The name of the node.
         '''        
-        return self.createContinuousNode(name,(-float("Inf"),float("Inf")),Gauss)
+        return self.createContinuousNode(
+            name,
+            (-float("Inf"),
+            float("Inf")),
+            primo.reasoning.density.Gauss)
         
     def createExponentialNode(self, name):
         '''
@@ -287,7 +288,10 @@ class ContinuousNodeFactory(object):
         
         @param name: The name of the node.
         '''  
-        return self.createContinuousNode(name,(0,float("Inf")),Exponential)
+        return self.createContinuousNode(
+            name,
+            (0,float("Inf")),
+            primo.reasoning.density.Exponential)
         
     def createBetaNode(self, name):
         '''
@@ -295,9 +299,12 @@ class ContinuousNodeFactory(object):
         
         @param name: The name of the node.
         '''  
-        return self.createContinuousNode(name,(0,1),Beta)
+        return self.createContinuousNode(
+            name,
+            (0, 1),
+            primo.reasoning.density.Beta)
     
-    def createContinuousNode(self,name,value_range,DensityClass):
+    def createContinuousNode(self,name,value_range,density_class):
         '''
         Create a ContinuousNode. This method should only be invoked from
         outside this class if no specialized method is available.
@@ -308,7 +315,10 @@ class ContinuousNodeFactory(object):
         @param DensityClass: A class from primo.reasoning.density that shall be
             the node's pdf
         '''  
-        return ContinuousNode(name,value_range,DensityClass)
+        return ContinuousNode(
+            name,
+            value_range,
+            density_class)
 
 
 class DecisionNode(Node):
diff --git a/primo/decision/make_decision/MakeDecision.py b/primo/reasoning/decision.py
similarity index 77%
rename from primo/decision/make_decision/MakeDecision.py
rename to primo/reasoning/decision.py
index 27702d7..085bcc9 100644
--- a/primo/decision/make_decision/MakeDecision.py
+++ b/primo/reasoning/decision.py
@@ -1,10 +1,69 @@
 # -*- coding: utf-8 -*-
 
 import itertools
-from primo.core import BayesianDecisionNetwork
-from primo.decision import DecisionNode
-from primo.decision import UtilityTable
-from primo.reasoning import DiscreteNode
+import operator
+
+import numpy as np
+
+import primo.nodes
+
+
+class UtilityTable(object):
+    '''
+    self.variables -- list of the parent nodes
+    self.table -- utility table which contains the utility
+    '''
+    
+    def __init__(self):
+        super(UtilityTable, self).__init__()
+        self.table = np.array(0)
+        self.variables = []
+    
+    def add_variable(self, variable):
+        self.variables.append(variable)
+
+        ax = self.table.ndim
+        self.table=np.expand_dims(self.table,ax)
+        self.table=np.repeat(self.table,len(variable.value_range),axis = ax)
+
+    def get_ut_index(self, node_value_pairs):
+        nodes, values = zip(*node_value_pairs)
+        index = []
+        for node in self.variables:
+            index_in_values_list = nodes.index(node)
+            value = values[index_in_values_list]
+            index.append(node.value_range.index(value))
+        return tuple(index)    
+        
+    def set_utility_table(self, table, nodes):
+        if not set(nodes) == set(self.variables):
+            raise Exception("The list which should define the ordering of the variables does not match"
+                " the variables that this cpt depends on (plus the node itself)")
+        if not self.table.ndim == table.ndim:
+            raise Exception("The provided probability table does not have the right number of dimensions")
+        for d,node in enumerate(nodes):
+            if len(node.value_range) != table.shape[d]:
+                raise Exception("The size of the provided probability table does not match the number of possible values of the node "+node.name+" in dimension "+str(d))
+
+        self.table = table
+        self.variables = nodes
+      
+    def set_utility(self, value, node_value_pairs):
+        index = self.get_ut_index(node_value_pairs)
+        self.table[index]=value
+    
+    def get_utility_table(self):
+        return self.table
+        
+    def get_variables(self):
+        return self.variables
+    
+    def get_utility(self, node_value_pairs):
+        index = self.get_ut_index(node_value_pairs)
+        return self.table[index]
+
+    def __str__(self):
+        return str(self.table)
 
 class MakeDecision(object):
     """
@@ -65,7 +124,7 @@ class MakeDecision(object):
         
         #Check if the Decision Nodes that are ordered before the provided Decision Node have a state
         for node in partialOrder:
-            if isinstance(node, DecisionNode):
+            if isinstance(node, primo.nodes.DecisionNode):
                 if not decisionNode.name == node.name:
                     if node.get_state() is None:
                         raise Exception("Decision Nodes that are ordered before the provided Decision Node must have a state!")
-- 
GitLab