From b50363345664f1a2fdbc1d4729844bc47ee9b3b0 Mon Sep 17 00:00:00 2001
From: "Olivier J.N. Bertrand" <olivier.bertrand@uni-bielefeld.de>
Date: Mon, 20 Aug 2018 16:23:25 +0200
Subject: [PATCH] Add memory loading for db

---
 navipy/database/__init__.py | 49 ++++++++++++++++++++++++++-----------
 1 file changed, 35 insertions(+), 14 deletions(-)

diff --git a/navipy/database/__init__.py b/navipy/database/__init__.py
index 91d1d77..5e847ad 100644
--- a/navipy/database/__init__.py
+++ b/navipy/database/__init__.py
@@ -34,6 +34,23 @@ def convert_array(text):
     return np.load(out)
 
 
+def database2memory(source):
+    """ Load a database in the RAM of the machine
+
+    It can be useful to speed up certain calculation
+    However quite some memory may be used.
+
+    :param source: path to the databse to load into memory
+    :returns: a database
+    """
+    db = DataBaseLoad(source)
+    dbmem = DataBaseSave(":memory:")
+    for row_id, posorient in db.posorients.iterrows():
+        scene = db.scene(posorient)
+        dbmem.write_image(posorient, np.squeeze(scene[..., 0]))
+    return dbmem
+
+
 # Converts np.array to TEXT when inserting
 sqlite3.register_adapter(np.ndarray, adapt_array)
 # Converts TEXT to np.array when selecting
@@ -65,7 +82,7 @@ class DataBase():
             self._logger.exception(msg)
             raise TypeError(msg)
         _, ext = os.path.splitext(filename)
-        if ext != '.db':
+        if ext != '.db' and (filename != ':memory:'):
             msg = 'filename must have the .db extension'
             self._logger.exception(msg)
             raise NameError(msg)
@@ -116,12 +133,18 @@ class DataBase():
         for col in self.normalisation_columns:
             self.tablecolumns['normalisation'][col] = 'real'
 
-        if os.path.exists(filename):
+        if self.create is False:
             self._logger.info('Connect to database')
-            self.db = sqlite3.connect(
-                'file:' + filename + '?cache=shared', uri=True,
-                detect_types=sqlite3.PARSE_DECLTYPES,
-                timeout=10)
+
+            if os.path.exists(filename) or filename == ':memory:':
+                self.db = sqlite3.connect(
+                    'file:' + filename + '?cache=shared', uri=True,
+                    detect_types=sqlite3.PARSE_DECLTYPES,
+                    timeout=10)
+            else:
+                msg = 'Database {} does not exist'.format(filename)
+                self._logger.exception(msg)
+                raise NameError(msg)
             self.db_cursor = self.db.cursor()
             # Check table
             self._logger.debug('Check tables')
@@ -131,10 +154,12 @@ class DataBase():
                                     named {}'.format(filename, tablename)
                     self._logger.exception(msg)
                     raise Exception(msg)
-        elif self.create:
+        else:
             self._logger.info('Create to database')
             self.db = sqlite3.connect(
-                filename, detect_types=sqlite3.PARSE_DECLTYPES)
+                'file:' + filename + '?cache=shared', uri=True,
+                detect_types=sqlite3.PARSE_DECLTYPES,
+                timeout=10)
             self.db_cursor = self.db.cursor()
             # Create table
             self._logger.info('Create tables')
@@ -146,10 +171,6 @@ class DataBase():
                 self.db_cursor.execute(
                     "create table {} {}".format(key, columns))
             self.db.commit()
-        else:
-            msg = 'Database {} does not exist'.format(filename)
-            self._logger.exception(msg)
-            raise NameError(msg)
 
         azimuth = np.deg2rad(np.linspace(-180, 180, 360))
         elevation = np.deg2rad(np.linspace(-90, 90, 180))
@@ -780,12 +801,12 @@ class DataBaseLoad(DataBase):
         return denormed_im
 
 
-class DataBaseSave(DataBase):
+class DataBaseSave(DataBaseLoad):
     def __init__(self, filename, channels=['R', 'G', 'B', 'D'],
                  arr_dtype=np.uint8):
         """
         """
-        DataBase.__init__(self, filename, channels=channels)
+        DataBaseLoad.__init__(self, filename, channels=channels)
         self.arr_dtype = arr_dtype
 
     @property
-- 
GitLab