From b4b87120d8d6c178c2b0c00a312d4fa4b47b45f3 Mon Sep 17 00:00:00 2001 From: "Olivier J.N. Bertrand" <olivier.bertrand@uni-bielefeld.de> Date: Mon, 17 Sep 2018 16:50:36 +0200 Subject: [PATCH] Simplify the Database API by adding a mode instead of Load/Save classes --- navipy/__init__.py | 4 +- navipy/comparing/test.py | 2 +- navipy/database/__init__.py | 352 +++++++++++++-------------- navipy/database/test.py | 45 +--- navipy/database/tools.py | 6 +- navipy/moving/__init__.py | 2 +- navipy/moving/agent.py | 10 +- navipy/moving/test_agent.py | 2 +- navipy/processing/test.py | 2 +- navipy/sensors/blendtest_renderer.py | 6 +- navipy/sensors/renderer.py | 9 +- 11 files changed, 191 insertions(+), 249 deletions(-) diff --git a/navipy/__init__.py b/navipy/__init__.py index 785eae1..4765cc3 100644 --- a/navipy/__init__.py +++ b/navipy/__init__.py @@ -37,7 +37,7 @@ An agent using an average skyline homing vector, could be build as follow :lines: 4-5,11-36 """ -from navipy.database import DataBaseLoad +from navipy.database import DataBase import logging @@ -68,7 +68,7 @@ class Brain(): @property def posorients(self): - if isinstance(self.renderer, DataBaseLoad): + if isinstance(self.renderer, DataBase): return self.renderer.posorients else: raise NotImplementedError("Subclasses should implement this, " + diff --git a/navipy/comparing/test.py b/navipy/comparing/test.py index 18b6a8e..f6ea1a4 100644 --- a/navipy/comparing/test.py +++ b/navipy/comparing/test.py @@ -11,7 +11,7 @@ class TestCase(unittest.TestCase): """ loads the database """ self.mydb_filename = pkg_resources.resource_filename( 'navipy', 'resources/database.db') - self.mydb = database.DataBaseLoad(self.mydb_filename) + self.mydb = database.DataBase(self.mydb_filename) def test_imagediff_curr(self): """ diff --git a/navipy/database/__init__.py b/navipy/database/__init__.py index 325234f..e0faf13 100644 --- a/navipy/database/__init__.py +++ b/navipy/database/__init__.py @@ -34,23 +34,6 @@ def convert_array(text): return np.load(out) -def database2memory(source): - """ Load a database in the RAM of the machine - - It can be useful to speed up certain calculation - However quite some memory may be used. - - :param source: path to the databse to load into memory - :returns: a database - """ - db = DataBaseLoad(source) - dbmem = DataBaseSave(":memory:") - for row_id, posorient in db.posorients.iterrows(): - scene = db.scene(posorient) - dbmem.write_image(posorient, np.squeeze(scene[..., 0])) - return dbmem - - # Converts np.array to TEXT when inserting sqlite3.register_adapter(np.ndarray, adapt_array) # Converts TEXT to np.array when selecting @@ -58,12 +41,13 @@ sqlite3.register_converter("array", convert_array) class DataBase(): - """DataBase is the parent class of DataBaseLoad and DataBaseSave. + """DataBase It creates three sql table on initialisation. """ __float_tolerance = 1e-14 - def __init__(self, filename, channels=['R', 'G', 'B', 'D']): + def __init__(self, filename, mode='a', + channels=['R', 'G', 'B', 'D'], arr_dtype=np.uint8): """Initialisation of the database the first database is the image database to store the images the second database is the position_orientation database to @@ -86,6 +70,30 @@ class DataBase(): msg = 'filename must have the .db extension' self._logger.exception(msg) raise NameError(msg) + # We try to determine if we need to load or create + # the database + if (not os.path.exists(filename)) and mode == 'r': + # The file does not exist, and we want to read + # it. This is not possible + msg = 'Cannot read database {}' + msg = msg.format(filename) + self._logger.exception(msg) + raise NameError(msg) + elif mode == 'w': + # The file exist, and we want to write it + # We need to create the database + self._logger.info('WriteRead-mode') + self.arr_dtype = arr_dtype + elif mode == 'a': + self._logger.info('AppendRead-mode') + self.arr_dtype = arr_dtype + else: + # The file exist, and we want to either + # write it or append data to it. + # We need to load the database + self._logger.info('ReadOnly-mode') + self.mode = mode + if not isinstance(channels, list): msg = 'nb_channel should be a list or np array' self._logger.exception(msg) @@ -103,6 +111,11 @@ class DataBase(): msg = 'channels must be single value' self._logger.exception(msg) raise ValueError(msg) + if c not in ['R', 'G', 'B', 'D']: + msg = 'channels must be either\ + R,G,B or D (Distance)' + self._logger.exception(msg) + raise ValueError(msg) self._logger.debug('database\nfilename: {}\nchannel: {}'.format( filename, channels)) self.filename = filename @@ -133,18 +146,13 @@ class DataBase(): for col in self.normalisation_columns: self.tablecolumns['normalisation'][col] = 'real' - if (self.create is False) or (os.path.exists(filename)): + if (self.mode in ['a', 'r']) and (os.path.exists(filename)): self._logger.info('Connect to database') - if os.path.exists(filename) or filename == ':memory:': - self.db = sqlite3.connect( - 'file:' + filename + '?cache=shared', uri=True, - detect_types=sqlite3.PARSE_DECLTYPES, - timeout=10) - else: - msg = 'Database {} does not exist'.format(filename) - self._logger.exception(msg) - raise NameError(msg) + self.db = sqlite3.connect( + 'file:' + filename + '?cache=shared', uri=True, + detect_types=sqlite3.PARSE_DECLTYPES, + timeout=10) self.db_cursor = self.db.cursor() # Check table self._logger.debug('Check tables') @@ -155,7 +163,7 @@ class DataBase(): self._logger.exception(msg) raise Exception(msg) else: - self._logger.info('Create to database') + self._logger.info('Create database') self.db = sqlite3.connect( 'file:' + filename + '?cache=shared', uri=True, detect_types=sqlite3.PARSE_DECLTYPES, @@ -431,39 +439,37 @@ class DataBase(): WHERE {}; """.format(where), params) return self.db_cursor.fetchone()[0] - elif self.create & (convention != 'quaternion'): - self.db_cursor.execute( - """ - INSERT - INTO position_orientation(x,y,z,q_0,q_1,q_2,q_3,rotconv_id) - VALUES (?,?,?,?,?,?,?,?) - """, ( - posorient['location']['x'], - posorient['location']['y'], - posorient['location']['z'], - posorient[convention]['alpha_0'], - posorient[convention]['alpha_1'], - posorient[convention]['alpha_2'], - np.nan, - convention)) - rowid = self.db_cursor.lastrowid - self.db.commit() - return rowid - elif self.create: - self.db_cursor.execute( - """ - INSERT - INTO position_orientation(x,y,z,q_0,q_1,q_2,rotconv_id) - VALUES (?,?,?,?,?,?,?,?) - """, ( - posorient['location']['x'], - posorient['location']['y'], - posorient['location']['z'], - posorient[convention]['q_0'], - posorient[convention]['q_1'], - posorient[convention]['q_2'], - posorient[convention]['q_3'], - convention)) + elif (self.mode in ['a', 'w']): + if convention != 'quaternion': + self.db_cursor.execute( + """ + INSERT + INTO position_orientation(x,y,z,q_0,q_1,q_2,q_3,rotconv_id) + VALUES (?,?,?,?,?,?,?,?) + """, ( + posorient['location']['x'], + posorient['location']['y'], + posorient['location']['z'], + posorient[convention]['alpha_0'], + posorient[convention]['alpha_1'], + posorient[convention]['alpha_2'], + np.nan, + convention)) + else: + self.db_cursor.execute( + """ + INSERT + INTO position_orientation(x,y,z,q_0,q_1,q_2,rotconv_id) + VALUES (?,?,?,?,?,?,?,?) + """, ( + posorient['location']['x'], + posorient['location']['y'], + posorient['location']['z'], + posorient[convention]['q_0'], + posorient[convention]['q_1'], + posorient[convention]['q_2'], + posorient[convention]['q_3'], + convention)) rowid = self.db_cursor.lastrowid self.db.commit() return rowid @@ -473,33 +479,6 @@ class DataBase(): self._logger.exception(msg) raise ValueError(msg) - @property - def create(self): - return False - - -class DataBaseLoad(DataBase): - """A database generated by the rendering module is based on sqlite3. - """ - - def __init__(self, filename, channels=['R', 'G', 'B', 'D']): - """Initialise the DataBaseLoader""" - DataBase.__init__(self, filename, channels=channels) - for c in channels: - if c not in ['R', 'G', 'B', 'D']: - msg = 'channels must be either\ - R,G,B or D (Distance)' - self._logger.exception(msg) - raise ValueError(msg) - self.__convention = None - - @property - def create(self): - """use to decide weather to alter the database or not - return False because we do not want - to write on database (Load class)""" - return False - def iter_posorients(self): """Iter through all position orientation in the database """ @@ -517,10 +496,13 @@ class DataBaseLoad(DataBase): toyield.name = toyield.id toyield.drop('id', inplace=True) yield toyield + # + # Access to single values + # @property def rotation_convention(self): - """ Return the convention in the database + """ Return the convention of the database The database can technically contains more than one convention. Although it is discourage to do so, it is not forbidden. @@ -532,6 +514,7 @@ class DataBaseLoad(DataBase): "select * from position_orientation;", self.db) posorient.set_index('id', inplace=True) if self.__convention is None: + # we need to assign it from the posorient if 'rotconv_id' in posorient.columns: rotconv = posorient.loc[:, 'rotconv_id'] if np.all(rotconv == rotconv.iloc[0]): @@ -569,11 +552,14 @@ class DataBaseLoad(DataBase): :returns: all position orientations :rtype: list of pd.Series """ - posorient = pd.read_sql_query( + normal = pd.read_sql_query( "select * from normalisation;", self.db) - posorient.set_index('id', inplace=True) - return posorient + normal.set_index('id', inplace=True) + return normal + # + # Read from database + # def read_posorient(self, posorient=None, rowid=None): if rowid is not None: if not isinstance(rowid, int): @@ -594,7 +580,7 @@ class DataBaseLoad(DataBase): raise Exception(msg) if posorient is not None: rowid = self.get_posid(posorient) - # Read images + # Read pososition porientation tablename = 'position_orientation' toreturn = pd.read_sql_query( """ @@ -722,6 +708,89 @@ class DataBaseLoad(DataBase): toreturn = toreturn[..., np.newaxis] check_scene(toreturn) return toreturn + # + # Write + # + + def insert_replace(self, tablename, params): + if not isinstance(tablename, str): + msg = 'table are named by string' + self._logger.exception('table are named by string') + raise TypeError(msg) + if not isinstance(params, dict): + msg = 'params should be dictionary columns:val' + self._logger.exception(msg) + raise TypeError(msg) + params_list = list() + columns_str = '' + for key, val in params.items(): + columns_str += key + ',' + params_list.append(val) + columns_str = columns_str[:-1] # remove last comma + if len(params_list) == 0: + self._logger.warning('nothing to be done in {}'.format(tablename)) + return + questionsmarks = '?' + for _ in range(1, len(params_list)): + questionsmarks += ',?' + self.db_cursor.execute( + """ + INSERT OR REPLACE + INTO {} ({}) + VALUES ({}) + """.format(tablename, + columns_str, + questionsmarks), + tuple(params_list) + ) + self.db.commit() + + def write_image(self, posorient, image): + """stores an image in the database. Automatically + calculates the cminmax range from the image and + channels. + :param posorient: is a 1x6 vector containing: + *in case of euler angeles the index should be + ['location']['x'] + ['location']['y'] + ['location']['z'] + [convention][alpha_0] + [convention][alpha_1] + [convention][alpha_2] + **where convention can be: + rxyz, rxzy, ryxz, ryzx, rzyx, rzxy + *in case of quaternions the index should be + ['location']['x'] + ['location']['y'] + ['location']['z'] + [convention]['q_0'] + [convention]['q_1'] + [convention]['q_2'] + [convention]['q_3'] + **where convention can be: + quaternion + :param image: image to be stored + :type image: np.ndarray + :type posorient: pd.Series + """ + normed_im, cmaxminrange = self.normalise_image(image, self.arr_dtype) + rowid = self.get_posid(posorient) + # Write image + tablename = 'image' + params = dict() + params['rowid'] = rowid + params['data'] = normed_im + self.insert_replace(tablename, params) + # + tablename = 'normalisation' + params = dict() + params['rowid'] = rowid + for chan_n in self.normalisation_columns: + params[chan_n] = cmaxminrange.loc[chan_n] + self.insert_replace(tablename, params) + # + # Image processing + # def denormalise_image(self, image, cmaxminrange): """denomalises an image @@ -800,99 +869,6 @@ class DataBaseLoad(DataBase): self._logger.warning(msg) return denormed_im - -class DataBaseSave(DataBaseLoad): - def __init__(self, filename, channels=['R', 'G', 'B', 'D'], - arr_dtype=np.uint8): - """ - """ - DataBaseLoad.__init__(self, filename, channels=channels) - self.arr_dtype = arr_dtype - - @property - def create(self): - """use to decide weather to alter the database or not - return True because we will need - to write on database (Save class)""" - return True - - def write_image(self, posorient, image): - """stores an image in the database. Automatically - calculates the cminmax range from the image and - channels. - :param posorient: is a 1x6 vector containing: - *in case of euler angeles the index should be - ['location']['x'] - ['location']['y'] - ['location']['z'] - [convention][alpha_0] - [convention][alpha_1] - [convention][alpha_2] - **where convention can be: - rxyz, rxzy, ryxz, ryzx, rzyx, rzxy - *in case of quaternions the index should be - ['location']['x'] - ['location']['y'] - ['location']['z'] - [convention]['q_0'] - [convention]['q_1'] - [convention]['q_2'] - [convention]['q_3'] - **where convention can be: - quaternion - :param image: image to be stored - :type image: np.ndarray - :type posorient: pd.Series - """ - normed_im, cmaxminrange = self.normalise_image(image, self.arr_dtype) - rowid = self.get_posid(posorient) - # Write image - tablename = 'image' - params = dict() - params['rowid'] = rowid - params['data'] = normed_im - self.insert_replace(tablename, params) - # - tablename = 'normalisation' - params = dict() - params['rowid'] = rowid - for chan_n in self.normalisation_columns: - params[chan_n] = cmaxminrange.loc[chan_n] - self.insert_replace(tablename, params) - - def insert_replace(self, tablename, params): - if not isinstance(tablename, str): - msg = 'table are named by string' - self._logger.exception('table are named by string') - raise TypeError(msg) - if not isinstance(params, dict): - msg = 'params should be dictionary columns:val' - self._logger.exception(msg) - raise TypeError(msg) - params_list = list() - columns_str = '' - for key, val in params.items(): - columns_str += key + ',' - params_list.append(val) - columns_str = columns_str[:-1] # remove last comma - if len(params_list) == 0: - self._logger.warning('nothing to be done in {}'.format(tablename)) - return - questionsmarks = '?' - for _ in range(1, len(params_list)): - questionsmarks += ',?' - self.db_cursor.execute( - """ - INSERT OR REPLACE - INTO {} ({}) - VALUES ({}) - """.format(tablename, - columns_str, - questionsmarks), - tuple(params_list) - ) - self.db.commit() - def normalise_image(self, image, dtype=np.uint8): """normalises an image to a range between 0 and 1. :param image: image to be normalised diff --git a/navipy/database/test.py b/navipy/database/test.py index bb2ddb0..973f4c7 100644 --- a/navipy/database/test.py +++ b/navipy/database/test.py @@ -6,7 +6,7 @@ import navipy.database as database # from navipy.processing.tools import is_numeric_array import pkg_resources import tempfile -from navipy.database import DataBaseLoad, DataBaseSave, DataBase +from navipy.database import DataBase from navipy import unittestlogger @@ -15,7 +15,7 @@ class TestCase(unittest.TestCase): unittestlogger() self.mydb_filename = pkg_resources.resource_filename( 'navipy', 'resources/database2.db') - self.mydb = DataBaseLoad(self.mydb_filename) + self.mydb = DataBase(self.mydb_filename, mode='r') def test_DataBase_init_(self): """ @@ -34,7 +34,7 @@ class TestCase(unittest.TestCase): with self.assertRaises(TypeError): DataBase(n) - # only works if testdb was created before e.g. with DataBaseSave + # only works if testdb was created before e.g. with DataBase # with self.assertRaises(NameError): # DataBase('test') @@ -195,23 +195,6 @@ class TestCase(unittest.TestCase): with self.assertRaises(Exception): self.mydb.get_posid(posorient2) - def test_DataBaseLoad_init_(self): - """ - this test checks the function DataBaseLoad works - correctly. - it checks if correct errors are raised for: - - filename does not end with .db - - filename is not a string or char - i.e. integer, float, none, nan - """ - # filename must end with .db - with self.assertRaises(NameError): - DataBaseLoad('test') - # filename must be string - for n in [2, 5.0, None, np.nan]: - with self.assertRaises(TypeError): - DataBaseLoad(n) - def test_read_posorient(self): """ this test checks the function read_posorient works @@ -628,24 +611,6 @@ class TestCase(unittest.TestCase): with self.assertRaises(ValueError): self.mydb.denormalise_image(imagecorrect, cminmaxrange) - def test_DataBaseSave(self): - """ - this test checks the function DataBaseSaver works - correctly. - it checks if correct errors are raised for: - - the filename is of type integer - - checks for correct result if a new DataBaseSaver - is created (no error is thrown) - """ - # should work and creat new database - with tempfile.TemporaryDirectory() as folder: - testdb_filename = folder + '/testdatabase.db' - database.DataBaseSave(testdb_filename) - # should not work - with self.assertRaises(Exception): - database.DataBaseSave(filename=3) - def test_normalise_image(self): """ this test checks the function normalise_image works @@ -659,7 +624,7 @@ class TestCase(unittest.TestCase): image = np.squeeze(image) with tempfile.TemporaryDirectory() as folder: testdb_filename = folder + '/testdatabase.db' - loadDB = DataBaseSave(testdb_filename) + loadDB = DataBase(testdb_filename, mode='w') loadDB.normalise_image(image) # not working @@ -698,7 +663,7 @@ class TestCase(unittest.TestCase): params['age'] = 20 with tempfile.TemporaryDirectory() as folder: testdb_filename = folder + '/testdatabase.db' - tmpmydb = database.DataBaseSave(testdb_filename) + tmpmydb = database.DataBase(testdb_filename, mode='w') for name in [3, 7.5, np.nan, None]: with self.assertRaises(TypeError): diff --git a/navipy/database/tools.py b/navipy/database/tools.py index 2233ea6..ce86abe 100644 --- a/navipy/database/tools.py +++ b/navipy/database/tools.py @@ -1,7 +1,7 @@ """ Some tools to work with databases """ -from navipy.database import DataBaseLoad, DataBaseSave +from navipy.database import DataBase def copy(filename_in, filename_out): @@ -10,8 +10,8 @@ def copy(filename_in, filename_out): :param filename_in: Path to the input database :param filename_out: Path to the output database """ - dbin = DataBaseLoad(filename_in) - dbout = DataBaseSave(filename_out) + dbin = DataBase(filename_in, mode='r') + dbout = DataBase(filename_out, mode='a') for i, posorient in dbin.get_posorients().iterrows(): print(posorient) try: diff --git a/navipy/moving/__init__.py b/navipy/moving/__init__.py index 7a199eb..0947922 100644 --- a/navipy/moving/__init__.py +++ b/navipy/moving/__init__.py @@ -28,7 +28,7 @@ They differ by the method use to update the sensory information: +================+====================================================+ |:CyberBeeAgent: |:Cyberbee: update within blender. | +----------------+----------------------------------------------------+ -|:GridAgent: |:DataBaseLoad: update from a pre-rendered database. | +|:GridAgent: |:DataBase: update from a pre-rendered database. | +----------------+----------------------------------------------------+ To deduce the agent motion from the current state of the agent \ diff --git a/navipy/moving/agent.py b/navipy/moving/agent.py index 4672b94..6632530 100644 --- a/navipy/moving/agent.py +++ b/navipy/moving/agent.py @@ -28,7 +28,7 @@ import multiprocessing from multiprocessing import Queue, JoinableQueue, Process import inspect import navipy.moving.maths as navimomath -from navipy.database import DataBaseLoad +from navipy.database import DataBase import time import os @@ -247,9 +247,9 @@ GridAgent is a close loop agent here its position is snap to a grid. def __init__(self, brain, posorients_queue=None, results_queue=None): - if not isinstance(brain.renderer, DataBaseLoad): + if not isinstance(brain.renderer, DataBase): msg = 'GridAgent only works with a brain having ' - msg += 'a renderer of type DataBaseLoad' + msg += 'a renderer of type DataBase' raise TypeError(msg) convention = brain.renderer.rotation_convention if (posorients_queue is not None) and (results_queue is not None): @@ -358,9 +358,9 @@ the agent motion, or self._brain = copy.copy(brain) # Init the graph self._graph = nx.DiGraph() - if not isinstance(self._brain.renderer, DataBaseLoad): + if not isinstance(self._brain.renderer, DataBase): msg = 'GraphAgent only works with a brain having ' - msg += 'a renderer of type DataBaseLoad' + msg += 'a renderer of type DataBase' raise TypeError(msg) for row_id, posor in self._brain.posorients.iterrows(): posor.name = row_id diff --git a/navipy/moving/test_agent.py b/navipy/moving/test_agent.py index 2f97910..69b19b3 100644 --- a/navipy/moving/test_agent.py +++ b/navipy/moving/test_agent.py @@ -47,7 +47,7 @@ class TestNavipyMovingAgent(unittest.TestCase): def setUp(self): self.mydb_filename = pkg_resources.resource_filename( 'navipy', 'resources/database.db') - self.mydb = navidb.DataBaseLoad(self.mydb_filename) + self.mydb = navidb.DataBase(self.mydb_filename, mode='r') self.convention = 'rzyx' self.brain = BrainTest(self.mydb) tuples = [('location', 'x'), ('location', 'y'), diff --git a/navipy/processing/test.py b/navipy/processing/test.py index 7aafffc..94a61ee 100644 --- a/navipy/processing/test.py +++ b/navipy/processing/test.py @@ -12,7 +12,7 @@ class TestCase(unittest.TestCase): unittestlogger() self.mydb_filename = pkg_resources.resource_filename( 'navipy', 'resources/database.db') - self.mydb = database.DataBaseLoad(self.mydb_filename) + self.mydb = database.DataBase(self.mydb_filename, mode='r') def test_scene_posorient(self): """ diff --git a/navipy/sensors/blendtest_renderer.py b/navipy/sensors/blendtest_renderer.py index 45167ec..7554787 100644 --- a/navipy/sensors/blendtest_renderer.py +++ b/navipy/sensors/blendtest_renderer.py @@ -6,7 +6,7 @@ import numpy as np import unittest import pkg_resources import tempfile -from navipy.database import DataBaseLoad +from navipy.database import DataBase class TestBlenderRender_renderer(unittest.TestCase): @@ -82,13 +82,13 @@ class TestBlenderRender_renderer(unittest.TestCase): rotconv = 'rzyx' db_reffilename = pkg_resources.resource_filename( 'navipy', 'resources/database.db') - db_ref = DataBaseLoad(db_reffilename) + db_ref = DataBase(db_reffilename, mode='r') tfile = tempfile.NamedTemporaryFile() outputfile = tfile.name+'.db' self.renderer.render_ongrid(outputfile, x, y, z, alpha_0, rotconv=rotconv) - db = DataBaseLoad(outputfile) + db = DataBase(outputfile, mode='r') posorients = db_ref.posorients for row_i, posorient in posorients.iterrows(): refscene = db_ref.scene(posorient) diff --git a/navipy/sensors/renderer.py b/navipy/sensors/renderer.py index cd7d153..8a44c6c 100644 --- a/navipy/sensors/renderer.py +++ b/navipy/sensors/renderer.py @@ -20,7 +20,7 @@ import navipy.maths.constants as constants from navipy.trajectories import Trajectory from PIL import Image from navipy.scene import check_scene -from navipy.database import DataBaseSave +from navipy.database import DataBase import logging @@ -80,9 +80,10 @@ class AbstractRender(): else: mode = 'database' self._logger.debug('render outputmode:{}'.format(mode)) - dataloger = DataBaseSave(outputfile, - channels=['R', 'G', 'B', 'D'], - arr_dtype=np.uint8) + dataloger = DataBase(outputfile, + mode='a', + channels=['R', 'G', 'B', 'D'], + arr_dtype=np.uint8) # We now can render self._logger.info('Start rendering') for frame_i, posorient in trajectory.iterrows(): -- GitLab