Commit 99f2ab5e authored by Jan Pöppel's avatar Jan Pöppel
Browse files

rebased with master

parents 7f2f9201 7d53472d
......@@ -83,6 +83,37 @@ class BayesianNetwork(object):
else:
raise Exception("There is no node with name {} in the network.".format(node))
def change_node_name(self, oldName, newName):
"""
Renames the given node to the new name.
Will have to modify all occurances of the old name.
"""
if oldName in self.node_lookup:
n = self.node_lookup[oldName]
children = list(self.graph.succ[n])
parents = list(self.graph.pred[n])
for child in children:
del child.parents[oldName]
child.parents[newName] = n
idx = child.parentOrder.index(oldName)
child.parentOrder[idx] = newName
#Fix nx graph
self.graph.remove_node(n)
n.name = newName
del self.node_lookup[oldName]
self.node_lookup[n] = n
self.graph.add_node(n)
for child in children:
self.graph.add_edge(n, self.node_lookup[child])
for parent in parents:
self.graph.add_edge(self.node_lookup[parent], n)
else:
raise Exception("There is no node with name {} in the network.".format(oldName))
def get_all_nodes(self):
return self.graph.nodes()
......
......@@ -107,6 +107,26 @@ class DiscreteNode(RandomNode):
self.parentOrder.append(parentNode.name)
self._update_dimensions()
def set_values(self, newValues):
"""
Allows to change/set the values of this variable. This will
invalidate the node, as it is expected to receive a new cpt
matching the new values.
Important: This will not update any children of this node,
therefore this method should only be used before specifying
children, or only by the network class which will also invalidate
all children of this node!
Parameters
----------
newValues: [String,]
List of the new value names.
"""
self.values = list(newValues)
self._update_dimensions()
def _update_dimensions(self):
"""
Private helper function to update the dimensions of the cpd.
......
......@@ -21,7 +21,7 @@
import unittest
from primo2.networks import BayesianNetwork
from primo2.nodes import RandomNode
from primo2.nodes import RandomNode, DiscreteNode
class BayesNetTest(unittest.TestCase):
......@@ -53,6 +53,43 @@ class BayesNetTest(unittest.TestCase):
self.bn.add_node(n2)
self.assertEqual(len(self.bn), 2)
def test_change_node_values_of_parent(self):
n1 = DiscreteNode("Node1")
n2 = DiscreteNode("Node2")
self.bn.add_node(n1)
self.bn.add_node(n2)
self.bn.add_edge(n2,n1)
self.assertEqual(n1.cpd.shape, (2,2))
self.bn.change_node_values(n2, ["Value1","Value2","Value3"])
self.assertEqual(n2.values, ["Value1","Value2","Value3"])
self.assertEqual(n1.valid, False)
self.assertEqual(n1.cpd.shape, (2,3))
def test_change_node_name(self):
n1 = DiscreteNode("Node1")
n2 = DiscreteNode("Node2")
n3 = DiscreteNode("Node3")
self.bn.add_node(n1)
self.bn.add_node(n2)
self.bn.add_node(n3)
self.bn.add_edge(n2,n1)
self.bn.add_edge(n3,n2)
self.bn.change_node_name("Node2", "NewName")
self.assertEqual(n2.name, "NewName")
with self.assertRaises(Exception) as cm:
self.bn.get_node("Node2")
self.assertEqual(str(cm.exception), "There is no node with name Node2 in the BayesianNetwork")
self.assertEqual(n1.parentOrder, ["NewName"])
self.assertTrue(n2 in self.bn.graph.nodes())
self.assertTrue(n2 in self.bn.node_lookup)
self.assertTrue(n1 in self.bn.graph.succ[n2])
self.assertTrue(n3 in self.bn.graph.pred[n2])
self.assertTrue(n2 in self.bn.graph.succ[n3])
# def test_addEdge(self):
# self.fail("TODO")
......
......@@ -106,6 +106,22 @@ class DiscreteNode(unittest.TestCase):
n.set_cpd(cpd)
self.assertTrue(n.valid)
def test_set_values(self):
n = nodes.DiscreteNode("Node1")
n.set_values(["Value1", "Value2", "Value3"])
self.assertEqual(n.cpd.shape, (3,))
self.assertEqual(n.values, ["Value1", "Value2", "Value3"])
self.assertEqual(n.valid, False)
def test_set_values_with_parent(self):
n = nodes.DiscreteNode("Node1")
n2 = nodes.DiscreteNode("Node2")
n.add_parent(n2)
n.set_values(["Value1", "Value2", "Value3"])
self.assertEqual(n.cpd.shape, (3,2))
self.assertEqual(n.values, ["Value1", "Value2", "Value3"])
self.assertEqual(n.valid, False)
def test_add_parent(self):
n = nodes.DiscreteNode("Node1", ["Value1", "Value2"])
n2 = nodes.DiscreteNode("Node2", ["Value3", "Value4", "Value5"])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment