Skip to content
Snippets Groups Projects
Commit 7c388d93 authored by Lukas Kettenbach's avatar Lukas Kettenbach
Browse files

Finished XMLBIF export. Add test for XMLBIF import and export.

parent 34c494b0
Branches
Tags
No related merge requests found
...@@ -12,7 +12,9 @@ class Node(object): ...@@ -12,7 +12,9 @@ class Node(object):
def __init__(self, node_name): def __init__(self, node_name):
# Remove all special characters and replace " " with "_" # Remove all special characters and replace " " with "_"
name = re.sub(r"[^a-zA-Z_0-9 ]*", "", node_name) name = re.sub(r"[^a-zA-Z_0-9 ]*", "", node_name)
self.name = name.replace(" ", "_") self.name = name.replace(" ", "_")
# for visual illustration
self.pos = (0, 0)
@abc.abstractmethod @abc.abstractmethod
def announce_parent(self, node): def announce_parent(self, node):
......
import unittest
from primo.utils import XMLBIF
from primo.core import BayesNet
from primo.reasoning import DiscreteNode
import numpy
import os
class ImportExportTest(unittest.TestCase):
def setUp(self):
# Create BayesNet
self.bn = BayesNet();
# Create Nodes
weather0 = DiscreteNode("Weather0", ["Sun", "Rain"])
weather = DiscreteNode("Weather", ["Sun", "Rain"])
ice_cream_eaten = DiscreteNode("Ice Cream Eaten", [True, False])
# Add nodes
self.bn.add_node(weather0)
self.bn.add_node(weather)
self.bn.add_node(ice_cream_eaten)
# Add edges
self.bn.add_edge(weather, ice_cream_eaten)
self.bn.add_edge(weather0, weather);
# Set probabilities
cpt_weather0 = numpy.array([.6, .4])
weather0.set_probability_table(cpt_weather0, [weather0])
cpt_weather = numpy.array([[.7, .5],
[.3, .5]])
weather.set_probability_table(cpt_weather, [weather0, weather])
ice_cream_eaten.set_probability(.9, [(ice_cream_eaten, True), (weather, "Sun")])
ice_cream_eaten.set_probability(.1, [(ice_cream_eaten, False), (weather, "Sun")])
ice_cream_eaten.set_probability(.2, [(ice_cream_eaten, True), (weather, "Rain")])
ice_cream_eaten.set_probability(.8, [(ice_cream_eaten, False), (weather, "Rain")])
def test_import_export(self):
# write BN
xmlbif = XMLBIF(self.bn, "Test Net")
xmlbif.write("test_out.xmlbif")
# read BN
bn2 = XMLBIF.read("test_out.xmlbif")
for node1 in self.bn.get_nodes():
name_found = False
cpd_equal = False
value_range_equal = False
str_equal = False
pos_equal = False
for node2 in bn2.get_nodes():
# Test node names
if node1.name == node2.name:
name_found = True
cpd_equal = node1.get_cpd == node2.get_cpd
value_range_equal = node1.get_value_range() == node2.get_value_range()
str_equal = str(node1) == str(node2)
pos_equal = node1.pos == node2.pos
self.assertTrue(name_found)
self.assertTrue(cpd_equal)
self.assertTrue(value_range_equal)
self.assertTrue(str_equal)
self.assertTrue(pos_equal)
# remove file
os.remove("test_out.xmlbif")
#include this so you can run this test without nose
if __name__ == '__main__':
unittest.main()
...@@ -4,7 +4,6 @@ from primo.core import BayesNet ...@@ -4,7 +4,6 @@ from primo.core import BayesNet
from primo.core import Node from primo.core import Node
from primo.reasoning import DiscreteNode from primo.reasoning import DiscreteNode
import re import re
import numpy
class XMLBIF(object): class XMLBIF(object):
...@@ -252,7 +251,6 @@ class XMLBIF(object): ...@@ -252,7 +251,6 @@ class XMLBIF(object):
definition_nodes = network_nodes[0].getElementsByTagName("DEFINITION") definition_nodes = network_nodes[0].getElementsByTagName("DEFINITION")
for definition_node in definition_nodes: for definition_node in definition_nodes:
node = None node = None
parent_count = 0
for for_node in definition_node.getElementsByTagName("FOR"): for for_node in definition_node.getElementsByTagName("FOR"):
name = XMLBIF.get_node_text(for_node.childNodes) name = XMLBIF.get_node_text(for_node.childNodes)
node = network.get_node(name) node = network.get_node(name)
...@@ -263,9 +261,10 @@ class XMLBIF(object): ...@@ -263,9 +261,10 @@ class XMLBIF(object):
parent_name = XMLBIF.get_node_text(given_node.childNodes) parent_name = XMLBIF.get_node_text(given_node.childNodes)
parent_node = network.get_node(parent_name) parent_node = network.get_node(parent_name)
node.announce_parent(parent_node) node.announce_parent(parent_node)
parent_count += 1
for table_node in definition_node.getElementsByTagName("TABLE"): for table_node in definition_node.getElementsByTagName("TABLE"):
table = XMLBIF.get_node_table_from_text(table_node.childNodes, parent_count) table = XMLBIF.get_node_table_from_text(table_node.childNodes)
node.get_cpd().get_table().T.flat = table
break
return network return network
...@@ -302,11 +301,10 @@ class XMLBIF(object): ...@@ -302,11 +301,10 @@ class XMLBIF(object):
return (number_list[0], number_list[1]) return (number_list[0], number_list[1])
@staticmethod @staticmethod
def get_node_table_from_text(nodelist, parent_count): def get_node_table_from_text(nodelist):
''' '''
Keyword arguments: Keyword arguments:
nodelist -- is a list of nodes (xml.dom.minidom.Node). nodelist -- is a list of nodes (xml.dom.minidom.Node).
parent_count -- is the number of parents.
Returns the probability table of the given nodelist as pair numpy.array. Returns the probability table of the given nodelist as pair numpy.array.
''' '''
...@@ -316,7 +314,4 @@ class XMLBIF(object): ...@@ -316,7 +314,4 @@ class XMLBIF(object):
rc.append(node.data) rc.append(node.data)
text = ''.join(rc) text = ''.join(rc)
number_list = re.findall(r"[0-9]*\.*[0-9]+", text) number_list = re.findall(r"[0-9]*\.*[0-9]+", text)
if len(number_list) return number_list
for number in number_list: \ No newline at end of file
return (number_list[0], number_list[1])
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment