# -*- coding: utf-8 -*-

import numpy
import operator

class UtilityTable(object):
    '''
    Table of utility values indexed by the values of a set of variables.

    self.variables -- list of the parent nodes; the i-th entry corresponds
                      to axis i of self.table
    self.table -- numpy array which contains the utility for every
                  combination of values of the variables

    Every variable must provide a ``value_range`` sequence (its possible
    values); its ``name`` attribute is used in error messages.
    '''

    def __init__(self):
        super(UtilityTable, self).__init__()
        # Start as a 0-dimensional *float* array. An integer array would
        # silently truncate fractional utilities written by set_utility().
        self.table = numpy.array(0.0)
        self.variables = []

    def add_variable(self, variable):
        '''
        Append ``variable`` as a new, last axis of the utility table.

        The new axis gets one entry per element of
        ``variable.value_range``; the utilities stored so far are
        replicated along it.
        '''
        self.variables.append(variable)

        new_axis = self.table.ndim
        expanded = numpy.expand_dims(self.table, new_axis)
        self.table = numpy.repeat(
            expanded, len(variable.value_range), axis=new_axis)

    def get_ut_index(self, node_value_pairs):
        '''
        Translate (node, value) pairs into an index tuple for self.table.

        node_value_pairs -- iterable of (node, value) pairs, in any order.
                            It must contain a pair for every variable of
                            this table; pairs for other nodes are ignored.
                            If a node occurs more than once, its first
                            pair is used.

        Returns a tuple with one integer per variable, ordered like
        self.variables (the empty tuple if the table has no variables).

        Raises ValueError if a variable has no pair in node_value_pairs.
        A value that is not in the node's value_range makes the
        ``value_range.index`` lookup fail (ValueError for lists/tuples).
        '''
        # Split the pairs without zip(*...), which cannot be unpacked into
        # two names when node_value_pairs is empty.
        pairs = list(node_value_pairs)
        nodes = [node for node, _ in pairs]
        values = [value for _, value in pairs]

        index = []
        for variable in self.variables:
            value = values[nodes.index(variable)]
            index.append(variable.value_range.index(value))
        return tuple(index)

    def set_utility_table(self, table, nodes):
        '''
        Replace the whole utility table and the ordering of the variables.

        table -- numpy array (or nested sequence convertible by
                 numpy.asarray) with one axis per node; the dtype of the
                 given data is kept as it is
        nodes -- sequence of the variables of this table; nodes[d]
                 describes axis d of ``table``

        Raises ValueError if ``nodes`` does not consist of exactly the
        variables of this table, or if the number of dimensions or the
        shape of ``table`` does not fit ``nodes``.
        '''
        # Own copy: later add_variable() calls must neither mutate the
        # caller's list nor fail because a tuple was passed in.
        nodes = list(nodes)
        table = numpy.asarray(table)

        if not set(nodes) == set(self.variables):
            raise ValueError("The list which should define the ordering of the variables does not match"
                " the variables that this utility table depends on")
        # len(nodes) is checked as well: duplicates in ``nodes`` pass the
        # set comparison above but would index past table.shape below.
        if not self.table.ndim == table.ndim or len(nodes) != table.ndim:
            raise ValueError("The provided utility table does not have the right number of dimensions")
        for d, node in enumerate(nodes):
            if len(node.value_range) != table.shape[d]:
                raise ValueError("The size of the provided utility table does not match the number of possible values of the node "+node.name+" in dimension "+str(d))

        self.table = table
        self.variables = nodes

    def set_utility(self, value, node_value_pairs):
        '''
        Store ``value`` as the utility of one combination of values.

        node_value_pairs -- see get_ut_index()
        '''
        index = self.get_ut_index(node_value_pairs)
        self.table[index] = value

    def get_utility_table(self):
        '''Return the underlying numpy array (not a copy).'''
        return self.table

    def get_variables(self):
        '''Return the list of variables (not a copy), ordered like the table axes.'''
        return self.variables

    def get_utility(self, node_value_pairs):
        '''
        Return the utility of one combination of values.

        node_value_pairs -- see get_ut_index()
        '''
        index = self.get_ut_index(node_value_pairs)
        return self.table[index]

    def __str__(self):
        return str(self.table)