Skip to content
Snippets Groups Projects
Commit 0fbbb4f6 authored by Lukas Kettenbach's avatar Lukas Kettenbach
Browse files

Some new graph methods

parent 05d1b792
No related branches found
No related tags found
No related merge requests found
...@@ -18,10 +18,16 @@ ns = bn.get_nodes(["Node2","Node1"]) ...@@ -18,10 +18,16 @@ ns = bn.get_nodes(["Node2","Node1"])
for n in ns: for n in ns:
print n.name print n.name
print bn.get_parents(n2)
print "Removing existing edge" print "Removing existing edge"
bn.remove_edge(n1, n2) bn.remove_edge(n1, n2)
print "nach edge remove"
print bn.get_parents(n2)
print "Removing not existing edge" print "Removing not existing edge"
bn.remove_edge(n1, n2) bn.remove_edge(n1, n2)
bn.draw() bn.draw()
...@@ -31,7 +31,14 @@ class BayesNet(object): ...@@ -31,7 +31,14 @@ class BayesNet(object):
raise Exception("Tried to add an Edge between two Nodes of which at least one was not contained in the Bayesnet") raise Exception("Tried to add an Edge between two Nodes of which at least one was not contained in the Bayesnet")
def remove_node(self, node): def remove_node(self, node):
raise Exception("Called unimplemented function") if node.name not in self.node_lookup.keys():
raise Exception("Node " + node.name + "does not exists")
else :
try:
self.graph.remove_node(node)
except nx.exception.NetworkXError:
raise Exception("Tried to remove a node which does not exists")
del self.node_lookup[node.name]
def remove_edge(self, node_from, node_to): def remove_edge(self, node_from, node_to):
try: try:
...@@ -44,22 +51,30 @@ class BayesNet(object): ...@@ -44,22 +51,30 @@ class BayesNet(object):
try: try:
return self.node_lookup[node_name] return self.node_lookup[node_name]
except KeyError: except KeyError:
raise Exception("There is no node with name "+node_name+" in the bayesnet") raise Exception("There is no node with name "+node_name+" in the bayesnet")
def get_nodes(self, node_names): def get_nodes(self, node_names):
nodes = [] nodes = []
if not node_names: if not node_names:
nodes = self.bn.nodes() nodes = self.graph.nodes()
else: else:
for node_name in node_names: for node_name in node_names:
nodes.append(self.get_node(node_name)) nodes.append(self.get_node(node_name))
return nodes return nodes
def get_parents(self, node): def get_parents(self, node):
raise Exception("Called unimplemented function") if node.name not in self.node_lookup.keys():
raise Exception("Node " + node.name + "does not exists")
else:
return self.graph.predecessors(node)
def get_children(self, node): def get_children(self, node):
raise Exception("Called unimplemented function") if node.name not in self.node_lookup.keys():
raise Exception("Node " + node.name + "does not exists")
else:
return self.graph.successors(node)
def get_markov_blanket(self, node): def get_markov_blanket(self, node):
raise Exception("Called unimplemented function") raise Exception("Called unimplemented function")
...@@ -70,4 +85,4 @@ class BayesNet(object): ...@@ -70,4 +85,4 @@ class BayesNet(object):
def draw(self): def draw(self):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
nx.draw(self.graph) nx.draw(self.graph)
plt.show() plt.show()
from BayesNet import BayesNet
from Node import Node
...@@ -11,7 +11,7 @@ class NodeAddAndRemoveTestCase(unittest.TestCase): ...@@ -11,7 +11,7 @@ class NodeAddAndRemoveTestCase(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.bn = None self.bn = None
def test_add(self): def test_add_node(self):
n = Node("Some Node") n = Node("Some Node")
self.bn.add_node(n) self.bn.add_node(n)
self.assertEqual(n, self.bn.get_node("Some Node")) self.assertEqual(n, self.bn.get_node("Some Node"))
...@@ -20,12 +20,33 @@ class NodeAddAndRemoveTestCase(unittest.TestCase): ...@@ -20,12 +20,33 @@ class NodeAddAndRemoveTestCase(unittest.TestCase):
node_with_same_name = Node("Some Node") node_with_same_name = Node("Some Node")
self.assertRaises(Exception, self.bn.add_node, node_with_same_name) self.assertRaises(Exception, self.bn.add_node, node_with_same_name)
def test_remove(self): def test_remove_node(self):
n = Node("Some Node to remove") n = Node("Some Node to remove")
self.bn.add_node(n) self.bn.add_node(n)
self.bn.remove_node(n) self.bn.remove_node(n)
self.assertFalse(n in self.bn.get_nodes()) self.assertFalse(n in self.bn.get_nodes([]))
def test_add_edge(self):
n1 = Node("1")
n2 = Node("2")
self.bn.add_node(n1)
self.bn.add_node(n2)
self.bn.add_edge(n1, n2)
self.assertTrue(n1 in self.bn.get_parents(n2))
self.assertTrue(n2 in self.bn.get_children(n1))
self.bn.remove_node(n1)
self.bn.remove_node(n2)
def test_remove_edge(self):
n1 = Node("1")
n2 = Node("2")
self.bn.add_node(n1)
self.bn.add_node(n2)
self.bn.add_edge(n1, n2)
self.bn.remove_edge(n1, n2)
self.assertEqual([], self.bn.get_parents(n2))
self.bn.remove_node(n1)
self.bn.remove_node(n2)
#include this so you can run this test without nose #include this so you can run this test without nose
if __name__ == '__main__': if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment