Skip to content
Snippets Groups Projects
BayesNet.py 4.05 KiB
Newer Older
Manuel Baum's avatar
Manuel Baum committed
import sys
Manuel Baum's avatar
Manuel Baum committed
import networkx as nx
from core.Node import Node
Manuel Baum's avatar
Manuel Baum committed
class BayesNet(object):
    '''Directed graph of Node objects plus a name -> node lookup table.

    Attributes:
    graph -- networkx DiGraph holding the nodes and the directed edges
    node_lookup -- dict mapping each node's name to the node object
    '''

    def __init__(self):
        '''Create an empty network.'''
        # Per-instance state: as class attributes these would be shared by
        # every BayesNet instance.
        self.graph = nx.DiGraph()
        self.node_lookup = {}

    def add_node(self, node):
        '''Add node to the network and register it under node.name.

        Raises Exception if node is not a Node instance or if a node with
        the same name is already registered.'''
        if not isinstance(node, Node):
            raise Exception("Can only add 'Node' and its subclasses as nodes into the BayesNet")
        if node.name in self.node_lookup:
            raise Exception("Node name already exists in Bayesnet: " + node.name)
        self.node_lookup[node.name] = node
        self.graph.add_node(node)

    def add_edge(self, node_from, node_to):
        '''Add a directed edge node_from -> node_to and announce node_from
        as a parent to node_to.

        Raises Exception if either node is not contained in the graph.'''
        # Membership test on the graph itself avoids building a node list.
        if node_from not in self.graph or node_to not in self.graph:
            raise Exception("Tried to add an Edge between two Nodes of which at least one was not contained in the Bayesnet")
        self.graph.add_edge(node_from, node_to)
        node_to.announce_parent(node_from)

    def remove_node(self, node):
        '''Remove node from the graph and from the name lookup.

        Raises Exception if no node is registered under node.name or if the
        graph does not contain the node.'''
        if node.name not in self.node_lookup:
            raise Exception("Node " + node.name + " does not exist")
        try:
            self.graph.remove_node(node)
        except nx.exception.NetworkXError:
            raise Exception("Tried to remove a node which does not exist.")
        # Only drop the lookup entry once the graph removal succeeded.
        del self.node_lookup[node.name]

    def remove_edge(self, node_from, node_to):
        '''Remove the directed edge node_from -> node_to.

        Raises Exception if the edge is not in the graph.'''
        try:
            self.graph.remove_edge(node_from, node_to)
        except nx.exception.NetworkXError:
            raise Exception("Tried to remove an edge which does not exist in the BayesNet")
        # TODO: adapt the CPD of the child node after removing the edge.

    def get_node(self, node_name):
        '''Return the node registered under node_name.

        Raises Exception if no such node exists.'''
        try:
            return self.node_lookup[node_name]
        except KeyError:
            raise Exception("There is no node with name " + str(node_name) + " in the BayesNet")

    def get_nodes(self, node_names=None):
        '''Return a list of nodes.

        Keyword arguments:
        node_names -- iterable of node names; if empty or None, all nodes of
                      the graph are returned (default: None)

        Raises Exception if one of the names is not registered.'''
        if not node_names:
            # list() gives the same return type whether nodes() yields a
            # list or a view.
            return list(self.graph.nodes())
        return [self.get_node(node_name) for node_name in node_names]

    def get_parents(self, node):
        '''Return the list of direct predecessors of node.

        Raises Exception if no node is registered under node.name.'''
        if node.name not in self.node_lookup:
            raise Exception("Node " + node.name + " does not exist")
        return list(self.graph.predecessors(node))

    def get_children(self, node):
        '''Return the list of direct successors of node.

        Raises Exception if no node is registered under node.name.'''
        if node.name not in self.node_lookup:
            raise Exception("Node " + node.name + " does not exist")
        return list(self.graph.successors(node))

    def get_markov_blanket(self, node):
        '''Not implemented; always raises NotImplementedError.'''
        raise NotImplementedError("Called unimplemented function")

    def is_dag(self):
        '''Not implemented; always raises NotImplementedError.'''
        raise NotImplementedError("Called unimplemented function")

    def draw(self):
        '''Draw the graph with networkx/matplotlib and show the figure.'''
        # Imported here so matplotlib is only needed when drawing.
        import matplotlib.pyplot as plt
        nx.draw(self.graph)
        plt.show()

    def is_valid(self):
        '''Check if graph structure is valid.

        Returns True if the graph is directed and acyclic (self-loops count
        as cycles), False otherwise.'''
        return nx.is_directed_acyclic_graph(self.graph)

    def has_loop(self, node, origin=None):
        '''Check if any path from node leads back to origin.

        Keyword arguments:
        node -- the start node
        origin -- the node to look for; defaults to node itself
                  (default: None)

        Returns True if origin is reachable from node via at least one edge,
        False otherwise.'''
        if origin is None:
            origin = node

        # Iterative depth-first search. The visited set guarantees
        # termination even when a cycle that does not contain origin is
        # reachable from node.
        visited = set()
        pending = [node]
        while pending:
            current = pending.pop()
            for successor in self.graph.successors(current):
                if successor == origin:
                    return True
                if successor not in visited:
                    visited.add(successor)
                    pending.append(successor)
        return False

    def clear(self):
        '''Remove all nodes and edges from the graph and empty the name
        lookup.'''
        self.graph.clear()
        self.node_lookup.clear()

    def number_of_nodes(self):
        '''Return the number of nodes in the graph.'''
        return len(self)

    def __len__(self):
        '''Return the number of nodes in the graph.'''
        return len(self.graph)