diff --git a/ExampleBayesNet.py b/ExampleBayesNet.py index 903c9cee38098a0859be23199fcc47d9691af188..af083193aa7a528f8dd875c1ce80e45a9c9091ee 100644 --- a/ExampleBayesNet.py +++ b/ExampleBayesNet.py @@ -2,7 +2,26 @@ from core.BayesNet import * from core.Node import * bn = BayesNet() -n = Node() +n1 = Node("Node1") +n2 = Node("Node2") +n3 = Node("Node3") -bn.add_node(n) -bn.add_node("lol") +bn.add_node(n1) +bn.add_node(n2) + +bn.add_edge(n1,n2) + +n = bn.get_node("Node1") +print n.name + +ns = bn.get_nodes(["Node2","Node1"]) +for n in ns: + print n.name + +print "Removing existing edge" +bn.remove_edge(n1, n2) +print "Removing not existing edge" +bn.remove_edge(n1, n2) + + +bn.draw() diff --git a/core/BayesNet.py b/core/BayesNet.py index 24382910a3dff3e4192161aaea2c8203de079383..460a5066c247f2d025a03ef26a244a1fe7f0a67e 100644 --- a/core/BayesNet.py +++ b/core/BayesNet.py @@ -1,18 +1,69 @@ import sys -sys.path.append("../lib/networkx-1.7-py2.7.egg") + +sys.path.append("lib/networkx-1.7-py2.7.egg") import networkx as nx -import Node.Node +from core.Node import Node class BayesNet(object): graph = nx.DiGraph() + node_lookup = {} def __init__(self): - print "lol" + pass def add_node(self, node): - if isinstance(node, Node.Node): + if isinstance(node, Node): + if node.name in self.node_lookup.keys(): + raise Exception("Node name already exists in Bayesnet: "+node.name) + self.node_lookup[node.name]=node self.graph.add_node(node) else: raise Exception("Can only add 'Node' and its subclasses as nodes into the BayesNet") - + + def add_edge(self, node_from, node_to): + if node_from in self.graph.nodes() and node_to in self.graph.nodes(): + self.graph.add_edge(node_from, node_to) + + #raise Exception("Fixme: Adapt CPD of child-node") + + else: + raise Exception("Tried to add an Edge between two Nodes of which at least one was not contained in the Bayesnet") + + def remove_node(self, node): + raise Exception("Called unimplemented function") + + def remove_edge(self, node_from, node_to): + try: + self.graph.remove_edge(node_from, node_to) + except nx.exception.NetworkXError: + raise Exception("Tried to remove an Edge which does not exist in the BayesNet") + #raise Exception("Fixme: Adapt CPD of child-node") + + def get_node(self, node_name): + try: + return self.node_lookup[node_name] + except KeyError: + raise Exception("There is no node with name "+node_name+" in the bayesnet") + + def get_nodes(self, node_names): + for node_name in node_names: + yield self.get_node(node_name) + + def get_parents(self, node): + raise Exception("Called unimplemented function") + + def get_children(self, node): + raise Exception("Called unimplemented function") + + def get_markov_blanket(self, node): + raise Exception("Called unimplemented function") + + def is_dag(self): + raise Exception("Called unimplemented function") + + def draw(self): + import matplotlib.pyplot as plt + nx.draw(self.graph) + plt.show() + diff --git a/core/Node.py b/core/Node.py index c67a764d1a4f474667c3c31735ea1d41abf5aa91..d3c158aa6a696e2dae6430de1e4373400295f8b0 100644 --- a/core/Node.py +++ b/core/Node.py @@ -1,2 +1,5 @@ class Node(object): - name = "SomeNode" + name = "UninitializedName" + + def __init__(self, node_name): + self.name = node_name