Skip to content
Snippets Groups Projects
Commit 95aa4580 authored by Lukas Kettenbach's avatar Lukas Kettenbach
Browse files

changed repo. structure

parent 7722d78a
No related branches found
No related tags found
No related merge requests found
Showing
with 0 additions and 437 deletions
import sys
import networkx as nx
from core.Node import Node
class BayesNet(object):
graph = nx.DiGraph()
node_lookup = {}
def __init__(self):
pass
def add_node(self, node):
if isinstance(node, Node):
if node.name in self.node_lookup.keys():
raise Exception("Node name already exists in Bayesnet: "+node.name)
self.node_lookup[node.name]=node
self.graph.add_node(node)
else:
raise Exception("Can only add 'Node' and its subclasses as nodes into the BayesNet")
def add_edge(self, node_from, node_to):
if node_from in self.graph.nodes() and node_to in self.graph.nodes():
self.graph.add_edge(node_from, node_to)
node_to.announce_parent(node_from)
else:
raise Exception("Tried to add an Edge between two Nodes of which at least one was not contained in the Bayesnet")
def remove_node(self, node):
if node.name not in self.node_lookup.keys():
raise Exception("Node " + node.name + "does not exists")
else :
try:
self.graph.remove_node(node)
except nx.exception.NetworkXError:
raise Exception("Tried to remove a node which does not exist.")
del self.node_lookup[node.name]
def remove_edge(self, node_from, node_to):
try:
self.graph.remove_edge(node_from, node_to)
except nx.exception.NetworkXError:
raise Exception("Tried to remove an edge which does not exist in the BayesNet")
#raise Exception("Fixme: Adapt CPD of child-node")
def get_node(self, node_name):
try:
return self.node_lookup[node_name]
except KeyError:
raise Exception("There is no node with name "+node_name+" in the BayesNet")
def get_nodes(self, node_names):
nodes = []
if not node_names:
nodes = self.graph.nodes()
else:
for node_name in node_names:
nodes.append(self.get_node(node_name))
return nodes
def get_parents(self, node):
if node.name not in self.node_lookup.keys():
raise Exception("Node " + node.name + "does not exists")
else:
return self.graph.predecessors(node)
def get_children(self, node):
if node.name not in self.node_lookup.keys():
raise Exception("Node " + node.name + "does not exists")
else:
return self.graph.successors(node)
def get_markov_blanket(self, node):
raise Exception("Called unimplemented function")
def is_dag(self):
raise Exception("Called unimplemented function")
def draw(self):
import matplotlib.pyplot as plt
nx.draw(self.graph)
plt.show()
def is_valid(self):
'''Check if graph structure is valid.
Returns true if graph is directed and acyclic, false otherwiese'''
if self.graph.number_of_selfloops() > 0:
return False
for node in self.graph.nodes():
if self.has_loop(node):
return False
return True
def has_loop(self, node, origin=None):
'''Check if any path from node leads back to node.
Keyword arguments:
node -- the start node
origin -- same as node for internal recursive loop (default: None)
Returns true on succes, false otherwise.'''
if not origin:
origin = node
for successor in self.graph.successors(node):
if successor == origin:
return True
else:
return self.has_loop(successor, origin)
def clear(self):
'''Remove all nodes and edges from the graph.
This also removes the name, and all graph, node and edge attributes.'''
self.graph.clear()
self.node_lookup.clear()
def number_of_nodes(self):
'''Return the number of nodes in the graph.'''
return len(self)
def __len__(self):
'''Return the number of nodes in the graph.'''
return len(self.graph)
# -*- coding: utf-8 -*-
from BayesNet import BayesNet
class DynamicBayesNet(BayesNet):
def __init__(self):
super(DynamicBayesNet, self).__init__()
def add_edge(self, node_from, node_to, arc=False):
'''Add an directed edge to the graph.
Keyword arguments:
node_from -- from node
node_to -- to node
arc -- is this edge a temporal conditional dependency (default: False)
'''
super().add_edge(node_from, node_to)
# Adding an edge that already exists updates the edge data.
self.graph.add_edge(node_from, node_to, arc=arc)
import abc
class Node(object):
__metaclass__ = abc.ABCMeta
name = "UninitializedName"
def __init__(self, node_name):
self.name = node_name
@abc.abstractmethod
def announce_parent(self, node):
"""This method will be called by the graph-management to inform nodes
which just became children of other nodes, so they can adapt themselves
(e.g. their cpt)"""
return
def __str__(self):
print self.name
return self.name
from BayesNet import BayesNet
from Node import Node
import unittest
from core.BayesNet import BayesNet
from reasoning.DiscreteNode import DiscreteNode
class NodeAddAndRemoveTestCase(unittest.TestCase):
def setUp(self):
self.bn = BayesNet()
def tearDown(self):
self.bn = None
def test_clear_and_len(self):
self.assertFalse(0 == len(self.bn))
self.assertFalse(0 == self.bn.number_of_nodes())
self.bn.clear()
self.assertEqual(0, len(self.bn))
self.assertEqual(0, self.bn.number_of_nodes())
def test_add_node(self):
self.bn.clear()
n = DiscreteNode("Some Node", [True, False])
self.bn.add_node(n)
self.assertEqual(n, self.bn.get_node("Some Node"))
self.assertTrue(n in self.bn.get_nodes(["Some Node"]))
node_with_same_name = DiscreteNode("Some Node", [True, False])
self.assertRaises(Exception, self.bn.add_node, node_with_same_name)
def test_remove_node(self):
self.bn.clear()
n = DiscreteNode("Some Node to remove", [True, False])
self.bn.add_node(n)
self.bn.remove_node(n)
self.assertFalse(n in self.bn.get_nodes([]))
def test_add_edge(self):
self.bn.clear()
n1 = DiscreteNode("1", [True, False])
n2 = DiscreteNode("2", [True, False])
self.bn.add_node(n1)
self.bn.add_node(n2)
self.bn.add_edge(n1, n2)
self.assertTrue(n1 in self.bn.get_parents(n2))
self.assertTrue(n2 in self.bn.get_children(n1))
def test_remove_edge(self):
self.bn.clear()
n1 = DiscreteNode("1", [True, False])
n2 = DiscreteNode("2", [True, False])
self.bn.add_node(n1)
self.bn.add_node(n2)
self.bn.add_edge(n1, n2)
self.assertEqual([n1], self.bn.get_parents(n2))
self.bn.remove_edge(n1, n2)
self.assertEqual([], self.bn.get_parents(n2))
def test_is_valid(self):
self.bn.clear()
n1 = DiscreteNode("1", [True, False])
n2 = DiscreteNode("2", [True, False])
self.bn.add_node(n1)
self.bn.add_node(n2)
self.bn.add_edge(n1, n2)
self.assertTrue(self.bn.is_valid())
self.bn.add_edge(n1, n1)
self.assertFalse(self.bn.is_valid())
self.bn.remove_edge(n1, n1)
self.assertTrue(self.bn.is_valid())
n3 = DiscreteNode("3", [True, False])
n4 = DiscreteNode("4", [True, False])
self.bn.add_node(n3)
self.bn.add_node(n4)
self.assertTrue(self.bn.is_valid())
self.bn.add_edge(n2, n3)
self.assertTrue(self.bn.is_valid())
self.bn.add_edge(n3, n4)
self.assertTrue(self.bn.is_valid())
self.bn.add_edge(n4, n1)
self.assertFalse(self.bn.is_valid())
#include this so you can run this test without nose
if __name__ == '__main__':
unittest.main()
# -*- coding: utf-8 -*-
from core.Node import Node
class DecisionNode(Node):
'''TODO: write doc'''
def __init__(self):
super(DecisionNode, self).__init__()
\ No newline at end of file
# -*- coding: utf-8 -*-
from core.Node import Node
class UtilityNode(Node):
'''TODO: write doc'''
def __init__(self):
super(UtilityNode, self).__init__()
\ No newline at end of file
File moved
File moved
File deleted
# -*- coding: utf-8 -*-
from reasoning.RandomNode import RandomNode
from reasoning.density.ProbabilityTable import ProbabilityTable
class DiscreteNode(RandomNode):
'''#TODO: write doc'''
def __init__(self, name, value_range):
super(DiscreteNode, self).__init__(name)
self.value_range = value_range
self.cpd = ProbabilityTable()
self.cpd.add_variable(self)
def announce_parent(self, node):
self.cpd.add_variable(node)
def __str__(self):
return self.name + "\n" + str(self.cpd)
def set_probability(self, value, node_value_pairs):
self.cpd.set_probability(value, node_value_pairs)
def set_probability_table(self, table, nodes):
self.cpd.set_probability_table(table, nodes)
def is_valid(self):
return self.cpd.is_normalized_as_cpt(self)
# -*- coding: utf-8 -*-
from reasoning.RandomNode import RandomNode
class GaussNode(RandomNode):
'''TODO: write doc'''
def __init__(self):
super(GaussNode, self).__init__()
\ No newline at end of file
# -*- coding: utf-8 -*-
from core.Node import Node
from reasoning.density import Density
class RandomNode(Node):
'''TODO: write doc'''
cpd = Density()
def __init__(self, name):
super(RandomNode, self).__init__(name)
def is_valid(self):
raise Exception("Called an unimplemented function")
from DiscreteNode import DiscreteNode
from GaussNode import GaussNode
from RandomNode import RandomNode
# -*- coding: utf-8 -*-
class Density(object):
'''TODO: write doc'''
def __init__(self):
super(Density, self).__init__()
# -*- coding: utf-8 -*-
from reasoning.density import Density
class Gauss(Density):
'''TODO: write doc'''
def __init__(self):
super(Gauss, self).__init__()
# -*- coding: utf-8 -*-
import numpy
from reasoning.density import Density
class ProbabilityTable(Density):
'''TODO: write doc'''
def __init__(self):
super(ProbabilityTable, self).__init__()
#self.owner = owner
#self.variables = [owner]
#size_of_range = len(owner.value_range)
#self.table = numpy.ones(size_of_range) / size_of_range
self.variables = []
self.table = numpy.array(0)
def add_variable(self, variable):
self.variables.append(variable)
ax = self.table.ndim
self.table=numpy.expand_dims(self.table,ax)
self.table=numpy.repeat(self.table,len(variable.value_range),axis = ax)
def set_probability_table(self, table, nodes):
if not set(nodes) == set(self.variables):
raise Exception("The list which should define the ordering of the variables does not match"
" the variables that this cpt depends on (plus the node itself)")
if not self.table.ndim == table.ndim:
raise Exception("The provided probability table does not have the right number of dimensions")
for d,node in enumerate(nodes):
if len(node.value_range) != table.shape[d]:
raise Exception("The size of the provided probability table does not match the number of possible values of the node "+node.name+" in dimension "+str(d))
self.table = table
self.variables = nodes
def set_probability(self, value, node_value_pairs):
index = self.get_cpt_index(node_value_pairs)
self.table[tuple(index)]=value
def get_cpt_index(self, node_value_pairs):
nodes, values = zip(*node_value_pairs)
index = []
for node in self.variables:
index_in_values_list = nodes.index(node)
value = values[index_in_values_list]
index.append(node.value_range.index(value))
return index
def is_normalized_as_cpt(self,owner):
dim_of_owner = self.variables.index(owner)
sum_of_owner_probs = numpy.sum(self.table, dim_of_owner)
return set(sum_of_owner_probs.flatten()) == set([1])
def is_normalized_as_jpt(self):
return numpy.sum(table) == 1.0
def multiplication(self, factor):
raise Exception("Called unimplemented function")
def marginalization(self, variable):
raise Exception("Called unimplemented function")
def reduction(self):
raise Exception("Called unimplemented function")
def division(self, factor):
raise Exception("Called unimplemented function")
def __str__(self):
return str(self.table)
from Density import Density
from ProbabilityTable import ProbabilityTable
from Gauss import Gauss
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment