From b50363345664f1a2fdbc1d4729844bc47ee9b3b0 Mon Sep 17 00:00:00 2001 From: "Olivier J.N. Bertrand" <olivier.bertrand@uni-bielefeld.de> Date: Mon, 20 Aug 2018 16:23:25 +0200 Subject: [PATCH] Add memory loading for db --- navipy/database/__init__.py | 49 ++++++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/navipy/database/__init__.py b/navipy/database/__init__.py index 91d1d77..5e847ad 100644 --- a/navipy/database/__init__.py +++ b/navipy/database/__init__.py @@ -34,6 +34,23 @@ def convert_array(text): return np.load(out) +def database2memory(source): + """ Load a database in the RAM of the machine + + It can be useful to speed up certain calculation + However quite some memory may be used. + + :param source: path to the databse to load into memory + :returns: a database + """ + db = DataBaseLoad(source) + dbmem = DataBaseSave(":memory:") + for row_id, posorient in db.posorients.iterrows(): + scene = db.scene(posorient) + dbmem.write_image(posorient, np.squeeze(scene[..., 0])) + return dbmem + + # Converts np.array to TEXT when inserting sqlite3.register_adapter(np.ndarray, adapt_array) # Converts TEXT to np.array when selecting @@ -65,7 +82,7 @@ class DataBase(): self._logger.exception(msg) raise TypeError(msg) _, ext = os.path.splitext(filename) - if ext != '.db': + if ext != '.db' and (filename != ':memory:'): msg = 'filename must have the .db extension' self._logger.exception(msg) raise NameError(msg) @@ -116,12 +133,18 @@ class DataBase(): for col in self.normalisation_columns: self.tablecolumns['normalisation'][col] = 'real' - if os.path.exists(filename): + if self.create is False: self._logger.info('Connect to database') - self.db = sqlite3.connect( - 'file:' + filename + '?cache=shared', uri=True, - detect_types=sqlite3.PARSE_DECLTYPES, - timeout=10) + + if os.path.exists(filename) or filename == ':memory:': + self.db = sqlite3.connect( + 'file:' + filename + '?cache=shared', uri=True, + detect_types=sqlite3.PARSE_DECLTYPES, + timeout=10) + else: + msg = 'Database {} does not exist'.format(filename) + self._logger.exception(msg) + raise NameError(msg) self.db_cursor = self.db.cursor() # Check table self._logger.debug('Check tables') @@ -131,10 +154,12 @@ class DataBase(): named {}'.format(filename, tablename) self._logger.exception(msg) raise Exception(msg) - elif self.create: + else: self._logger.info('Create to database') self.db = sqlite3.connect( - filename, detect_types=sqlite3.PARSE_DECLTYPES) + 'file:' + filename + '?cache=shared', uri=True, + detect_types=sqlite3.PARSE_DECLTYPES, + timeout=10) self.db_cursor = self.db.cursor() # Create table self._logger.info('Create tables') @@ -146,10 +171,6 @@ class DataBase(): self.db_cursor.execute( "create table {} {}".format(key, columns)) self.db.commit() - else: - msg = 'Database {} does not exist'.format(filename) - self._logger.exception(msg) - raise NameError(msg) azimuth = np.deg2rad(np.linspace(-180, 180, 360)) elevation = np.deg2rad(np.linspace(-90, 90, 180)) @@ -780,12 +801,12 @@ class DataBaseLoad(DataBase): return denormed_im -class DataBaseSave(DataBase): +class DataBaseSave(DataBaseLoad): def __init__(self, filename, channels=['R', 'G', 'B', 'D'], arr_dtype=np.uint8): """ """ - DataBase.__init__(self, filename, channels=channels) + DataBaseLoad.__init__(self, filename, channels=channels) self.arr_dtype = arr_dtype @property -- GitLab