Commit 304adee3 authored by Jan Pöppel's avatar Jan Pöppel
Browse files

fixed error in changing node names and used Kolja Berger's PEP8 improvements

parent 8334380d
......@@ -24,51 +24,53 @@ import networkx as nx
from . import exceptions
from . import nodes
class BayesianNetwork(object):
def __init__(self):
# super(BayesianNetwork, self).__init__()
self.graph = nx.DiGraph()
self.node_lookup = {}
self.name = "" #Only used to be compatible with XMLBIF
self.meta = [] #Used to be compatible with XMLBIF, stores properties
self.name = "" # Only used to be compatible with XMLBIF
self.meta = [] # Used to be compatible with XMLBIF, stores properties
def add_node(self, node):
if isinstance(node, nodes.RandomNode):
if node.name in self.node_lookup:
raise ValueError("The network already contains a node called '{}'.".format(node.name))
self.node_lookup[node.name]=node
raise ValueError("The network already contains a node " \
"called '{}'.".format(node.name))
self.node_lookup[node.name] = node
self.graph.add_node(node)
else:
raise TypeError("Only subclasses of RandomNode are valid nodes.")
def remove_node(self, node):
if node in self.graph:
#Go over all children of this node
# Go over all children of this node
for child in self.graph.succ[node]:
child.remove_parent(self.node_lookup[node])
self.graph.remove_node(node)
del self.node_lookup[node]
def remove_edge(self, fromName, toName):
if fromName in self.graph and toName in self.graph:
self.node_lookup[toName].remove_parent(self.node_lookup[fromName])
self.graph.remove_edge(fromName, toName)
def add_edge(self, fromName, toName):
if fromName in self.graph and toName in self.graph:
self.graph.add_edge(self.node_lookup[fromName], self.node_lookup[toName])
self.node_lookup[toName].add_parent(self.node_lookup[fromName])
def remove_edge(self, from_name, to_name):
if from_name in self.graph and to_name in self.graph:
self.node_lookup[to_name].remove_parent(self.node_lookup[from_name])
self.graph.remove_edge(from_name, to_name)
def add_edge(self, from_name, to_name):
if from_name in self.graph and to_name in self.graph:
self.graph.add_edge(self.node_lookup[from_name], self.node_lookup[to_name])
self.node_lookup[to_name].add_parent(self.node_lookup[from_name])
else:
raise Exception("Tried to add an Edge between two Nodes of which at least one was not contained in the BayesianNetwork")
raise Exception("Tried to add an Edge between two Nodes of " \
"which at least one was not contained in " \
"the BayesianNetwork")
def get_node(self, node_name):
try:
return self.node_lookup[node_name]
except KeyError:
raise Exception("There is no node with name {} in the BayesianNetwork".format(node_name))
raise Exception("There is no node with name {} in the " \
"BayesianNetwork".format(node_name))
def change_node_values(self, node, new_values):
"""
......@@ -88,30 +90,30 @@ class BayesianNetwork(object):
for child in self.graph.succ[node]:
child._update_dimensions()
else:
raise Exception("There is no node with name {} in the network.".format(node))
raise Exception("There is no node with name {} in " \
"the network.".format(node))
def change_node_name(self, oldName, newName):
def change_node_name(self, old_name, new_name):
"""
Renames the given node to the new name.
Will have to modify all occurances of the old name.
"""
if oldName in self.node_lookup:
n = self.node_lookup[oldName]
if old_name in self.node_lookup:
n = self.node_lookup[old_name]
children = list(self.graph.succ[n])
parents = list(self.graph.pred[n])
for child in children:
del child.parents[oldName]
child.parents[newName] = n
idx = child.parentOrder.index(oldName)
child.parentOrder[idx] = newName
del child.parents[old_name]
child.parents[new_name] = n
idx = child.parentOrder.index(old_name)
child.parentOrder[idx] = new_name
#Fix nx graph
# Fix nx graph
self.graph.remove_node(n)
n.name = newName
del self.node_lookup[oldName]
self.node_lookup[n] = n
n.name = new_name
del self.node_lookup[old_name]
self.node_lookup[new_name] = n
self.graph.add_node(n)
for child in children:
......@@ -120,7 +122,8 @@ class BayesianNetwork(object):
self.graph.add_edge(self.node_lookup[parent], n)
else:
raise Exception("There is no node with name {} in the network.".format(oldName))
raise Exception("There is no node with name {} in the " \
"network.".format(old_name))
def get_all_nodes(self):
return self.graph.nodes()
......@@ -137,21 +140,22 @@ class BayesianNetwork(object):
nodes.append(self.get_node(node_name))
return nodes
def get_children(self, nodeName):
def get_children(self, node_name):
"""
Returns a list of all the children of the given node.
Parameter
--------
nodeName : String or RandomNode
node_name : String or RandomNode
The name of the node whose children are to be returned.
Returns
-------
[RandomNode,]
A list containing all the nodes that have the given node as parent.
A list containing all the nodes that have the given
node as parent.
"""
return self.graph.succ[nodeName]
return self.graph.succ[node_name]
def get_sample(self, evidence):
sample = {}
......@@ -165,19 +169,18 @@ class BayesianNetwork(object):
return sample
def clear(self):
'''Remove all nodes and edges from the graph.
This also removes the name, and all graph, node and edge attributes.'''
"""Remove all nodes and edges from the graph.
This also removes the name, and all graph, node and edge attributes."""
self.graph.clear()
self.node_lookup.clear()
def number_of_nodes(self):
'''Return the number of nodes in the graph.'''
"""Return the number of nodes in the graph."""
return len(self)
def __len__(self):
'''Return the number of nodes in the graph.'''
"""Return the number of nodes in the graph."""
return len(self.graph)
......@@ -310,4 +313,4 @@ class DynamicBayesianNetwork(object):
A list of pairs, each of which represents one transition.
See add_transition for more information.
"""
return self._transitions
return self._transitions
\ No newline at end of file
......@@ -80,7 +80,11 @@ class BayesNetTest(unittest.TestCase):
self.assertEqual(n2.name, "NewName")
with self.assertRaises(Exception) as cm:
self.bn.get_node("Node2")
self.assertEqual(str(cm.exception), "There is no node with name Node2 in the BayesianNetwork")
self.assertEqual(str(cm.exception), "There is no node with name " \
"Node2 in the BayesianNetwork")
also_n2 = self.bn.get_node("NewName")
self.assertEqual(also_n2, n2)
self.assertEqual(n1.parentOrder, ["NewName"])
self.assertTrue(n2 in self.bn.graph.nodes())
......@@ -90,6 +94,31 @@ class BayesNetTest(unittest.TestCase):
self.assertTrue(n2 in self.bn.graph.succ[n3])
def test_change_node_name_twice(self):
n1 = DiscreteNode("Node1")
n2 = DiscreteNode("Node2")
n3 = DiscreteNode("Node3")
self.bn.add_node(n1)
self.bn.add_node(n2)
self.bn.add_node(n3)
self.bn.add_edge(n2,n1)
self.bn.add_edge(n3,n2)
self.bn.change_node_name("Node2", "NewName")
self.assertEqual(n2.name, "NewName")
with self.assertRaises(Exception) as cm:
self.bn.get_node("Node2")
self.assertEqual(str(cm.exception), "There is no node with name " \
"Node2 in the BayesianNetwork")
print("lookup old name: ", self.bn.node_lookup["NewName"])
self.bn.change_node_name("NewName", "2ndNewName")
self.assertEqual(n2.name, "2ndNewName")
with self.assertRaises(Exception) as cm:
self.bn.get_node("NewName")
self.assertEqual(str(cm.exception), "There is no node with name " \
"NewName in the BayesianNetwork")
# def test_addEdge(self):
# self.fail("TODO")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment