Newer
Older
class BayesNet(object):
def __init__(self):
    '''Create an empty network.

    Sets up a new networkx directed graph that holds the structure and
    an empty dictionary that maps node names to node objects.'''
    self.graph = nx.DiGraph()
    self.node_lookup = dict()
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def add_node(self, node):
    '''Register a node in the network.

    Keyword arguments:
    node -- the node to add; must be an instance of Node or a subclass

    Raises an Exception if node is not a Node, or if a node with the
    same name has already been registered.'''
    if not isinstance(node, Node):
        raise Exception("Can only add 'Node' and its subclasses as nodes into the BayesNet")
    name = node.name
    if name in self.node_lookup:
        raise Exception("Node name already exists in Bayesnet: "+name)
    # Keep the name index and the graph in step with each other.
    self.node_lookup[name] = node
    self.graph.add_node(node)
def add_edge(self, node_from, node_to):
    '''Insert a directed edge from node_from to node_to.

    Keyword arguments:
    node_from -- the parent node of the new edge
    node_to -- the child node of the new edge

    After the edge is added, node_to is informed about its new parent
    through announce_parent(). Raises an Exception if either node is
    not part of the graph.'''
    if node_from not in self.graph.nodes() or node_to not in self.graph.nodes():
        raise Exception("Tried to add an Edge between two Nodes of which at least one was not contained in the Bayesnet")
    self.graph.add_edge(node_from, node_to)
    node_to.announce_parent(node_from)
def remove_node(self, node):
    '''Remove a node from the network.

    Keyword arguments:
    node -- the node to remove; it is looked up by its name attribute

    Raises an Exception if no node with that name is registered, or if
    the underlying graph reports that it does not contain the node.'''
    if node.name not in self.node_lookup:
        # The message used to read "Node <name>does not exists".
        raise Exception("Node " + node.name + " does not exist")
    try:
        self.graph.remove_node(node)
    except nx.exception.NetworkXError:
        raise Exception("Tried to remove a node which does not exist.")
    # Only forget the name once the graph removal went through, so the
    # lookup table and the graph cannot drift apart.
    del self.node_lookup[node.name]
def remove_edge(self, node_from, node_to):
    '''Delete the directed edge from node_from to node_to.

    Keyword arguments:
    node_from -- the parent node of the edge
    node_to -- the child node of the edge

    Raises an Exception if the graph reports a NetworkXError while
    removing the edge.'''
    graph = self.graph
    try:
        graph.remove_edge(node_from, node_to)
    except nx.exception.NetworkXError:
        raise Exception("Tried to remove an edge which does not exist in the BayesNet")
    # TODO: adapt the CPD of the child node once its parent edge is gone.
def get_node(self, node_name):
    '''Look up a node by its name.

    Keyword arguments:
    node_name -- the name under which the node was registered

    Returns the node object. Raises an Exception if no node with that
    name is known.'''
    lookup = self.node_lookup
    if node_name in lookup:
        return lookup[node_name]
    raise Exception("There is no node with name "+node_name+" in the BayesNet")
def get_all_nodes(self):
    '''Return all nodes of the network, as reported by the graph's nodes().'''
    all_nodes = self.graph.nodes()
    return all_nodes
def get_nodes(self, node_names=None):
    '''Return the nodes with the given names, or all nodes.

    Keyword arguments:
    node_names -- iterable of node names (default: None)

    If node_names is None or empty, the result is whatever the graph's
    nodes() returns. Otherwise a list of nodes in the order of
    node_names is returned. An unknown name makes get_node() raise an
    Exception.'''
    # None replaces the former mutable default []; both are falsy, so
    # callers that omit the argument or pass [] see no difference.
    if not node_names:
        return self.graph.nodes()
    return [self.get_node(node_name) for node_name in node_names]
def get_nodes_in_topological_sort(self):
    '''Return the nodes in the order produced by networkx' topological_sort.'''
    ordering = nx.topological_sort(self.graph)
    return ordering
def get_parents(self, node):
    '''Return the parents of a node.

    Keyword arguments:
    node -- the node whose parents are requested

    Returns the graph's predecessors of node. Raises an Exception if no
    node with that name is registered.'''
    if node.name not in self.node_lookup:
        # The message used to read "Node <name>does not exists".
        raise Exception("Node " + node.name + " does not exist")
    return self.graph.predecessors(node)
def get_children(self, node):
    '''Return the children of a node.

    Keyword arguments:
    node -- the node whose children are requested

    Returns the graph's successors of node. Raises an Exception if no
    node with that name is registered.'''
    if node.name not in self.node_lookup:
        # The message used to read "Node <name>does not exists".
        raise Exception("Node " + node.name + " does not exist")
    return self.graph.successors(node)
def get_markov_blanket(self, node):
    '''Placeholder for the Markov blanket of a node; not implemented.

    Keyword arguments:
    node -- the node whose Markov blanket would be computed (unused)

    Always raises NotImplementedError. That is a subclass of Exception,
    so callers that catch Exception keep working.'''
    raise NotImplementedError("Called unimplemented function")
def is_dag(self):
    '''Placeholder for a directed-acyclic-graph check; not implemented.

    Always raises NotImplementedError. That is a subclass of Exception,
    so callers that catch Exception keep working.'''
    raise NotImplementedError("Called unimplemented function")
def draw(self):
    '''Draw the network with networkx and display it through pyplot.

    The body previously only imported pyplot and drew nothing; it now
    mirrors draw_graphviz, using the plain networkx draw() function.'''
    # Imported here so that matplotlib is only needed when drawing.
    import matplotlib.pyplot as plt
    nx.draw(self.graph)
    plt.show()
def draw_graphviz(self):
    '''Draw the network via networkx' draw_graphviz and display it through pyplot.'''
    # Imported here so that matplotlib is only needed when drawing.
    from matplotlib import pyplot
    nx.draw_graphviz(self.graph)
    pyplot.show()
def is_valid(self):
    '''Check if the graph structure is valid.

    Returns False if the graph contains a self-loop or if has_loop()
    reports a loop for any node, True otherwise.'''
    selfloop_count = self.graph.number_of_selfloops()
    if selfloop_count > 0:
        return False
    loop_reports = (self.has_loop(candidate) for candidate in self.graph.nodes())
    return not any(loop_reports)
def has_loop(self, node, origin=None):
    '''Check if any path from node leads back to origin.

    Keyword arguments:
    node -- the start node
    origin -- the node that has to be reached again (default: None,
              meaning node itself)

    Returns True if origin is reachable from node by following at least
    one edge, False otherwise.'''
    # Compare against None explicitly so that a falsy node object is
    # still accepted as an origin.
    if origin is None:
        origin = node
    # Iterative depth-first search. The visited set guarantees
    # termination even when the graph has a cycle that does not pass
    # through origin, and every successor is explored, not just the
    # first one.
    visited = {node}
    pending = [node]
    while pending:
        current = pending.pop()
        for successor in self.graph.successors(current):
            if successor == origin:
                return True
            if successor not in visited:
                visited.add(successor)
                pending.append(successor)
    return False
def clear(self):
    '''Empty the network.

    Calls clear() on the underlying graph and on the name lookup, so
    that no nodes remain registered afterwards.'''
    for container in (self.graph, self.node_lookup):
        container.clear()
def number_of_nodes(self):
    '''Return the number of nodes in the graph; same value as len() of this object.'''
    count = len(self)
    return count
def __len__(self):
'''Return the number of nodes in the graph.'''