Skip to content
Snippets Groups Projects
XMLBIF.py 11.2 KiB
Newer Older
  • Learn to ignore specific revisions
  • # -*- coding: utf-8 -*-
    import xml.dom.minidom as minidom
    from primo.core import BayesNet
    from primo.core import Node
    from primo.reasoning import DiscreteNode
    import re
    
    
    class XMLBIF(object):
        '''
        This class represents the Interchange Format for Bayesian Networks (XMLBIF).
        It helps you to convert a BayesNet to a XMLBIF and a XMLBIF to a BayesNet.
    
        See: http://www.cs.cmu.edu/~fgcozman/Research/InterchangeFormat/
        '''
        def __init__(self, network, network_name = "Unnamed network",
                     encoding = "UTF-8", ndent = "    ", newl = "\n",
                     addindent = "    "):
            '''
            Set up the exporter and immediately build the XMLBIF document.

            Keyword arguments:
            network -- the BayesNet to serialise. generate_XMLBIF() rejects
            every node that is not a DiscreteNode.
            network_name -- name written into the <NAME> tag of the network.
            encoding -- encoding handed to minidom when serialising
            (defaults to UTF-8).
            ndent -- indentation string handed to minidom when serialising
            (defaults to four spaces).
            newl -- string emitted at the end of each line.
            addindent -- incremental indentation used by write() for nested
            elements (defaults to four spaces).

            Raises an Exception if network is not a BayesNet.
            '''
            # Guard clause: refuse anything that is not a BayesNet up front.
            if not isinstance(network, BayesNet):
                raise Exception("Given network is not a BayesNet.")

            self.network = network
            self.network_name = network_name

            # Serialisation settings, used by __str__() and write().
            self.encoding = encoding
            self.ndent = ndent
            self.newl = newl
            self.addindent = addindent

            # Start with an empty document; generate_XMLBIF() replaces it.
            self.root = minidom.Document()
            self.generate_XMLBIF()
    
        def __str__(self):
            '''
            Returns a pretty string representation of the XMLBIF.
            '''
            return self.root.toprettyxml(self.ndent, self.newl, self.encoding);
    
        def write(self, filename):
            '''
            Write this XMLBIF instance to disk.
    
            Keyword arguments:
            filename -- is a string containing the filename.
            '''
            f = open(filename, "w")
            self.root.writexml(f, self.ndent, self.addindent, self.newl, self.encoding)
    
    
    
        def generate_XMLBIF(self):
            '''
            Generate the XMLBIF document.

            This method is used internally. Do not call it outside this class.

            The node positions are recalculated first. Then a fresh document
            is built: one <VARIABLE> tag per node, followed by one
            <DEFINITION> tag per node. The result is stored in self.root.

            Raises an Exception if a node of the network is not a
            DiscreteNode.

            Returns this XMLBIF instance.
            '''
            self.calculate_positions()
            root_node = minidom.Document()

            def text_element(tag, data):
                # Build <tag>data</tag>, owned by the new document.
                element = root_node.createElement(tag)
                element.appendChild(root_node.createTextNode(str(data)))
                return element

            tag_bif = root_node.createElement("BIF")
            tag_bif.setAttribute("VERSION", "0.3")
            tag_net = root_node.createElement("NETWORK")
            root_node.appendChild(tag_bif)
            tag_bif.appendChild(tag_net)
            tag_net.appendChild(text_element("NAME", self.network_name))

            # All <VARIABLE> tags come first; the <DEFINITION> tags follow.
            for node_name in self.network.node_lookup:
                current_node = self.network.node_lookup[node_name]
                if not isinstance(current_node, DiscreteNode):
                    raise Exception("Node " + str(current_node) + " is not a DiscreteNode.")
                tag_net.appendChild(self.create_node_tag(current_node))

            # Generate CPTs
            for node_name in self.network.node_lookup:
                current_node = self.network.node_lookup[node_name]
                cpd = current_node.get_cpd()

                tag_def = root_node.createElement("DEFINITION")
                tag_def.appendChild(text_element("FOR", node_name))

                # It's not guaranteed that the own node is at dimension zero
                # in the probability table. But for the function the order of
                # the variables is important, so every variable of the table
                # is written out, in reversed order.
                for parent in reversed(cpd.get_variables()):
                    tag_def.appendChild(text_element("GIVEN", parent.name))

                # Flatten the transposed table; every entry is followed by a
                # single space, including the last one.
                table_text = "".join(str(elem) + " "
                                     for elem in cpd.get_table().T.flat)
                tag_def.appendChild(text_element("TABLE", table_text))

                tag_net.appendChild(tag_def)

            self.root = root_node
            return self
    
    
    
        def create_node_tag(self, node):
            '''
            Build the XMLBIF "variable" tag for a single node:
            <VARIABLE TYPE="nature">
                <NAME>node_name</NAME>
                <OUTCOME>...</OUTCOME>
                <OUTCOME>...</OUTCOME>
                <PROPERTY>position = (x, y)</PROPERTY>
            </VARIABLE>

            Keyword arguments:
            node -- a Node; its name, value_range and position are read.

            Raises an Exception if node is not a Node.

            Returns the <VARIABLE> element.
            '''
            if not isinstance(node, Node):
                raise Exception("Node " + str(node) + " is not a Node.")

            def with_text(tag, data):
                # Wrap data into a text node below a new <tag> element.
                element = minidom.Element(tag)
                content = minidom.Text()
                content.data = data
                element.appendChild(content)
                return element

            variable = minidom.Element("VARIABLE")
            variable.setAttribute("TYPE", "nature")

            # <NAME> comes first ...
            variable.appendChild(with_text("NAME", node.name))

            # ... followed by one <OUTCOME> per possible value ...
            for outcome in node.value_range:
                variable.appendChild(with_text("OUTCOME", outcome))

            # ... and finally the position as a <PROPERTY>.
            pos_x, pos_y = node.position
            position_text = "".join(
                ["position = (", str(pos_x), ", ", str(pos_y), ")"])
            variable.appendChild(with_text("PROPERTY", position_text))

            return variable
    
    
        def calculate_positions(self):
            '''
            Calculate the visual position for each node.
    
            This method is used internally. Do not call it outside this class.
            '''
            q = []
            p = []
            already_seen = []
            x_step = 150
            y_step = 100
            x_pos = 0
            y_pos = 0
            for node_name in self.network.node_lookup:
                node = self.network.node_lookup[node_name]
                if len(self.network.graph.predecessors(node)) == 0:
                    q.append(node)
                    already_seen.append(node)
            while q:
                p = q
                q = []
                y_pos += y_step
                x_pos = x_step
                while p:
                    node = p.pop()
                    node.position = (x_pos, y_pos)
                    x_pos += x_step
    
                    for child in self.network.graph.successors(node):
                        if not child in already_seen:
                            q.append(child)
                            already_seen.append(child)
    
        @staticmethod
        def read(filename_or_file, is_string = False):
            '''
            Reads a XMLBIF and returns a BayesNet.
    
            Keyword arguments:
            filename_or_file -- may be either a file name, or a file-like object.
            is_string -- is True if filename_or_file is a XML in a string
    
            Returns a BayesNet.
            '''
            if is_string:
                root = minidom.parseString(filename_or_file)
            else:
                root = minidom.parse(filename_or_file)
    
            return XMLBIF.generate_BayesNet(root)
    
        @staticmethod
        def generate_BayesNet(root):
            '''
            Generate a BayesNet from a XMLBIF.

            This method is used internally. Do not call it outside this class.

            Keyword arguments:
            root -- the parsed XMLBIF document (xml.dom.minidom.Document).

            Raises an Exception unless the document contains exactly one
            <BIF> tag with exactly one <NETWORK> tag in it.

            Returns a BayesNet.
            '''
            network = BayesNet()
            bif_nodes = root.getElementsByTagName("BIF")
            if len(bif_nodes) != 1:
                raise Exception("More than one or none <BIF>-tag in document.")
            network_nodes = bif_nodes[0].getElementsByTagName("NETWORK")
            if len(network_nodes) != 1:
                raise Exception("More than one or none <NETWORK>-tag in document.")
            network_node = network_nodes[0]

            # First pass: create one DiscreteNode per <VARIABLE>.
            for variable_node in network_node.getElementsByTagName("VARIABLE"):
                # Only the first <NAME> counts; a missing one falls back to
                # a default name.
                name = "Unnamed node"
                name_nodes = variable_node.getElementsByTagName("NAME")
                if name_nodes:
                    name = XMLBIF.get_node_text(name_nodes[0].childNodes)

                value_range = [
                    XMLBIF.get_node_text(outcome_node.childNodes)
                    for outcome_node
                    in variable_node.getElementsByTagName("OUTCOME")]

                # Only the first <PROPERTY> is read, and it is expected to
                # hold the position.
                position = (0, 0)
                property_nodes = variable_node.getElementsByTagName("PROPERTY")
                if property_nodes:
                    position = XMLBIF.get_node_position_from_text(
                        property_nodes[0].childNodes)

                new_node = DiscreteNode(name, value_range)
                new_node.position = position
                network.add_node(new_node)

            # Second pass: add the edges and fill the probability tables.
            for definition_node in network_node.getElementsByTagName("DEFINITION"):
                for_nodes = definition_node.getElementsByTagName("FOR")
                if not for_nodes:
                    continue
                node = network.get_node(
                    XMLBIF.get_node_text(for_nodes[0].childNodes))
                # Identity check: a custom __eq__ on the node class must not
                # influence the test for a missing node.
                if node is None:
                    continue

                # The <GIVEN> tags are processed in reversed document order.
                for given_node in reversed(definition_node.getElementsByTagName("GIVEN")):
                    parent_name = XMLBIF.get_node_text(given_node.childNodes)
                    parent_node = network.get_node(parent_name)
                    network.add_edge(parent_node, node)

                # Only the first <TABLE> is used. Its values are written
                # into the transposed table in flat order.
                table_nodes = definition_node.getElementsByTagName("TABLE")
                if table_nodes:
                    table = XMLBIF.get_node_table_from_text(
                        table_nodes[0].childNodes)
                    node.get_cpd().get_table().T.flat = table

            return network
    
        @staticmethod
        def get_node_text(nodelist):
            '''
            Keyword arguments:
            nodelist -- is a list of nodes (xml.dom.minidom.Node).
    
            Returns the text of the given nodelist or a empty string.
            '''
            rc = []
            for node in nodelist:
                if node.nodeType == node.TEXT_NODE:
                    rc.append(node.data)
            return ''.join(rc)
    
        @staticmethod
        def get_node_position_from_text(nodelist):
            '''
            Keyword arguments:
            nodelist -- is a list of nodes (xml.dom.minidom.Node).
    
            Returns the position of the given nodelist as pair (x, y).
            '''
            rc = []
            for node in nodelist:
                if node.nodeType == node.TEXT_NODE:
                    rc.append(node.data)
            text = ''.join(rc)
            number_list = re.findall(r"\d+", text)
            if len(number_list) != 2:
                raise Exception("Ambiguous node position in XMLBIF.")
            return (number_list[0], number_list[1])
    
        @staticmethod
    
        def get_node_table_from_text(nodelist):
    
            '''
            Keyword arguments:
            nodelist -- is a list of nodes (xml.dom.minidom.Node).
    
            Returns the probability table of the given nodelist as pair numpy.array.
            '''
            rc = []
            for node in nodelist:
                if node.nodeType == node.TEXT_NODE:
                    rc.append(node.data)
            text = ''.join(rc)
            number_list = re.findall(r"[0-9]*\.*[0-9]+", text)
    
            for (i, n) in enumerate(number_list):
                number_list[i] = float(n)